1"""Unit tests for altair API"""
2
3import io
4import json
5import operator
6import os
7import tempfile
8
9import jsonschema
10import pytest
11import pandas as pd
12
13import altair.vegalite.v4 as alt
14
15try:
16    import altair_saver  # noqa: F401
17except ImportError:
18    altair_saver = None
19
20
def getargs(*args, **kwargs):
    """Capture a call signature as an ``(args, kwargs)`` pair.

    The positional arguments come back as a tuple and the keyword
    arguments as a dict, so a parametrized test can replay the call with
    ``func(*args, **kwargs)``.
    """
    captured = (args, kwargs)
    return captured
23
24
# Binary operator associated with each compound-chart constructor name:
# ``+`` for layer, ``|`` for hconcat and ``&`` for vconcat.
OP_DICT = dict(
    layer=operator.add,
    hconcat=operator.or_,
    vconcat=operator.and_,
)
30
31
def _make_chart_type(chart_type):
    """Build a small example chart of the requested kind.

    ``chart_type`` is one of ``"layer"``, ``"hconcat"``, ``"vconcat"``,
    ``"concat"``, ``"facet"``, ``"facet_encoding"``, ``"repeat"`` or
    ``"chart"``.  Any other value raises ``ValueError``.
    """
    frame = pd.DataFrame(
        {
            "x": [28, 55, 43, 91, 81, 53, 19, 87],
            "y": [43, 91, 81, 53, 19, 87, 52, 28],
            "color": list("AAAABBBB"),
        }
    )
    base = alt.Chart(frame).mark_point().encode(x="x", y="y", color="color")

    # Compound charts are built by the module-level function of the same name.
    if chart_type in ("layer", "hconcat", "vconcat", "concat"):
        combine = getattr(alt, chart_type)
        return combine(base.mark_square(), base.mark_circle())
    if chart_type == "facet":
        return base.facet("color")
    if chart_type == "facet_encoding":
        return base.encode(facet="color")
    if chart_type == "repeat":
        repeated_x = alt.X(alt.repeat(), type="quantitative")
        return base.encode(repeated_x).repeat(["x", "y"])
    if chart_type == "chart":
        return base

    raise ValueError("chart_type='{}' is not recognized".format(chart_type))
55
56
@pytest.fixture
def basic_chart():
    """Provide a bar chart of column ``a`` (letters) against ``b`` (integers)."""
    table = pd.DataFrame(
        {
            "a": list("ABCDEFGHI"),
            "b": [28, 55, 43, 91, 81, 53, 19, 87, 52],
        }
    )
    bars = alt.Chart(table).mark_bar()
    return bars.encode(x="a", y="b")
67
68
def test_chart_data_types():
    """Check how URL, dict, DataFrame and named data appear in the chart dict."""

    def spec_for(source):
        # Same point chart for every data source; only the data differs.
        chart = alt.Chart(source).mark_point().encode(x="x:Q", y="y:Q")
        return chart.to_dict()

    def spec_with_consolidation(source, flag):
        # Build and serialize the chart while the consolidation option is set.
        with alt.data_transformers.enable(consolidate_datasets=flag):
            return spec_for(source)

    # Url Data
    url = "/path/to/my/data.csv"
    assert spec_for(url)["data"] == {"url": url}

    # Dict Data
    inline = {"values": [{"x": 1, "y": 2}, {"x": 2, "y": 3}]}
    spec = spec_with_consolidation(inline, False)
    assert spec["data"] == inline

    spec = spec_with_consolidation(inline, True)
    dataset_name = spec["data"]["name"]
    assert spec["datasets"][dataset_name] == inline["values"]

    # DataFrame data
    frame = pd.DataFrame({"x": range(5), "y": range(5)})
    spec = spec_with_consolidation(frame, False)
    assert spec["data"]["values"] == frame.to_dict(orient="records")

    spec = spec_with_consolidation(frame, True)
    dataset_name = spec["data"]["name"]
    assert spec["datasets"][dataset_name] == frame.to_dict(orient="records")

    # Named data object
    named = alt.NamedData(name="Foo")
    assert spec_for(named)["data"] == {"name": "Foo"}
