"""Unit tests for altair API"""

import io
import json
import operator
import os
import tempfile

import jsonschema
import pytest
import pandas as pd

import altair.vegalite.v4 as alt

try:
    import altair_saver  # noqa: F401
except ImportError:
    altair_saver = None


def getargs(*args, **kwargs):
    """Return the positional and keyword arguments as an ``(args, kwargs)`` pair.

    Used to build ``pytest.mark.parametrize`` cases that are later splatted
    back into ``Chart.encode(*args, **kwargs)``.
    """
    return args, kwargs


# Maps a compound-chart function name to the operator that builds the same
# compound chart: ``a + b`` -> layer, ``a | b`` -> hconcat, ``a & b`` -> vconcat.
OP_DICT = {
    "layer": operator.add,
    "hconcat": operator.or_,
    "vconcat": operator.and_,
}


def _make_chart_type(chart_type):
    """Build a small example chart of the requested (compound) type.

    Parameters
    ----------
    chart_type : str
        One of ``"layer"``, ``"hconcat"``, ``"vconcat"``, ``"concat"``,
        ``"facet"``, ``"facet_encoding"``, ``"repeat"`` or ``"chart"``.

    Returns
    -------
    A chart built from an 8-row DataFrame with ``x``, ``y`` and ``color``
    columns, wrapped according to ``chart_type``.

    Raises
    ------
    ValueError
        If ``chart_type`` is not one of the recognized names.
    """
    data = pd.DataFrame(
        {
            "x": [28, 55, 43, 91, 81, 53, 19, 87],
            "y": [43, 91, 81, 53, 19, 87, 52, 28],
            "color": list("AAAABBBB"),
        }
    )
    base = alt.Chart(data).mark_point().encode(x="x", y="y", color="color")

    if chart_type in ["layer", "hconcat", "vconcat", "concat"]:
        # alt.layer / alt.hconcat / alt.vconcat / alt.concat share a signature.
        func = getattr(alt, chart_type)
        return func(base.mark_square(), base.mark_circle())
    elif chart_type == "facet":
        return base.facet("color")
    elif chart_type == "facet_encoding":
        return base.encode(facet="color")
    elif chart_type == "repeat":
        return base.encode(alt.X(alt.repeat(), type="quantitative")).repeat(["x", "y"])
    elif chart_type == "chart":
        return base
    else:
        raise ValueError("chart_type='{}' is not recognized".format(chart_type))


@pytest.fixture
def basic_chart():
    """Return a simple bar chart of column ``b`` against column ``a``."""
    data = pd.DataFrame(
        {
            "a": ["A", "B", "C", "D", "E", "F", "G", "H", "I"],
            "b": [28, 55, 43, 91, 81, 53, 19, 87, 52],
        }
    )

    return alt.Chart(data).mark_bar().encode(x="a", y="b")


def test_chart_data_types():
    """URL, dict, DataFrame and named data serialize correctly in to_dict()."""

    def Chart(data):
        return alt.Chart(data).mark_point().encode(x="x:Q", y="y:Q")

    # Url Data
    data = "/path/to/my/data.csv"
    dct = Chart(data).to_dict()
    assert dct["data"] == {"url": data}

    # Dict Data
    data = {"values": [{"x": 1, "y": 2}, {"x": 2, "y": 3}]}
    with alt.data_transformers.enable(consolidate_datasets=False):
        dct = Chart(data).to_dict()
    assert dct["data"] == data

    with alt.data_transformers.enable(consolidate_datasets=True):
        dct = Chart(data).to_dict()
    name = dct["data"]["name"]
    assert dct["datasets"][name] == data["values"]

    # DataFrame data
    data = pd.DataFrame({"x": range(5), "y": range(5)})
    with alt.data_transformers.enable(consolidate_datasets=False):
        dct = Chart(data).to_dict()
    assert dct["data"]["values"] == data.to_dict(orient="records")

    with alt.data_transformers.enable(consolidate_datasets=True):
        dct = Chart(data).to_dict()
    name = dct["data"]["name"]
    assert dct["datasets"][name] == data.to_dict(orient="records")

    # Named data object
    data = alt.NamedData(name="Foo")
    dct = Chart(data).to_dict()
    assert dct["data"] == {"name": "Foo"}


