# utils.py

from __future__ import annotations

import base64
import functools
import hashlib
import inspect
from copy import deepcopy
from typing import Type

import click
from flask import current_app
from rq.job import Job
from sqlalchemy import JSON, String, cast, literal, select

from flexmeasures import Sensor, Asset
from flexmeasures.data import db
from flexmeasures.data.models.generic_assets import GenericAsset, GenericAssetType
from flexmeasures.data.models.planning import Scheduler


def get_scheduler_instance(
    scheduler_class: Type[Scheduler],
    asset_or_sensor: Asset | Sensor,
    scheduler_params: dict,
) -> Scheduler:
    """
    Get an instance of a Scheduler, adapting between the previous Scheduler signature
    (which took a `sensor`) and the new one (which takes an `asset_or_sensor`).
    """
    _scheduler_params = deepcopy(scheduler_params)
    if "asset_or_sensor" not in inspect.signature(scheduler_class).parameters:
        _scheduler_params["sensor"] = asset_or_sensor
    else:
        _scheduler_params["asset_or_sensor"] = asset_or_sensor
    return scheduler_class(**_scheduler_params)
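
# Illustrative usage sketch (not part of the original module); StorageScheduler,
# the sensor id and the scheduler_params contents are assumptions here:
#
#   sensor = db.session.get(Sensor, 1)
#   scheduler = get_scheduler_instance(
#       scheduler_class=StorageScheduler,
#       asset_or_sensor=sensor,
#       scheduler_params=dict(start=start, end=end, resolution=resolution),
#   )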


def get_asset_or_sensor_ref(asset_or_sensor: Asset | Sensor) -> dict:
    """Serialize an Asset or Sensor to a reference dict holding its class name and row id."""
    return {"id": asset_or_sensor.id, "class": asset_or_sensor.__class__.__name__}


def get_asset_or_sensor_from_ref(asset_or_sensor: dict):
    """
    Fetch the Asset or Sensor object described by the asset_or_sensor dictionary.
    This dictionary needs to contain the class name and row id.

    We currently cannot simplify this by just passing around the object
    instead of the class name: i.e. the function arguments need to
    be serializable as job parameters.

    Examples:

    >> get_asset_or_sensor_from_ref({"class": "Asset", "id": 1})
    Asset(id=1)

    >> get_asset_or_sensor_from_ref({"class": "Sensor", "id": 2})
    Sensor(id=2)
    """
    if asset_or_sensor["class"] == Asset.__name__:
        klass = Asset
    elif asset_or_sensor["class"] == Sensor.__name__:
        klass = Sensor
    else:
        raise ValueError(
            f"Unrecognized class `{asset_or_sensor['class']}`. Please consider using GenericAsset or Sensor."
        )
    return db.session.get(klass, asset_or_sensor["id"])
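
# Illustrative round trip (assumes a Sensor with id=2 exists in the session);
# refs like these are what gets passed around as (serializable) job parameters:
#
#   ref = get_asset_or_sensor_ref(db.session.get(Sensor, 2))  # {"id": 2, "class": "Sensor"}
#   sensor = get_asset_or_sensor_from_ref(ref)  # the same Sensor instance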


def get_or_create_model(
    model_class: Type[GenericAsset | GenericAssetType | Sensor], **kwargs
) -> GenericAsset | GenericAssetType | Sensor:
    """Get a model from the database, or add it if it's missing.

    For example:
    >>> weather_station_type = get_or_create_model(
    ...     GenericAssetType,
    ...     name="weather station",
    ...     description="A weather station with various sensors.",
    ... )
    """
    # Unpack custom initialization parameters that map to multiple database columns
    init_kwargs = kwargs.copy()
    lookup_kwargs = kwargs.copy()
    if "knowledge_horizon" in kwargs:
        (
            lookup_kwargs["knowledge_horizon_fnc"],
            lookup_kwargs["knowledge_horizon_par"],
        ) = lookup_kwargs.pop("knowledge_horizon")

    # Find out which attributes are dictionaries mapped to JSON database columns,
    # or callables mapped to String database columns (by their name)
    filter_json_kwargs = {}
    filter_by_kwargs = lookup_kwargs.copy()
    for kw, arg in lookup_kwargs.items():
        model_attribute = getattr(model_class, kw)
        if hasattr(model_attribute, "type") and isinstance(model_attribute.type, JSON):
            filter_json_kwargs[kw] = filter_by_kwargs.pop(kw)
        elif callable(arg) and isinstance(model_attribute.type, String):
            # Callables are stored in the database by their name,
            # e.g. knowledge_horizon_fnc = x_days_ago_at_y_oclock
            # is stored as "x_days_ago_at_y_oclock"
            filter_by_kwargs[kw] = filter_by_kwargs[kw].__name__
        else:
            # The kw is already present in filter_by_kwargs and doesn't need to be adapted,
            # i.e. it can be used as an argument to .filter_by()
            pass

    # See if the model already exists as a db row
    model_query = select(model_class).filter_by(**filter_by_kwargs)
    for kw, arg in filter_json_kwargs.items():
        model_query = model_query.filter(
            cast(getattr(model_class, kw), String) == cast(literal(arg, JSON()), String)
        )
    model = db.session.execute(model_query).scalar_one_or_none()

    # Create the model and add it to the database if it didn't already exist
    if model is None:
        model = model_class(**init_kwargs)
        db.session.add(model)
        db.session.flush()  # assign ID
        click.echo(f"Created {repr(model)}")
    return model
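
# Illustrative sketch of the custom "knowledge_horizon" parameter (the helper
# name and its parameters are assumptions here): the (function, kwargs) tuple is
# unpacked into the knowledge_horizon_fnc and knowledge_horizon_par columns, and
# the callable is matched in the database by its name.
#
#   sensor = get_or_create_model(
#       Sensor,
#       name="wind speed",
#       generic_asset=weather_station,
#       knowledge_horizon=(x_days_ago_at_y_oclock, dict(x=1, y=12, z="Europe/Amsterdam")),
#   )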


def make_hash_sha256(o):
    """
    We use SHA256 instead of Python's built-in hash(), because the latter is salted
    per process and therefore yields different results across restarts.
    Source: https://stackoverflow.com/a/42151923
    """
    hasher = hashlib.sha256()
    hasher.update(repr(make_hashable(o)).encode())
    return base64.b64encode(hasher.digest()).decode()


def make_hashable(o):
    """
    Recursively turn an object (e.g. a dictionary with nested objects) into a hashable representation.
    Source: https://stackoverflow.com/a/42151923
    """
    if isinstance(o, (tuple, list)):
        return tuple(make_hashable(e) for e in o)
    if isinstance(o, dict):
        return tuple(sorted((k, make_hashable(v)) for k, v in o.items()))
    if isinstance(o, (set, frozenset)):
        return tuple(sorted(make_hashable(e) for e in o))
    if callable(
        getattr(o, "make_hashable", None)
    ):  # check if the object o has the method make_hashable
        return o.make_hashable()
    return o
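
# Illustrative sketch: nested structures reduce to sorted tuples, so dicts with
# the same contents hash identically regardless of insertion order.
#
#   make_hashable({"b": [1, 2], "a": {"c": 3}})
#   # -> (('a', (('c', 3),)), ('b', (1, 2)))
#   make_hash_sha256({"a": 1, "b": 2}) == make_hash_sha256({"b": 2, "a": 1})
#   # -> True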


def hash_function_arguments(args, kwargs):
    """Combine the hashes of the args and kwargs.

    We go with h(x, y) = hash(hash(x) || hash(y)), because it avoids the following:
    1) h(x, y) = hash(x || y) might create a collision if we delete the last n characters
       of x and prepend them to y, e.g. h("abc", "d") == h("ab", "cd").
    2) We don't want to sort x and y, because we need h(x, y) != h(y, x).
    3) The extra (outer) hash makes it harder to decompose the input arguments and to track
       whether the same args or kwargs are used several times. More of a security measure.
    Source: https://crypto.stackexchange.com/questions/55162/best-way-to-hash-two-values-into-one
    """
    return make_hash_sha256(
        make_hash_sha256(args) + make_hash_sha256(kwargs)
    )  # concat two hashes
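
# Illustrative sketch of the properties claimed above:
#
#   hash_function_arguments(("a", "b"), {}) != hash_function_arguments(("b", "a"), {})
#   # -> True (order-sensitive)
#   hash_function_arguments(("a",), {"k": 1}) == hash_function_arguments(("a",), {"k": 1})
#   # -> True (deterministic across processes, unlike built-in hash())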


def job_cache(queue: str):
    """
    To avoid recomputing the same task multiple times, this decorator checks whether the decorated function
    has already been called with the same arguments. Input arguments are hashed and stored as Redis keys,
    with the job IDs as values (mapping `input_arguments_hash` -> `job_id`).

    The benefits of using Redis to store the input arguments, over a local cache such as an LRU cache, are:
    1) It works in distributed environments (e.g. computing clusters), where multiple workers avoid
       repeating work because the cache is shared across them.
    2) Cached calls are logged, which means we can easily debug.
    3) The cache survives restarts.

    Arguments
    :param queue: name of the queue
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Get the Redis connection
            connection = current_app.redis_connection

            kwargs_for_hash = kwargs.copy()
            requeue = kwargs_for_hash.pop("requeue", False)

            # Pop the flag that forces creating a new job, so it doesn't affect the hash
            force_new_job_creation = kwargs_for_hash.pop(
                "force_new_job_creation", False
            )

            # Create a hash from args and kwargs_for_hash
            args_hash = f"{queue}:{func.__name__}:{hash_function_arguments(args, kwargs_for_hash)}"

            # Check the Redis connection for whether the key hash exists
            if connection.exists(args_hash) and not force_new_job_creation:
                current_app.logger.info(
                    f"The function {func.__name__} has been called already with the same arguments. Skipping..."
                )

                # Get the job ID
                job_id = connection.get(args_hash).decode()

                # Check if the job exists and, if it doesn't, skip fetching and create a new job
                if Job.exists(job_id, connection=connection):
                    job = Job.fetch(
                        job_id, connection=connection
                    )  # get job object from job ID

                    # Requeue if failed and the requeue flag is true
                    if job.is_failed and requeue:
                        job.requeue()

                    return job  # return the same job regardless of its status (SUCCESS, FAILED, ...)

            # If the job description is new -> create a new job
            job = func(*args, **kwargs)

            # Store the function call in Redis by mapping the hash of the function arguments to its job ID
            connection.set(
                args_hash, job.id, ex=current_app.config["FLEXMEASURES_JOB_CACHE_TTL"]
            )

            return job

        return wrapper

    return decorator
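
# Illustrative usage sketch (not part of the original module): the decorated
# function is expected to create and return an rq Job; `my_queue` and
# `compute_forecasts` are assumptions here.
#
#   @job_cache("forecasting")
#   def create_forecasting_job(sensor_id: int, **kwargs) -> Job:
#       return my_queue.enqueue(compute_forecasts, sensor_id, **kwargs)
#
#   job1 = create_forecasting_job(1)
#   job2 = create_forecasting_job(1)  # cache hit: the same job is returned
#   assert job1.id == job2.id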