1import datetime
2import json
3import re
4import sys
5from dataclasses import dataclass as vanilla_dataclass
6from decimal import Decimal
7from enum import Enum
8from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
9from pathlib import Path
10from typing import List
11from uuid import UUID
12
13import pytest
14
15from pydantic import BaseModel, create_model
16from pydantic.color import Color
17from pydantic.dataclasses import dataclass as pydantic_dataclass
18from pydantic.json import pydantic_encoder, timedelta_isoformat
19from pydantic.types import ConstrainedDecimal, DirectoryPath, FilePath, SecretBytes, SecretStr
20
21
class MyEnum(Enum):
    """Small two-member enum; ``MyEnum.foo`` is one of the inputs of ``test_encoding``."""

    # Member values are strings that deliberately differ from the member names.
    foo = 'bar'
    snap = 'crackle'
25
26
@pytest.mark.parametrize(
    'input,output',
    [
        # identifiers, addresses and colours
        (UUID('ebcdab58-6eb8-46fb-a190-d07a33e9eac8'), '"ebcdab58-6eb8-46fb-a190-d07a33e9eac8"'),
        (IPv4Address('192.168.0.1'), '"192.168.0.1"'),
        (Color('#000'), '"black"'),
        (Color((1, 12, 123)), '"#010c7b"'),
        # secrets: non-empty values are masked, empty ones stay empty
        (SecretStr('abcd'), '"**********"'),
        (SecretStr(''), '""'),
        (SecretBytes(b'xyz'), '"**********"'),
        (SecretBytes(b''), '""'),
        # remaining ipaddress types
        (IPv6Address('::1:0:1'), '"::1:0:1"'),
        (IPv4Interface('192.168.0.0/24'), '"192.168.0.0/24"'),
        (IPv6Interface('2001:db00::/120'), '"2001:db00::/120"'),
        (IPv4Network('192.168.0.0/24'), '"192.168.0.0/24"'),
        (IPv6Network('2001:db00::/120'), '"2001:db00::/120"'),
        # date and time values
        (datetime.datetime(2032, 1, 1, 1, 1), '"2032-01-01T01:01:00"'),
        (datetime.datetime(2032, 1, 1, 1, 1, tzinfo=datetime.timezone.utc), '"2032-01-01T01:01:00+00:00"'),
        (datetime.datetime(2032, 1, 1), '"2032-01-01T00:00:00"'),
        (datetime.time(12, 34, 56), '"12:34:56"'),
        (datetime.timedelta(days=12, seconds=34, microseconds=56), '1036834.000056'),
        # collections and one-shot iterables
        ({1, 2, 3}, '[1, 2, 3]'),
        (frozenset([1, 2, 3]), '[1, 2, 3]'),
        ((v for v in range(4)), '[0, 1, 2, 3]'),
        # scalars, models, enums and compiled patterns
        (b'this is bytes', '"this is bytes"'),
        (Decimal('12.34'), '12.34'),
        (create_model('BarModel', a='b', c='d')(), '{"a": "b", "c": "d"}'),
        (MyEnum.foo, '"bar"'),
        (re.compile('^regex$'), '"^regex$"'),
    ],
)
def test_encoding(input, output):
    """Dump ``input`` with ``pydantic_encoder`` as the ``default`` hook and compare to ``output``."""
    encoded = json.dumps(input, default=pydantic_encoder)
    assert encoded == output
60
61
@pytest.mark.skipif(sys.platform.startswith('win'), reason='paths look different on windows')
def test_path_encoding(tmpdir):
    """``Path``, ``FilePath`` and ``DirectoryPath`` fields are dumped as plain path strings."""

    class PathModel(BaseModel):
        path: Path
        file_path: FilePath
        dir_path: DirectoryPath

    base = Path(tmpdir)

    # Create a real file and a real directory to hand to the model.
    existing_file = base / 'bar'
    existing_file.touch()
    existing_dir = base / 'baz'
    existing_dir.mkdir()

    model = PathModel(path=Path('/path/test/example/'), file_path=existing_file, dir_path=existing_dir)

    # The trailing slash of the ``path`` input is not present in the expected output.
    expected = f'{{"path": "/path/test/example", "file_path": "{existing_file}", "dir_path": "{existing_dir}"}}'
    assert json.dumps(model, default=pydantic_encoder) == expected