def test_chart_infer_types():
    """Encoding types are inferred from DataFrame dtypes unless overridden."""
    data = pd.DataFrame(
        {
            "x": pd.date_range("2012", periods=10, freq="Y"),
            "y": range(10),
            "c": list("abcabcabca"),
        }
    )

    def _check_encodings(chart):
        dct = chart.to_dict()
        assert dct["encoding"]["x"]["type"] == "temporal"
        assert dct["encoding"]["x"]["field"] == "x"
        assert dct["encoding"]["y"]["type"] == "quantitative"
        assert dct["encoding"]["y"]["field"] == "y"
        assert dct["encoding"]["color"]["type"] == "nominal"
        assert dct["encoding"]["color"]["field"] == "c"

    # Pass field names by keyword
    chart = alt.Chart(data).mark_point().encode(x="x", y="y", color="c")
    _check_encodings(chart)

    # pass Channel objects by keyword
    chart = (
        alt.Chart(data)
        .mark_point()
        .encode(x=alt.X("x"), y=alt.Y("y"), color=alt.Color("c"))
    )
    _check_encodings(chart)

    # pass Channel objects by value
    chart = alt.Chart(data).mark_point().encode(alt.X("x"), alt.Y("y"), alt.Color("c"))
    _check_encodings(chart)

    # override default types
    chart = (
        alt.Chart(data)
        .mark_point()
        .encode(alt.X("x", type="nominal"), alt.Y("y", type="ordinal"))
    )
    dct = chart.to_dict()
    assert dct["encoding"]["x"]["type"] == "nominal"
    assert dct["encoding"]["y"]["type"] == "ordinal"


@pytest.mark.parametrize(
    "args, kwargs",
    [
        getargs(detail=["value:Q", "name:N"], tooltip=["value:Q", "name:N"]),
        getargs(detail=["value", "name"], tooltip=["value", "name"]),
        getargs(alt.Detail(["value:Q", "name:N"]), alt.Tooltip(["value:Q", "name:N"])),
        getargs(alt.Detail(["value", "name"]), alt.Tooltip(["value", "name"])),
        getargs(
            [alt.Detail("value:Q"), alt.Detail("name:N")],
            [alt.Tooltip("value:Q"), alt.Tooltip("name:N")],
        ),
        getargs(
            [alt.Detail("value"), alt.Detail("name")],
            [alt.Tooltip("value"), alt.Tooltip("name")],
        ),
    ],
)
def test_multiple_encodings(args, kwargs):
    """Multi-field detail/tooltip encodings serialize to lists of field defs."""
    df = pd.DataFrame({"value": [1, 2, 3], "name": ["A", "B", "C"]})
    encoding_dct = [
        {"field": "value", "type": "quantitative"},
        {"field": "name", "type": "nominal"},
    ]
    chart = alt.Chart(df).mark_point().encode(*args, **kwargs)
    dct = chart.to_dict()
    assert dct["encoding"]["detail"] == encoding_dct
    assert dct["encoding"]["tooltip"] == encoding_dct


def test_chart_operations():
    """The +, | and & operators (and in-place forms) build compound charts."""
    data = pd.DataFrame(
        {
            "x": pd.date_range("2012", periods=10, freq="Y"),
            "y": range(10),
            "c": list("abcabcabca"),
        }
    )
    chart1 = alt.Chart(data).mark_line().encode(x="x", y="y", color="c")
    chart2 = chart1.mark_point()
    chart3 = chart1.mark_circle()
    chart4 = chart1.mark_square()

    chart = chart1 + chart2 + chart3
    assert isinstance(chart, alt.LayerChart)
    assert len(chart.layer) == 3
    chart += chart4
    assert len(chart.layer) == 4

    chart = chart1 | chart2 | chart3
    assert isinstance(chart, alt.HConcatChart)
    assert len(chart.hconcat) == 3
    chart |= chart4
    assert len(chart.hconcat) == 4

    chart = chart1 & chart2 & chart3
    assert isinstance(chart, alt.VConcatChart)
    assert len(chart.vconcat) == 3
    chart &= chart4
    assert len(chart.vconcat) == 4


def test_selection_to_dict():
    """Selection-conditioned value and field encodings serialize without error."""
    brush = alt.selection(type="interval")

    # test some value selections
    # Note: X and Y cannot have conditions
    alt.Chart("path/to/data.json").mark_point().encode(
        color=alt.condition(brush, alt.ColorValue("red"), alt.ColorValue("blue")),
        opacity=alt.condition(brush, alt.value(0.5), alt.value(1.0)),
        text=alt.condition(brush, alt.TextValue("foo"), alt.value("bar")),
    ).to_dict()

    # test some field selections
    # Note: X and Y cannot have conditions
    # Conditions cannot both be fields
    alt.Chart("path/to/data.json").mark_point().encode(
        color=alt.condition(brush, alt.Color("col1:N"), alt.value("blue")),
        opacity=alt.condition(brush, "col1:N", alt.value(0.5)),
        text=alt.condition(brush, alt.value("abc"), alt.Text("col2:N")),
        size=alt.condition(brush, alt.value(20), "col2:N"),
    ).to_dict()