104
105
def test_chart_infer_types():
    """Check encoding types inferred from DataFrame columns, and explicit overrides."""
    frame = pd.DataFrame(
        {
            "x": pd.date_range("2012", periods=10, freq="Y"),
            "y": range(10),
            "c": list("abcabcabca"),
        }
    )

    # channel -> (expected inferred type, expected field name)
    inferred = {
        "x": ("temporal", "x"),
        "y": ("quantitative", "y"),
        "color": ("nominal", "c"),
    }

    def _check_encodings(chart):
        encoding = chart.to_dict()["encoding"]
        for channel, (kind, field) in inferred.items():
            assert encoding[channel]["type"] == kind
            assert encoding[channel]["field"] == field

    points = alt.Chart(frame).mark_point()

    # Pass field names by keyword
    _check_encodings(points.encode(x="x", y="y", color="c"))

    # pass Channel objects by keyword
    _check_encodings(points.encode(x=alt.X("x"), y=alt.Y("y"), color=alt.Color("c")))

    # pass Channel objects by value
    _check_encodings(points.encode(alt.X("x"), alt.Y("y"), alt.Color("c")))

    # override default types
    overridden = points.encode(alt.X("x", type="nominal"), alt.Y("y", type="ordinal"))
    encoding = overridden.to_dict()["encoding"]
    assert encoding["x"]["type"] == "nominal"
    assert encoding["y"]["type"] == "ordinal"
149
150
@pytest.mark.parametrize(
    "args, kwargs",
    [
        getargs(detail=["value:Q", "name:N"], tooltip=["value:Q", "name:N"]),
        getargs(detail=["value", "name"], tooltip=["value", "name"]),
        getargs(alt.Detail(["value:Q", "name:N"]), alt.Tooltip(["value:Q", "name:N"])),
        getargs(alt.Detail(["value", "name"]), alt.Tooltip(["value", "name"])),
        getargs(
            [alt.Detail("value:Q"), alt.Detail("name:N")],
            [alt.Tooltip("value:Q"), alt.Tooltip("name:N")],
        ),
        getargs(
            [alt.Detail("value"), alt.Detail("name")],
            [alt.Tooltip("value"), alt.Tooltip("name")],
        ),
    ],
)
def test_multiple_encodings(args, kwargs):
    """Every spelling of a two-field detail/tooltip encoding yields the same dict."""
    frame = pd.DataFrame({"value": [1, 2, 3], "name": ["A", "B", "C"]})
    expected = [
        {"field": "value", "type": "quantitative"},
        {"field": "name", "type": "nominal"},
    ]

    encoding = alt.Chart(frame).mark_point().encode(*args, **kwargs).to_dict()["encoding"]

    assert encoding["detail"] == expected
    assert encoding["tooltip"] == expected
178
179
def test_chart_operations():
    """Check that ``+``, ``|`` and ``&`` (and their in-place forms) combine charts."""
    frame = pd.DataFrame(
        {
            "x": pd.date_range("2012", periods=10, freq="Y"),
            "y": range(10),
            "c": list("abcabcabca"),
        }
    )
    lines = alt.Chart(frame).mark_line().encode(x="x", y="y", color="c")
    points = lines.mark_point()
    circles = lines.mark_circle()
    squares = lines.mark_square()

    # ``+`` layers charts
    layered = lines + points + circles
    assert isinstance(layered, alt.LayerChart)
    assert len(layered.layer) == 3
    layered += squares
    assert len(layered.layer) == 4

    # ``|`` concatenates horizontally
    side_by_side = lines | points | circles
    assert isinstance(side_by_side, alt.HConcatChart)
    assert len(side_by_side.hconcat) == 3
    side_by_side |= squares
    assert len(side_by_side.hconcat) == 4

    # ``&`` concatenates vertically
    stacked = lines & points & circles
    assert isinstance(stacked, alt.VConcatChart)
    assert len(stacked.vconcat) == 3
    stacked &= squares
    assert len(stacked.vconcat) == 4
210
211
def test_selection_to_dict():
    """Charts with selection-driven conditions serialize without raising."""
    brush = alt.selection(type="interval")
    points = alt.Chart("path/to/data.json").mark_point()

    # test some value selections
    # Note: X and Y cannot have conditions
    value_conditions = points.encode(
        color=alt.condition(brush, alt.ColorValue("red"), alt.ColorValue("blue")),
        opacity=alt.condition(brush, alt.value(0.5), alt.value(1.0)),
        text=alt.condition(brush, alt.TextValue("foo"), alt.value("bar")),
    )
    value_conditions.to_dict()

    # test some field selections
    # Note: X and Y cannot have conditions
    # Conditions cannot both be fields
    field_conditions = points.encode(
        color=alt.condition(brush, alt.Color("col1:N"), alt.value("blue")),
        opacity=alt.condition(brush, "col1:N", alt.value(0.5)),
        text=alt.condition(brush, alt.value("abc"), alt.Text("col2:N")),
        size=alt.condition(brush, alt.value(20), "col2:N"),
    )
    field_conditions.to_dict()
