utils.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. """
  2. Utils for FlexMeasures CLI
  3. """
  4. from __future__ import annotations
  5. from typing import Any
  6. from datetime import datetime, timedelta
  7. import click
  8. import pytz
  9. from click_default_group import DefaultGroup
  10. from flexmeasures.utils.time_utils import get_most_recent_hour, get_timezone
  11. from flexmeasures.utils.validation_utils import validate_color_hex, validate_url
  12. from flexmeasures import Sensor
  13. class MsgStyle(object):
  14. """Stores the text styles for the different events
  15. Styles options are the attributes of the `click.style` which can be found
  16. [here](https://click.palletsprojects.com/en/8.1.x/api/#click.style).
  17. """
  18. SUCCESS: dict[str, Any] = {"fg": "green"}
  19. WARN: dict[str, Any] = {"fg": "yellow"}
  20. ERROR: dict[str, Any] = {"fg": "red"}
  21. class DeprecatedOption(click.Option):
  22. """A custom option that can be used to mark an option as deprecated.
  23. References
  24. ----------------
  25. Copied from https://stackoverflow.com/a/50402799/13775459
  26. """
  27. def __init__(self, *args, **kwargs):
  28. self.deprecated = kwargs.pop("deprecated", ())
  29. self.preferred = kwargs.pop("preferred", args[0][-1])
  30. super(DeprecatedOption, self).__init__(*args, **kwargs)