def test_selection_expression():
    """Attribute and item access on a selection yield expression objects."""
    selection = alt.selection_single(fields=["value"])

    assert isinstance(selection.value, alt.expr.Expression)
    assert selection.value.to_dict() == "{0}.value".format(selection.name)

    assert isinstance(selection["value"], alt.expr.Expression)
    assert selection["value"].to_dict() == "{0}['value']".format(selection.name)


@pytest.mark.parametrize("format", ["html", "json", "png", "svg"])
def test_save(format, basic_chart):
    """Saving to a file object and to a filename produces identical output.

    svg/png output needs the optional ``altair_saver`` package; without it a
    ValueError pointing at the altair_saver repository is expected instead.
    """
    if format == "png":
        out = io.BytesIO()
        mode = "rb"
    else:
        out = io.StringIO()
        mode = "r"

    if format in ["svg", "png"] and not altair_saver:
        with pytest.raises(ValueError) as err:
            basic_chart.save(out, format=format)
        assert "github.com/altair-viz/altair_saver" in str(err.value)
        return

    basic_chart.save(out, format=format)
    out.seek(0)
    content = out.read()

    if format == "json":
        assert "$schema" in json.loads(content)
    if format == "html":
        assert content.startswith("<!DOCTYPE html>")

    # mkstemp opens the file; close our handle so save() can reopen it by name.
    fid, filename = tempfile.mkstemp(suffix="." + format)
    os.close(fid)

    try:
        # Format is inferred from the file suffix here.
        basic_chart.save(filename)
        with open(filename, mode) as f:
            assert f.read() == content
    finally:
        os.remove(filename)


def test_facet_basic():
    """Wrapped and explicit row/column facets serialize with top-level data."""
    # wrapped facet
    chart1 = (
        alt.Chart("data.csv")
        .mark_point()
        .encode(x="x:Q", y="y:Q")
        .facet("category:N", columns=2)
    )

    dct1 = chart1.to_dict()

    assert dct1["facet"] == alt.Facet("category:N").to_dict()
    assert dct1["columns"] == 2
    assert dct1["data"] == alt.UrlData("data.csv").to_dict()

    # explicit row/col facet
    chart2 = (
        alt.Chart("data.csv")
        .mark_point()
        .encode(x="x:Q", y="y:Q")
        .facet(row="category1:Q", column="category2:Q")
    )

    dct2 = chart2.to_dict()

    assert dct2["facet"]["row"] == alt.Facet("category1:Q").to_dict()
    assert dct2["facet"]["column"] == alt.Facet("category2:Q").to_dict()
    assert "columns" not in dct2
    assert dct2["data"] == alt.UrlData("data.csv").to_dict()


def test_facet_parse():
    """Facet shorthands are parsed and URL data stays at the top level."""
    chart = (
        alt.Chart("data.csv")
        .mark_point()
        .encode(x="x:Q", y="y:Q")
        .facet(row="row:N", column="column:O")
    )
    dct = chart.to_dict()
    assert dct["data"] == {"url": "data.csv"}
    assert "data" not in dct["spec"]
    assert dct["facet"] == {
        "column": {"field": "column", "type": "ordinal"},
        "row": {"field": "row", "type": "nominal"},
    }


def test_facet_parse_data():
    """Faceted DataFrame data is inlined or consolidated at the top level."""
    data = pd.DataFrame({"x": range(5), "y": range(5), "row": list("abcab")})
    chart = (
        alt.Chart(data)
        .mark_point()
        .encode(x="x", y="y:O")
        .facet(row="row", column="column:O")
    )
    with alt.data_transformers.enable(consolidate_datasets=False):
        dct = chart.to_dict()
    assert "values" in dct["data"]
    assert "data" not in dct["spec"]
    assert dct["facet"] == {
        "column": {"field": "column", "type": "ordinal"},
        "row": {"field": "row", "type": "nominal"},
    }

    with alt.data_transformers.enable(consolidate_datasets=True):
        dct = chart.to_dict()
    assert "datasets" in dct
    assert "name" in dct["data"]
    assert "data" not in dct["spec"]
    assert dct["facet"] == {
        "column": {"field": "column", "type": "ordinal"},
        "row": {"field": "row", "type": "nominal"},
    }


