from plotly.express._core import build_dataframe
from plotly.express._doc import make_docstring
from plotly.express._chart_types import choropleth_mapbox, scatter_mapbox
import numpy as np
import pandas as pd


def _project_latlon_to_wgs84(lat, lon):
    """
    Projects lat and lon to WGS84, used to get regular hexagons on a mapbox map

    The formulas are those of the spherical Mercator projection expressed in
    radians: ``x`` is the longitude in radians and ``y`` is
    ``arctanh(sin(lat))``. Inputs are in degrees and may be scalars or
    np.ndarray (the computation is element-wise).

    Returns
    -------
    tuple
        (x, y) projected coordinates
    """
    x = lon * np.pi / 180
    y = np.arctanh(np.sin(lat * np.pi / 180))
    return x, y


def _project_wgs84_to_latlon(x, y):
    """
    Projects WGS84 to lat and lon, used to get regular hexagons on a mapbox map

    Inverse of `_project_latlon_to_wgs84`: takes projected coordinates in
    radians and returns degrees.

    Returns
    -------
    tuple
        (lat, lon) in degrees
    """
    lon = x * 180 / np.pi
    lat = (2 * np.arctan(np.exp(y)) - np.pi / 2) * 180 / np.pi
    return lat, lon


def _getBoundsZoomLevel(lon_min, lon_max, lat_min, lat_max, mapDim):
    """
    Get the mapbox zoom level given bounds and a figure dimension
    Source: https://stackoverflow.com/questions/6048975/google-maps-v3-how-to-calculate-the-zoom-level-for-a-given-bounds

    Parameters
    ----------
    lon_min, lon_max, lat_min, lat_max : float
        Bounds to fit, in degrees
    mapDim : dict
        Figure size in pixels, with keys "height" and "width"

    Returns
    -------
    float or int
        The smaller of the zoom levels fitting the latitude extent and the
        longitude extent, capped at 18
    """

    scale = (
        2  # adjustment to reflect MapBox base tiles are 512x512 vs. Google's 256x256
    )
    WORLD_DIM = {"height": 256 * scale, "width": 256 * scale}
    ZOOM_MAX = 18

    def latRad(lat):
        sin = np.sin(lat * np.pi / 180)
        radX2 = np.log((1 + sin) / (1 - sin)) / 2
        return max(min(radX2, np.pi), -np.pi) / 2

    def zoom(mapPx, worldPx, fraction):
        # A zero extent (e.g. all points share the same latitude) would divide
        # by zero below. That direction puts no constraint on the zoom, so
        # return the cap directly; the final min() yields the same result as
        # the infinite zoom the division would have produced.
        if fraction == 0:
            return ZOOM_MAX
        return 0.95 * np.log(mapPx / worldPx / fraction) / np.log(2)

    latFraction = (latRad(lat_max) - latRad(lat_min)) / np.pi

    lngDiff = lon_max - lon_min
    # A negative difference is wrapped around by adding 360 degrees
    lngFraction = ((lngDiff + 360) if lngDiff < 0 else lngDiff) / 360

    latZoom = zoom(mapDim["height"], WORLD_DIM["height"], latFraction)
    lngZoom = zoom(mapDim["width"], WORLD_DIM["width"], lngFraction)

    return min(latZoom, lngZoom, ZOOM_MAX)


def _new_list_lattice(n_rows, n_cols):
    """
    Creates an object array of shape (n_rows, n_cols) in which every cell
    holds its own empty list.
    """
    lattice = np.empty((n_rows, n_cols), dtype=object)
    for i in range(n_rows):
        for j in range(n_cols):
            lattice[i, j] = []
    return lattice


def _aggregate_list_lattice(lattice, agg_func, min_count):
    """
    Reduces every list of a lattice built by `_new_list_lattice` to a scalar.

    Cells holding at least `min_count` values are replaced (in place) by
    `agg_func(values)`, the others by NaN. Returns the lattice as a float
    array.
    """
    n_rows, n_cols = lattice.shape
    for i in range(n_rows):
        for j in range(n_cols):
            vals = lattice[i, j]
            if len(vals) >= min_count:
                lattice[i, j] = agg_func(vals)
            else:
                lattice[i, j] = np.nan
    return lattice.astype(float)


