994170c26bc6_add_account_table.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
"""add account table

Revision ID: 994170c26bc6
Revises: b6d49ed7cceb
Create Date: 2021-08-11 19:21:07.083253

Introduces the ``account`` table, links every user (``fm_user``) and every
generic asset to an account, and drops ``generic_asset.owner_id`` (whose values
are first backed up to a SQL script so a downgrade can restore them).
"""
  6. from typing import List, Tuple, Optional
  7. import os
  8. import json
  9. from alembic import context, op
  10. import sqlalchemy as sa
  11. from sqlalchemy import orm
  12. import inflection
  13. from flexmeasures.data.models.user import Account, User
  14. from flexmeasures.data.models.time_series import Sensor
# revision identifiers, used by Alembic.
revision = "994170c26bc6"
down_revision = "b6d49ed7cceb"
branch_labels = None
depends_on = None

# Backup file for the generic_asset.owner_id values dropped by upgrade():
# backup_generic_asset_user_associations() writes UPDATE statements to it and
# downgrade_data() replays them. Relative path, so it resolves against the
# current working directory of the process running the migration.
asset_ownership_backup_script = "generic_asset_fm_user_ownership.sql"

# Minimal, untyped table definitions listing only the columns this migration
# reads or writes; queries go through these rather than through the ORM models.
t_assets = sa.Table(
    "asset",
    sa.MetaData(),
    sa.Column("id"),
    sa.Column("owner_id"),
)
t_generic_assets = sa.Table(
    "generic_asset",
    sa.MetaData(),
    sa.Column("id"),
    sa.Column("name"),
    sa.Column("account_id"),
    sa.Column("generic_asset_type_id"),
)
  35. def upgrade():
  36. """
  37. Add account table.
  38. 1. Users need an account. You can pass this info in (user ID to account name) like this:
  39. flexmeasures db upgrade +1 -x '{"1": "One account", "2": "Bccount", "4": "Bccount"}'
  40. Note that user IDs are strings here, as this is a JSON array.
  41. The +1 makes sure we only upgrade by 1 revision, as these arguments are only meant to be used by this upgrade function.
  42. Users not mentioned here get an account derived from their email address' main domain, capitalized (info@company.com becomes "Company")
  43. 2. The ownership of a generic_asset now goes to account.
  44. Here we fill in the user's new account (see point 1).
  45. (we save a backup of the generic_asset.owner_id info which linked to fm_user)
  46. The old-style asset's ownership remains in place for now! Our code will keep it consistent, until we have completed the move.
  47. """
  48. backup_generic_asset_user_associations()
  49. upgrade_schema()
  50. upgrade_data()
  51. op.alter_column("fm_user", "account_id", nullable=False)
  52. op.drop_column("generic_asset", "owner_id")
  53. def downgrade():
  54. downgrade_schema()
  55. downgrade_data()
