test_data_source.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. from __future__ import annotations
  2. import pytest
  3. from flexmeasures.data.models.reporting import Reporter
  4. from flexmeasures.data.models.data_sources import keep_latest_version, DataSource
  5. from datetime import datetime
  6. from pytz import UTC
  7. import numpy as np
  8. import timely_beliefs as tb
  9. def test_get_reporter_from_source(db, app, test_reporter, add_nearby_weather_sensors):
  10. reporter = test_reporter.data_generator
  11. reporter_sensor = add_nearby_weather_sensors.get("farther_temperature")
  12. assert isinstance(reporter, Reporter)
  13. assert reporter.__class__.__name__ == "TestReporter"
  14. res = reporter.compute(
  15. input=[{"sensor": reporter_sensor}],
  16. output=[{"sensor": reporter_sensor}],
  17. start=datetime(2023, 1, 1, tzinfo=UTC),
  18. end=datetime(2023, 1, 2, tzinfo=UTC),
  19. )[0]["data"]
  20. assert res.lineage.sources[0] == reporter.data_source
  21. with pytest.raises(AttributeError):
  22. reporter.compute(
  23. input=[{"sensor": reporter_sensor}],
  24. output=[{"sensor": reporter_sensor}],
  25. start=datetime(2023, 1, 1, tzinfo=UTC),
  26. end="not a date",
  27. )
  28. def test_data_source(db, app, test_reporter):
  29. # get TestReporter class from the data_generators registry
  30. TestReporter = app.data_generators["reporter"].get("TestReporter")
  31. reporter1 = TestReporter(config={"a": "1"})
  32. db.session.add(reporter1.data_source)
  33. reporter2 = TestReporter(config={"a": "1"})
  34. # reporter1 and reporter2 have the same data_source because they share the same config
  35. assert reporter1.data_source == reporter2.data_source
  36. assert reporter1.data_source.attributes.get("data_generator").get(
  37. "config"
  38. ) == reporter2.data_source.attributes.get("data_generator").get("config")
  39. reporter3 = TestReporter(config={"a": "2"})
  40. # reporter3 and reporter2 have different data sources because they have different config values
  41. assert reporter3.data_source != reporter2.data_source
  42. assert reporter3.data_source.attributes.get("data_generator").get(
  43. "config"
  44. ) != reporter2.data_source.attributes.get("data_generator").get("config")
  45. # recreate reporter3 from its data source
  46. reporter4 = reporter3.data_source.data_generator
  47. # check that reporter3 and reporter4 share the same config values
  48. assert reporter4._config == reporter3._config
  49. def test_data_generator_save_config(db, app, test_reporter, add_nearby_weather_sensors):
  50. TestReporter = app.data_generators["reporter"].get("TestReporter")
  51. reporter_sensor = add_nearby_weather_sensors.get("farther_temperature")
  52. reporter = TestReporter(config={"a": "1"})
  53. res = reporter.compute(
  54. input=[{"sensor": reporter_sensor}],
  55. output=[{"sensor": reporter_sensor}],
  56. start=datetime(2023, 1, 1, tzinfo=UTC),
  57. end=datetime(2023, 1, 2, tzinfo=UTC),
  58. )[0]["data"]
  59. assert res.lineage.sources[0].attributes.get("data_generator").get("config") == {
  60. "a": "1"
  61. }
  62. reporter = TestReporter(config={"a": "1"}, save_config=False)
  63. res = reporter.compute(
  64. input=[{"sensor": reporter_sensor}],
  65. output=[{"sensor": reporter_sensor}],
  66. start=datetime(2023, 1, 1, tzinfo=UTC),
  67. end=datetime(2023, 1, 2, tzinfo=UTC),
  68. )[0]["data"]
  69. # check that the data_generator is not saving the config in the data_source attributes
  70. assert res.lineage.sources[0].attributes.get("data_generator") == dict()
  71. def test_data_generator_save_parameters(
  72. db, app, test_reporter, add_nearby_weather_sensors
  73. ):
  74. TestReporter = app.data_generators["reporter"].get("TestReporter")
  75. reporter_sensor = add_nearby_weather_sensors.get("farther_temperature")
  76. reporter = TestReporter(config={"a": "1"}, save_parameters=True)
  77. parameters = {
  78. "input": [{"sensor": reporter_sensor.id}],
  79. "output": [{"sensor": reporter_sensor.id}],
  80. "start": "2023-01-01T00:00:00+00:00",
  81. "end": "2023-01-02T00:00:00+00:00",
  82. "b": "test",
  83. }
  84. parameters_without_start_end = {
  85. "input": [{"sensor": reporter_sensor.id}],
  86. "output": [{"sensor": reporter_sensor.id}],
  87. "b": "test",
  88. }
  89. res = reporter.compute(parameters=parameters)[0]["data"]
  90. assert res.lineage.sources[0].attributes.get("data_generator").get("config") == {
  91. "a": "1"
  92. }
  93. assert (
  94. res.lineage.sources[0].attributes.get("data_generator").get("parameters")
  95. == parameters_without_start_end
  96. )
  97. dg2 = reporter.data_source.data_generator
  98. parameters_2 = {
  99. "start": "2023-01-01T10:00:00+00:00",
  100. "end": "2023-01-02T00:00:00+00:00",
  101. "b": "test2",
  102. }
  103. res = dg2.compute(parameters=parameters_2)[0]["data"]
  104. # check that compute gets data stored in the DB (i.e. `input`/`output`) and updated data
  105. # from the method call (e.g. field `b``)
  106. assert dg2._parameters["b"] == parameters_2["b"]
  107. assert dg2._parameters["start"].isoformat() == parameters_2["start"]
  108. def test_keep_last_version():
  109. s1 = DataSource(name="s1", model="model 1", type="forecaster", version="0.1.0")
  110. s2 = DataSource(name="s1", model="model 1", type="forecaster")
  111. s3 = DataSource(name="s1", model="model 2", type="forecaster")
  112. s4 = DataSource(name="s1", model="model 2", type="scheduler")
  113. def create_dummy_frame(sources: list[DataSource]) -> tb.BeliefsDataFrame:
  114. sensor = tb.Sensor("A")
  115. beliefs = [
  116. tb.TimedBelief(
  117. sensor=sensor,
  118. event_start=datetime(2023, 1, 1, tzinfo=UTC),
  119. belief_time=datetime(2023, 1, 1, tzinfo=UTC),
  120. event_value=1,
  121. source=s,
  122. )
  123. for s in sources
  124. ]
  125. bdf = tb.BeliefsDataFrame(beliefs)
  126. bdf["source.name"] = (
  127. bdf.index.get_level_values("source").map(lambda x: x.name).values
  128. )
  129. bdf["source.model"] = (
  130. bdf.index.get_level_values("source").map(lambda x: x.model).values
  131. )
  132. bdf["source.type"] = (
  133. bdf.index.get_level_values("source").map(lambda x: x.type).values
  134. )
  135. bdf["source.version"] = (
  136. bdf.index.get_level_values("source").map(lambda x: x.version).values
  137. )
  138. return bdf
  139. # the data source with no version is assumed to have version 0.0.0
  140. bdf = create_dummy_frame([s1, s2])
  141. np.testing.assert_array_equal(keep_latest_version(bdf).sources, [s1])
  142. # sources with different models are preserved
  143. bdf = create_dummy_frame([s1, s2, s3])
  144. np.testing.assert_array_equal(keep_latest_version(bdf).sources, [s1, s3])
  145. # two sources with the same model but different types
  146. bdf = create_dummy_frame([s3, s4])
  147. np.testing.assert_array_equal(keep_latest_version(bdf).sources, [s3, s4])
  148. # repeated source
  149. bdf = create_dummy_frame([s1, s1])
  150. np.testing.assert_array_equal(keep_latest_version(bdf).sources, [s1])