def test_selection():
    """Selections can be created, added to charts and combined logically."""
    # test instantiation of selections
    interval = alt.selection_interval(name="selec_1")
    assert interval.selection.type == "interval"
    assert interval.name == "selec_1"

    single = alt.selection_single(name="selec_2")
    assert single.selection.type == "single"
    assert single.name == "selec_2"

    multi = alt.selection_multi(name="selec_3")
    assert multi.selection.type == "multi"
    assert multi.name == "selec_3"

    # test adding to chart
    chart = alt.Chart().add_selection(single)
    chart = chart.add_selection(multi, interval)
    assert set(chart.selection) == {"selec_1", "selec_2", "selec_3"}

    # test logical operations
    assert isinstance(single & multi, alt.Selection)
    assert isinstance(single | multi, alt.Selection)
    assert isinstance(~single, alt.Selection)
    assert isinstance((single & multi)[0].group, alt.SelectionAnd)
    assert isinstance((single | multi)[0].group, alt.SelectionOr)
    assert isinstance((~single)[0].group, alt.SelectionNot)

    # test that default names increment (regression for #1454)
    sel1 = alt.selection_single()
    sel2 = alt.selection_multi()
    sel3 = alt.selection_interval()
    names = {s.name for s in (sel1, sel2, sel3)}
    assert len(names) == 3


def test_transforms():
    """Each transform_* method appends the corresponding *Transform object."""
    # aggregate transform
    agg1 = alt.AggregatedFieldDef(**{"as": "x1", "op": "mean", "field": "y"})
    agg2 = alt.AggregatedFieldDef(**{"as": "x2", "op": "median", "field": "z"})
    chart = alt.Chart().transform_aggregate([agg1], ["foo"], x2="median(z)")
    kwds = dict(aggregate=[agg1, agg2], groupby=["foo"])
    assert chart.transform == [alt.AggregateTransform(**kwds)]

    # bin transform
    chart = alt.Chart().transform_bin("binned", field="field", bin=True)
    kwds = {"as": "binned", "field": "field", "bin": True}
    assert chart.transform == [alt.BinTransform(**kwds)]

    # calculate transform
    chart = alt.Chart().transform_calculate("calc", "datum.a * 4")
    kwds = {"as": "calc", "calculate": "datum.a * 4"}
    assert chart.transform == [alt.CalculateTransform(**kwds)]

    # density transform
    chart = alt.Chart().transform_density("x", as_=["value", "density"])
    kwds = {"as": ["value", "density"], "density": "x"}
    assert chart.transform == [alt.DensityTransform(**kwds)]

    # filter transform
    chart = alt.Chart().transform_filter("datum.a < 4")
    assert chart.transform == [alt.FilterTransform(filter="datum.a < 4")]

    # flatten transform
    chart = alt.Chart().transform_flatten(["A", "B"], ["X", "Y"])
    kwds = {"as": ["X", "Y"], "flatten": ["A", "B"]}
    assert chart.transform == [alt.FlattenTransform(**kwds)]

    # fold transform
    chart = alt.Chart().transform_fold(["A", "B", "C"], as_=["key", "val"])
    kwds = {"as": ["key", "val"], "fold": ["A", "B", "C"]}
    assert chart.transform == [alt.FoldTransform(**kwds)]

    # impute transform
    chart = alt.Chart().transform_impute("field", "key", groupby=["x"])
    kwds = {"impute": "field", "key": "key", "groupby": ["x"]}
    assert chart.transform == [alt.ImputeTransform(**kwds)]

    # joinaggregate transform
    chart = alt.Chart().transform_joinaggregate(min="min(x)", groupby=["key"])
    kwds = {
        "joinaggregate": [
            alt.JoinAggregateFieldDef(field="x", op="min", **{"as": "min"})
        ],
        "groupby": ["key"],
    }
    assert chart.transform == [alt.JoinAggregateTransform(**kwds)]

    # loess transform
    chart = alt.Chart().transform_loess("x", "y", as_=["xx", "yy"])
    kwds = {"on": "x", "loess": "y", "as": ["xx", "yy"]}
    assert chart.transform == [alt.LoessTransform(**kwds)]

    # lookup transform (data)
    lookup_data = alt.LookupData(alt.UrlData("foo.csv"), "id", ["rate"])
    chart = alt.Chart().transform_lookup("a", from_=lookup_data, as_="a", default="b")
    kwds = {"from": lookup_data, "as": "a", "lookup": "a", "default": "b"}
    assert chart.transform == [alt.LookupTransform(**kwds)]

    # lookup transform (selection)
    lookup_selection = alt.LookupSelection(key="key", selection="sel")
    chart = alt.Chart().transform_lookup(
        "a", from_=lookup_selection, as_="a", default="b"
    )
    kwds = {"from": lookup_selection, "as": "a", "lookup": "a", "default": "b"}
    assert chart.transform == [alt.LookupTransform(**kwds)]

    # pivot transform
    chart = alt.Chart().transform_pivot("x", "y")
    assert chart.transform == [alt.PivotTransform(pivot="x", value="y")]

    # quantile transform
    chart = alt.Chart().transform_quantile("x", as_=["prob", "value"])
    kwds = {"quantile": "x", "as": ["prob", "value"]}
    assert chart.transform == [alt.QuantileTransform(**kwds)]

    # regression transform
    chart = alt.Chart().transform_regression("x", "y", as_=["xx", "yy"])
    kwds = {"on": "x", "regression": "y", "as": ["xx", "yy"]}
    assert chart.transform == [alt.RegressionTransform(**kwds)]

    # sample transform
    chart = alt.Chart().transform_sample()
    assert chart.transform == [alt.SampleTransform(1000)]

    # stack transform
    chart = alt.Chart().transform_stack("stacked", "x", groupby=["y"])
    assert chart.transform == [
        alt.StackTransform(stack="x", groupby=["y"], **{"as": "stacked"})
    ]

    # timeUnit transform
    chart = alt.Chart().transform_timeunit("foo", field="x", timeUnit="date")
    kwds = {"as": "foo", "field": "x", "timeUnit": "date"}
    assert chart.transform == [alt.TimeUnitTransform(**kwds)]

    # window transform
    chart = alt.Chart().transform_window(xsum="sum(x)", ymin="min(y)", frame=[None, 0])
    window = [
        alt.WindowFieldDef(**{"as": "xsum", "field": "x", "op": "sum"}),
        alt.WindowFieldDef(**{"as": "ymin", "field": "y", "op": "min"}),
    ]

    # kwargs don't maintain order in Python < 3.6, so window list can
    # be reversed
    assert chart.transform == [
        alt.WindowTransform(frame=[None, 0], window=window)
    ] or chart.transform == [alt.WindowTransform(frame=[None, 0], window=window[::-1])]


