defaults.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. from __future__ import annotations
  2. from functools import wraps
  3. from typing import Callable
  4. import altair as alt
  # Typography: FONT_SIZE is shared by the axis/legend/title config and by the
  # annotation labels; ANNOTATION_MARGIN is added to their vertical offset (dy).
  FONT_SIZE = 16
  ANNOTATION_MARGIN = 16

  # Default chart dimensions, filled in by apply_chart_defaults when a spec omits them.
  HEIGHT = 300
  WIDTH = "container"

  REDUCED_WIDTH = 60
  REDUCED_HEIGHT = 60

  SELECTOR_COLOR = "darkred"

  # Time-format pattern handed to Vega's timeFormat() to compute the "full_date" field.
  TIME_FORMAT = "%H:%M on %A %b %e, %Y"
  # Vega label expression for temporal axes: ticks at midnight or with non-zero
  # seconds keep the default timeFormat, every other tick uses 24-hour '%H:%M'.
  FORMAT_24H = "(hours(datum.value) == 0 & minutes(datum.value) == 0) | seconds(datum.value) != 0 ? timeFormat(datum.value) : timeFormat(datum.value, '%H:%M')"
  TIME_SELECTION_TOOLTIP = "Click and drag to select a time window"
  # Reusable Vega-Lite encoding field definitions, keyed by a short alias.
  # "full_date" refers to the field computed by the transform that
  # apply_chart_defaults appends to every chart.
  FIELD_DEFINITIONS = {
      "event_start": {
          "field": "event_start",
          "type": "temporal",
          "title": None,
          "axis": {
              "labelExpr": FORMAT_24H,
              "labelOverlap": True,
              "labelSeparation": 1,
          },
      },
      "event_value": {"field": "event_value", "type": "quantitative"},
      "sensor": {"field": "sensor.id", "type": "nominal", "title": None},
      "sensor_name": {"field": "sensor.name", "type": "nominal", "title": "Sensor"},
      "sensor_description": {
          "field": "sensor.description",
          "type": "nominal",
          "title": "Sensor",
      },
      "source": {"field": "source.id", "type": "nominal", "title": None},
      "source_type": {"field": "source.type", "type": "nominal", "title": "Type"},
      "source_name": {"field": "source.name", "type": "nominal", "title": "Source"},
      "source_model": {"field": "source.model", "type": "nominal", "title": "Model"},
      "full_date": {
          "field": "full_date",
          "type": "nominal",
          "title": "Time and date",
      },
      "source_name_and_id": {
          "field": "source_name_and_id",
          "type": "nominal",
          "title": "Source",
      },
  }
  # Rule-mark layer positioned along x by `belief_time`, fed from the named
  # "replay" dataset.
  REPLAY_RULER = dict(
      data=dict(name="replay"),
      mark=dict(type="rule"),
      encoding=dict(
          x=dict(field="belief_time", type="temporal"),
      ),
  )
  # Annotation shading template: translucent grey bars spanning each annotation's
  # start..end interval. It declares the "highlight" (point selection on
  # mouseover) and "select" (point selection) params that TEXT_LAYER's opacity
  # conditions refer to. apply_chart_defaults attaches the "data" entry.
  SHADE_LAYER = dict(
      mark=dict(type="bar", color="#bbbbbb", opacity=0.3, size=HEIGHT),
      encoding=dict(
          x={"field": "start", "type": "temporal", "title": None},
          x2={"field": "end", "type": "temporal", "title": None},
      ),
      params=[
          dict(name="highlight", select=dict(type="point", on="mouseover")),
          dict(name="select", select="point"),
      ],
  )
  # Annotation text template: italic labels anchored at y=HEIGHT, offset by
  # dy=FONT_SIZE + ANNOTATION_MARGIN, and x-positioned at each annotation's start.
  # A label is opaque only while the "select" or "highlight" param (declared in
  # SHADE_LAYER) matches it; otherwise its opacity is 0.
  TEXT_LAYER = dict(
      mark=dict(
          type="text",
          y=HEIGHT,
          dy=FONT_SIZE + ANNOTATION_MARGIN,
          baseline="top",
          align="left",
          fontSize=FONT_SIZE,
          fontStyle="italic",
      ),
      encoding=dict(
          x={"field": "start", "type": "temporal", "title": None},
          text=dict(type="nominal", field="content"),
          opacity=dict(
              condition=[
                  dict(param="select", empty=False, value=1),
                  dict(param="highlight", empty=False, value=1),
              ],
              value=0,
          ),
      ),
  )
  # Config merged underneath every chart by apply_chart_defaults; settings that a
  # chart spec provides itself take precedence over these.
  LEGIBILITY_DEFAULTS = {
      "config": {
          "axis": {"titleFontSize": FONT_SIZE, "labelFontSize": FONT_SIZE},
          # Unrotated y-axis title, shifted to the top-left of the axis
          "axisY": {
              "titleAngle": 0,
              "titleAlign": "left",
              "titleY": -15,
              "titleX": -40,
          },
          "title": {"fontSize": FONT_SIZE},
          "legend": {
              "titleFontSize": FONT_SIZE,
              "labelFontSize": FONT_SIZE,
              "labelLimit": None,
              "orient": "bottom",
              "columns": 1,
              "direction": "vertical",
          },
      },
  }
  # When merge_vega_lite_specs has to combine a plain string with a dict for the
  # same key, the string is stored in the dict under the key named here
  # ("type" is used for keys not listed).
  vega_lite_field_mapping = dict(title="text", mark="type")
  def apply_chart_defaults(fn):
      """Wrap a chart specification so it is returned as Vega-Lite specs with defaults.

      :param fn: either a function returning a chart specification (a dict or an
                 Altair chart), or such a chart specification itself.
      :returns: a function returning a dict with Vega-Lite specs. Besides the
                wrapped function's own arguments it accepts two keyword arguments:
                - dataset_name: name of the Vega-Lite dataset to bind the chart to.
                - include_annotations: if truthy (and dataset_name is given), the
                  chart is layered between annotation shades and annotation texts,
                  both bound to the dataset "<dataset_name>_annotations".
                A "full_date" transform, default height/width and LEGIBILITY_DEFAULTS
                are applied as well. Neither `fn` (when it is a dict) nor the
                module-level templates are modified.
      """
      import copy

      @wraps(fn)
      def decorated_chart_specs(*args, **kwargs) -> dict:
          """:returns: dict with vega-lite specs, even when applied to an Altair chart."""
          dataset_name = kwargs.pop("dataset_name", None)
          include_annotations = kwargs.pop("include_annotations", None)
          if isinstance(fn, Callable):
              # function that returns a chart specification
              chart_specs: dict | alt.TopLevelMixin = fn(*args, **kwargs)
          else:
              # not a function, but a direct chart specification
              chart_specs = fn
          if isinstance(chart_specs, alt.TopLevelMixin):
              # to_dict() already produces a fresh dict we are free to modify
              chart_specs = chart_specs.to_dict()
              chart_specs.pop("$schema", None)
          else:
              # Work on a copy: `fn` itself (or a dict it hands out repeatedly) must
              # not collect our keys, or each call would append yet another
              # full_date transform to the very same spec.
              chart_specs = dict(chart_specs)

          # Add transform function to calculate full date (a new list, so a
          # transform list shared with the caller is left untouched)
          chart_specs["transform"] = [
              *(chart_specs.get("transform") or []),
              {
                  "as": "full_date",
                  "calculate": f"timeFormat(datum.event_start, '{TIME_FORMAT}')",
              },
          ]

          if dataset_name:
              chart_specs["data"] = {"name": dataset_name}
              if include_annotations:
                  # Deep-copy the module-level layer templates: filling in their
                  # "data" in place would re-point every previously returned spec
                  # at the annotations of the most recent dataset.
                  annotation_shades_layer = copy.deepcopy(SHADE_LAYER)
                  annotation_shades_layer["data"] = {
                      "name": dataset_name + "_annotations"
                  }
                  annotation_text_layer = copy.deepcopy(TEXT_LAYER)
                  annotation_text_layer["data"] = {"name": dataset_name + "_annotations"}
                  chart_specs = {
                      "layer": [
                          annotation_shades_layer,
                          chart_specs,
                          annotation_text_layer,
                      ]
                  }

          # Fall back to default height and width, if needed
          chart_specs.setdefault("height", HEIGHT)
          chart_specs.setdefault("width", WIDTH)

          # Improve default legibility; merge into a private copy of the defaults
          # so they can never be altered for subsequent charts.
          return merge_vega_lite_specs(
              copy.deepcopy(LEGIBILITY_DEFAULTS),
              chart_specs,
          )

      return decorated_chart_specs
  def merge_vega_lite_specs(child: dict, parent: dict) -> dict:
      """Merge nested dictionaries, with child inheriting values from parent.

      Child values are updated with parent values if they exist: for a key present
      in both, the parent's value wins, except that two dicts are merged
      recursively.
      In case a field is a string and that field is updated with some dict (or the
      other way around), the string is moved inside the dict under a field defined
      in vega_lite_field_mapping (falling back to "type").
      For example, 'title' becomes 'text' and 'mark' becomes 'type'.
      A string meeting a non-dict value is not wrapped; the parent's value is used.

      Neither input is modified. The result may still share nested objects with
      the inputs, because values present on one side only are not copied.

      :param child: specs providing default values (e.g. LEGIBILITY_DEFAULTS).
      :param parent: specs whose values take precedence.
      :returns: a new dict with the merged specs. Keys follow the order of child,
                followed by keys that only appear in parent.
      """
      merged = {}
      keys = list(child) + [key for key in parent if key not in child]
      for key in keys:
          if key not in parent:
              merged[key] = child[key]
              continue
          if key not in child:
              merged[key] = parent[key]
              continue
          child_value, parent_value = child[key], parent[key]
          # Wrap a bare string only when it meets a dict, so that e.g. a string
          # title can be combined with a dict of title properties. Rebind local
          # names instead of writing back, leaving the callers' dicts (often
          # module-level defaults) untouched.
          if isinstance(child_value, str) and isinstance(parent_value, dict):
              child_value = {vega_lite_field_mapping.get(key, "type"): child_value}
          elif isinstance(parent_value, str) and isinstance(child_value, dict):
              parent_value = {vega_lite_field_mapping.get(key, "type"): parent_value}
          if isinstance(child_value, dict) and isinstance(parent_value, dict):
              merged[key] = merge_vega_lite_specs(child_value, parent_value)
          else:
              merged[key] = parent_value
      return merged