def _compute_hexbin(x, y, x_range, y_range, color, nx, agg_func, min_count):
    """
    Computes the aggregation at hexagonal bin level.
    Also defines the coordinates of the hexagons for plotting.
    The binning is inspired by matplotlib's implementation.

    Parameters
    ----------
    x : np.ndarray
        Array of x values (shape N)
    y : np.ndarray
        Array of y values (shape N)
    x_range : np.ndarray
        Min and max x (shape 2)
    y_range : np.ndarray
        Min and max y (shape 2)
    color : np.ndarray
        Metric to aggregate at hexagon level (shape N).
        If None, the number of points per hexagon is returned instead.
    nx : int
        Number of hexagons horizontally
    agg_func : function
        Numpy compatible aggregator, this function must take a one-dimensional
        np.ndarray as input and output a scalar
    min_count : int
        Minimum number of points in the hexagon for the hexagon to be displayed.
        If None, every hexagon is kept when `color` is None, and only
        hexagons containing at least one point are kept otherwise.

    Returns
    -------
    np.ndarray
        X coordinates of each hexagon (shape M x 6)
    np.ndarray
        Y coordinates of each hexagon (shape M x 6)
    np.ndarray
        Centers of the hexagons (shape M x 2)
    np.ndarray
        Aggregated value in each hexagon (shape M)

    M may be 0 (no hexagon reaches `min_count`), in which case all returned
    arrays are empty.

    """
    xmin = x_range.min()
    xmax = x_range.max()
    ymin = y_range.min()
    ymax = y_range.max()

    # In the x-direction, the hexagons exactly cover the region from
    # xmin to xmax. Need some padding to avoid roundoff errors.
    padding = 1.0e-9 * (xmax - xmin)
    xmin -= padding
    xmax += padding

    Dx = xmax - xmin
    Dy = ymax - ymin
    if Dx == 0 and Dy > 0:
        # No horizontal extent: size the hexagons from the vertical extent
        dx = Dy / nx
    elif Dx == 0 and Dy == 0:
        # Degenerate range (a single location): fall back to the projected
        # width of one degree
        dx, _ = _project_latlon_to_wgs84(1, 1)
    else:
        dx = Dx / nx
    dy = dx * np.sqrt(3)
    ny = np.ceil(Dy / dy).astype(int)

    # Center the hexagons vertically since we only want regular hexagons
    ymin -= (ymin + dy * ny - ymax) / 2

    # Express the points in lattice units
    x = (x - xmin) / dx
    y = (y - ymin) / dy
    # Candidate cell in the first lattice (centers on integer coordinates)
    ix1 = np.round(x).astype(int)
    iy1 = np.round(y).astype(int)
    # Candidate cell in the second lattice (centers offset by half a cell)
    ix2 = np.floor(x).astype(int)
    iy2 = np.floor(y).astype(int)

    nx1 = nx + 1
    ny1 = ny + 1
    nx2 = nx
    ny2 = ny
    n = nx1 * ny1 + nx2 * ny2

    # Each point goes to whichever of its two candidate centers is closer
    d1 = (x - ix1) ** 2 + 3.0 * (y - iy1) ** 2
    d2 = (x - ix2 - 0.5) ** 2 + 3.0 * (y - iy2 - 0.5) ** 2
    bdist = d1 < d2

    if color is None:
        # No metric to aggregate: count the points falling in each hexagon
        lattice1 = np.zeros((nx1, ny1))
        lattice2 = np.zeros((nx2, ny2))
        c1 = (0 <= ix1) & (ix1 < nx1) & (0 <= iy1) & (iy1 < ny1) & bdist
        c2 = (0 <= ix2) & (ix2 < nx2) & (0 <= iy2) & (iy2 < ny2) & ~bdist
        np.add.at(lattice1, (ix1[c1], iy1[c1]), 1)
        np.add.at(lattice2, (ix2[c2], iy2[c2]), 1)
        if min_count is not None:
            lattice1[lattice1 < min_count] = np.nan
            lattice2[lattice2 < min_count] = np.nan
        accum = np.concatenate([lattice1.ravel(), lattice2.ravel()])
        good_idxs = ~np.isnan(accum)
    else:
        if min_count is None:
            min_count = 1

        # create accumulation arrays
        lattice1 = _new_list_lattice(nx1, ny1)
        lattice2 = _new_list_lattice(nx2, ny2)

        # Collect the metric values per hexagon; points whose cell falls
        # outside the lattice are ignored
        for in_first, i1, j1, i2, j2, value in zip(bdist, ix1, iy1, ix2, iy2, color):
            if in_first:
                if 0 <= i1 < nx1 and 0 <= j1 < ny1:
                    lattice1[i1, j1].append(value)
            else:
                if 0 <= i2 < nx2 and 0 <= j2 < ny2:
                    lattice2[i2, j2].append(value)

        accum = np.hstack(
            (
                _aggregate_list_lattice(lattice1, agg_func, min_count).ravel(),
                _aggregate_list_lattice(lattice2, agg_func, min_count).ravel(),
            )
        )
        good_idxs = ~np.isnan(accum)

    aggregated_value = accum[good_idxs]

    # Centers of every hexagon of both lattices, in the same order as `accum`
    centers = np.zeros((n, 2), float)
    centers[: nx1 * ny1, 0] = np.repeat(np.arange(nx1), ny1)
    centers[: nx1 * ny1, 1] = np.tile(np.arange(ny1), nx1)
    centers[nx1 * ny1 :, 0] = np.repeat(np.arange(nx2) + 0.5, ny2)
    centers[nx1 * ny1 :, 1] = np.tile(np.arange(ny2), nx2) + 0.5
    centers[:, 0] *= dx
    centers[:, 1] *= dy
    centers[:, 0] += xmin
    centers[:, 1] += ymin
    centers = centers[good_idxs]

    # Define normalised regular hexagon coordinates
    hx = np.array([0, 0.5, 0.5, 0, -0.5, -0.5])
    hy = np.array(
        [
            -0.5 / np.cos(np.pi / 6),
            -0.5 * np.tan(np.pi / 6),
            0.5 * np.tan(np.pi / 6),
            0.5 / np.cos(np.pi / 6),
            0.5 * np.tan(np.pi / 6),
            -0.5 * np.tan(np.pi / 6),
        ]
    )

    # Coordinates for all hexagonal patches.
    # Broadcasting a (1, 6) template against the (M, 1) centers gives (M, 6)
    # arrays and, unlike stacking the centers one by one, also works when no
    # hexagon was kept (M == 0).
    hxs = hx[np.newaxis, :] * dx + centers[:, [0]]
    hys = hy[np.newaxis, :] * dy / np.sqrt(3) + centers[:, [1]]

    return hxs, hys, centers, aggregated_value