class DeprecatedOptionsCommand(click.Command):
    """A custom command that warns when one of its options is invoked by a deprecated name.

    Options declared with `DeprecatedOption` get their parser-level `process`
    hook wrapped, so that a warning is echoed whenever the name used on the
    command line is listed in that option's `deprecated` attribute.

    References
    ----------------
    Adapted from https://stackoverflow.com/a/50402799/13775459
    """

    def make_parser(self, ctx):
        """Build the parser as usual, then wrap the processors of deprecated options.

        :param ctx: Click context, passed through to the parent's `make_parser`.
        :returns:   The parser created by the parent class, with the `process`
                    attribute of every `DeprecatedOption`-backed parser option replaced.
        """
        parser = super(DeprecatedOptionsCommand, self).make_parser(ctx)

        # Collect the parser options. Short and long names can map to the same
        # parser option, so a set is used to visit each one once.
        # NOTE(review): `_short_opt` and `_long_opt` are private attributes of the parser.
        options = set(parser._short_opt.values())
        options |= set(parser._long_opt.values())

        for option in options:
            # Only options declared via DeprecatedOption carry deprecation info
            if not isinstance(option.obj, DeprecatedOption):
                continue

            def make_process(an_option):
                """Construct a closure to the parser option processor.

                Using a factory binds `an_option` (and the values derived from it)
                per option, rather than sharing the loop variable between closures.
                """

                orig_process = an_option.process
                deprecated = getattr(an_option.obj, "deprecated", None)
                preferred = getattr(an_option.obj, "preferred", None)
                msg = "Expected `deprecated` value for `{}`"
                # Fails only if the option object has no (or a None) `deprecated` attribute
                assert deprecated is not None, msg.format(an_option.obj.name)

                def process(value, state):
                    """Warn if the option was invoked by a deprecated name, then process it as usual.

                    The function above us on the stack used 'opt' to
                    pick option from a dict, see if it is deprecated.
                    """

                    # Reach up the stack and get 'opt' from the direct caller's locals.
                    # This lookup must stay in this function body: moving it into a
                    # helper would change which frame `f_back` refers to.
                    # NOTE(review): `inspect.currentframe()` can return None on Python
                    # implementations without stack frame support — confirm this is acceptable.
                    import inspect

                    frame = inspect.currentframe()
                    try:
                        # None if the caller has no local named 'opt'
                        opt = frame.f_back.f_locals.get("opt")
                    finally:
                        # Drop the frame reference right away
                        del frame

                    if opt in deprecated:
                        click.secho(
                            f"Option '{opt}' will be replaced by '{preferred}'.",
                            **MsgStyle.WARN,
                        )
                    # Always delegate to the original processor
                    return orig_process(value, state)

                return process

            option.process = make_process(option)

        return parser
  73. class DeprecatedDefaultGroup(DefaultGroup):
  74. """Invokes a default subcommand, *and* shows a deprecation message.
  75. Also adds the `invoked_default` boolean attribute to the context.
  76. A group callback can use this information to figure out if it's being executed directly
  77. (invoking the default subcommand) or because the execution flow passes onwards to a subcommand.
  78. By default it's None, but it can be the name of the default subcommand to execute.
  79. .. sourcecode:: python
  80. import click
  81. from flexmeasures.cli.utils import DeprecatedDefaultGroup
  82. @click.group(cls=DeprecatedDefaultGroup, default="bar", deprecation_message="renamed to `foo bar`.")
  83. def foo(ctx):
  84. if ctx.invoked_default:
  85. click.echo("foo")
  86. @foo.command()
  87. def bar():
  88. click.echo("bar")
  89. .. sourcecode:: console
  90. $ flexmeasures foo
  91. DeprecationWarning: renamed to `foo bar`.
  92. foo
  93. bar
  94. $ flexmeasures foo bar
  95. bar
  96. """
  97. def __init__(self, *args, **kwargs):
  98. self.deprecation_message = "DeprecationWarning: " + kwargs.pop(
  99. "deprecation_message", ""
  100. )
  101. super().__init__(*args, **kwargs)
  102. def get_command(self, ctx, cmd_name):
  103. ctx.invoked_default = None
  104. if cmd_name not in self.commands:
  105. click.echo(click.style(self.deprecation_message, fg="red"), err=True)
  106. ctx.invoked_default = self.default_cmd_name
  107. return super().get_command(ctx, cmd_name)
  108. def get_timerange_from_flag(
  109. last_hour: bool = False,
  110. last_day: bool = False,
  111. last_7_days: bool = False,
  112. last_month: bool = False,
  113. last_year: bool = False,
  114. timezone: pytz.BaseTzInfo | None = None,
  115. ) -> tuple[datetime, datetime]:
  116. """This function returns a time range [start,end] of the last-X period.
  117. See input parameters for more details.
  118. :param bool last_hour: flag to get the time range of the last finished hour.
  119. :param bool last_day: flag to get the time range for yesterday.
  120. :param bool last_7_days: flag to get the time range of the previous 7 days.
  121. :param bool last_month: flag to get the time range of last calendar month
  122. :param bool last_year: flag to get the last completed calendar year
  123. :param timezone: timezone object to represent
  124. :returns: start:datetime, end:datetime
  125. """
  126. if timezone is None:
  127. timezone = get_timezone()
  128. current_hour = get_most_recent_hour().astimezone(timezone)
  129. if last_hour: # last finished hour
  130. end = current_hour
  131. start = current_hour - timedelta(hours=1)
  132. if last_day: # yesterday
  133. end = current_hour.replace(hour=0)
  134. start = end - timedelta(days=1)
  135. if last_7_days: # last finished 7 day period.
  136. end = current_hour.replace(hour=0)
  137. start = end - timedelta(days=7)
  138. if last_month:
  139. end = current_hour.replace(
  140. hour=0, day=1
  141. ) # get the first day of the current month
  142. start = (end - timedelta(days=1)).replace(
  143. day=1
  144. ) # get first day of the previous month
  145. if last_year: # last calendar year
  146. end = current_hour.replace(
  147. month=1, day=1, hour=0
  148. ) # get first day of current year
  149. start = (end - timedelta(days=1)).replace(
  150. day=1, month=1
  151. ) # get first day of previous year
  152. return start, end
  153. def validate_unique(ctx, param, value):
  154. """Callback function to ensure multiple values are unique."""
  155. if value is not None:
  156. # Check if all values are unique
  157. if len(value) != len(set(value)):
  158. raise click.BadParameter("Values must be unique.")
  159. return value
  160. def abort(message: str):
  161. click.secho(message, **MsgStyle.ERROR)
  162. raise click.Abort()
  163. def done(message: str):
  164. click.secho(message, **MsgStyle.SUCCESS)
  165. def path_to_str(path: list, separator: str = ">") -> str:
  166. """
  167. Converts a list representing a path to a string format, using a specified separator.
  168. """
  169. return separator.join(path)
  170. def are_all_equal(paths: list[list[str]]) -> bool:
  171. """
  172. Checks if all given entity paths represent the same path.
  173. """
  174. return len(set(path_to_str(p) for p in paths)) == 1
  175. def reduce_entity_paths(asset_paths: list[list[str]]) -> list[list[str]]:
  176. """
  177. Simplifies a list of entity paths by trimming their common ancestor.
  178. Examples:
  179. >>> reduce_entity_paths([["Account1", "Asset1"], ["Account2", "Asset2"]])
  180. [["Account1", "Asset1"], ["Account2", "Asset2"]]
  181. >>> reduce_entity_paths([["Asset1"], ["Asset2"]])
  182. [["Asset1"], ["Asset2"]]
  183. >>> reduce_entity_paths([["Account1", "Asset1"], ["Account1", "Asset2"]])
  184. [["Asset1"], ["Asset2"]]
  185. >>> reduce_entity_paths([["Asset1", "Asset2"], ["Asset1"]])
  186. [["Asset1"], ["Asset1", "Asset2"]]
  187. >>> reduce_entity_paths([["Account1", "Asset", "Asset1"], ["Account1", "Asset", "Asset2"]])
  188. [["Asset1"], ["Asset2"]]
  189. """
  190. reduced_entities = 0
  191. # At least we need to leave one entity in each list
  192. max_reduced_entities = min([len(p) - 1 for p in asset_paths])
  193. # Find the common path
  194. while (
  195. are_all_equal([p[:reduced_entities] for p in asset_paths])
  196. and reduced_entities <= max_reduced_entities
  197. ):
  198. reduced_entities += 1
  199. return [p[reduced_entities - 1 :] for p in asset_paths]
  200. def get_sensor_aliases(
  201. sensors: list[Sensor],
  202. reduce_paths: bool = True,
  203. separator: str = "/",
  204. ) -> dict:
  205. """
  206. Generates aliases for all sensors by appending a unique path to each sensor's name.
  207. Parameters:
  208. :param sensors: A list of Sensor objects.
  209. :param reduce_paths: Flag indicating whether to reduce each sensor's entity path. Defaults to True.
  210. :param separator: Character or string used to separate entities within each sensor's path. Defaults to "/".
  211. :return: A dictionary mapping sensor IDs to their generated aliases.
  212. """
  213. entity_paths = [
  214. s.generic_asset.get_path(separator=separator).split(separator) for s in sensors
  215. ]
  216. if reduce_paths:
  217. entity_paths = reduce_entity_paths(entity_paths)
  218. entity_paths = [path_to_str(p, separator=separator) for p in entity_paths]
  219. aliases = {
  220. sensor.id: f"{sensor.name} ({path})"
  221. for path, sensor in zip(entity_paths, sensors)
  222. }
  223. return aliases
  224. def validate_color_cli(ctx, param, value):
  225. """
  226. Optional parameter validation
  227. Validates that a given value is a valid hex color code.
  228. Parameters:
  229. :param ctx: Click context.
  230. :param param: Click parameter name.
  231. :param value: The color code to validate.
  232. """
  233. try:
  234. validate_color_hex(value)
  235. except ValueError as e:
  236. click.secho(str(e), **MsgStyle.ERROR)
  237. raise click.Abort()
  238. def validate_url_cli(ctx, param, value):
  239. """
  240. Optional parameter validation
  241. Validates that a given value is a valid URL format using regex.
  242. Parameters:
  243. :param ctx: Click context.
  244. :param param: Click parameter name.
  245. :param value: The URL to validate.
  246. """
  247. try:
  248. validate_url(value)
  249. except ValueError as e:
  250. click.secho(str(e), **MsgStyle.ERROR)
  251. raise click.Abort()