77
78
def test_model_encoding():
    """A model with float, bytes, Decimal and nested-model fields exports via ``dict()`` and ``json()``."""

    class ModelA(BaseModel):
        x: int
        y: str

    class Model(BaseModel):
        a: float
        b: bytes
        c: Decimal
        d: ModelA

    instance = Model(a=10.2, b='foobar', c=10.2, d={'x': 123, 'y': '123'})

    # dict() keeps Python types: bytes stay bytes and the float input became a Decimal.
    expected_dict = {'a': 10.2, 'b': b'foobar', 'c': Decimal('10.2'), 'd': {'x': 123, 'y': '123'}}
    assert instance.dict() == expected_dict

    # json() renders the whole model, and honours ``exclude``.
    full_json = instance.json()
    assert full_json == '{"a": 10.2, "b": "foobar", "c": 10.2, "d": {"x": 123, "y": "123"}}'
    without_b = instance.json(exclude={'b'})
    assert without_b == '{"a": 10.2, "c": 10.2, "d": {"x": 123, "y": "123"}}'
94
95
def test_subclass_encoding():
    """A ``datetime`` subclass field is dumped as an ISO string just like a plain ``datetime`` field."""

    class SubDate(datetime.datetime):
        pass

    class Model(BaseModel):
        a: datetime.datetime
        b: SubDate

    model = Model(a=datetime.datetime(2032, 1, 1, 1, 1), b=SubDate(2020, 2, 29, 12, 30))

    expected_dict = {'a': datetime.datetime(2032, 1, 1, 1, 1), 'b': SubDate(2020, 2, 29, 12, 30)}
    assert model.dict() == expected_dict

    expected_json = '{"a": "2032-01-01T01:01:00", "b": "2020-02-29T12:30:00"}'
    assert model.json() == expected_json
107
108
def test_subclass_custom_encoding():
    """Encoders registered for a base type are applied to values of its subclasses.

    ``json_encoders`` is keyed on ``datetime.datetime`` and ``datetime.timedelta``
    while the model fields hold ``SubDate`` / ``SubDelta`` instances; the JSON
    output must still come from the custom encoders.
    """

    class SubDate(datetime.datetime):
        pass

    class SubDelta(datetime.timedelta):
        pass

    class Model(BaseModel):
        a: SubDate
        b: SubDelta

        class Config:
            json_encoders = {
                # Use only strftime directives that Python documents as portable.
                # ``%C`` (century) is a C-library extension and is not supported on
                # every platform, so the four-digit year ``%Y`` is used instead.
                datetime.datetime: lambda v: v.strftime('%a, %d %b %Y %H:%M:%S'),
                datetime.timedelta: timedelta_isoformat,
            }

    m = Model(a=SubDate(2032, 1, 1, 1, 1), b=SubDelta(hours=100))

    # 100 hours normalise to 4 days + 14400 seconds.
    assert m.dict() == {'a': SubDate(2032, 1, 1, 1, 1), 'b': SubDelta(days=4, seconds=14400)}
    assert m.json() == '{"a": "Thu, 01 Jan 2032 01:01:00", "b": "P4DT4H0M0.000000S"}'
129
130
def test_invalid_model():
    """Dumping the plain class ``Foo`` through ``pydantic_encoder`` raises ``TypeError``."""

    class Foo:
        pass

    # Note that the class object itself, not an instance, is passed to json.dumps.
    unsupported = Foo
    with pytest.raises(TypeError):
        json.dumps(unsupported, default=pydantic_encoder)
137
138
@pytest.mark.parametrize(
    'input,output',
    [
        # seconds only: hour and minute components render as zero
        (datetime.timedelta(days=12, seconds=34, microseconds=56), 'P12DT0H0M34.000056S'),
        # every component populated
        (datetime.timedelta(days=1001, hours=1, minutes=2, seconds=3, microseconds=654_321), 'P1001DT1H2M3.654321S'),
    ],
)
def test_iso_timedelta(input, output):
    """``timedelta_isoformat`` renders a timedelta as the expected ``P…DT…H…M…S`` string."""
    rendered = timedelta_isoformat(input)
    assert rendered == output
148
149
def test_custom_encoder():
    """Per-type ``json_encoders`` override the output for matching fields only."""

    class Model(BaseModel):
        x: datetime.timedelta
        y: Decimal
        z: datetime.date

        class Config:
            json_encoders = {
                datetime.timedelta: lambda v: f'{v.total_seconds():0.3f}s',
                Decimal: lambda v: 'a decimal',
            }

    # ``z`` has no custom encoder and is dumped as an ISO date string.
    model = Model(x=123, y=5, z='2032-06-01')
    encoded = model.json()
    assert encoded == '{"x": "123.000s", "y": "a decimal", "z": "2032-06-01"}'
160
161
def test_custom_iso_timedelta():
    """Registering ``timedelta_isoformat`` as the timedelta encoder yields an ISO duration string."""

    class Model(BaseModel):
        x: datetime.timedelta

        class Config:
            json_encoders = {datetime.timedelta: timedelta_isoformat}

    # 123 seconds -> 2 minutes 3 seconds.
    encoded = Model(x=123).json()
    assert encoded == '{"x": "P0DT0H2M3.000000S"}'
