1""" test parquet compat """
2import datetime
3from distutils.version import LooseVersion
4from io import BytesIO
5import os
6import pathlib
7from warnings import catch_warnings
8
9import numpy as np
10import pytest
11
12from pandas.compat import PY38, is_platform_windows
13import pandas.util._test_decorators as td
14
15import pandas as pd
16import pandas._testing as tm
17
18from pandas.io.parquet import (
19    FastParquetImpl,
20    PyArrowImpl,
21    get_engine,
22    read_parquet,
23    to_parquet,
24)
25
26try:
27    import pyarrow
28
29    _HAVE_PYARROW = True
30except ImportError:
31    _HAVE_PYARROW = False
32
33try:
34    import fastparquet
35
36    _HAVE_FASTPARQUET = True
37except ImportError:
38    _HAVE_FASTPARQUET = False
39
40
41pytestmark = pytest.mark.filterwarnings(
42    "ignore:RangeIndex.* is deprecated:DeprecationWarning"
43)
44
45
46# setup engines & skips
47@pytest.fixture(
48    params=[
49        pytest.param(
50            "fastparquet",
51            marks=pytest.mark.skipif(
52                not _HAVE_FASTPARQUET, reason="fastparquet is not installed"
53            ),
54        ),
55        pytest.param(
56            "pyarrow",
57            marks=pytest.mark.skipif(
58                not _HAVE_PYARROW, reason="pyarrow is not installed"
59            ),
60        ),
61    ]
62)
63def engine(request):
64    return request.param
65
66
67@pytest.fixture
68def pa():
69    if not _HAVE_PYARROW:
70        pytest.skip("pyarrow is not installed")
71    return "pyarrow"
72
73
74@pytest.fixture
75def fp():
76    if not _HAVE_FASTPARQUET:
77        pytest.skip("fastparquet is not installed")
78    return "fastparquet"
79
80
81@pytest.fixture
82def df_compat():
83    return pd.DataFrame({"A": [1, 2, 3], "B": "foo"})
84
85
86@pytest.fixture
87def df_cross_compat():
88    df = pd.DataFrame(
89        {
90            "a": list("abc"),
91            "b": list(range(1, 4)),
92            # 'c': np.arange(3, 6).astype('u1'),
93            "d": np.arange(4.0, 7.0, dtype="float64"),
94            "e": [True, False, True],
95            "f": pd.date_range("20130101", periods=3),
96            # 'g': pd.date_range('20130101', periods=3,
97            #                    tz='US/Eastern'),
98            # 'h': pd.date_range('20130101', periods=3, freq='ns')
99        }
100    )
101    return df
102
103
104@pytest.fixture
105def df_full():
106    return pd.DataFrame(
107        {
108            "string": list("abc"),
109            "string_with_nan": ["a", np.nan, "c"],
110            "string_with_none": ["a", None, "c"],
111            "bytes": [b"foo", b"bar", b"baz"],
112            "unicode": ["foo", "bar", "baz"],
113            "int": list(range(1, 4)),
114            "uint": np.arange(3, 6).astype("u1"),
115            "float": np.arange(4.0, 7.0, dtype="float64"),
116            "float_with_nan": [2.0, np.nan, 3.0],
117            "bool": [True, False, True],
118            "datetime": pd.date_range("20130101", periods=3),
119            "datetime_with_nat": [
120                pd.Timestamp("20130101"),
121                pd.NaT,
122                pd.Timestamp("20130103"),
123            ],
124        }
125    )
126
127
128@pytest.fixture(
129    params=[
130        datetime.datetime.now(datetime.timezone.utc),
131        datetime.datetime.now(datetime.timezone.min),
132        datetime.datetime.now(datetime.timezone.max),
133        datetime.datetime.strptime("2019-01-04T16:41:24+0200", "%Y-%m-%dT%H:%M:%S%z"),
134        datetime.datetime.strptime("2019-01-04T16:41:24+0215", "%Y-%m-%dT%H:%M:%S%z"),
135        datetime.datetime.strptime("2019-01-04T16:41:24-0200", "%Y-%m-%dT%H:%M:%S%z"),
136        datetime.datetime.strptime("2019-01-04T16:41:24-0215", "%Y-%m-%dT%H:%M:%S%z"),
137    ]
138)
139def timezone_aware_date_list(request):
140    return request.param
141
142
143def check_round_trip(
144    df,
145    engine=None,
146    path=None,
147    write_kwargs=None,
148    read_kwargs=None,
149    expected=None,
150    check_names=True,
151    check_like=False,
152    check_dtype=True,
153    repeat=2,
154):
    """Verify that the parquet serializer and deserializer produce the same results.

    Performs a pandas-to-disk and disk-to-pandas round trip,
    then compares the two resulting DataFrames to verify equality.

    Parameters
    ----------
    df : DataFrame
    engine : str, optional
        'pyarrow' or 'fastparquet'
    path : str, optional
    write_kwargs : dict of str:str, optional
    read_kwargs : dict of str:str, optional
    expected : DataFrame, optional
        Expected deserialization result, otherwise will be equal to `df`.
    check_names : bool, default True
        Whether to compare the ``names`` attribute of the index and columns.
    check_like : bool, default False
        If True, ignore the order of index & columns.
    check_dtype : bool, default True
        Whether to compare the dtypes of the two DataFrames.
    repeat : int, default 2
        How many times to repeat the test.
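
    Examples
    --------
    A minimal sketch of typical usage (illustrative only; it assumes at least
    one parquet engine, e.g. pyarrow, is installed):

    >>> df = pd.DataFrame({"A": [1, 2, 3]})
    >>> check_round_trip(df, engine="pyarrow")
    >>> check_round_trip(df, engine="pyarrow", read_kwargs={"columns": ["A"]})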
    """
    write_kwargs = write_kwargs or {"compression": None}
    read_kwargs = read_kwargs or {}

    if expected is None:
        expected = df

    if engine:
        write_kwargs["engine"] = engine
        read_kwargs["engine"] = engine

    def compare(repeat):
        for _ in range(repeat):
            df.to_parquet(path, **write_kwargs)
            with catch_warnings(record=True):
                actual = read_parquet(path, **read_kwargs)

            tm.assert_frame_equal(
                expected,
                actual,
                check_names=check_names,
                check_like=check_like,
                check_dtype=check_dtype,
            )

    if path is None:
        with tm.ensure_clean() as path:
            compare(repeat)
    else:
        compare(repeat)


def test_invalid_engine(df_compat):
    with pytest.raises(ValueError):
        check_round_trip(df_compat, "foo", "bar")


def test_options_py(df_compat, pa):
    # use the set option

    with pd.option_context("io.parquet.engine", "pyarrow"):
        check_round_trip(df_compat)


def test_options_fp(df_compat, fp):
    # use the set option

    with pd.option_context("io.parquet.engine", "fastparquet"):
        check_round_trip(df_compat)


def test_options_auto(df_compat, fp, pa):
    # use the set option

    with pd.option_context("io.parquet.engine", "auto"):
        check_round_trip(df_compat)


def test_options_get_engine(fp, pa):
    assert isinstance(get_engine("pyarrow"), PyArrowImpl)
    assert isinstance(get_engine("fastparquet"), FastParquetImpl)

    with pd.option_context("io.parquet.engine", "pyarrow"):
        assert isinstance(get_engine("auto"), PyArrowImpl)
        assert isinstance(get_engine("pyarrow"), PyArrowImpl)
        assert isinstance(get_engine("fastparquet"), FastParquetImpl)

    with pd.option_context("io.parquet.engine", "fastparquet"):
        assert isinstance(get_engine("auto"), FastParquetImpl)
        assert isinstance(get_engine("pyarrow"), PyArrowImpl)
        assert isinstance(get_engine("fastparquet"), FastParquetImpl)

    with pd.option_context("io.parquet.engine", "auto"):
        assert isinstance(get_engine("auto"), PyArrowImpl)
        assert isinstance(get_engine("pyarrow"), PyArrowImpl)
        assert isinstance(get_engine("fastparquet"), FastParquetImpl)


def test_get_engine_auto_error_message():
    # Expect different error messages from get_engine(engine="auto")
    # if engines aren't installed vs. are installed but bad version
    from pandas.compat._optional import VERSIONS

    # Do we have engines installed, but a bad version of them?
    pa_min_ver = VERSIONS.get("pyarrow")
    fp_min_ver = VERSIONS.get("fastparquet")
    have_pa_bad_version = (
        False
        if not _HAVE_PYARROW
        else LooseVersion(pyarrow.__version__) < LooseVersion(pa_min_ver)
    )
    have_fp_bad_version = (
        False
        if not _HAVE_FASTPARQUET
        else LooseVersion(fastparquet.__version__) < LooseVersion(fp_min_ver)
    )
    # Do we have usable engines installed?
    have_usable_pa = _HAVE_PYARROW and not have_pa_bad_version
    have_usable_fp = _HAVE_FASTPARQUET and not have_fp_bad_version

    if not have_usable_pa and not have_usable_fp:
        # No usable engines found.
        if have_pa_bad_version:
            match = f"Pandas requires version .{pa_min_ver}. or newer of .pyarrow."
            with pytest.raises(ImportError, match=match):
                get_engine("auto")
        else:
            match = "Missing optional dependency .pyarrow."
            with pytest.raises(ImportError, match=match):
                get_engine("auto")

        if have_fp_bad_version:
            match = f"Pandas requires version .{fp_min_ver}. or newer of .fastparquet."
            with pytest.raises(ImportError, match=match):
                get_engine("auto")
        else:
            match = "Missing optional dependency .fastparquet."
            with pytest.raises(ImportError, match=match):
                get_engine("auto")


def test_cross_engine_pa_fp(df_cross_compat, pa, fp):
    # cross-compat with differing reading/writing engines

    df = df_cross_compat
    with tm.ensure_clean() as path:
        df.to_parquet(path, engine=pa, compression=None)

        result = read_parquet(path, engine=fp)
        tm.assert_frame_equal(result, df)

        result = read_parquet(path, engine=fp, columns=["a", "d"])
        tm.assert_frame_equal(result, df[["a", "d"]])


def test_cross_engine_fp_pa(df_cross_compat, pa, fp):
    # cross-compat with differing reading/writing engines

    if (
        LooseVersion(pyarrow.__version__) < "0.15"
        and LooseVersion(pyarrow.__version__) >= "0.13"
    ):
        pytest.xfail(
            "Reading fastparquet with pyarrow in 0.14 fails: "
            "https://issues.apache.org/jira/browse/ARROW-6492"
        )

    df = df_cross_compat
    with tm.ensure_clean() as path:
        df.to_parquet(path, engine=fp, compression=None)

        with catch_warnings(record=True):
            result = read_parquet(path, engine=pa)
            tm.assert_frame_equal(result, df)

            result = read_parquet(path, engine=pa, columns=["a", "d"])
            tm.assert_frame_equal(result, df[["a", "d"]])


class Base:
    def check_error_on_write(self, df, engine, exc):
        # check that we are raising the exception on writing
        with tm.ensure_clean() as path:
            with pytest.raises(exc):
                to_parquet(df, path, engine, compression=None)

    @tm.network
    def test_parquet_read_from_url(self, df_compat, engine):
        if engine != "auto":
            pytest.importorskip(engine)
        url = (
            "https://raw.githubusercontent.com/pandas-dev/pandas/"
            "master/pandas/tests/io/data/parquet/simple.parquet"
        )
        df = pd.read_parquet(url)
        tm.assert_frame_equal(df, df_compat)


class TestBasic(Base):
    def test_error(self, engine):
        for obj in [
            pd.Series([1, 2, 3]),
            1,
            "foo",
            pd.Timestamp("20130101"),
            np.array([1, 2, 3]),
        ]:
            self.check_error_on_write(obj, engine, ValueError)

    def test_columns_dtypes(self, engine):
        df = pd.DataFrame({"string": list("abc"), "int": list(range(1, 4))})

        # unicode
        df.columns = ["foo", "bar"]
        check_round_trip(df, engine)

    def test_columns_dtypes_invalid(self, engine):
        df = pd.DataFrame({"string": list("abc"), "int": list(range(1, 4))})

        # numeric
        df.columns = [0, 1]
        self.check_error_on_write(df, engine, ValueError)

        # bytes
        df.columns = [b"foo", b"bar"]
        self.check_error_on_write(df, engine, ValueError)

        # python object
        df.columns = [
            datetime.datetime(2011, 1, 1, 0, 0),
            datetime.datetime(2011, 1, 1, 1, 1),
        ]
        self.check_error_on_write(df, engine, ValueError)

    @pytest.mark.parametrize("compression", [None, "gzip", "snappy", "brotli"])
    def test_compression(self, engine, compression):

        if compression == "snappy":
            pytest.importorskip("snappy")

        elif compression == "brotli":
            pytest.importorskip("brotli")

        df = pd.DataFrame({"A": [1, 2, 3]})
        check_round_trip(df, engine, write_kwargs={"compression": compression})

    def test_read_columns(self, engine):
        # GH18154
        df = pd.DataFrame({"string": list("abc"), "int": list(range(1, 4))})

        expected = pd.DataFrame({"string": list("abc")})
        check_round_trip(
            df, engine, expected=expected, read_kwargs={"columns": ["string"]}
        )

    def test_write_index(self, engine):
        check_names = engine != "fastparquet"

        df = pd.DataFrame({"A": [1, 2, 3]})
        check_round_trip(df, engine)

        indexes = [
            [2, 3, 4],
            pd.date_range("20130101", periods=3),
            list("abc"),
            [1, 3, 4],
        ]
        # non-default index
        for index in indexes:
            df.index = index
            if isinstance(index, pd.DatetimeIndex):
                df.index = df.index._with_freq(None)  # freq doesn't round-trip
            check_round_trip(df, engine, check_names=check_names)

        # index with meta-data
        df.index = [0, 1, 2]
        df.index.name = "foo"
        check_round_trip(df, engine)

    def test_write_multiindex(self, pa):
        # Not supported in fastparquet as of 0.1.3 or in older pyarrow versions
        engine = pa

        df = pd.DataFrame({"A": [1, 2, 3]})
        index = pd.MultiIndex.from_tuples([("a", 1), ("a", 2), ("b", 1)])
        df.index = index
        check_round_trip(df, engine)

    def test_multiindex_with_columns(self, pa):
        engine = pa
        dates = pd.date_range("01-Jan-2018", "01-Dec-2018", freq="MS")
        df = pd.DataFrame(np.random.randn(2 * len(dates), 3), columns=list("ABC"))
        index1 = pd.MultiIndex.from_product(
            [["Level1", "Level2"], dates], names=["level", "date"]
        )
        index2 = index1.copy(names=None)
        for index in [index1, index2]:
            df.index = index

            check_round_trip(df, engine)
            check_round_trip(
                df, engine, read_kwargs={"columns": ["A", "B"]}, expected=df[["A", "B"]]
            )

    def test_write_ignoring_index(self, engine):
        # ENH 20768
        # Ensure index=False omits the index from the written Parquet file.
        df = pd.DataFrame({"a": [1, 2, 3], "b": ["q", "r", "s"]})

        write_kwargs = {"compression": None, "index": False}

        # Because we're dropping the index, we expect the loaded dataframe to
        # have the default integer index.
        expected = df.reset_index(drop=True)

        check_round_trip(df, engine, write_kwargs=write_kwargs, expected=expected)

        # Ignore custom index
        df = pd.DataFrame(
            {"a": [1, 2, 3], "b": ["q", "r", "s"]}, index=["zyx", "wvu", "tsr"]
        )

        check_round_trip(df, engine, write_kwargs=write_kwargs, expected=expected)

        # Ignore multi-indexes as well.
        arrays = [
            ["bar", "bar", "baz", "baz", "foo", "foo", "qux", "qux"],
            ["one", "two", "one", "two", "one", "two", "one", "two"],
        ]
        df = pd.DataFrame(
            {"one": list(range(8)), "two": [-i for i in range(8)]}, index=arrays
        )

        expected = df.reset_index(drop=True)
        check_round_trip(df, engine, write_kwargs=write_kwargs, expected=expected)

    def test_write_column_multiindex(self, engine):
        # Not able to write column multi-indexes with non-string column names.
        mi_columns = pd.MultiIndex.from_tuples([("a", 1), ("a", 2), ("b", 1)])
        df = pd.DataFrame(np.random.randn(4, 3), columns=mi_columns)
        self.check_error_on_write(df, engine, ValueError)

    def test_write_column_multiindex_nonstring(self, pa):
        # GH #34777
        # Not supported in fastparquet as of 0.1.3
        engine = pa

        # Not able to write column multi-indexes with non-string column names
        arrays = [
            ["bar", "bar", "baz", "baz", "foo", "foo", "qux", "qux"],
            [1, 2, 1, 2, 1, 2, 1, 2],
        ]
        df = pd.DataFrame(np.random.randn(8, 8), columns=arrays)
        df.columns.names = ["Level1", "Level2"]

        self.check_error_on_write(df, engine, ValueError)

    def test_write_column_multiindex_string(self, pa):
        # GH #34777
        # Not supported in fastparquet as of 0.1.3
        engine = pa

        # Write column multi-indexes with string column names
        arrays = [
            ["bar", "bar", "baz", "baz", "foo", "foo", "qux", "qux"],
            ["one", "two", "one", "two", "one", "two", "one", "two"],
        ]
        df = pd.DataFrame(np.random.randn(8, 8), columns=arrays)
        df.columns.names = ["ColLevel1", "ColLevel2"]

        check_round_trip(df, engine)

    def test_write_column_index_string(self, pa):
        # GH #34777
        # Not supported in fastparquet as of 0.1.3
        engine = pa

        # Write column indexes with string column names
        arrays = ["bar", "baz", "foo", "qux"]
        df = pd.DataFrame(np.random.randn(8, 4), columns=arrays)
        df.columns.name = "StringCol"

        check_round_trip(df, engine)

    def test_write_column_index_nonstring(self, pa):
        # GH #34777
        # Not supported in fastparquet as of 0.1.3
        engine = pa

        # Write column indexes with non-string column names
        arrays = [1, 2, 3, 4]
        df = pd.DataFrame(np.random.randn(8, 4), columns=arrays)
        df.columns.name = "NonStringCol"

        self.check_error_on_write(df, engine, ValueError)


class TestParquetPyArrow(Base):
    def test_basic(self, pa, df_full):

        df = df_full

        # additional supported types for pyarrow
        dti = pd.date_range("20130101", periods=3, tz="Europe/Brussels")
        dti = dti._with_freq(None)  # freq doesn't round-trip
        df["datetime_tz"] = dti
        df["bool_with_none"] = [True, None, True]

        check_round_trip(df, pa)

    def test_basic_subset_columns(self, pa, df_full):
        # GH18628

        df = df_full
        # additional supported types for pyarrow
        df["datetime_tz"] = pd.date_range("20130101", periods=3, tz="Europe/Brussels")

        check_round_trip(
            df,
            pa,
            expected=df[["string", "int"]],
            read_kwargs={"columns": ["string", "int"]},
        )

    def test_to_bytes_without_path_or_buf_provided(self, pa, df_full):
        # GH 37105

        buf_bytes = df_full.to_parquet(engine=pa)
        assert isinstance(buf_bytes, bytes)

        buf_stream = BytesIO(buf_bytes)
        res = pd.read_parquet(buf_stream)

        tm.assert_frame_equal(df_full, res)

    def test_duplicate_columns(self, pa):
        # not currently able to handle duplicate columns
        df = pd.DataFrame(np.arange(12).reshape(4, 3), columns=list("aaa")).copy()
        self.check_error_on_write(df, pa, ValueError)

    def test_unsupported(self, pa):
        if LooseVersion(pyarrow.__version__) < LooseVersion("0.15.1.dev"):
            # period - will be supported using an extension type with pyarrow 1.0
            df = pd.DataFrame({"a": pd.period_range("2013", freq="M", periods=3)})
            # pyarrow 0.11 raises ArrowTypeError
            # older pyarrows raise ArrowInvalid
            self.check_error_on_write(df, pa, Exception)

        # timedelta
        df = pd.DataFrame({"a": pd.timedelta_range("1 day", periods=3)})
        self.check_error_on_write(df, pa, NotImplementedError)

        # mixed python objects
        df = pd.DataFrame({"a": ["a", 1, 2.0]})
        # pyarrow 0.11 raises ArrowTypeError
        # older pyarrows raise ArrowInvalid
        self.check_error_on_write(df, pa, Exception)

    def test_categorical(self, pa):

        # supported in >= 0.7.0
        df = pd.DataFrame()
        df["a"] = pd.Categorical(list("abcdef"))

        # test for null, out-of-order values, and unobserved category
        df["b"] = pd.Categorical(
            ["bar", "foo", "foo", "bar", None, "bar"],
            dtype=pd.CategoricalDtype(["foo", "bar", "baz"]),
        )

        # test for ordered flag
        df["c"] = pd.Categorical(
            ["a", "b", "c", "a", "c", "b"], categories=["b", "c", "d"], ordered=True
        )

        if LooseVersion(pyarrow.__version__) >= LooseVersion("0.15.0"):
            check_round_trip(df, pa)
        else:
            # de-serialized as object for pyarrow < 0.15
            expected = df.astype(object)
            check_round_trip(df, pa, expected=expected)

    @pytest.mark.xfail(
        is_platform_windows() and PY38,
        reason="localhost connection rejected",
        strict=False,
    )
    def test_s3_roundtrip_explicit_fs(self, df_compat, s3_resource, pa, s3so):
        s3fs = pytest.importorskip("s3fs")
        if LooseVersion(pyarrow.__version__) <= LooseVersion("0.17.0"):
            pytest.skip()
        s3 = s3fs.S3FileSystem(**s3so)
        kw = {"filesystem": s3}
        check_round_trip(
            df_compat,
            pa,
            path="pandas-test/pyarrow.parquet",
            read_kwargs=kw,
            write_kwargs=kw,
        )

    def test_s3_roundtrip(self, df_compat, s3_resource, pa, s3so):
        if LooseVersion(pyarrow.__version__) <= LooseVersion("0.17.0"):
            pytest.skip()
        # GH #19134
        s3so = {"storage_options": s3so}
        check_round_trip(
            df_compat,
            pa,
            path="s3://pandas-test/pyarrow.parquet",
            read_kwargs=s3so,
            write_kwargs=s3so,
        )

    @td.skip_if_no("s3fs")  # also requires flask
    @pytest.mark.parametrize(
        "partition_col",
        [
            ["A"],
            [],
        ],
    )
    def test_s3_roundtrip_for_dir(
        self, df_compat, s3_resource, pa, partition_col, s3so
    ):
        # GH #26388
        expected_df = df_compat.copy()

        # GH #35791
        # read_table uses the new Arrow Datasets API since pyarrow 1.0.0
        # Previously, pyarrow partition columns came back as 'category' dtype and
        # were appended to the end of the dataframe on read. In the new API the
        # category dtype is only used when the partition field is a string, but
        # pyarrow 2.0.0 changed this again to use the category dtype for all
        # types (not only strings).
        pa10 = (LooseVersion(pyarrow.__version__) >= LooseVersion("1.0.0")) and (
            LooseVersion(pyarrow.__version__) < LooseVersion("2.0.0")
        )
        if partition_col:
            if pa10:
                partition_col_type = "int32"
            else:
                partition_col_type = "category"

            expected_df[partition_col] = expected_df[partition_col].astype(
                partition_col_type
            )

        check_round_trip(
            df_compat,
            pa,
            expected=expected_df,
            path="s3://pandas-test/parquet_dir",
            read_kwargs={"storage_options": s3so},
            write_kwargs={
                "partition_cols": partition_col,
                "compression": None,
                "storage_options": s3so,
            },
            check_like=True,
            repeat=1,
        )

    @td.skip_if_no("pyarrow")
    def test_read_file_like_obj_support(self, df_compat):
        buffer = BytesIO()
        df_compat.to_parquet(buffer)
        df_from_buf = pd.read_parquet(buffer)
        tm.assert_frame_equal(df_compat, df_from_buf)

    @td.skip_if_no("pyarrow")
    def test_expand_user(self, df_compat, monkeypatch):
        monkeypatch.setenv("HOME", "TestingUser")
        monkeypatch.setenv("USERPROFILE", "TestingUser")
        with pytest.raises(OSError, match=r".*TestingUser.*"):
            pd.read_parquet("~/file.parquet")
        with pytest.raises(OSError, match=r".*TestingUser.*"):
            df_compat.to_parquet("~/file.parquet")

    def test_partition_cols_supported(self, pa, df_full):
        # GH #23283
        partition_cols = ["bool", "int"]
        df = df_full
        with tm.ensure_clean_dir() as path:
            df.to_parquet(path, partition_cols=partition_cols, compression=None)
            import pyarrow.parquet as pq

            dataset = pq.ParquetDataset(path, validate_schema=False)
            assert len(dataset.partitions.partition_names) == 2
            assert dataset.partitions.partition_names == set(partition_cols)
            assert read_parquet(path).shape == df.shape

    def test_partition_cols_string(self, pa, df_full):
        # GH #27117
        partition_cols = "bool"
        partition_cols_list = [partition_cols]
        df = df_full
        with tm.ensure_clean_dir() as path:
            df.to_parquet(path, partition_cols=partition_cols, compression=None)
            import pyarrow.parquet as pq

            dataset = pq.ParquetDataset(path, validate_schema=False)
            assert len(dataset.partitions.partition_names) == 1
            assert dataset.partitions.partition_names == set(partition_cols_list)
            assert read_parquet(path).shape == df.shape

    @pytest.mark.parametrize("path_type", [str, pathlib.Path])
    def test_partition_cols_pathlib(self, pa, df_compat, path_type):
        # GH 35902

        partition_cols = "B"
        partition_cols_list = [partition_cols]
        df = df_compat

        with tm.ensure_clean_dir() as path_str:
            path = path_type(path_str)
            df.to_parquet(path, partition_cols=partition_cols_list)
            assert read_parquet(path).shape == df.shape

    def test_empty_dataframe(self, pa):
        # GH #27339
        df = pd.DataFrame()
        check_round_trip(df, pa)

    def test_write_with_schema(self, pa):
        import pyarrow

        df = pd.DataFrame({"x": [0, 1]})
        schema = pyarrow.schema([pyarrow.field("x", type=pyarrow.bool_())])
        out_df = df.astype(bool)
        check_round_trip(df, pa, write_kwargs={"schema": schema}, expected=out_df)

    @td.skip_if_no("pyarrow", min_version="0.15.0")
    def test_additional_extension_arrays(self, pa):
        # test additional ExtensionArrays that are supported through the
        # __arrow_array__ protocol
        df = pd.DataFrame(
            {
                "a": pd.Series([1, 2, 3], dtype="Int64"),
                "b": pd.Series([1, 2, 3], dtype="UInt32"),
                "c": pd.Series(["a", None, "c"], dtype="string"),
            }
        )
        if LooseVersion(pyarrow.__version__) >= LooseVersion("0.16.0"):
            expected = df
        else:
            # de-serialized as plain int / object
            expected = df.assign(
                a=df.a.astype("int64"), b=df.b.astype("int64"), c=df.c.astype("object")
            )
        check_round_trip(df, pa, expected=expected)

        df = pd.DataFrame({"a": pd.Series([1, 2, 3, None], dtype="Int64")})
        if LooseVersion(pyarrow.__version__) >= LooseVersion("0.16.0"):
            expected = df
        else:
            # if missing values in integer, currently de-serialized as float
            expected = df.assign(a=df.a.astype("float64"))
        check_round_trip(df, pa, expected=expected)

    @td.skip_if_no("pyarrow", min_version="0.16.0")
    def test_additional_extension_types(self, pa):
        # test additional ExtensionArrays that are supported through the
        # __arrow_array__ protocol + by defining a custom ExtensionType
        df = pd.DataFrame(
            {
                # Arrow does not yet support struct in writing to Parquet (ARROW-1644)
                # "c": pd.arrays.IntervalArray.from_tuples([(0, 1), (1, 2), (3, 4)]),
                "d": pd.period_range("2012-01-01", periods=3, freq="D"),
            }
        )
        check_round_trip(df, pa)

    @td.skip_if_no("pyarrow", min_version="0.16")
    def test_use_nullable_dtypes(self, pa):
        import pyarrow.parquet as pq

        table = pyarrow.table(
            {
                "a": pyarrow.array([1, 2, 3, None], "int64"),
                "b": pyarrow.array([1, 2, 3, None], "uint8"),
                "c": pyarrow.array(["a", "b", "c", None]),
                "d": pyarrow.array([True, False, True, None]),
            }
        )
        with tm.ensure_clean() as path:
            # write manually with pyarrow to write integers
            pq.write_table(table, path)
            result1 = read_parquet(path)
            result2 = read_parquet(path, use_nullable_dtypes=True)

        assert result1["a"].dtype == np.dtype("float64")
        expected = pd.DataFrame(
            {
                "a": pd.array([1, 2, 3, None], dtype="Int64"),
                "b": pd.array([1, 2, 3, None], dtype="UInt8"),
                "c": pd.array(["a", "b", "c", None], dtype="string"),
                "d": pd.array([True, False, True, None], dtype="boolean"),
            }
        )
        tm.assert_frame_equal(result2, expected)

    @td.skip_if_no("pyarrow", min_version="0.14")
    def test_timestamp_nanoseconds(self, pa):
        # with version 2.0, pyarrow defaults to writing the nanoseconds, so
        # this should work without error
        df = pd.DataFrame({"a": pd.date_range("2017-01-01", freq="1n", periods=10)})
        check_round_trip(df, pa, write_kwargs={"version": "2.0"})

    def test_timezone_aware_index(self, pa, timezone_aware_date_list):
        if LooseVersion(pyarrow.__version__) >= LooseVersion("2.0.0"):
            # temporarily skip this test until it is properly resolved
            # https://github.com/pandas-dev/pandas/issues/37286
            pytest.skip()
        idx = 5 * [timezone_aware_date_list]
        df = pd.DataFrame(index=idx, data={"index_as_col": idx})

        # see gh-36004
        # compare time(zone) values only, skip their class:
        # pyarrow always creates fixed offset timezones using pytz.FixedOffset()
        # even if it was datetime.timezone() originally
        #
        # technically they are the same:
        # they both implement datetime.tzinfo
        # they both wrap datetime.timedelta()
        # this use-case sets the resolution to 1 minute
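        #
        # As an illustration (comment only, not executed): a +02:00 offset can be
        # written either as datetime.timezone(datetime.timedelta(hours=2)) or as
        # pytz.FixedOffset(120); both report utcoffset() == timedelta(hours=2), so
        # the timestamp values compare equal even though the tzinfo classes differ.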
        check_round_trip(df, pa, check_dtype=False)

    @td.skip_if_no("pyarrow", min_version="1.0.0")
    def test_filter_row_groups(self, pa):
        # https://github.com/pandas-dev/pandas/issues/26551
        df = pd.DataFrame({"a": list(range(0, 3))})
        with tm.ensure_clean() as path:
            df.to_parquet(path, pa)
            result = read_parquet(
                path, pa, filters=[("a", "==", 0)], use_legacy_dataset=False
            )
        assert len(result) == 1


class TestParquetFastParquet(Base):
    @td.skip_if_no("fastparquet", min_version="0.3.2")
    def test_basic(self, fp, df_full):
        df = df_full

        dti = pd.date_range("20130101", periods=3, tz="US/Eastern")
        dti = dti._with_freq(None)  # freq doesn't round-trip
        df["datetime_tz"] = dti
        df["timedelta"] = pd.timedelta_range("1 day", periods=3)
        check_round_trip(df, fp)

    @pytest.mark.skip(reason="not supported")
    def test_duplicate_columns(self, fp):

        # not currently able to handle duplicate columns
        df = pd.DataFrame(np.arange(12).reshape(4, 3), columns=list("aaa")).copy()
        self.check_error_on_write(df, fp, ValueError)

    def test_bool_with_none(self, fp):
        df = pd.DataFrame({"a": [True, None, False]})
        expected = pd.DataFrame({"a": [1.0, np.nan, 0.0]}, dtype="float16")
        check_round_trip(df, fp, expected=expected)

    def test_unsupported(self, fp):

        # period
        df = pd.DataFrame({"a": pd.period_range("2013", freq="M", periods=3)})
        self.check_error_on_write(df, fp, ValueError)

        # mixed
        df = pd.DataFrame({"a": ["a", 1, 2.0]})
        self.check_error_on_write(df, fp, ValueError)

    def test_categorical(self, fp):
        df = pd.DataFrame({"a": pd.Categorical(list("abc"))})
        check_round_trip(df, fp)

    def test_filter_row_groups(self, fp):
        d = {"a": list(range(0, 3))}
        df = pd.DataFrame(d)
        with tm.ensure_clean() as path:
            df.to_parquet(path, fp, compression=None, row_group_offsets=1)
            result = read_parquet(path, fp, filters=[("a", "==", 0)])
        assert len(result) == 1

    def test_s3_roundtrip(self, df_compat, s3_resource, fp, s3so):
        # GH #19134
        check_round_trip(
            df_compat,
            fp,
            path="s3://pandas-test/fastparquet.parquet",
            read_kwargs={"storage_options": s3so},
            write_kwargs={"compression": None, "storage_options": s3so},
        )

    def test_partition_cols_supported(self, fp, df_full):
        # GH #23283
        partition_cols = ["bool", "int"]
        df = df_full
        with tm.ensure_clean_dir() as path:
            df.to_parquet(
                path,
                engine="fastparquet",
                partition_cols=partition_cols,
                compression=None,
            )
            assert os.path.exists(path)
            import fastparquet

            actual_partition_cols = fastparquet.ParquetFile(path, False).cats
            assert len(actual_partition_cols) == 2

    def test_partition_cols_string(self, fp, df_full):
        # GH #27117
        partition_cols = "bool"
        df = df_full
        with tm.ensure_clean_dir() as path:
            df.to_parquet(
                path,
                engine="fastparquet",
                partition_cols=partition_cols,
                compression=None,
            )
            assert os.path.exists(path)
            import fastparquet

            actual_partition_cols = fastparquet.ParquetFile(path, False).cats
            assert len(actual_partition_cols) == 1

    def test_partition_on_supported(self, fp, df_full):
        # GH #23283
        partition_cols = ["bool", "int"]
        df = df_full
        with tm.ensure_clean_dir() as path:
            df.to_parquet(
                path,
                engine="fastparquet",
                compression=None,
                partition_on=partition_cols,
            )
            assert os.path.exists(path)
            import fastparquet

            actual_partition_cols = fastparquet.ParquetFile(path, False).cats
            assert len(actual_partition_cols) == 2

    def test_error_on_using_partition_cols_and_partition_on(self, fp, df_full):
        # GH #23283
        partition_cols = ["bool", "int"]
        df = df_full
        with pytest.raises(ValueError):
            with tm.ensure_clean_dir() as path:
                df.to_parquet(
                    path,
                    engine="fastparquet",
                    compression=None,
                    partition_on=partition_cols,
                    partition_cols=partition_cols,
                )

    def test_empty_dataframe(self, fp):
        # GH #27339
        df = pd.DataFrame()
        expected = df.copy()
        expected.index.name = "index"
        check_round_trip(df, fp, expected=expected)

    def test_timezone_aware_index(self, fp, timezone_aware_date_list):
        idx = 5 * [timezone_aware_date_list]

        df = pd.DataFrame(index=idx, data={"index_as_col": idx})

        expected = df.copy()
        expected.index.name = "index"
        check_round_trip(df, fp, expected=expected)

    def test_use_nullable_dtypes_not_supported(self, fp):
        df = pd.DataFrame({"a": [1, 2]})

        with tm.ensure_clean() as path:
            df.to_parquet(path)
            with pytest.raises(ValueError, match="not supported for the fastparquet"):
                read_parquet(path, engine="fastparquet", use_nullable_dtypes=True)