232
233
def test_selection_expression():
    """Attribute and item access on a selection produce expression objects."""
    picker = alt.selection_single(fields=["value"])

    by_attribute = picker.value
    assert isinstance(by_attribute, alt.expr.Expression)
    assert by_attribute.to_dict() == "{0}.value".format(picker.name)

    by_item = picker["value"]
    assert isinstance(by_item, alt.expr.Expression)
    assert by_item.to_dict() == "{0}['value']".format(picker.name)
242
243
@pytest.mark.parametrize("format", ["html", "json", "png", "svg"])
def test_save(format, basic_chart):
    """Saving to a file object and to a named file produce the same content.

    When ``altair_saver`` is not importable, the svg and png formats are
    expected to raise a ``ValueError`` that points at the altair_saver
    project.
    """
    # png is the only format read back as bytes; the rest are text.
    is_binary = format == "png"
    out = io.BytesIO() if is_binary else io.StringIO()
    mode = "rb" if is_binary else "r"

    needs_saver = format in ["svg", "png"]
    if needs_saver and not altair_saver:
        with pytest.raises(ValueError) as err:
            basic_chart.save(out, format=format)
        assert "github.com/altair-viz/altair_saver" in str(err.value)
        return

    basic_chart.save(out, format=format)
    out.seek(0)
    content = out.read()

    if format == "json":
        assert "$schema" in json.loads(content)
    if format == "html":
        assert content.startswith("<!DOCTYPE html>")

    # Save again by filename; the format is left to be picked from the suffix.
    handle, filename = tempfile.mkstemp(suffix="." + format)
    os.close(handle)

    try:
        basic_chart.save(filename)
        with open(filename, mode) as saved:
            assert saved.read() == content
    finally:
        os.remove(filename)
277
278
def test_facet_basic():
    """Check the dict output of a wrapped facet and of an explicit row/column facet."""
    points = alt.Chart("data.csv").mark_point().encode(x="x:Q", y="y:Q")
    url_data = alt.UrlData("data.csv").to_dict()

    # wrapped facet
    wrapped = points.facet("category:N", columns=2).to_dict()

    assert wrapped["facet"] == alt.Facet("category:N").to_dict()
    assert wrapped["columns"] == 2
    assert wrapped["data"] == url_data

    # explicit row/col facet
    gridded = points.facet(row="category1:Q", column="category2:Q").to_dict()

    assert gridded["facet"]["row"] == alt.Facet("category1:Q").to_dict()
    assert gridded["facet"]["column"] == alt.Facet("category2:Q").to_dict()
    assert "columns" not in gridded
    assert gridded["data"] == url_data
308
309
def test_facet_parse():
    """Facet shorthands are parsed, and the data sits at the top level of the dict."""
    points = alt.Chart("data.csv").mark_point().encode(x="x:Q", y="y:Q")
    spec = points.facet(row="row:N", column="column:O").to_dict()

    expected_facet = {
        "column": {"field": "column", "type": "ordinal"},
        "row": {"field": "row", "type": "nominal"},
    }
    assert spec["data"] == {"url": "data.csv"}
    assert "data" not in spec["spec"]
    assert spec["facet"] == expected_facet
324
325
def test_facet_parse_data():
    """Facet of a DataFrame chart: data stays top-level with or without consolidation."""
    frame = pd.DataFrame({"x": range(5), "y": range(5), "row": list("abcab")})
    points = alt.Chart(frame).mark_point().encode(x="x", y="y:O")
    faceted = points.facet(row="row", column="column:O")

    expected_facet = {
        "column": {"field": "column", "type": "ordinal"},
        "row": {"field": "row", "type": "nominal"},
    }

    # Without consolidation the values are inlined under "data".
    with alt.data_transformers.enable(consolidate_datasets=False):
        inlined = faceted.to_dict()
    assert "values" in inlined["data"]
    assert "data" not in inlined["spec"]
    assert inlined["facet"] == expected_facet

    # With consolidation "data" only names an entry under "datasets".
    with alt.data_transformers.enable(consolidate_datasets=True):
        consolidated = faceted.to_dict()
    assert "datasets" in consolidated
    assert "name" in consolidated["data"]
    assert "data" not in consolidated["spec"]
    assert consolidated["facet"] == expected_facet
