# test_annotations.py
  1. import pandas as pd
  2. from sqlalchemy import select, func
  3. from flexmeasures.data.models.annotations import Annotation, get_or_create_annotation
  4. from flexmeasures.data.models.data_sources import DataSource
  5. def test_get_or_create_annotation(db, setup_sources):
  6. """Save an annotation, then get_or_create a new annotation with the same contents."""
  7. num_annotations_before = db.session.scalar(
  8. select(func.count()).select_from(Annotation)
  9. )
  10. source = db.session.scalars(select(DataSource).limit(1)).first()
  11. first_annotation = Annotation(
  12. content="Dutch new year",
  13. start=pd.Timestamp("2020-01-01 00:00+01"),
  14. end=pd.Timestamp("2020-01-02 00:00+01"),
  15. source=source,
  16. type="holiday",
  17. )
  18. assert first_annotation == get_or_create_annotation(first_annotation)
  19. num_annotations_intermediate = db.session.scalar(
  20. select(func.count()).select_from(Annotation)
  21. )
  22. assert num_annotations_intermediate == num_annotations_before + 1
  23. assert (
  24. db.session.execute(
  25. select(Annotation).filter_by(
  26. content=first_annotation.content,
  27. start=first_annotation.start,
  28. end=first_annotation.end,
  29. source=first_annotation.source,
  30. type=first_annotation.type,
  31. )
  32. ).scalar_one_or_none()
  33. ) == first_annotation
  34. assert first_annotation.id is not None
  35. second_annotation = Annotation(
  36. content="Dutch new year",
  37. start=pd.Timestamp("2020-01-01 00:00+01"),
  38. end=pd.Timestamp("2020-01-02 00:00+01"),
  39. source=source,
  40. type="holiday",
  41. )
  42. assert first_annotation == get_or_create_annotation(second_annotation)
  43. num_annotations_after = db.session.scalar(select(func.count(Annotation.id)))
  44. assert num_annotations_after == num_annotations_intermediate
  45. assert second_annotation.id is None
  46. def test_search_annotations(db, setup_annotations):
  47. account = setup_annotations["account"]
  48. asset = setup_annotations["asset"]
  49. sensor = setup_annotations["sensor"]
  50. for obj in (account, asset, sensor):
  51. annotations = getattr(obj, "search_annotations")()
  52. assert len(annotations) == 1
  53. assert annotations[0].content == "Dutch new year"