def _compute_wgs84_hexbin(
    lat=None,
    lon=None,
    lat_range=None,
    lon_range=None,
    color=None,
    nx=None,
    agg_func=None,
    min_count=None,
):
    """
    Computes the lat-lon aggregation at hexagonal bin level.
    Latitude and longitude need to be projected to WGS84 before aggregating
    in order to display regular hexagons on the map.

    Parameters
    ----------
    lat : np.ndarray
        Array of latitudes (shape N)
    lon : np.ndarray
        Array of longitudes (shape N)
    lat_range : np.ndarray
        Min and max latitudes (shape 2), computed from `lat` if None
    lon_range : np.ndarray
        Min and max longitudes (shape 2), computed from `lon` if None
    color : np.ndarray
        Metric to aggregate at hexagon level (shape N)
    nx : int
        Number of hexagons horizontally
    agg_func : function
        Numpy compatible aggregator, this function must take a one-dimensional
        np.ndarray as input and output a scalar
    min_count : int
        Minimum number of points in the hexagon for the hexagon to be displayed

    Returns
    -------
    np.ndarray
        Lat coordinates of each hexagon (shape M x 6)
    np.ndarray
        Lon coordinates of each hexagon (shape M x 6)
    pd.Series
        Unique id for each hexagon, to be used in the geojson data (shape M)
    np.ndarray
        Aggregated value in each hexagon (shape M)

    """
    # Project to WGS 84
    x, y = _project_latlon_to_wgs84(lat, lon)

    if lat_range is None:
        lat_range = np.array([lat.min(), lat.max()])
    if lon_range is None:
        lon_range = np.array([lon.min(), lon.max()])

    x_range, y_range = _project_latlon_to_wgs84(lat_range, lon_range)

    hxs, hys, centers, aggregated_value = _compute_hexbin(
        x, y, x_range, y_range, color, nx, agg_func, min_count
    )

    # Convert back to lat-lon
    hexagons_lats, hexagons_lons = _project_wgs84_to_latlon(hxs, hys)

    # Create unique feature id based on hexagon center
    centers = centers.astype(str)
    hexagons_ids = pd.Series(centers[:, 0]) + "," + pd.Series(centers[:, 1])

    return hexagons_lats, hexagons_lons, hexagons_ids, aggregated_value