def test_filter_transform_selection_predicates():
    """Selection objects and their logical combinations become filter predicates."""
    selector1 = alt.selection_interval(name="s1")
    selector2 = alt.selection_interval(name="s2")
    base = alt.Chart("data.txt").mark_point()

    chart = base.transform_filter(selector1)
    assert chart.to_dict()["transform"] == [{"filter": {"selection": "s1"}}]

    chart = base.transform_filter(~selector1)
    assert chart.to_dict()["transform"] == [{"filter": {"selection": {"not": "s1"}}}]

    chart = base.transform_filter(selector1 & selector2)
    assert chart.to_dict()["transform"] == [
        {"filter": {"selection": {"and": ["s1", "s2"]}}}
    ]

    chart = base.transform_filter(selector1 | selector2)
    assert chart.to_dict()["transform"] == [
        {"filter": {"selection": {"or": ["s1", "s2"]}}}
    ]

    chart = base.transform_filter(selector1 | ~selector2)
    assert chart.to_dict()["transform"] == [
        {"filter": {"selection": {"or": ["s1", {"not": "s2"}]}}}
    ]

    chart = base.transform_filter(~selector1 | ~selector2)
    assert chart.to_dict()["transform"] == [
        {"filter": {"selection": {"or": [{"not": "s1"}, {"not": "s2"}]}}}
    ]

    chart = base.transform_filter(~(selector1 & selector2))
    assert chart.to_dict()["transform"] == [
        {"filter": {"selection": {"not": {"and": ["s1", "s2"]}}}}
    ]


def test_resolve_methods():
    """resolve_axis/legend/scale set the matching field of chart.resolve."""
    chart = alt.LayerChart().resolve_axis(x="shared", y="independent")
    assert chart.resolve == alt.Resolve(
        axis=alt.AxisResolveMap(x="shared", y="independent")
    )

    chart = alt.LayerChart().resolve_legend(color="shared", fill="independent")
    assert chart.resolve == alt.Resolve(
        legend=alt.LegendResolveMap(color="shared", fill="independent")
    )

    chart = alt.LayerChart().resolve_scale(x="shared", y="independent")
    assert chart.resolve == alt.Resolve(
        scale=alt.ScaleResolveMap(x="shared", y="independent")
    )


def test_layer_encodings():
    """LayerChart.encode() accepts shorthand encodings."""
    chart = alt.LayerChart().encode(x="column:Q")
    assert chart.encoding.x == alt.X(shorthand="column:Q")