352
353
def test_selection():
    """Exercise selection constructors, ``add_selection`` and logical operators."""
    # test instantiation of selections
    interval = alt.selection_interval(name="selec_1")
    single = alt.selection_single(name="selec_2")
    multi = alt.selection_multi(name="selec_3")

    instantiated = [
        (interval, "interval", "selec_1"),
        (single, "single", "selec_2"),
        (multi, "multi", "selec_3"),
    ]
    for sel, kind, label in instantiated:
        assert sel.selection.type == kind
        assert sel.name == label

    # test adding to chart
    chart = alt.Chart().add_selection(single)
    chart = chart.add_selection(multi, interval)
    assert set(chart.selection.keys()) == {"selec_1", "selec_2", "selec_3"}

    # test logical operations
    both = single & multi
    either = single | multi
    negated = ~single
    for combined in (both, either, negated):
        assert isinstance(combined, alt.Selection)
    assert isinstance(both[0].group, alt.SelectionAnd)
    assert isinstance(either[0].group, alt.SelectionOr)
    assert isinstance(negated[0].group, alt.SelectionNot)

    # test that default names increment (regression for #1454)
    unnamed = (
        alt.selection_single(),
        alt.selection_multi(),
        alt.selection_interval(),
    )
    distinct_names = {sel.name for sel in unnamed}
    assert len(distinct_names) == 3
387
388
def test_transforms():
    """Each ``transform_*`` method stores the matching transform object on the chart."""
    # aggregate transform
    mean_y = alt.AggregatedFieldDef(**{"as": "x1", "op": "mean", "field": "y"})
    median_z = alt.AggregatedFieldDef(**{"as": "x2", "op": "median", "field": "z"})
    chart = alt.Chart().transform_aggregate([mean_y], ["foo"], x2="median(z)")
    expected = alt.AggregateTransform(aggregate=[mean_y, median_z], groupby=["foo"])
    assert chart.transform == [expected]

    # bin transform
    chart = alt.Chart().transform_bin("binned", field="field", bin=True)
    expected = alt.BinTransform(**{"as": "binned", "field": "field", "bin": True})
    assert chart.transform == [expected]

    # calculate transform
    chart = alt.Chart().transform_calculate("calc", "datum.a * 4")
    expected = alt.CalculateTransform(**{"as": "calc", "calculate": "datum.a * 4"})
    assert chart.transform == [expected]

    # density transform
    chart = alt.Chart().transform_density("x", as_=["value", "density"])
    expected = alt.DensityTransform(**{"as": ["value", "density"], "density": "x"})
    assert chart.transform == [expected]

    # filter transform
    chart = alt.Chart().transform_filter("datum.a < 4")
    expected = alt.FilterTransform(filter="datum.a < 4")
    assert chart.transform == [expected]

    # flatten transform
    chart = alt.Chart().transform_flatten(["A", "B"], ["X", "Y"])
    expected = alt.FlattenTransform(**{"as": ["X", "Y"], "flatten": ["A", "B"]})
    assert chart.transform == [expected]

    # fold transform
    chart = alt.Chart().transform_fold(["A", "B", "C"], as_=["key", "val"])
    expected = alt.FoldTransform(**{"as": ["key", "val"], "fold": ["A", "B", "C"]})
    assert chart.transform == [expected]

    # impute transform
    chart = alt.Chart().transform_impute("field", "key", groupby=["x"])
    expected = alt.ImputeTransform(impute="field", key="key", groupby=["x"])
    assert chart.transform == [expected]

    # joinaggregate transform
    chart = alt.Chart().transform_joinaggregate(min="min(x)", groupby=["key"])
    min_x = alt.JoinAggregateFieldDef(field="x", op="min", **{"as": "min"})
    expected = alt.JoinAggregateTransform(joinaggregate=[min_x], groupby=["key"])
    assert chart.transform == [expected]

    # loess transform
    chart = alt.Chart().transform_loess("x", "y", as_=["xx", "yy"])
    expected = alt.LoessTransform(**{"on": "x", "loess": "y", "as": ["xx", "yy"]})
    assert chart.transform == [expected]

    # lookup transform (data)
    lookup_data = alt.LookupData(alt.UrlData("foo.csv"), "id", ["rate"])
    chart = alt.Chart().transform_lookup("a", from_=lookup_data, as_="a", default="b")
    expected = alt.LookupTransform(
        **{"from": lookup_data, "as": "a", "lookup": "a", "default": "b"}
    )
    assert chart.transform == [expected]

    # lookup transform (selection)
    lookup_selection = alt.LookupSelection(key="key", selection="sel")
    chart = alt.Chart().transform_lookup(
        "a", from_=lookup_selection, as_="a", default="b"
    )
    expected = alt.LookupTransform(
        **{"from": lookup_selection, "as": "a", "lookup": "a", "default": "b"}
    )
    assert chart.transform == [expected]

    # pivot transform
    chart = alt.Chart().transform_pivot("x", "y")
    expected = alt.PivotTransform(pivot="x", value="y")
    assert chart.transform == [expected]

    # quantile transform
    chart = alt.Chart().transform_quantile("x", as_=["prob", "value"])
    expected = alt.QuantileTransform(**{"quantile": "x", "as": ["prob", "value"]})
    assert chart.transform == [expected]

    # regression transform
    chart = alt.Chart().transform_regression("x", "y", as_=["xx", "yy"])
    expected = alt.RegressionTransform(
        **{"on": "x", "regression": "y", "as": ["xx", "yy"]}
    )
    assert chart.transform == [expected]

    # sample transform
    chart = alt.Chart().transform_sample()
    expected = alt.SampleTransform(1000)
    assert chart.transform == [expected]

    # stack transform
    chart = alt.Chart().transform_stack("stacked", "x", groupby=["y"])
    expected = alt.StackTransform(stack="x", groupby=["y"], **{"as": "stacked"})
    assert chart.transform == [expected]

    # timeUnit transform
    chart = alt.Chart().transform_timeunit("foo", field="x", timeUnit="date")
    expected = alt.TimeUnitTransform(**{"as": "foo", "field": "x", "timeUnit": "date"})
    assert chart.transform == [expected]

    # window transform
    chart = alt.Chart().transform_window(xsum="sum(x)", ymin="min(y)", frame=[None, 0])
    window = [
        alt.WindowFieldDef(**{"as": "xsum", "field": "x", "op": "sum"}),
        alt.WindowFieldDef(**{"as": "ymin", "field": "y", "op": "min"}),
    ]

    # kwargs don't maintain order in Python < 3.6, so window list can
    # be reversed
    in_given_order = chart.transform == [
        alt.WindowTransform(frame=[None, 0], window=window)
    ]
    assert in_given_order or chart.transform == [
        alt.WindowTransform(frame=[None, 0], window=window[::-1])
    ]
