1from plotly.express._core import build_dataframe
2from plotly.express._doc import make_docstring
3from plotly.express._chart_types import choropleth_mapbox, scatter_mapbox
4import numpy as np
5import pandas as pd
6
7
8def _project_latlon_to_wgs84(lat, lon):
9    """
10    Projects lat and lon to WGS84, used to get regular hexagons on a mapbox map
11    """
12    x = lon * np.pi / 180
13    y = np.arctanh(np.sin(lat * np.pi / 180))
14    return x, y
15
16
17def _project_wgs84_to_latlon(x, y):
18    """
19    Projects WGS84 to lat and lon, used to get regular hexagons on a mapbox map
20    """
21    lon = x * 180 / np.pi
22    lat = (2 * np.arctan(np.exp(y)) - np.pi / 2) * 180 / np.pi
23    return lat, lon
24
25
26def _getBoundsZoomLevel(lon_min, lon_max, lat_min, lat_max, mapDim):
27    """
28    Get the mapbox zoom level given bounds and a figure dimension
29    Source: https://stackoverflow.com/questions/6048975/google-maps-v3-how-to-calculate-the-zoom-level-for-a-given-bounds
30    """
31
32    scale = (
33        2  # adjustment to reflect MapBox base tiles are 512x512 vs. Google's 256x256
34    )
35    WORLD_DIM = {"height": 256 * scale, "width": 256 * scale}
36    ZOOM_MAX = 18
37
38    def latRad(lat):
39        sin = np.sin(lat * np.pi / 180)
40        radX2 = np.log((1 + sin) / (1 - sin)) / 2
41        return max(min(radX2, np.pi), -np.pi) / 2
42
43    def zoom(mapPx, worldPx, fraction):
44        return 0.95 * np.log(mapPx / worldPx / fraction) / np.log(2)
45
46    latFraction = (latRad(lat_max) - latRad(lat_min)) / np.pi
47
48    lngDiff = lon_max - lon_min
49    lngFraction = ((lngDiff + 360) if lngDiff < 0 else lngDiff) / 360
50
51    latZoom = zoom(mapDim["height"], WORLD_DIM["height"], latFraction)
52    lngZoom = zoom(mapDim["width"], WORLD_DIM["width"], lngFraction)
53
54    return min(latZoom, lngZoom, ZOOM_MAX)
55
56
57def _compute_hexbin(x, y, x_range, y_range, color, nx, agg_func, min_count):
58    """
59    Computes the aggregation at hexagonal bin level.
60    Also defines the coordinates of the hexagons for plotting.
61    The binning is inspired by matplotlib's implementation.
62
63    Parameters
64    ----------
65    x : np.ndarray
66        Array of x values (shape N)
67    y : np.ndarray
68        Array of y values (shape N)
69    x_range : np.ndarray
70        Min and max x (shape 2)
71    y_range : np.ndarray
72        Min and max y (shape 2)
73    color : np.ndarray
74        Metric to aggregate at hexagon level (shape N)
75    nx : int
76        Number of hexagons horizontally
77    agg_func : function
78        Numpy compatible aggregator, this function must take a one-dimensional
79        np.ndarray as input and output a scalar
80    min_count : int
81        Minimum number of points in the hexagon for the hexagon to be displayed
82
83    Returns
84    -------
85    np.ndarray
86        X coordinates of each hexagon (shape M x 6)
87    np.ndarray
88        Y coordinates of each hexagon (shape M x 6)
89    np.ndarray
90        Centers of the hexagons (shape M x 2)
91    np.ndarray
92        Aggregated value in each hexagon (shape M)
93
94    """
95    xmin = x_range.min()
96    xmax = x_range.max()
97    ymin = y_range.min()
98    ymax = y_range.max()
99
100    # In the x-direction, the hexagons exactly cover the region from
101    # xmin to xmax. Need some padding to avoid roundoff errors.
102    padding = 1.0e-9 * (xmax - xmin)
103    xmin -= padding
104    xmax += padding
105
106    Dx = xmax - xmin
107    Dy = ymax - ymin
108    if Dx == 0 and Dy > 0:
109        dx = Dy / nx
110    elif Dx == 0 and Dy == 0:
111        dx, _ = _project_latlon_to_wgs84(1, 1)
112    else:
113        dx = Dx / nx
114    dy = dx * np.sqrt(3)
115    ny = np.ceil(Dy / dy).astype(int)
116
117    # Center the hexagons vertically since we only want regular hexagons
118    ymin -= (ymin + dy * ny - ymax) / 2
119
120    x = (x - xmin) / dx
121    y = (y - ymin) / dy
122    ix1 = np.round(x).astype(int)
123    iy1 = np.round(y).astype(int)
124    ix2 = np.floor(x).astype(int)
125    iy2 = np.floor(y).astype(int)
126
127    nx1 = nx + 1
128    ny1 = ny + 1
129    nx2 = nx
130    ny2 = ny
131    n = nx1 * ny1 + nx2 * ny2
132
133    d1 = (x - ix1) ** 2 + 3.0 * (y - iy1) ** 2
134    d2 = (x - ix2 - 0.5) ** 2 + 3.0 * (y - iy2 - 0.5) ** 2
135    bdist = d1 < d2
136
137    if color is None:
138        lattice1 = np.zeros((nx1, ny1))
139        lattice2 = np.zeros((nx2, ny2))
140        c1 = (0 <= ix1) & (ix1 < nx1) & (0 <= iy1) & (iy1 < ny1) & bdist
141        c2 = (0 <= ix2) & (ix2 < nx2) & (0 <= iy2) & (iy2 < ny2) & ~bdist
142        np.add.at(lattice1, (ix1[c1], iy1[c1]), 1)
143        np.add.at(lattice2, (ix2[c2], iy2[c2]), 1)
144        if min_count is not None:
145            lattice1[lattice1 < min_count] = np.nan
146            lattice2[lattice2 < min_count] = np.nan
147        accum = np.concatenate([lattice1.ravel(), lattice2.ravel()])
148        good_idxs = ~np.isnan(accum)
149    else:
150        if min_count is None:
151            min_count = 1
152
153        # create accumulation arrays
154        lattice1 = np.empty((nx1, ny1), dtype=object)
155        for i in range(nx1):
156            for j in range(ny1):
157                lattice1[i, j] = []
158        lattice2 = np.empty((nx2, ny2), dtype=object)
159        for i in range(nx2):
160            for j in range(ny2):
161                lattice2[i, j] = []
162
163        for i in range(len(x)):
164            if bdist[i]:
165                if 0 <= ix1[i] < nx1 and 0 <= iy1[i] < ny1:
166                    lattice1[ix1[i], iy1[i]].append(color[i])
167            else:
168                if 0 <= ix2[i] < nx2 and 0 <= iy2[i] < ny2:
169                    lattice2[ix2[i], iy2[i]].append(color[i])
170
171        for i in range(nx1):
172            for j in range(ny1):
173                vals = lattice1[i, j]
174                if len(vals) >= min_count:
175                    lattice1[i, j] = agg_func(vals)
176                else:
177                    lattice1[i, j] = np.nan
178        for i in range(nx2):
179            for j in range(ny2):
180                vals = lattice2[i, j]
181                if len(vals) >= min_count:
182                    lattice2[i, j] = agg_func(vals)
183                else:
184                    lattice2[i, j] = np.nan
185
186        accum = np.hstack(
187            (lattice1.astype(float).ravel(), lattice2.astype(float).ravel())
188        )
189        good_idxs = ~np.isnan(accum)
190
191    agreggated_value = accum[good_idxs]
192
193    centers = np.zeros((n, 2), float)
194    centers[: nx1 * ny1, 0] = np.repeat(np.arange(nx1), ny1)
195    centers[: nx1 * ny1, 1] = np.tile(np.arange(ny1), nx1)
196    centers[nx1 * ny1 :, 0] = np.repeat(np.arange(nx2) + 0.5, ny2)
197    centers[nx1 * ny1 :, 1] = np.tile(np.arange(ny2), nx2) + 0.5
198    centers[:, 0] *= dx
199    centers[:, 1] *= dy
200    centers[:, 0] += xmin
201    centers[:, 1] += ymin
202    centers = centers[good_idxs]
203
204    # Define normalised regular hexagon coordinates
205    hx = [0, 0.5, 0.5, 0, -0.5, -0.5]
206    hy = [
207        -0.5 / np.cos(np.pi / 6),
208        -0.5 * np.tan(np.pi / 6),
209        0.5 * np.tan(np.pi / 6),
210        0.5 / np.cos(np.pi / 6),
211        0.5 * np.tan(np.pi / 6),
212        -0.5 * np.tan(np.pi / 6),
213    ]
214
215    # Number of hexagons needed
216    m = len(centers)
217
218    # Coordinates for all hexagonal patches
219    hxs = np.array([hx] * m) * dx + np.vstack(centers[:, 0])
220    hys = np.array([hy] * m) * dy / np.sqrt(3) + np.vstack(centers[:, 1])
221
222    return hxs, hys, centers, agreggated_value
223
224
def _compute_wgs84_hexbin(
    lat=None,
    lon=None,
    lat_range=None,
    lon_range=None,
    color=None,
    nx=None,
    agg_func=None,
    min_count=None,
):
    """
    Bin lat/lon points into hexagons that look regular on a mapbox map.

    The points (and the bounding ranges) are projected with
    ``_project_latlon_to_wgs84`` before binning with ``_compute_hexbin``,
    and the resulting hexagon vertices are projected back to degrees.

    Parameters
    ----------
    lat : np.ndarray
        Array of latitudes (shape N)
    lon : np.ndarray
        Array of longitudes (shape N)
    lat_range : np.ndarray
        Min and max latitudes (shape 2); derived from ``lat`` when None
    lon_range : np.ndarray
        Min and max longitudes (shape 2); derived from ``lon`` when None
    color : np.ndarray
        Metric to aggregate at hexagon level (shape N)
    nx : int
        Number of hexagons horizontally
    agg_func : function
        Numpy compatible aggregator, this function must take a one-dimensional
        np.ndarray as input and output a scalar
    min_count : int
        Minimum number of points in the hexagon for the hexagon to be displayed

    Returns
    -------
    np.ndarray
        Lat coordinates of each hexagon (shape M x 6)
    np.ndarray
        Lon coordinates of each hexagon (shape M x 6)
    pd.Series
        Unique id for each hexagon, "x,y" built from its projected center,
        to be used in the geojson data (shape M)
    np.ndarray
        Aggregated value in each hexagon (shape M)
    """
    # Project the points first; the ranges default to the data extent.
    x, y = _project_latlon_to_wgs84(lat, lon)

    lat_range = np.array([lat.min(), lat.max()]) if lat_range is None else lat_range
    lon_range = np.array([lon.min(), lon.max()]) if lon_range is None else lon_range

    x_range, y_range = _project_latlon_to_wgs84(lat_range, lon_range)

    hxs, hys, centers, aggregated = _compute_hexbin(
        x, y, x_range, y_range, color, nx, agg_func, min_count
    )

    # Vertices back to degrees for plotting.
    hexagons_lats, hexagons_lons = _project_wgs84_to_latlon(hxs, hys)

    # Feature ids: stringified projected center coordinates joined by a comma.
    center_strings = centers.astype(str)
    hexagons_ids = (
        pd.Series(center_strings[:, 0]) + "," + pd.Series(center_strings[:, 1])
    )

    return hexagons_lats, hexagons_lons, hexagons_ids, aggregated