def test_add_selection():
    """add_selection() accepts one or several selections across calls."""
    selections = [
        alt.selection_interval(),
        alt.selection_single(),
        alt.selection_multi(),
    ]
    chart = (
        alt.Chart()
        .mark_point()
        .add_selection(selections[0])
        .add_selection(selections[1], selections[2])
    )
    expected = {s.name: s.selection for s in selections}
    assert chart.selection == expected


def test_repeat_add_selections():
    """Adding a selection before or after repeat() gives the same spec."""
    base = alt.Chart("data.csv").mark_point()
    selection = alt.selection_single()
    chart1 = base.add_selection(selection).repeat(list("ABC"))
    chart2 = base.repeat(list("ABC")).add_selection(selection)
    assert chart1.to_dict() == chart2.to_dict()


def test_facet_add_selections():
    """Adding a selection before or after facet() gives the same spec."""
    base = alt.Chart("data.csv").mark_point()
    selection = alt.selection_single()
    chart1 = base.add_selection(selection).facet("val:Q")
    chart2 = base.facet("val:Q").add_selection(selection)
    assert chart1.to_dict() == chart2.to_dict()


def test_layer_add_selection():
    """Adding a selection to a layer equals adding it to the first sub-layer."""
    base = alt.Chart("data.csv").mark_point()
    selection = alt.selection_single()
    chart1 = alt.layer(base.add_selection(selection), base)
    chart2 = alt.layer(base, base).add_selection(selection)
    assert chart1.to_dict() == chart2.to_dict()


@pytest.mark.parametrize("charttype", [alt.concat, alt.hconcat, alt.vconcat])
def test_compound_add_selections(charttype):
    """Adding a selection to a concat chart adds it to every subchart."""
    base = alt.Chart("data.csv").mark_point()
    selection = alt.selection_single()
    chart1 = charttype(base.add_selection(selection), base.add_selection(selection))
    chart2 = charttype(base, base).add_selection(selection)
    assert chart1.to_dict() == chart2.to_dict()


def test_selection_property():
    """properties(selection=...) registers the selection under its name."""
    sel = alt.selection_interval()
    chart = alt.Chart("data.csv").mark_point().properties(selection=sel)

    assert list(chart["selection"].keys()) == [sel.name]


def test_LookupData():
    """LookupData serializes DataFrame data inline as records."""
    df = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
    lookup = alt.LookupData(data=df, key="x")

    dct = lookup.to_dict()
    assert dct["key"] == "x"
    assert dct["data"] == {
        "values": [{"x": 1, "y": 4}, {"x": 2, "y": 5}, {"x": 3, "y": 6}]
    }


def test_themes():
    """Enabled themes control the top-level config emitted by to_dict()."""
    chart = alt.Chart("foo.txt").mark_point()

    # to_dict() must run inside each ``with`` so the theme is active.
    with alt.themes.enable("default"):
        assert chart.to_dict()["config"] == {
            "view": {"continuousWidth": 400, "continuousHeight": 300}
        }

    with alt.themes.enable("opaque"):
        assert chart.to_dict()["config"] == {
            "background": "white",
            "view": {"continuousWidth": 400, "continuousHeight": 300},
        }

    with alt.themes.enable("none"):
        assert "config" not in chart.to_dict()


def test_chart_from_dict():
    """Chart.from_dict restores the chart class and rejects invalid specs."""
    base = alt.Chart("data.csv").mark_point().encode(x="x:Q", y="y:Q")

    charts = [
        base,
        base + base,
        base | base,
        base & base,
        base.facet("c:N"),
        (base + base).facet(row="c:N", data="data.csv"),
        base.repeat(["c", "d"]),
        (base + base).repeat(row=["c", "d"]),
    ]

    for chart in charts:
        chart_out = alt.Chart.from_dict(chart.to_dict())
        # Exact class match is intended: from_dict must pick the right subclass.
        assert type(chart_out) is type(chart)

    # test that an invalid spec leads to a schema validation error
    with pytest.raises(jsonschema.ValidationError):
        alt.Chart.from_dict({"invalid": "spec"})


def test_consolidate_datasets(basic_chart):
    """Duplicate inline datasets are consolidated into one named dataset."""
    subchart1 = basic_chart
    subchart2 = basic_chart.copy()
    subchart2.data = basic_chart.data.copy()
    chart = subchart1 | subchart2

    with alt.data_transformers.enable(consolidate_datasets=True):
        dct_consolidated = chart.to_dict()

    with alt.data_transformers.enable(consolidate_datasets=False):
        dct_standard = chart.to_dict()

    assert "datasets" in dct_consolidated
    assert "datasets" not in dct_standard

    datasets = dct_consolidated["datasets"]

    # two dataset copies should be recognized as duplicates
    assert len(datasets) == 1

    # make sure data matches original & names are correct
    name, data = datasets.popitem()

    for spec in dct_standard["hconcat"]:
        assert spec["data"]["values"] == data

    for spec in dct_consolidated["hconcat"]:
        assert spec["data"] == {"name": name}