501
502
def test_filter_transform_selection_predicates():
    """Selections combined with ``~``, ``&`` and ``|`` serialize as filter predicates."""
    s1 = alt.selection_interval(name="s1")
    s2 = alt.selection_interval(name="s2")
    base = alt.Chart("data.txt").mark_point()

    def filter_spec(predicate):
        # The "transform" entry of the chart dict after filtering on ``predicate``.
        return base.transform_filter(predicate).to_dict()["transform"]

    def as_selection(expected):
        # Shape of a single selection-based filter entry.
        return [{"filter": {"selection": expected}}]

    assert filter_spec(s1) == as_selection("s1")

    assert filter_spec(~s1) == as_selection({"not": "s1"})

    assert filter_spec(s1 & s2) == as_selection({"and": ["s1", "s2"]})

    assert filter_spec(s1 | s2) == as_selection({"or": ["s1", "s2"]})

    assert filter_spec(s1 | ~s2) == as_selection({"or": ["s1", {"not": "s2"}]})

    assert filter_spec(~s1 | ~s2) == as_selection(
        {"or": [{"not": "s1"}, {"not": "s2"}]}
    )

    assert filter_spec(~(s1 & s2)) == as_selection({"not": {"and": ["s1", "s2"]}})
538
539
def test_resolve_methods():
    """``resolve_axis``/``resolve_legend``/``resolve_scale`` populate ``chart.resolve``."""
    with_axis = alt.LayerChart().resolve_axis(x="shared", y="independent")
    axis_map = alt.AxisResolveMap(x="shared", y="independent")
    assert with_axis.resolve == alt.Resolve(axis=axis_map)

    with_legend = alt.LayerChart().resolve_legend(color="shared", fill="independent")
    legend_map = alt.LegendResolveMap(color="shared", fill="independent")
    assert with_legend.resolve == alt.Resolve(legend=legend_map)

    with_scale = alt.LayerChart().resolve_scale(x="shared", y="independent")
    scale_map = alt.ScaleResolveMap(x="shared", y="independent")
    assert with_scale.resolve == alt.Resolve(scale=scale_map)
555
556
def test_layer_encodings():
    """``encode`` on a ``LayerChart`` stores the channel on ``chart.encoding``."""
    layered = alt.LayerChart().encode(x="column:Q")
    expected_x = alt.X(shorthand="column:Q")
    assert layered.encoding.x == expected_x
560
561
def test_add_selection():
    """Selections added over several ``add_selection`` calls are all kept, keyed by name."""
    interval = alt.selection_interval()
    single = alt.selection_single()
    multi = alt.selection_multi()

    chart = alt.Chart().mark_point()
    chart = chart.add_selection(interval)
    chart = chart.add_selection(single, multi)

    expected = dict((sel.name, sel.selection) for sel in (interval, single, multi))
    assert chart.selection == expected
