data_sources.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. from __future__ import annotations
  2. from flask import current_app
  3. from sqlalchemy import select
  4. from flexmeasures import User
  5. from flexmeasures.data import db
  6. from flexmeasures.data.models.data_sources import DataSource
  7. from flexmeasures.data.models.user import is_user
  8. def get_or_create_source(
  9. source: User | str,
  10. source_type: str | None = None,
  11. model: str | None = None,
  12. version: str | None = None,
  13. attributes: dict | None = None,
  14. flush: bool = True,
  15. ) -> DataSource:
  16. if is_user(source):
  17. source_type = "user"
  18. query = select(DataSource).filter(DataSource.type == source_type)
  19. if model is not None:
  20. query = query.filter(DataSource.model == model)
  21. if version is not None:
  22. query = query.filter(DataSource.version == version)
  23. if attributes is not None:
  24. query = query.filter(
  25. DataSource.attributes_hash == DataSource.hash_attributes(attributes)
  26. )
  27. if is_user(source):
  28. query = query.filter(DataSource.user == source)
  29. elif isinstance(source, str):
  30. query = query.filter(DataSource.name == source)
  31. else:
  32. raise TypeError("source should be of type User or str")
  33. _source = db.session.execute(query).scalar_one_or_none()
  34. if not _source:
  35. if is_user(source):
  36. _source = DataSource(user=source, model=model, version=version)
  37. else:
  38. if source_type is None:
  39. raise TypeError("Please specify a source type")
  40. _source = DataSource(
  41. name=source,
  42. model=model,
  43. version=version,
  44. type=source_type,
  45. attributes=attributes,
  46. )
  47. current_app.logger.info(f"Setting up {_source} as new data source...")
  48. db.session.add(_source)
  49. if flush:
  50. # assigns id so that we can reference the new object in the current db session
  51. db.session.flush()
  52. return _source
  53. def get_source_or_none(
  54. source: int | str, source_type: str | None = None
  55. ) -> DataSource | None:
  56. """
  57. :param source: source id
  58. :param source_type: optionally, filter by source type
  59. """
  60. query = select(DataSource)
  61. if source_type is not None:
  62. query = query.filter(DataSource.type == source_type)
  63. query = query.filter(DataSource.id == int(source))
  64. return db.session.execute(query).scalar_one_or_none()