def upgrade_schema():
    """Create the ``account`` table and add ``account_id`` foreign keys to users and generic assets.

    Both new ``account_id`` columns are created nullable, so existing rows stay valid
    until ``upgrade_data`` fills them in. The foreign key from ``generic_asset.owner_id``
    to ``fm_user`` is dropped here, but the ``owner_id`` column itself is kept, because
    ``upgrade_data`` still reads it (``upgrade`` drops it afterwards).
    """
    op.create_table(
        "account",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(length=100), nullable=True),
        sa.PrimaryKeyConstraint("id", name=op.f("account_pkey")),
        # Unique names: upgrade_data() identifies accounts by name.
        sa.UniqueConstraint("name", name=op.f("account_name_key")),
    )
    # Users -> account (made NOT NULL by upgrade() once every user has an account).
    op.add_column("fm_user", sa.Column("account_id", sa.Integer(), nullable=True))
    op.create_foreign_key(
        op.f("fm_user_account_id_account_fkey"),
        "fm_user",
        "account",
        ["account_id"],
        ["id"],
    )
    # Generic assets -> account. The old FK to fm_user goes; the owner_id column stays for now.
    op.add_column("generic_asset", sa.Column("account_id", sa.Integer(), nullable=True))
    op.drop_constraint(
        "generic_asset_owner_id_fm_user_fkey", "generic_asset", type_="foreignkey"
    )
    # ON DELETE CASCADE: deleting an account row deletes its generic_asset rows.
    op.create_foreign_key(
        op.f("generic_asset_account_id_account_fkey"),
        "generic_asset",
        "account",
        ["account_id"],
        ["id"],
        ondelete="CASCADE",
    )
  84. def upgrade_data():
  85. # add custom accounts
  86. user_account_mappings = context.get_x_argument()
  87. connection = op.get_bind()
  88. session = orm.Session(bind=connection)
  89. for i, user_account_map in enumerate(user_account_mappings):
  90. print(user_account_map)
  91. user_account_dict = json.loads(user_account_map)
  92. for user_id, account_name in user_account_dict.items():
  93. print(
  94. f"Linking user {user_id} to account {account_name} (as from custom param) ..."
  95. )
  96. account_results = session.execute(
  97. sa.select(Account.id).filter_by(name=account_name)
  98. ).scalar_one_or_none()
  99. if account_results is None:
  100. print(f"need to create account {account_name} ...")
  101. account = Account(name=account_name)
  102. session.add(account)
  103. session.flush()
  104. account_id = account.id
  105. else:
  106. account_id = account_results
  107. user_results = session.execute(
  108. sa.select(User.id).filter_by(id=user_id)
  109. ).scalar_one_or_none()
  110. if not user_results:
  111. raise ValueError(f"User with ID {user_id} does not exist!")
  112. connection.execute(
  113. f"UPDATE fm_user SET account_id = {account_id} WHERE id = {user_id}"
  114. )
  115. # Make sure each existing user has an account
  116. for user_results in session.execute(
  117. sa.select(User.id, User.email, User.account_id)
  118. ).all():
  119. user_id = user_results[0]
  120. user_email = user_results[1]
  121. user_account_id = user_results[2]
  122. if user_account_id is None:
  123. domain = user_email.split("@")[-1].rsplit(".", maxsplit=1)[0]
  124. main_domain = domain.rsplit(".", maxsplit=1)[-1]
  125. account_name = inflection.titleize(main_domain)
  126. print(f"Linking user {user_id} to account {account_name} ...")
  127. account_results = (
  128. session.query(Account.id).filter_by(name=account_name).one_or_none()
  129. )
  130. if account_results is None:
  131. print(f"need to create account {account_name} ...")
  132. account = Account(name=account_name)
  133. session.add(account)
  134. session.flush()
  135. account_id = account.id
  136. else:
  137. account_id = account_results[0]
  138. connection.execute(
  139. f"UPDATE fm_user SET account_id = {account_id} WHERE id = {user_id}"
  140. )
  141. # For all generic assets, set the user's account
  142. # We query the db for old ownership directly, as the generic asset code already points to account
  143. asset_ownership_db = _generic_asset_ownership()
  144. generic_asset_results = connection.execute(
  145. sa.select(
  146. *[
  147. t_generic_assets.c.id,
  148. t_generic_assets.c.name,
  149. t_generic_assets.c.generic_asset_type_id,
  150. ]
  151. )
  152. ).all()
  153. for ga_id, ga_name, ga_generic_asset_type_id in generic_asset_results:
  154. # 1. first look into GenericAsset ownership
  155. old_owner_id = _get_old_owner_id_from_db_result(asset_ownership_db, ga_id)
  156. user_results = (
  157. session.query(User.id, User.account_id)
  158. .filter_by(id=old_owner_id)
  159. .one_or_none()
  160. if old_owner_id is not None
  161. else None
  162. )
  163. # 2. Otherwise, then try the old-style Asset's ownership (via Sensor)
  164. if user_results is None:
  165. sensor_results = (
  166. session.query(Sensor.id).filter_by(generic_asset_id=ga_id).first()
  167. )
  168. if sensor_results is None:
  169. print(
  170. f"GenericAsset {ga_id} ({ga_name}) does not have an assorted sensor. You might want to investigate ..."
  171. )
  172. continue
  173. asset_results = connection.execute(
  174. sa.select(*[t_assets.c.owner_id]).where(
  175. t_assets.c.id == sensor_results[0]
  176. )
  177. ).one_or_none()
  178. if asset_results is None:
  179. print(
  180. f"Generic asset {ga_name} does not have an asset associated, probably because it's of type {ga_generic_asset_type_id}."
  181. )
  182. else:
  183. user_results = (
  184. session.query(User.id, User.account_id)
  185. .filter_by(id=asset_results[0])
  186. .one_or_none()
  187. )
  188. if user_results is not None:
  189. account_id = user_results[1]
  190. connection.execute(
  191. sa.update(t_generic_assets)
  192. .where(t_generic_assets.c.id == ga_id)
  193. .values(account_id=account_id)
  194. )
  195. session.commit()