def test_consolidate_InlineData():
    """InlineData keeps its format when consolidated; empty named data is kept."""
    data = alt.InlineData(
        values=[{"a": 1, "b": 1}, {"a": 2, "b": 2}], format={"type": "csv"}
    )
    chart = alt.Chart(data).mark_point()

    with alt.data_transformers.enable(consolidate_datasets=False):
        dct = chart.to_dict()
    assert dct["data"]["format"] == data.format
    assert dct["data"]["values"] == data.values

    with alt.data_transformers.enable(consolidate_datasets=True):
        dct = chart.to_dict()
    assert dct["data"]["format"] == data.format
    assert list(dct["datasets"].values())[0] == data.values

    data = alt.InlineData(values=[], name="runtime_data")
    chart = alt.Chart(data).mark_point()

    with alt.data_transformers.enable(consolidate_datasets=False):
        dct = chart.to_dict()
    assert dct["data"] == data.to_dict()

    with alt.data_transformers.enable(consolidate_datasets=True):
        dct = chart.to_dict()
    assert dct["data"] == data.to_dict()


def test_repeat():
    """Wrapped and explicit row/column repeats serialize their repeat refs."""
    # wrapped repeat
    chart1 = (
        alt.Chart("data.csv")
        .mark_point()
        .encode(x=alt.X(alt.repeat(), type="quantitative"), y="y:Q")
        .repeat(["A", "B", "C", "D"], columns=2)
    )

    dct1 = chart1.to_dict()

    assert dct1["repeat"] == ["A", "B", "C", "D"]
    assert dct1["columns"] == 2
    assert dct1["spec"]["encoding"]["x"]["field"] == {"repeat": "repeat"}

    # explicit row/col repeat
    chart2 = (
        alt.Chart("data.csv")
        .mark_point()
        .encode(
            x=alt.X(alt.repeat("row"), type="quantitative"),
            y=alt.Y(alt.repeat("column"), type="quantitative"),
        )
        .repeat(row=["A", "B", "C"], column=["C", "B", "A"])
    )

    dct2 = chart2.to_dict()

    assert dct2["repeat"] == {"row": ["A", "B", "C"], "column": ["C", "B", "A"]}
    assert "columns" not in dct2
    assert dct2["spec"]["encoding"]["x"]["field"] == {"repeat": "row"}
    assert dct2["spec"]["encoding"]["y"]["field"] == {"repeat": "column"}


def test_data_property():
    """Passing data to Chart() or via properties(data=...) is equivalent."""
    data = pd.DataFrame({"x": [1, 2, 3], "y": list("ABC")})
    chart1 = alt.Chart(data).mark_point()
    chart2 = alt.Chart().mark_point().properties(data=data)

    assert chart1.to_dict() == chart2.to_dict()


@pytest.mark.parametrize("method", ["layer", "hconcat", "vconcat", "concat"])
@pytest.mark.parametrize(
    "data", ["data.json", pd.DataFrame({"x": range(3), "y": list("abc")})]
)
def test_subcharts_with_same_data(method, data):
    """Compound charts lift data shared by all subcharts to the top level."""
    func = getattr(alt, method)

    point = alt.Chart(data).mark_point().encode(x="x:Q", y="y:Q")
    line = point.mark_line()
    text = point.mark_text()

    chart1 = func(point, line, text)
    assert chart1.data is not alt.Undefined
    assert all(c.data is alt.Undefined for c in getattr(chart1, method))

    # alt.concat has no operator form, so only check the others via OP_DICT.
    if method != "concat":
        op = OP_DICT[method]
        chart2 = op(op(point, line), text)
        assert chart2.data is not alt.Undefined
        assert all(c.data is alt.Undefined for c in getattr(chart2, method))


@pytest.mark.parametrize("method", ["layer", "hconcat", "vconcat", "concat"])
@pytest.mark.parametrize(
    "data", ["data.json", pd.DataFrame({"x": range(3), "y": list("abc")})]
)
def test_subcharts_different_data(method, data):
    """Subcharts keep their own data when it is not shared by all of them."""
    func = getattr(alt, method)

    point = alt.Chart(data).mark_point().encode(x="x:Q", y="y:Q")
    otherdata = alt.Chart("data.csv").mark_point().encode(x="x:Q", y="y:Q")
    nodata = alt.Chart().mark_point().encode(x="x:Q", y="y:Q")

    chart1 = func(point, otherdata)
    assert chart1.data is alt.Undefined
    assert getattr(chart1, method)[0].data is data

    chart2 = func(point, nodata)
    assert chart2.data is alt.Undefined
    assert getattr(chart2, method)[0].data is data


