12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455 |
- from flask import abort
- from flask_security import current_user
- from marshmallow import fields
- from sqlalchemy import select
- from flexmeasures.data import db
- from flexmeasures.data.models.user import User, Account
class AccountIdField(fields.Integer):
    """Field that represents an account ID.

    On load, the incoming id is validated as an integer (by ``fields.Integer``)
    and resolved to the matching ``Account`` instance. On dump, an ``Account``
    is turned back into its integer id.
    """

    def _deserialize(self, account_id: str, attr, obj, **kwargs) -> Account:
        """Resolve an account id to the ``Account`` it identifies.

        :param account_id: raw input value (e.g. an int or a numeric string).
        :returns: the ``Account`` whose id matches.
        :raises marshmallow.ValidationError: if the value is not a valid integer.
        Aborts with HTTP 404 (via ``flask.abort``) if no such account exists.
        """
        # Run fields.Integer's own validation first, so malformed input (e.g. "abc"
        # or a boolean) becomes a ValidationError rather than an uncaught ValueError.
        validated_id = super()._deserialize(account_id, attr, obj, **kwargs)
        account: Account = db.session.execute(
            select(Account).filter_by(id=validated_id)
        ).scalar_one_or_none()
        if account is None:
            # abort() raises by itself; no explicit ``raise`` is needed.
            abort(404, f"Account {account_id} not found")
        return account

    def _serialize(
        self, account: "Account | None", attr, data, **kwargs
    ) -> "int | None":
        """Return the id of ``account``, or ``None`` when there is no account.

        ``None`` is passed through (as ``fields.Integer`` does) instead of failing
        with an AttributeError on ``None.id``.
        """
        if account is None:
            return None
        return account.id

    @classmethod
    def load_current(cls):
        """Return the current user's account, or ``None`` for anonymous users.

        Pass this as the ``load_default`` argument to ``__init__`` if you want
        the current user's account by default.
        """
        if current_user.is_anonymous:
            return None
        return current_user.account
class UserIdField(fields.Integer):
    """Field that represents a user ID.

    On load, the incoming id is validated as an integer (by ``fields.Integer``)
    and resolved to the matching ``User`` instance. On dump, a ``User`` is
    turned back into its integer id.

    Unless the field is ``required`` or the caller passes its own
    ``load_default``, a missing value defaults to the current user
    (``None`` for anonymous users).
    """

    def __init__(self, *args, **kwargs):
        # Only supply our default when the caller did not choose one. Overwriting
        # it would silently discard an explicit load_default, and marshmallow
        # raises ValueError when load_default is set on a required field.
        if not kwargs.get("required", False):
            kwargs.setdefault("load_default", self.load_current)
        super().__init__(*args, **kwargs)

    @classmethod
    def load_current(cls):
        """Return the current user, or ``None`` for anonymous users."""
        if current_user.is_anonymous:
            return None
        return current_user

    def _deserialize(self, user_id: int, attr, obj, **kwargs) -> User:
        """Resolve a user id to the ``User`` it identifies.

        :param user_id: raw input value (e.g. an int or a numeric string).
        :returns: the ``User`` whose id matches.
        :raises marshmallow.ValidationError: if the value is not a valid integer.
        Aborts with HTTP 404 (via ``flask.abort``) if no such user exists.
        """
        # Run fields.Integer's own validation first, so malformed input (e.g. "abc"
        # or a boolean) becomes a ValidationError rather than an uncaught ValueError.
        validated_id = super()._deserialize(user_id, attr, obj, **kwargs)
        user: User = db.session.execute(
            select(User).filter_by(id=validated_id)
        ).scalar_one_or_none()
        if user is None:
            # abort() raises by itself; no explicit ``raise`` is needed.
            abort(404, f"User {user_id} not found")
        return user

    def _serialize(self, user: "User | None", attr, data, **kwargs) -> "int | None":
        """Return the id of ``user``, or ``None`` when there is no user.

        ``None`` is passed through (as ``fields.Integer`` does) instead of failing
        with an AttributeError on ``None.id``.
        """
        if user is None:
            return None
        return user.id
|