576
577
def test_repeat_add_selections():
    """Adding a selection before or after ``repeat`` gives the same dict."""
    base = alt.Chart("data.csv").mark_point()
    picker = alt.selection_single()
    fields = list("ABC")

    added_first = base.add_selection(picker).repeat(fields)
    added_last = base.repeat(list("ABC")).add_selection(picker)

    assert added_first.to_dict() == added_last.to_dict()
584
585
def test_facet_add_selections():
    """Adding a selection before or after ``facet`` gives the same dict."""
    base = alt.Chart("data.csv").mark_point()
    picker = alt.selection_single()

    added_first = base.add_selection(picker).facet("val:Q")
    added_last = base.facet("val:Q").add_selection(picker)

    assert added_first.to_dict() == added_last.to_dict()
592
593
def test_layer_add_selection():
    """A selection added to a layer chart matches one added to its first layer."""
    base = alt.Chart("data.csv").mark_point()
    picker = alt.selection_single()

    on_first_layer = alt.layer(base.add_selection(picker), base)
    on_layer_chart = alt.layer(base, base).add_selection(picker)

    assert on_first_layer.to_dict() == on_layer_chart.to_dict()
600
601
@pytest.mark.parametrize("charttype", [alt.concat, alt.hconcat, alt.vconcat])
def test_compound_add_selections(charttype):
    """A selection added to a concat chart matches one added to every subchart."""
    base = alt.Chart("data.csv").mark_point()
    picker = alt.selection_single()

    on_each_subchart = charttype(base.add_selection(picker), base.add_selection(picker))
    on_compound = charttype(base, base).add_selection(picker)

    assert on_each_subchart.to_dict() == on_compound.to_dict()
609
610
def test_selection_property():
    """A selection passed through ``properties`` is stored under its own name."""
    brush = alt.selection_interval()
    points = alt.Chart("data.csv").mark_point()
    chart = points.properties(selection=brush)

    stored_names = list(chart["selection"].keys())
    assert stored_names == [brush.name]
616
617
def test_LookupData():
    """A DataFrame given to ``LookupData`` is serialized as inline record values."""
    frame = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
    serialized = alt.LookupData(data=frame, key="x").to_dict()

    expected_records = [{"x": 1, "y": 4}, {"x": 2, "y": 5}, {"x": 3, "y": 6}]
    assert serialized["key"] == "x"
    assert serialized["data"] == {"values": expected_records}
627
628
def test_themes():
    """Check the ``config`` produced under the default, opaque and none themes."""
    chart = alt.Chart("foo.txt").mark_point()
    view = {"continuousWidth": 400, "continuousHeight": 300}

    with alt.themes.enable("default"):
        config = chart.to_dict()["config"]
        assert config == {"view": view}

    with alt.themes.enable("opaque"):
        config = chart.to_dict()["config"]
        assert config == {"background": "white", "view": view}

    with alt.themes.enable("none"):
        spec = chart.to_dict()
        assert "config" not in spec
645
646
def test_chart_from_dict():
    """``Chart.from_dict`` restores the original chart class and rejects bad specs."""
    base = alt.Chart("data.csv").mark_point().encode(x="x:Q", y="y:Q")
    layered = base + base

    originals = [
        base,
        layered,
        base | base,
        base & base,
        base.facet("c:N"),
        layered.facet(row="c:N", data="data.csv"),
        base.repeat(["c", "d"]),
        layered.repeat(row=["c", "d"]),
    ]

    for original in originals:
        restored = alt.Chart.from_dict(original.to_dict())
        assert type(restored) is type(original)

    # test that an invalid spec leads to a schema validation error
    bad_spec = {"invalid": "spec"}
    with pytest.raises(jsonschema.ValidationError):
        alt.Chart.from_dict(bad_spec)
668
669
def test_consolidate_datasets(basic_chart):
    """Equal data in two subcharts is consolidated into a single named dataset."""
    left = basic_chart
    right = basic_chart.copy()
    right.data = basic_chart.data.copy()
    combined = left | right

    with alt.data_transformers.enable(consolidate_datasets=True):
        consolidated = combined.to_dict()

    with alt.data_transformers.enable(consolidate_datasets=False):
        standard = combined.to_dict()

    assert "datasets" in consolidated
    assert "datasets" not in standard

    datasets = consolidated["datasets"]

    # two dataset copies should be recognized as duplicates
    assert len(datasets) == 1

    # make sure data matches original & names are correct
    dataset_name, records = datasets.popitem()

    for subspec in standard["hconcat"]:
        assert subspec["data"]["values"] == records

    for subspec in consolidated["hconcat"]:
        assert subspec["data"] == {"name": dataset_name}
