# pandas_reporter.py

from marshmallow import Schema, fields, ValidationError, validates_schema, validate
from inspect import signature

from flexmeasures.data.schemas import AwareDateTimeField
from flexmeasures.data.schemas.reporting import (
    ReporterConfigSchema,
    ReporterParametersSchema,
)
from flexmeasures.data.schemas.io import RequiredInput, RequiredOutput

from timely_beliefs import BeliefsDataFrame, BeliefsSeries
from pandas.core.resample import Resampler
from pandas.core.groupby.grouper import Grouper


class PandasMethodCall(Schema):
    df_input = fields.Str()
    df_output = fields.Str()

    method = fields.Str(required=True)
    args = fields.List(fields.Raw())
    kwargs = fields.Dict()

    @validates_schema
    def validate_method_call(self, data, **kwargs):
        """Validates the method name and its arguments against a set of base classes.

        This validation ensures that the provided method exists in one of the
        specified base classes (`BeliefsSeries`, `BeliefsDataFrame`, `Resampler`, `Grouper`)
        and that the provided arguments (`args` and `kwargs`) are valid for the method's
        signature.

        Args:
            data (dict): A dictionary containing the method name (`method`) and optionally
                the method arguments (`args` as a list and `kwargs` as a dictionary).
            **kwargs: Additional keyword arguments passed by the validation framework.

        Raises:
            ValidationError: If the method is not callable in any of the base classes or
                if the provided arguments do not match the method signature.
        """
        method = data["method"]
        is_callable = []
        bad_arguments = True

        # Iterate through the base classes to validate the method
        for base_class in [BeliefsSeries, BeliefsDataFrame, Resampler, Grouper]:
            # Check if the method exists in the base class
            method_callable = getattr(base_class, method, None)
            if method_callable is None:
                # Method does not exist in this base class
                is_callable.append(False)
                continue

            # Check if the found method is callable
            is_callable.append(callable(method_callable))

            # Retrieve the method's signature for argument validation
            method_signature = signature(method_callable)

            try:
                # Copy `args` and `kwargs` to avoid modifying the input data
                args = data.get("args", []).copy()
                _kwargs = data.get("kwargs", {}).copy()

                # Insert the base class as the first argument to the method (self/cls context)
                args.insert(0, BeliefsDataFrame)

                # Bind the arguments to the method's signature for validation
                method_signature.bind(*args, **_kwargs)
                bad_arguments = False  # Arguments are valid if binding succeeds
            except TypeError:
                # If binding raises a TypeError, the arguments are invalid
                pass

        # Raise an error if the arguments are invalid across all base classes
        if bad_arguments:
            raise ValidationError(
                f"Bad arguments or keyword arguments for method {method}"
            )

        # Raise an error if the method is not callable in any of the base classes
        if not any(is_callable):
            raise ValidationError(
                f"Method {method} is not a valid BeliefsSeries, BeliefsDataFrame, Resampler or Grouper method."
            )
class PandasReporterConfigSchema(ReporterConfigSchema):
    """
    This schema describes the configuration of a PandasReporter: the input and output
    DataFrames it requires and the pandas transformations used to compute the outputs.

    Example:

    {
        "required_input" : [
            {"name" : "df1", "unit" : "MWh"}
        ],
        "required_output" : [
            {"name" : "df2", "unit" : "kWh"}
        ],
        "transformations" : [
            {
                "df_input" : "df1",
                "df_output" : "df2",
                "method" : "copy"
            },
            {
                "df_input" : "df2",
                "df_output" : "df2",
                "method" : "sum"
            },
            {
                "method" : "sum",
                "kwargs" : {"axis" : 0}
            }
        ],
    }
    """

    required_input = fields.List(
        fields.Nested(RequiredInput()), validate=validate.Length(min=1)
    )
    required_output = fields.List(
        fields.Nested(RequiredOutput()), validate=validate.Length(min=1)
    )
    transformations = fields.List(fields.Nested(PandasMethodCall()), required=True)
    droplevels = fields.Bool(required=False, load_default=False)

    @validates_schema
    def validate_chaining(self, data, **kwargs):
        """
        This validator ensures that we are always given an input and that
        every required output is computed.
        """
        # fake_data mocks the PandasReporter class attribute data. It maps each DataFrame
        # name to the BeliefsDataFrame class to simulate applying the transformations.
        fake_data = dict(
            (_input["name"], BeliefsDataFrame) for _input in data.get("required_input")
        )
        output_names = [_output["name"] for _output in data.get("required_output")]

        previous_df = None
        output_method = dict()

        for transformation in data.get("transformations"):
            df_input = transformation.get("df_input", previous_df)
            df_output = transformation.get("df_output", df_input)

            if df_output in output_names:
                output_method[df_output] = transformation.get("method")

            if df_input not in fake_data:
                raise ValidationError("Cannot find the input DataFrame.")

            previous_df = df_output  # keep track of the last BeliefsDataFrame calculation
            fake_data[df_output] = BeliefsDataFrame

        for _output in output_names:
            if _output not in fake_data:
                raise ValidationError(
                    f"Cannot find final output `{_output}` DataFrame among the resulting DataFrames."
                )

            if (_output in output_method) and (
                output_method[_output] in ["resample", "groupby"]
            ):
                raise ValidationError(
                    f"Final output (`{_output}`) cannot be of type `Resampler` or `DataFrameGroupBy`."
                )
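

# Illustrative sketch: loading the example config from the class docstring above.
# This assumes marshmallow's standard `.load()` entry point and the RequiredInput/
# RequiredOutput schemas as imported; the DataFrame names and units are taken from
# the docstring example. The chaining validator accepts it because "df1" is a
# required input and the final "df2" is produced by `sum`, not by `resample`/`groupby`.
def _example_load_reporter_config() -> dict:
    """Validate a full PandasReporter config (illustrative only)."""
    return PandasReporterConfigSchema().load(
        {
            "required_input": [{"name": "df1", "unit": "MWh"}],
            "required_output": [{"name": "df2", "unit": "kWh"}],
            "transformations": [
                {"df_input": "df1", "df_output": "df2", "method": "copy"},
                {"df_input": "df2", "df_output": "df2", "method": "sum"},
                {"method": "sum", "kwargs": {"axis": 0}},
            ],
        }
    )

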
class PandasReporterParametersSchema(ReporterParametersSchema):
    # Make start and end optional, conditional on providing the time parameters
    # for the individual sensors in `input_variables`.
    start = AwareDateTimeField(required=False)
    end = AwareDateTimeField(required=False)
    use_latest_version_only = fields.Bool(required=False)

    @validates_schema
    def validate_time_parameters(self, data, **kwargs):
        """This method validates that all input sensors have start
        and end parameters available.
        """
        # It's enough to provide a common start and end.
        if ("start" in data) and ("end" in data):
            return

        for input_description in data.get("input", []):
            input_sensor = input_description["sensor"]
            if ("event_starts_after" not in input_description) and (
                "start" not in data
            ):
                raise ValidationError(
                    f"Start parameter not provided for sensor {input_sensor}"
                )

            if ("event_ends_before" not in input_description) and ("end" not in data):
                raise ValidationError(
                    f"End parameter not provided for sensor {input_sensor}"
                )
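

# Illustrative sketch: exercising the time-parameter rule directly on the
# validator, which only inspects which keys are present. A full `.load()` would
# also resolve the sensors against the FlexMeasures database, which is out of
# scope here; the datetimes and the "<sensor 1>" placeholder are made up.
if __name__ == "__main__":
    from datetime import datetime, timezone

    schema = PandasReporterParametersSchema()
    start = datetime(2024, 1, 1, tzinfo=timezone.utc)
    end = datetime(2024, 1, 2, tzinfo=timezone.utc)

    # A common start and end covers every input sensor, so this passes.
    schema.validate_time_parameters({"start": start, "end": end, "input": []})

    # Without a global start, each input needs its own event_starts_after.
    try:
        schema.validate_time_parameters(
            {"end": end, "input": [{"sensor": "<sensor 1>"}]}
        )
    except ValidationError as err:
        print(err)  # Start parameter not provided for sensor <sensor 1>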