accounts.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. from __future__ import annotations
  2. from flask_classful import FlaskView, route
  3. from flexmeasures.data import db
  4. from webargs.flaskparser import use_kwargs, use_args
  5. from flask_security import current_user, auth_required
  6. from flask_json import as_json
  7. from sqlalchemy import or_, select, func
  8. from marshmallow import fields
  9. import marshmallow.validate as validate
  10. from flask_sqlalchemy.pagination import SelectPagination
  11. from flexmeasures.auth.policy import user_has_admin_access
  12. from flexmeasures.auth.decorators import permission_required_for_context
  13. from flexmeasures.data.models.audit_log import AuditLog
  14. from flexmeasures.data.models.user import Account, User
  15. from flexmeasures.data.models.generic_assets import GenericAsset
  16. from flexmeasures.data.services.accounts import get_accounts, get_audit_log_records
  17. from flexmeasures.api.common.schemas.users import AccountIdField
  18. from flexmeasures.data.schemas.account import AccountSchema
  19. from flexmeasures.api.common.schemas.search import SearchFilterField
  20. from flexmeasures.utils.time_utils import server_now
  21. """
  22. API endpoints to manage accounts.
  23. Both POST (to create) and DELETE are not accessible via the API, but as CLI functions.
  24. Editing (PATCH) is also not yet implemented, but might be next, e.g. for the name or roles.
  25. """
  26. # Instantiate schemas outside of endpoint logic to minimize response time
  27. account_schema = AccountSchema()
  28. accounts_schema = AccountSchema(many=True)
  29. partial_account_schema = AccountSchema(partial=True)
  30. class AccountAPI(FlaskView):
  31. route_base = "/accounts"
  32. trailing_slash = False
  33. decorators = [auth_required()]
  34. @route("", methods=["GET"])
  35. @use_kwargs(
  36. {
  37. "page": fields.Int(
  38. required=False, validate=validate.Range(min=1), load_default=None
  39. ),
  40. "per_page": fields.Int(
  41. required=False, validate=validate.Range(min=1), load_default=10
  42. ),
  43. "filter": SearchFilterField(required=False, load_default=None),
  44. "sort_by": fields.Str(
  45. required=False,
  46. load_default=None,
  47. validate=validate.OneOf(["id", "name", "assets", "users"]),
  48. ),
  49. "sort_dir": fields.Str(
  50. required=False,
  51. load_default=None,
  52. validate=validate.OneOf(["asc", "desc"]),
  53. ),
  54. },
  55. location="query",
  56. )
  57. @as_json
  58. def index(
  59. self,
  60. page: int | None = None,
  61. per_page: int | None = None,
  62. filter: list[str] | None = None,
  63. sort_by: str | None = None,
  64. sort_dir: str | None = None,
  65. ):
  66. """API endpoint to list all accounts accessible to the current user.
  67. .. :quickref: Account; Download account list
  68. This endpoint returns all accessible accounts.
  69. Accessible accounts are your own account and accounts you are a consultant for, or all accounts for admins.
  70. The endpoint supports pagination of the asset list using the `page` and `per_page` query parameters.
  71. - If the `page` parameter is not provided, all assets are returned, without pagination information. The result will be a list of assets.
  72. - If a `page` parameter is provided, the response will be paginated, showing a specific number of assets per page as defined by `per_page` (default is 10).
  73. - If a search 'filter' such as 'solar "ACME corp"' is provided, the response will filter out assets where each search term is either present in their name or account name.
  74. The response schema for pagination is inspired by https://datatables.net/manual/server-side#Returned-data
  75. **Example response**
  76. An example of one account being returned:
  77. .. sourcecode:: json
  78. {
  79. "data" : [
  80. {
  81. 'id': 1,
  82. 'name': 'Test Account'
  83. 'account_roles': [1, 3],
  84. 'consultancy_account_id': 2,
  85. 'primary_color': '#1a3443'
  86. 'secondary_color': '#f1a122'
  87. 'logo_url': 'https://example.com/logo.png'
  88. }
  89. ],
  90. "num-records" : 1,
  91. "filtered-records" : 1
  92. }
  93. If no pagination is requested, the response only consists of the list under the "data" key.
  94. :reqheader Authorization: The authentication token
  95. :reqheader Content-Type: application/json
  96. :resheader Content-Type: application/json
  97. :status 200: PROCESSED
  98. :status 400: INVALID_REQUEST
  99. :status 401: UNAUTHORIZED
  100. :status 403: INVALID_SENDER
  101. :status 422: UNPROCESSABLE_ENTITY
  102. """
  103. if user_has_admin_access(current_user, "read"):
  104. accounts = get_accounts()
  105. else:
  106. accounts = [current_user.account] + (
  107. current_user.account.consultancy_client_accounts
  108. if "consultant" in current_user.roles
  109. else []
  110. )
  111. query = db.session.query(Account).filter(
  112. Account.id.in_([a.id for a in accounts])
  113. )
  114. if filter:
  115. search_terms = filter[0].split(" ")
  116. query = query.filter(
  117. or_(*[Account.name.ilike(f"%{term}%") for term in search_terms])
  118. )
  119. if sort_by is not None and sort_dir is not None:
  120. valid_sort_columns = {
  121. "id": Account.id,
  122. "name": Account.name,
  123. "assets": func.count(GenericAsset.id),
  124. "users": func.count(User.id),
  125. }
  126. query = query.join(GenericAsset, isouter=True).join(User, isouter=True)
  127. query = query.group_by(Account.id).order_by(
  128. valid_sort_columns[sort_by].asc()
  129. if sort_dir == "asc"
  130. else valid_sort_columns[sort_by].desc()
  131. )
  132. if page:
  133. select_pagination: SelectPagination = db.paginate(
  134. query, per_page=per_page, page=page
  135. )
  136. accounts_reponse: list = []
  137. for account in select_pagination.items:
  138. user_count_query = select(func.count(User.id)).where(
  139. User.account_id == account.id
  140. )
  141. asset_count_query = select(func.count(GenericAsset.id)).where(
  142. GenericAsset.account_id == account.id
  143. )
  144. user_count = db.session.execute(user_count_query).scalar()
  145. asset_count = db.session.execute(asset_count_query).scalar()
  146. accounts_reponse.append(
  147. {
  148. **account_schema.dump(account),
  149. "user_count": user_count,
  150. "asset_count": asset_count,
  151. }
  152. )
  153. response = {
  154. "data": accounts_reponse,
  155. "num-records": select_pagination.total,
  156. "filtered-records": select_pagination.total,
  157. }
  158. else:
  159. response = accounts_schema.dump(query.all(), many=True)
  160. return response, 200
  161. @route("/<id>", methods=["GET"])
  162. @use_kwargs({"account": AccountIdField(data_key="id")}, location="path")
  163. @permission_required_for_context("read", ctx_arg_name="account")
  164. @as_json
  165. def get(self, id: int, account: Account):
  166. """API endpoint to get an account.
  167. .. :quickref: Account; Get an account
  168. This endpoint retrieves an account, given its id.
  169. Only admins, consultants and users belonging to the account itself can use this endpoint.
  170. **Example response**
  171. .. sourcecode:: json
  172. {
  173. 'id': 1,
  174. 'name': 'Test Account'
  175. 'account_roles': [1, 3],
  176. 'consultancy_account_id': 2,
  177. }
  178. :reqheader Authorization: The authentication token
  179. :reqheader Content-Type: application/json
  180. :resheader Content-Type: application/json
  181. :status 200: PROCESSED
  182. :status 400: INVALID_REQUEST, REQUIRED_INFO_MISSING, UNEXPECTED_PARAMS
  183. :status 401: UNAUTHORIZED
  184. :status 403: INVALID_SENDER
  185. :status 422: UNPROCESSABLE_ENTITY
  186. """
  187. return account_schema.dump(account), 200
  188. @route("/<id>", methods=["PATCH"])
  189. @use_args(partial_account_schema)
  190. @use_kwargs({"account": AccountIdField(data_key="id")}, location="path")
  191. @permission_required_for_context("update", ctx_arg_name="account")
  192. @as_json
  193. def patch(self, account_data: dict, id: int, account: Account):
  194. """Update an account given its identifier.
  195. .. :quickref: Account; Update an account
  196. This endpoint sets data for an existing account.
  197. The following fields are not allowed to be updated:
  198. - id
  199. The following fields are only editable if user role is admin:
  200. - consultancy_account_id
  201. **Example request**
  202. .. sourcecode:: json
  203. {
  204. 'name': 'Test Account'
  205. 'primary_color': '#1a3443'
  206. 'secondary_color': '#f1a122'
  207. 'logo_url': 'https://example.com/logo.png'
  208. 'consultancy_account_id': 2,
  209. }
  210. **Example response**
  211. The whole account is returned in the response:
  212. .. sourcecode:: json
  213. {
  214. 'id': 1,
  215. 'name': 'Test Account'
  216. 'account_roles': [1, 3],
  217. 'primary_color': '#1a3443'
  218. 'secondary_color': '#f1a122'
  219. 'logo_url': 'https://example.com/logo.png'
  220. 'consultancy_account_id': 2,
  221. }
  222. :reqheader Authorization: The authentication token
  223. :reqheader Content-Type: application/json
  224. :resheader Content-Type: application/json
  225. :status 200: UPDATED
  226. :status 400: INVALID_REQUEST, REQUIRED_INFO_MISSING, UNEXPECTED_PARAMS
  227. :status 401: UNAUTHORIZED
  228. :status 403: INVALID_SENDER
  229. :status 422: UNPROCESSABLE_ENTITY
  230. """
  231. # Get existing consultancy_account_id
  232. existing_consultancy_account_id = (
  233. account.consultancy_account.id if account.consultancy_account else None
  234. )
  235. if not user_has_admin_access(current_user, "update"):
  236. # Remove consultancy_account_id from account_data if no admin access
  237. account_data.pop("consultancy_account_id", None)
  238. else:
  239. # Check if consultancy_account_id has changed
  240. new_consultancy_account_id = account_data.get("consultancy_account_id")
  241. if existing_consultancy_account_id != new_consultancy_account_id:
  242. new_consultant_account = db.session.query(Account).get(
  243. new_consultancy_account_id
  244. )
  245. # Validate new consultant account
  246. if (
  247. not new_consultant_account
  248. or new_consultant_account.id == account.id
  249. ):
  250. return {"errors": ["Invalid consultancy_account_id"]}, 422
  251. # Track modified fields
  252. fields_to_check = [
  253. "name",
  254. "primary_color",
  255. "secondary_color",
  256. "logo_url",
  257. "consultancy_account_id",
  258. ]
  259. modified_fields = {
  260. field: getattr(account, field)
  261. for field in fields_to_check
  262. if account_data.get(field) != getattr(account, field)
  263. }
  264. # Compile modified fields string
  265. modified_fields_str = ", ".join(modified_fields.keys())
  266. for k, v in account_data.items():
  267. setattr(account, k, v)
  268. event_message = f"Account Updated, Field: {modified_fields_str}"
  269. # Add Audit log
  270. account_audit_log = AuditLog(
  271. event_datetime=server_now(),
  272. event=event_message,
  273. active_user_id=current_user.id,
  274. active_user_name=current_user.username,
  275. affected_user_id=current_user.id,
  276. affected_account_id=account.id,
  277. )
  278. db.session.add(account_audit_log)
  279. db.session.commit()
  280. return account_schema.dump(account), 200
  281. @route("/<id>/auditlog", methods=["GET"])
  282. @use_kwargs({"account": AccountIdField(data_key="id")}, location="path")
  283. @permission_required_for_context(
  284. "read",
  285. ctx_arg_name="account",
  286. pass_ctx_to_loader=True,
  287. ctx_loader=AuditLog.account_table_acl,
  288. )
  289. @as_json
  290. def auditlog(self, id: int, account: Account):
  291. """API endpoint to get history of account actions.
  292. **Example response**
  293. .. sourcecode:: json
  294. [
  295. {
  296. 'event': 'User test user deleted',
  297. 'event_datetime': '2021-01-01T00:00:00',
  298. 'active_user_id': 1,
  299. }
  300. ]
  301. :reqheader Authorization: The authentication token
  302. :reqheader Content-Type: application/json
  303. :resheader Content-Type: application/json
  304. :status 200: PROCESSED
  305. :status 400: INVALID_REQUEST, REQUIRED_INFO_MISSING, UNEXPECTED_PARAMS
  306. :status 401: UNAUTHORIZED
  307. :status 403: INVALID_SENDER
  308. :status 422: UNPROCESSABLE_ENTITY
  309. """
  310. audit_logs = get_audit_log_records(account)
  311. audit_logs = [
  312. {k: getattr(log, k) for k in ("event", "event_datetime", "active_user_id")}
  313. for log in audit_logs
  314. ]
  315. return audit_logs, 200