294
295
296def _hexagons_to_geojson(hexagons_lats, hexagons_lons, ids=None):
297    """
298    Creates a geojson of hexagonal features based on the outputs of
299    _compute_wgs84_hexbin
300    """
301    features = []
302    if ids is None:
303        ids = np.arange(len(hexagons_lats))
304    for lat, lon, idx in zip(hexagons_lats, hexagons_lons, ids):
305        points = np.array([lon, lat]).T.tolist()
306        points.append(points[0])
307        features.append(
308            dict(
309                type="Feature",
310                id=idx,
311                geometry=dict(type="Polygon", coordinates=[points]),
312            )
313        )
314    return dict(type="FeatureCollection", features=features)
315
316
def create_hexbin_mapbox(
    data_frame=None,
    lat=None,
    lon=None,
    color=None,
    nx_hexagon=5,
    agg_func=None,
    animation_frame=None,
    color_discrete_sequence=None,
    color_discrete_map={},
    labels={},
    color_continuous_scale=None,
    range_color=None,
    color_continuous_midpoint=None,
    opacity=None,
    zoom=None,
    center=None,
    mapbox_style=None,
    title=None,
    template=None,
    width=None,
    height=None,
    min_count=None,
    show_original_data=False,
    original_data_marker=None,
):
    """
    Returns a figure aggregating scattered points into connected hexagons
    """
    # Must stay the first statement so that `locals()` holds only the call
    # arguments. (The `{}` defaults above are only forwarded, never mutated.)
    args = build_dataframe(args=locals(), constructor=None)

    if agg_func is None:
        agg_func = np.mean

    df_all = args["data_frame"]
    lat_col = args["lat"]
    lon_col = args["lon"]
    frame_col = args["animation_frame"]

    # Bounds shared by every frame, so all frames use the same hexagon grid.
    lat_range = df_all[lat_col].agg(["min", "max"]).values
    lon_range = df_all[lon_col].agg(["min", "max"]).values

    # Hexagon geometry from all points (counts only); the per-frame results
    # below refer to these hexagons through their ids.
    hexagons_lats, hexagons_lons, hexagons_ids, _ = _compute_wgs84_hexbin(
        lat=df_all[lat_col].values,
        lon=df_all[lon_col].values,
        lat_range=lat_range,
        lon_range=lon_range,
        color=None,
        nx=nx_hexagon,
        agg_func=agg_func,
        min_count=min_count,
    )
    geojson = _hexagons_to_geojson(hexagons_lats, hexagons_lons, hexagons_ids)

    if zoom is None:
        # Fit the bounds into the figure size; a missing height defaults to
        # 450 px and a missing width falls back to the height.
        map_height = 450 if height is None else height
        map_width = map_height if width is None else width
        zoom = _getBoundsZoomLevel(
            lon_range[0],
            lon_range[1],
            lat_range[0],
            lat_range[1],
            dict(height=map_height, width=map_width),
        )

    if center is None:
        center = dict(lat=lat_range.mean(), lon=lon_range.mean())

    if frame_col is not None:
        groups = df_all.groupby(frame_col).groups
    else:
        groups = {0: df_all.index}

    agg_data_frame_list = []
    for index in groups.values():
        df = df_all.loc[index]
        _, _, frame_ids, frame_values = _compute_wgs84_hexbin(
            lat=df[lat_col].values,
            lon=df[lon_col].values,
            lat_range=lat_range,
            lon_range=lon_range,
            color=df[args["color"]].values if args["color"] else None,
            nx=nx_hexagon,
            agg_func=agg_func,
            min_count=min_count,
        )
        # Build typed columns directly (string ids, float values) instead of
        # packing both into one object array and re-parsing the numbers.
        agg_data_frame_list.append(
            pd.DataFrame({"locations": frame_ids.to_numpy(), "color": frame_values})
        )
    agg_data_frame = (
        pd.concat(agg_data_frame_list, axis=0, keys=groups.keys())
        .rename_axis(index=("frame", "index"))
        .reset_index("frame")
    )

    if range_color is None:
        range_color = [agg_data_frame["color"].min(), agg_data_frame["color"].max()]

    fig = choropleth_mapbox(
        data_frame=agg_data_frame,
        geojson=geojson,
        locations="locations",
        color="color",
        hover_data={"color": True, "locations": False, "frame": False},
        animation_frame=("frame" if frame_col is not None else None),
        color_discrete_sequence=color_discrete_sequence,
        color_discrete_map=color_discrete_map,
        labels=labels,
        color_continuous_scale=color_continuous_scale,
        range_color=range_color,
        color_continuous_midpoint=color_continuous_midpoint,
        opacity=opacity,
        zoom=zoom,
        center=center,
        mapbox_style=mapbox_style,
        title=title,
        template=template,
        width=width,
        height=height,
    )

    if show_original_data:
        # Sort by frame so the scatter frames line up with the hexbin frames,
        # which follow the (sorted) groupby order used above.
        original_fig = scatter_mapbox(
            data_frame=(
                df_all.sort_values(by=frame_col) if frame_col is not None else df_all
            ),
            lat=lat_col,
            lon=lon_col,
            animation_frame=frame_col,
        )
        original_trace = original_fig.data[0]
        original_trace.hoverinfo = "skip"
        original_trace.hovertemplate = None
        original_trace.marker = original_data_marker

        fig.add_trace(original_trace)

        if frame_col is not None:
            for i in range(len(original_fig.frames)):
                frame_trace = original_fig.frames[i].data[0]
                frame_trace.hoverinfo = "skip"
                frame_trace.hovertemplate = None
                frame_trace.marker = original_data_marker

                fig.frames[i].data = [fig.frames[i].data[0], frame_trace]

    return fig
469
470
# Assemble the public docstring of create_hexbin_mapbox. make_docstring
# presumably combines the function's short summary with shared plotly.express
# parameter docs; the entries below cover the hexbin-specific arguments.
create_hexbin_mapbox.__doc__ = make_docstring(
    create_hexbin_mapbox,
    override_dict={
        "nx_hexagon": ["int", "Number of hexagons (horizontally) to be created"],
        "agg_func": [
            "function",
            "Numpy array aggregator, it must take as input a 1D array",
            "and output a scalar value.",
        ],
        "min_count": [
            "int",
            "Minimum number of points in a hexagon for it to be displayed.",
            "If None and color is not set, display all hexagons.",
            "If None and color is set, only display hexagons that contain points.",
        ],
        "show_original_data": [
            "bool",
            "Whether to show the original data on top of the hexbin aggregation.",
        ],
        "original_data_marker": ["dict", "Scattermapbox marker options."],
    },
)
493