users.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. from flask import abort
  2. from flask_security import current_user
  3. from marshmallow import fields
  4. from sqlalchemy import select
  5. from flexmeasures.data import db
  6. from flexmeasures.data.models.user import User, Account
  7. class AccountIdField(fields.Integer):
  8. """
  9. Field that represents an account ID. It deserializes from the account id to an account instance.
  10. """
  11. def _deserialize(self, account_id: str, attr, obj, **kwargs) -> Account:
  12. account: Account = db.session.execute(
  13. select(Account).filter_by(id=int(account_id))
  14. ).scalar_one_or_none()
  15. if account is None:
  16. raise abort(404, f"Account {account_id} not found")
  17. return account
  18. def _serialize(self, account: Account, attr, data, **kwargs) -> int:
  19. return account.id
  20. @classmethod
  21. def load_current(cls):
  22. """
  23. Use this with the load_default arg to __init__ if you want the current user's account
  24. by default.
  25. """
  26. return current_user.account if not current_user.is_anonymous else None
  27. class UserIdField(fields.Integer):
  28. """
  29. Field that represents a user ID. It deserializes from the user id to a user instance.
  30. """
  31. def __init__(self, *args, **kwargs):
  32. kwargs["load_default"] = lambda: (
  33. current_user if not current_user.is_anonymous else None
  34. )
  35. super().__init__(*args, **kwargs)
  36. def _deserialize(self, user_id: int, attr, obj, **kwargs) -> User:
  37. user: User = db.session.execute(
  38. select(User).filter_by(id=int(user_id))
  39. ).scalar_one_or_none()
  40. if user is None:
  41. raise abort(404, f"User {user_id} not found")
  42. return user
  43. def _serialize(self, user: User, attr, data, **kwargs) -> int:
  44. return user.id