171
172
def test_con_decimal_encode() -> None:
    """
    Make sure a constrained decimal with ``decimal_places = 0``, as well as a
    plain decimal with fractional places, survives an encode/decode round trip.
    """

    class Id(ConstrainedDecimal):
        max_digits = 22
        decimal_places = 0
        ge = 0

    class Obj(BaseModel):
        id: Id
        price: Decimal = Decimal('0.01')

    payload = '{"id": 1, "price": 0.01}'

    # encode ...
    assert Obj(id=1).json() == payload
    # ... and decode back to an equal model
    assert Obj.parse_raw(payload) == Obj(id=1)
190
191
def test_json_encoder_simple_inheritance():
    """A child ``Config.json_encoders`` adds to, rather than replaces, the parent's encoders."""

    class Parent(BaseModel):
        dt: datetime.datetime = datetime.datetime.now()
        timedt: datetime.timedelta = datetime.timedelta(hours=100)

        class Config:
            json_encoders = {datetime.datetime: lambda _: 'parent_encoder'}

    class Child(Parent):
        class Config:
            json_encoders = {datetime.timedelta: lambda _: 'child_encoder'}

    # ``dt`` is still handled by the parent's encoder, ``timedt`` by the child's.
    encoded = Child().json()
    assert encoded == '{"dt": "parent_encoder", "timedt": "child_encoder"}'
205
206
def test_json_encoder_inheritance_override():
    """A child encoder registered for the same type wins over the parent's encoder."""

    class Parent(BaseModel):
        dt: datetime.datetime = datetime.datetime.now()

        class Config:
            json_encoders = {datetime.datetime: lambda _: 'parent_encoder'}

    class Child(Parent):
        class Config:
            json_encoders = {datetime.datetime: lambda _: 'child_encoder'}

    encoded = Child().json()
    assert encoded == '{"dt": "child_encoder"}'
219
220
def test_custom_encoder_arg():
    """The ``encoder=`` argument of ``.json()`` replaces the encoder used for a single call."""

    class Model(BaseModel):
        x: datetime.timedelta

    def constant_encoder(_value):
        return '__default__'

    model = Model(x=123)

    # Without an explicit encoder the timedelta is dumped as float seconds.
    assert model.json() == '{"x": 123.0}'
    # With one, its return value is what ends up in the output.
    assert model.json(encoder=constant_encoder) == '{"x": "__default__"}'
228
229
def test_encode_dataclass():
    """A standard-library dataclass instance is dumped as a JSON object of its fields."""

    @vanilla_dataclass
    class Foo:
        bar: int
        spam: str

    instance = Foo(bar=123, spam='apple pie')
    encoded = json.dumps(instance, default=pydantic_encoder)
    assert encoded == '{"bar": 123, "spam": "apple pie"}'
238
239
def test_encode_pydantic_dataclass():
    """A pydantic dataclass instance is dumped as a JSON object of its fields."""

    @pydantic_dataclass
    class Foo:
        bar: int
        spam: str

    instance = Foo(bar=123, spam='apple pie')
    encoded = json.dumps(instance, default=pydantic_encoder)
    assert encoded == '{"bar": 123, "spam": "apple pie"}'
248
249
def test_encode_custom_root():
    """A model with a ``__root__`` list field is dumped as a bare JSON array."""

    class Model(BaseModel):
        __root__: List[str]

    model = Model(__root__=['a', 'b'])
    encoded = model.json()
    assert encoded == '["a", "b"]'
255
256
def test_custom_decode_encode():
    """``Config.json_loads`` / ``Config.json_dumps`` are used by ``parse_raw`` and ``.json()``.

    The custom loader strips ``$`` padding before decoding and the custom
    dumper forces ``indent=2``. Call counters verify that each hook was
    invoked exactly once, so the test cannot pass by accident through the
    default ``json`` functions.
    """
    load_calls, dump_calls = 0, 0

    def custom_loads(s):
        nonlocal load_calls
        load_calls += 1
        return json.loads(s.strip('$'))

    def custom_dumps(s, default=None, **kwargs):
        nonlocal dump_calls
        dump_calls += 1
        # Extra keyword arguments are accepted but not forwarded; only ``default`` is.
        return json.dumps(s, default=default, indent=2)

    class Model(BaseModel):
        a: int
        b: str

        class Config:
            json_loads = custom_loads
            json_dumps = custom_dumps

    # The raw input is only valid JSON once the ``$`` padding has been stripped.
    m = Model.parse_raw('${"a": 1, "b": "foo"}$$')
    assert m.dict() == {'a': 1, 'b': 'foo'}
    # The two-space indentation shows the custom dumper produced the output.
    assert m.json() == '{\n  "a": 1,\n  "b": "foo"\n}'

    # One parse_raw() and one json() call: each hook ran exactly once.
    assert (load_calls, dump_calls) == (1, 1)
281