def _hexagons_to_geojson(hexagons_lats, hexagons_lons, ids=None):
    """
    Creates a geojson of hexagonal features based on the outputs of
    _compute_wgs84_hexbin

    Each hexagon becomes a "Polygon" feature whose ring lists [lon, lat]
    pairs and is closed by repeating its first vertex. If `ids` is None the
    features are numbered from 0.
    """
    features = []
    if ids is None:
        ids = np.arange(len(hexagons_lats))
    for lat, lon, idx in zip(hexagons_lats, hexagons_lons, ids):
        points = np.array([lon, lat]).T.tolist()
        # Close the ring
        points.append(points[0])
        features.append(
            dict(
                type="Feature",
                id=idx,
                geometry=dict(type="Polygon", coordinates=[points]),
            )
        )
    return dict(type="FeatureCollection", features=features)


def create_hexbin_mapbox(
    data_frame=None,
    lat=None,
    lon=None,
    color=None,
    nx_hexagon=5,
    agg_func=None,
    animation_frame=None,
    color_discrete_sequence=None,
    color_discrete_map={},
    labels={},
    color_continuous_scale=None,
    range_color=None,
    color_continuous_midpoint=None,
    opacity=None,
    zoom=None,
    center=None,
    mapbox_style=None,
    title=None,
    template=None,
    width=None,
    height=None,
    min_count=None,
    show_original_data=False,
    original_data_marker=None,
):
    """
    Returns a figure aggregating scattered points into connected hexagons
    """
    # NOTE: this must stay the first statement so that locals() only holds
    # the function arguments.
    args = build_dataframe(args=locals(), constructor=None)

    if agg_func is None:
        agg_func = np.mean

    lat_range = args["data_frame"][args["lat"]].agg(["min", "max"]).values
    lon_range = args["data_frame"][args["lon"]].agg(["min", "max"]).values

    # First pass over the whole data set, without a metric, only to obtain
    # the hexagon geometries and ids used to build the geojson.
    hexagons_lats, hexagons_lons, hexagons_ids, count = _compute_wgs84_hexbin(
        lat=args["data_frame"][args["lat"]].values,
        lon=args["data_frame"][args["lon"]].values,
        lat_range=lat_range,
        lon_range=lon_range,
        color=None,
        nx=nx_hexagon,
        agg_func=agg_func,
        min_count=min_count,
    )

    geojson = _hexagons_to_geojson(hexagons_lats, hexagons_lons, hexagons_ids)

    if zoom is None:
        # Missing figure dimensions: height falls back to 450, width falls
        # back to the height.
        if height is None and width is None:
            mapDim = dict(height=450, width=450)
        elif height is None:
            mapDim = dict(height=450, width=width)
        elif width is None:
            mapDim = dict(height=height, width=height)
        else:
            mapDim = dict(height=height, width=width)
        zoom = _getBoundsZoomLevel(
            lon_range[0], lon_range[1], lat_range[0], lat_range[1], mapDim
        )

    if center is None:
        center = dict(lat=lat_range.mean(), lon=lon_range.mean())

    if args["animation_frame"] is not None:
        groups = args["data_frame"].groupby(args["animation_frame"]).groups
    else:
        groups = {0: args["data_frame"].index}

    # Second pass: aggregate each animation frame separately, always on the
    # global lat/lon ranges so that hexagon ids match the geojson.
    agg_data_frame_list = []
    for frame, index in groups.items():
        df = args["data_frame"].loc[index]
        _, _, hexagons_ids, aggregated_value = _compute_wgs84_hexbin(
            lat=df[args["lat"]].values,
            lon=df[args["lon"]].values,
            lat_range=lat_range,
            lon_range=lon_range,
            color=df[args["color"]].values if args["color"] else None,
            nx=nx_hexagon,
            agg_func=agg_func,
            min_count=min_count,
        )
        agg_data_frame_list.append(
            pd.DataFrame(
                np.c_[hexagons_ids, aggregated_value], columns=["locations", "color"]
            )
        )
    agg_data_frame = (
        pd.concat(agg_data_frame_list, axis=0, keys=groups.keys())
        .rename_axis(index=("frame", "index"))
        .reset_index("frame")
    )

    # np.c_ stacked string ids next to the values, so the values have to be
    # converted back to numbers.
    agg_data_frame["color"] = pd.to_numeric(agg_data_frame["color"])

    if range_color is None:
        range_color = [agg_data_frame["color"].min(), agg_data_frame["color"].max()]

    fig = choropleth_mapbox(
        data_frame=agg_data_frame,
        geojson=geojson,
        locations="locations",
        color="color",
        hover_data={"color": True, "locations": False, "frame": False},
        animation_frame=("frame" if args["animation_frame"] is not None else None),
        color_discrete_sequence=color_discrete_sequence,
        color_discrete_map=color_discrete_map,
        labels=labels,
        color_continuous_scale=color_continuous_scale,
        range_color=range_color,
        color_continuous_midpoint=color_continuous_midpoint,
        opacity=opacity,
        zoom=zoom,
        center=center,
        mapbox_style=mapbox_style,
        title=title,
        template=template,
        width=width,
        height=height,
    )

    if show_original_data:
        # Overlay the raw points as a scatter trace that does not react to
        # hover.
        original_fig = scatter_mapbox(
            data_frame=(
                args["data_frame"].sort_values(by=args["animation_frame"])
                if args["animation_frame"] is not None
                else args["data_frame"]
            ),
            lat=args["lat"],
            lon=args["lon"],
            animation_frame=args["animation_frame"],
        )
        original_fig.data[0].hoverinfo = "skip"
        original_fig.data[0].hovertemplate = None
        original_fig.data[0].marker = original_data_marker

        fig.add_trace(original_fig.data[0])

        if args["animation_frame"] is not None:
            for i in range(len(original_fig.frames)):
                original_fig.frames[i].data[0].hoverinfo = "skip"
                original_fig.frames[i].data[0].hovertemplate = None
                original_fig.frames[i].data[0].marker = original_data_marker

                fig.frames[i].data = [
                    fig.frames[i].data[0],
                    original_fig.frames[i].data[0],
                ]

    return fig


create_hexbin_mapbox.__doc__ = make_docstring(
    create_hexbin_mapbox,
    override_dict=dict(
        nx_hexagon=["int", "Number of hexagons (horizontally) to be created"],
        agg_func=[
            "function",
            "Numpy array aggregator, it must take as input a 1D array",
            "and output a scalar value.",
        ],
        min_count=[
            "int",
            "Minimum number of points in a hexagon for it to be displayed.",
            "If None and color is not set, display all hexagons.",
            "If None and color is set, only display hexagons that contain points.",
        ],
        show_original_data=[
            "bool",
            "Whether to show the original data on top of the hexbin aggregation.",
        ],
        original_data_marker=["dict", "Scattermapbox marker options."],
    ),
)