698
699
def test_consolidate_InlineData():
    """Check ``InlineData`` output with and without dataset consolidation."""

    def spec_with_consolidation(chart, flag):
        # Serialize ``chart`` while the consolidation option is set to ``flag``.
        with alt.data_transformers.enable(consolidate_datasets=flag):
            return chart.to_dict()

    # Inline values with an explicit format.
    inline = alt.InlineData(
        values=[{"a": 1, "b": 1}, {"a": 2, "b": 2}], format={"type": "csv"}
    )
    chart = alt.Chart(inline).mark_point()

    spec = spec_with_consolidation(chart, False)
    assert spec["data"]["format"] == inline.format
    assert spec["data"]["values"] == inline.values

    spec = spec_with_consolidation(chart, True)
    assert spec["data"]["format"] == inline.format
    assert list(spec["datasets"].values())[0] == inline.values

    # Empty values with a name: output equals the data's own dict either way.
    inline = alt.InlineData(values=[], name="runtime_data")
    chart = alt.Chart(inline).mark_point()

    spec = spec_with_consolidation(chart, False)
    assert spec["data"] == inline.to_dict()

    spec = spec_with_consolidation(chart, True)
    assert spec["data"] == inline.to_dict()
726
727
def test_repeat():
    """Check the dict output of a wrapped repeat and of an explicit row/column repeat."""
    points = alt.Chart("data.csv").mark_point()

    # wrapped repeat
    wrapped_x = alt.X(alt.repeat(), type="quantitative")
    wrapped = points.encode(x=wrapped_x, y="y:Q").repeat(
        ["A", "B", "C", "D"], columns=2
    )
    wrapped_spec = wrapped.to_dict()

    assert wrapped_spec["repeat"] == ["A", "B", "C", "D"]
    assert wrapped_spec["columns"] == 2
    assert wrapped_spec["spec"]["encoding"]["x"]["field"] == {"repeat": "repeat"}

    # explicit row/col repeat
    gridded = points.encode(
        x=alt.X(alt.repeat("row"), type="quantitative"),
        y=alt.Y(alt.repeat("column"), type="quantitative"),
    ).repeat(row=["A", "B", "C"], column=["C", "B", "A"])
    gridded_spec = gridded.to_dict()

    assert gridded_spec["repeat"] == {
        "row": ["A", "B", "C"],
        "column": ["C", "B", "A"],
    }
    assert "columns" not in gridded_spec
    inner_encoding = gridded_spec["spec"]["encoding"]
    assert inner_encoding["x"]["field"] == {"repeat": "row"}
    assert inner_encoding["y"]["field"] == {"repeat": "column"}
760
761
def test_data_property():
    """Data set through ``properties`` matches data passed to the constructor."""
    frame = pd.DataFrame({"x": [1, 2, 3], "y": list("ABC")})

    via_constructor = alt.Chart(frame).mark_point()
    via_properties = alt.Chart().mark_point().properties(data=frame)

    assert via_constructor.to_dict() == via_properties.to_dict()
768
769
@pytest.mark.parametrize("method", ["layer", "hconcat", "vconcat", "concat"])
@pytest.mark.parametrize(
    "data", ["data.json", pd.DataFrame({"x": range(3), "y": list("abc")})]
)
def test_subcharts_with_same_data(method, data):
    """Data shared by every subchart ends up on the compound chart, not the subcharts."""
    combine = getattr(alt, method)

    point = alt.Chart(data).mark_point().encode(x="x:Q", y="y:Q")
    line = point.mark_line()
    text = point.mark_text()

    # Built through the module-level function.
    via_function = combine(point, line, text)
    assert via_function.data is not alt.Undefined
    for subchart in getattr(via_function, method):
        assert subchart.data is alt.Undefined

    # "concat" has no entry in OP_DICT, so there is no operator form to check.
    if method == "concat":
        return

    # Built through the corresponding binary operator.
    op = OP_DICT[method]
    via_operator = op(op(point, line), text)
    assert via_operator.data is not alt.Undefined
    for subchart in getattr(via_operator, method):
        assert subchart.data is alt.Undefined
790
791
@pytest.mark.parametrize("method", ["layer", "hconcat", "vconcat", "concat"])
@pytest.mark.parametrize(
    "data", ["data.json", pd.DataFrame({"x": range(3), "y": list("abc")})]
)
def test_subcharts_different_data(method, data):
    """When subcharts do not share data, each keeps its own and the parent has none."""
    combine = getattr(alt, method)

    with_data = alt.Chart(data).mark_point().encode(x="x:Q", y="y:Q")
    with_other_data = alt.Chart("data.csv").mark_point().encode(x="x:Q", y="y:Q")
    without_data = alt.Chart().mark_point().encode(x="x:Q", y="y:Q")

    for partner in (with_other_data, without_data):
        combined = combine(with_data, partner)
        assert combined.data is alt.Undefined
        first_subchart = getattr(combined, method)[0]
        assert first_subchart.data is data