def downgrade_schema():
    """Undo ``upgrade_schema`` and re-add ``generic_asset.owner_id``.

    ``owner_id`` comes back as an empty, nullable column. ``downgrade_data`` refills it
    if the backup script is found. The ``account`` table and both ``account_id``
    columns are dropped, so account assignments are lost.
    """
    # Re-add the old ownership column (values are restored by downgrade_data(), if possible).
    op.add_column(
        "generic_asset",
        sa.Column("owner_id", sa.INTEGER(), autoincrement=False, nullable=True),
    )
    # Point generic_asset's foreign key back from account to fm_user.
    op.drop_constraint(
        op.f("generic_asset_account_id_account_fkey"),
        "generic_asset",
        type_="foreignkey",
    )
    op.create_foreign_key(
        "generic_asset_owner_id_fm_user_fkey",
        "generic_asset",
        "fm_user",
        ["owner_id"],
        ["id"],
        ondelete="CASCADE",
    )
    op.drop_column("generic_asset", "account_id")
    # Remove the user -> account link; the account table goes last, once nothing references it.
    op.drop_constraint(
        op.f("fm_user_account_id_account_fkey"), "fm_user", type_="foreignkey"
    )
    op.drop_column("fm_user", "account_id")
    op.drop_table("account")
  220. def downgrade_data():
  221. if os.path.exists(asset_ownership_backup_script):
  222. print(
  223. f"Re-applying previous asset ownership from {asset_ownership_backup_script} ..."
  224. )
  225. connection = op.get_bind()
  226. session = orm.Session(bind=connection)
  227. with open(asset_ownership_backup_script, "r") as bckp_file:
  228. for statement in bckp_file.readlines():
  229. connection.execute(statement)
  230. session.commit()
  231. else:
  232. print(f"Could not find backup script {asset_ownership_backup_script} ...")
  233. print("Previous asset ownership information is probably lost.")
  234. def backup_generic_asset_user_associations():
  235. asset_ownership_results = _generic_asset_ownership()
  236. backed_up_ownerships = 0
  237. with open(asset_ownership_backup_script, "w") as bckp_file:
  238. for aid, oid in asset_ownership_results:
  239. if oid is None:
  240. oid = "null"
  241. bckp_file.write(
  242. f"UPDATE generic_asset SET owner_id = {oid} WHERE id = {aid};\n"
  243. )
  244. backed_up_ownerships += 1
  245. if backed_up_ownerships > 0:
  246. print("Your generic_asset.owner_id associations are being dropped!")
  247. print(
  248. f"We saved UPDATE statements to put them back in {asset_ownership_backup_script}."
  249. )
  250. def _generic_asset_ownership() -> List[Tuple[int, int]]:
  251. t_asset_owners = sa.Table(
  252. "generic_asset",
  253. sa.MetaData(),
  254. sa.Column("id", sa.Integer),
  255. sa.Column("owner_id", sa.Integer),
  256. )
  257. # Use SQLAlchemy's connection and transaction to go through the data
  258. connection = op.get_bind()
  259. # Select all existing ids that need migrating, while keeping names intact
  260. asset_ownership_results = connection.execute(
  261. sa.select(
  262. *[
  263. t_asset_owners.c.id,
  264. t_asset_owners.c.owner_id,
  265. ]
  266. )
  267. ).fetchall()
  268. return asset_ownership_results
  269. def _get_old_owner_id_from_db_result(
  270. generic_asset_ownership, asset_id
  271. ) -> Optional[int]:
  272. for aid, oid in generic_asset_ownership:
  273. if aid == asset_id:
  274. return oid
  275. return None