def test_layer_facet(basic_chart):
    """Faceting a layer chart keeps data at the facet level only."""
    chart = (basic_chart + basic_chart).facet(row="row:Q")
    assert chart.data is not alt.Undefined
    assert chart.spec.data is alt.Undefined
    for layer in chart.spec.layer:
        assert layer.data is alt.Undefined

    dct = chart.to_dict()
    assert "data" in dct


def test_layer_errors():
    """Layering configured, repeated or faceted charts raises ValueError."""
    toplevel_chart = alt.Chart("data.txt").mark_point().configure_legend(columns=2)

    facet_chart1 = alt.Chart("data.txt").mark_point().encode(facet="row:Q")

    facet_chart2 = alt.Chart("data.txt").mark_point().facet("row:Q")

    repeat_chart = alt.Chart("data.txt").mark_point().repeat(["A", "B", "C"])

    simple_chart = alt.Chart("data.txt").mark_point()

    with pytest.raises(ValueError) as err:
        toplevel_chart + simple_chart
    assert str(err.value).startswith(
        'Objects with "config" attribute cannot be used within LayerChart.'
    )

    with pytest.raises(ValueError) as err:
        repeat_chart + simple_chart
    assert str(err.value) == "Repeat charts cannot be layered."

    with pytest.raises(ValueError) as err:
        facet_chart1 + simple_chart
    assert str(err.value) == "Faceted charts cannot be layered."

    with pytest.raises(ValueError) as err:
        alt.layer(simple_chart) + facet_chart2
    assert str(err.value) == "Faceted charts cannot be layered."


@pytest.mark.parametrize(
    "chart_type",
    ["layer", "hconcat", "vconcat", "concat", "facet", "facet_encoding", "repeat"],
)
def test_resolve(chart_type):
    """Scale, legend and axis resolutions all land under the "resolve" key."""
    resolved = (
        _make_chart_type(chart_type)
        .resolve_scale(x="independent")
        .resolve_legend(color="independent")
        .resolve_axis(y="independent")
    )
    out = resolved.to_dict()
    expected_resolve = {
        "scale": {"x": "independent"},
        "legend": {"color": "independent"},
        "axis": {"y": "independent"},
    }
    assert out["resolve"] == expected_resolve


# TODO: test vconcat, hconcat, concat, facet_encoding when schema allows them.
# This is blocked by https://github.com/vega/vega-lite/issues/5261
@pytest.mark.parametrize("chart_type", ["chart", "layer"])
@pytest.mark.parametrize("facet_arg", [None, "facet", "row", "column"])
def test_facet(chart_type, facet_arg):
    """facet() with a positional or keyword field yields a 2-column facet spec."""
    base = _make_chart_type(chart_type)
    if facet_arg is not None:
        faceted = base.facet(**{facet_arg: "color:N", "columns": 2})
    else:
        faceted = base.facet("color:N", columns=2)
    out = faceted.to_dict()

    assert "spec" in out
    assert out["columns"] == 2
    # A wrapped facet (positional or facet=) stores the field def directly;
    # row=/column= nest it under that key.
    wrapped = facet_arg in (None, "facet")
    facet_def = out["facet"] if wrapped else out["facet"][facet_arg]
    assert facet_def == {"field": "color", "type": "nominal"}


def test_sequence():
    """alt.sequence() maps start/stop/step/as_ onto a "sequence" generator."""
    cases = [
        ((100,), {}, {"start": 0, "stop": 100}),
        ((5, 10), {}, {"start": 5, "stop": 10}),
        ((0, 1, 0.1), {"as_": "x"}, {"start": 0, "stop": 1, "step": 0.1, "as": "x"}),
    ]
    for positional, keywords, expected in cases:
        generated = alt.sequence(*positional, **keywords)
        assert generated.to_dict() == {"sequence": expected}


def test_graticule():
    """alt.graticule() serializes to True or to its parameter mapping."""
    assert alt.graticule().to_dict() == {"graticule": True}
    assert alt.graticule(step=[15, 15]).to_dict() == {"graticule": {"step": [15, 15]}}


def test_sphere():
    """alt.sphere() serializes to a bare sphere generator."""
    assert alt.sphere().to_dict() == {"sphere": True}