810
811
def test_layer_facet(basic_chart):
    """Faceting a layered chart leaves data only on the outer facet chart."""
    layered = basic_chart + basic_chart
    faceted = layered.facet(row="row:Q")

    assert faceted.data is not alt.Undefined
    assert faceted.spec.data is alt.Undefined
    assert all(layer.data is alt.Undefined for layer in faceted.spec.layer)

    spec = faceted.to_dict()
    assert "data" in spec
821
822
def test_layer_errors():
    """Layering configured, repeated or faceted charts raises ``ValueError``."""
    points = alt.Chart("data.txt").mark_point()

    configured = alt.Chart("data.txt").mark_point().configure_legend(columns=2)
    facet_by_encoding = alt.Chart("data.txt").mark_point().encode(facet="row:Q")
    facet_by_method = alt.Chart("data.txt").mark_point().facet("row:Q")
    repeated = alt.Chart("data.txt").mark_point().repeat(["A", "B", "C"])

    # Charts carrying top-level config cannot be layers.
    with pytest.raises(ValueError) as err:
        configured + points
    message = str(err.value)
    assert message.startswith(
        'Objects with "config" attribute cannot be used within LayerChart.'
    )

    # Repeat charts cannot be layers.
    with pytest.raises(ValueError) as err:
        repeated + points
    assert str(err.value) == "Repeat charts cannot be layered."

    # A facet encoding on one operand is rejected.
    with pytest.raises(ValueError) as err:
        facet_by_encoding + points
    assert str(err.value) == "Faceted charts cannot be layered."

    # So is a facet chart added to an existing layer chart.
    with pytest.raises(ValueError) as err:
        alt.layer(points) + facet_by_method
    assert str(err.value) == "Faceted charts cannot be layered."
851
852
@pytest.mark.parametrize(
    "chart_type",
    ["layer", "hconcat", "vconcat", "concat", "facet", "facet_encoding", "repeat"],
)
def test_resolve(chart_type):
    """Chained ``resolve_*`` calls all show up under ``"resolve"`` in the dict."""
    chart = _make_chart_type(chart_type)
    chart = chart.resolve_scale(x="independent")
    chart = chart.resolve_legend(color="independent")
    chart = chart.resolve_axis(y="independent")

    expected = {
        "scale": {"x": "independent"},
        "legend": {"color": "independent"},
        "axis": {"y": "independent"},
    }
    assert chart.to_dict()["resolve"] == expected
870
871
872# TODO: test vconcat, hconcat, concat, facet_encoding when schema allows them.
873# This is blocked by https://github.com/vega/vega-lite/issues/5261
@pytest.mark.parametrize("chart_type", ["chart", "layer"])
@pytest.mark.parametrize("facet_arg", [None, "facet", "row", "column"])
def test_facet(chart_type, facet_arg):
    """Facet a chart positionally or by keyword and check the resulting dict."""
    base = _make_chart_type(chart_type)
    if facet_arg is not None:
        # Keyword form: the parametrized name selects facet/row/column.
        faceted = base.facet(**{facet_arg: "color:N", "columns": 2})
    else:
        faceted = base.facet("color:N", columns=2)
    spec = faceted.to_dict()

    assert "spec" in spec
    assert spec["columns"] == 2

    expected = {"field": "color", "type": "nominal"}
    if facet_arg in (None, "facet"):
        assert spec["facet"] == expected
    else:
        assert spec["facet"][facet_arg] == expected
891
892
def test_sequence():
    """Check the dict produced by ``alt.sequence`` for one, two and four arguments."""
    stop_only = alt.sequence(100).to_dict()
    assert stop_only == {"sequence": {"start": 0, "stop": 100}}

    start_and_stop = alt.sequence(5, 10).to_dict()
    assert start_and_stop == {"sequence": {"start": 5, "stop": 10}}

    stepped = alt.sequence(0, 1, 0.1, as_="x").to_dict()
    expected = {"start": 0, "stop": 1, "step": 0.1, "as": "x"}
    assert stepped == {"sequence": expected}
904
905
def test_graticule():
    """``alt.graticule`` serializes to ``True`` without options, else to the options."""
    bare = alt.graticule().to_dict()
    assert bare == {"graticule": True}

    stepped = alt.graticule(step=[15, 15]).to_dict()
    assert stepped == {"graticule": {"step": [15, 15]}}
912
913
def test_sphere():
    """``alt.sphere`` serializes to ``{"sphere": True}``."""
    serialized = alt.sphere().to_dict()
    assert serialized == {"sphere": True}
917