1import pytest
2from collections import OrderedDict
3from datetime import datetime
4
5from marshmallow import EXCLUDE, fields, INCLUDE, RAISE, Schema, validate
6
7from apispec.ext.marshmallow import MarshmallowPlugin
8from apispec import exceptions, utils, APISpec
9
10from .schemas import CustomList, CustomStringField
11from .utils import get_schemas, build_ref
12
13
class TestMarshmallowFieldToOpenAPI:
    """Checks on how individual schema fields surface in generated parameters."""

    def test_fields_with_load_default_load(self, openapi):
        """The parameter default must be the field's ``load_default`` value."""

        class MySchema(Schema):
            field = fields.Str(dump_default="foo", load_default="bar")

        first_param = openapi.schema2parameters(MySchema, location="query")[0]
        # From OpenAPI 3 onwards the default is nested under "schema";
        # earlier versions carry it directly on the parameter.
        if openapi.openapi_version.major >= 3:
            assert first_param["schema"]["default"] == "bar"
        else:
            assert first_param["default"] == "bar"

    # json/body is invalid for OpenAPI 3
    @pytest.mark.parametrize("openapi", ("2.0",), indirect=True)
    def test_fields_default_location_mapping_if_schema_many(self, openapi):
        """A ``many=True`` schema at location "json" yields an "in: body" entry."""

        class ExampleSchema(Schema):
            id = fields.Int()

        many_schema = ExampleSchema(many=True)
        params = openapi.schema2parameters(schema=many_schema, location="json")
        assert params[0]["in"] == "body"

    def test_fields_with_dump_only(self, openapi):
        """Dump-only fields produce no query parameters, however declared."""

        # Variant 1: dump_only passed to the field constructor.
        class UserSchema(Schema):
            name = fields.Str(dump_only=True)

        field_level = openapi.schema2parameters(
            schema=UserSchema(), location="query"
        )
        assert len(field_level) == 0

        # Variant 2: dump_only listed in the schema's Meta options.
        class UserSchema(Schema):
            name = fields.Str()

            class Meta:
                dump_only = ("name",)

        meta_level = openapi.schema2parameters(
            schema=UserSchema(), location="query"
        )
        assert len(meta_level) == 0
50
51
class TestMarshmallowSchemaToModelDefinition:
    """Checks on converting whole schemas into JSON-schema style definitions."""

    def test_schema2jsonschema_with_explicit_fields(self, openapi):
        """Declared fields and the Meta title show up in the definition."""

        class UserSchema(Schema):
            _id = fields.Int()
            email = fields.Email(metadata={"description": "email address of the user"})
            name = fields.Str()

            class Meta:
                title = "User"

        definition = openapi.schema2jsonschema(UserSchema)
        assert definition["title"] == "User"
        assert definition["type"] == "object"

        properties = definition["properties"]
        assert properties["_id"]["type"] == "integer"
        email = properties["email"]
        assert email["type"] == "string"
        assert email["format"] == "email"
        assert email["description"] == "email address of the user"

    def test_schema2jsonschema_override_name(self, openapi):
        """``data_key`` renames a property; Meta ``exclude`` drops one."""

        class ExampleSchema(Schema):
            _id = fields.Int(data_key="id")
            _global = fields.Int(data_key="global")

            class Meta:
                exclude = ("_global",)

        definition = openapi.schema2jsonschema(ExampleSchema)
        assert definition["type"] == "object"
        properties = definition["properties"]
        # `_id` renamed to `id`
        assert "_id" not in properties
        assert properties["id"]["type"] == "integer"
        # `_global` excluded correctly
        assert "_global" not in properties
        assert "global" not in properties

    def test_required_fields(self, openapi):
        """Only fields declared ``required=True`` are listed as required."""

        class BandSchema(Schema):
            drummer = fields.Str(required=True)
            bassist = fields.Str()

        definition = openapi.schema2jsonschema(BandSchema)
        assert definition["required"] == ["drummer"]

    def test_partial(self, openapi):
        """Partial loading removes fields from the required list."""

        class BandSchema(Schema):
            drummer = fields.Str(required=True)
            bassist = fields.Str(required=True)

        # partial=True: nothing is required any more.
        fully_partial = openapi.schema2jsonschema(BandSchema(partial=True))
        assert "required" not in fully_partial

        # partial on "drummer" only: "bassist" stays required.
        drummer_partial = openapi.schema2jsonschema(
            BandSchema(partial=("drummer",))
        )
        assert drummer_partial["required"] == ["bassist"]

    @pytest.mark.parametrize("ordered_schema", (True, False))
    def test_ordered(self, openapi, ordered_schema):
        """Properties are an ``OrderedDict`` exactly when Meta.ordered is set."""

        class BandSchema(Schema):
            class Meta:
                ordered = ordered_schema

            drummer = fields.Str()
            bassist = fields.Str()

        # Converting the schema class ...
        from_class = openapi.schema2jsonschema(BandSchema)
        class_is_ordered = isinstance(from_class["properties"], OrderedDict)
        assert class_is_ordered == ordered_schema

        # ... and converting a schema instance must agree.
        from_instance = openapi.schema2jsonschema(BandSchema())
        instance_is_ordered = isinstance(from_instance["properties"], OrderedDict)
        assert instance_is_ordered == ordered_schema

    def test_no_required_fields(self, openapi):
        """Without required fields the "required" key is left out."""

        class BandSchema(Schema):
            drummer = fields.Str()
            bassist = fields.Str()

        definition = openapi.schema2jsonschema(BandSchema)
        assert "required" not in definition

    def test_title_and_description_may_be_added(self, openapi):
        """Meta ``title`` and ``description`` are copied into the definition."""

        class UserSchema(Schema):
            class Meta:
                title = "User"
                description = "A registered user"

        definition = openapi.schema2jsonschema(UserSchema)
        assert definition["description"] == "A registered user"
        assert definition["title"] == "User"

    def test_excluded_fields(self, openapi):
        """Fields named in Meta ``exclude`` do not become properties."""

        class WhiteStripesSchema(Schema):
            class Meta:
                exclude = ("bassist",)

            guitarist = fields.Str()
            drummer = fields.Str()
            bassist = fields.Str()

        definition = openapi.schema2jsonschema(WhiteStripesSchema)
        property_names = set(definition["properties"].keys())
        assert property_names == {"guitarist", "drummer"}

    def test_unknown_values_disallow(self, openapi):
        """``unknown = RAISE`` is expected to give additionalProperties False."""

        class UnknownRaiseSchema(Schema):
            class Meta:
                unknown = RAISE

            first = fields.Str()

        definition = openapi.schema2jsonschema(UnknownRaiseSchema)
        assert definition["additionalProperties"] is False

    def test_unknown_values_allow(self, openapi):
        """``unknown = INCLUDE`` is expected to give additionalProperties True."""

        class UnknownIncludeSchema(Schema):
            class Meta:
                unknown = INCLUDE

            first = fields.Str()

        definition = openapi.schema2jsonschema(UnknownIncludeSchema)
        assert definition["additionalProperties"] is True

    def test_unknown_values_ignore(self, openapi):
        """``unknown = EXCLUDE`` is expected to omit additionalProperties."""

        class UnknownExcludeSchema(Schema):
            class Meta:
                unknown = EXCLUDE

            first = fields.Str()

        definition = openapi.schema2jsonschema(UnknownExcludeSchema)
        assert "additionalProperties" not in definition

    def test_only_explicitly_declared_fields_are_translated(self, openapi):
        """Names listed in Meta.fields but never declared are warned about and skipped."""

        class UserSchema(Schema):
            _id = fields.Int()

            class Meta:
                title = "User"
                fields = ("_id", "email")

        warning_text = (
            "Only explicitly-declared fields will be included in the Schema Object."
        )
        with pytest.warns(UserWarning, match=warning_text):
            definition = openapi.schema2jsonschema(UserSchema)
            assert definition["type"] == "object"
            properties = definition["properties"]
            assert "_id" in properties
            assert "email" not in properties

    def test_observed_field_name_for_required_field(self, openapi):
        """The required list uses the field's ``data_key``, not its attribute name."""
        declared = {"user_id": fields.Int(data_key="id", required=True)}
        definition = openapi.fields2jsonschema(declared)
        assert definition["required"] == ["id"]

    @pytest.mark.parametrize("many", (True, False))
    def test_schema_instance_inspection(self, openapi, many):
        """A schema instance converts to an object definition for either ``many``."""

        class UserSchema(Schema):
            _id = fields.Int()

        instance = UserSchema(many=many)
        definition = openapi.schema2jsonschema(instance)
        assert definition["type"] == "object"
        assert "_id" in definition["properties"]

    def test_raises_error_if_no_declared_fields(self, openapi):
        """Converting a class that is not a Schema raises ``ValueError``."""

        class NotASchema:
            pass

        expected_error = (
            repr(NotASchema) + " is neither a Schema class nor a Schema instance."
        )
        with pytest.raises(ValueError, match=expected_error):
            openapi.schema2jsonschema(NotASchema)
223
224
class TestMarshmallowSchemaToParameters:
    """Checks on converting fields and schemas into operation parameters."""

    @pytest.mark.parametrize("ListClass", [fields.List, CustomList])
    def test_field_multiple(self, ListClass, openapi):
        """A list field becomes an array-typed, multi-valued query parameter."""
        list_field = ListClass(fields.Str)
        param = openapi._field2parameter(list_field, name="field", location="query")
        assert param["in"] == "query"
        if openapi.openapi_version.major >= 3:
            # OpenAPI 3: type information sits under "schema".
            assert param["schema"]["type"] == "array"
            assert param["schema"]["items"]["type"] == "string"
            assert param["style"] == "form"
            assert param["explode"] is True
        else:
            # OpenAPI 2: type information sits on the parameter itself.
            assert param["type"] == "array"
            assert param["items"]["type"] == "string"
            assert param["collectionFormat"] == "multi"

    def test_field_required(self, openapi):
        """A required field yields a parameter flagged ``required: True``."""
        required_field = fields.Str(required=True)
        param = openapi._field2parameter(
            required_field, name="field", location="query"
        )
        assert param["required"] is True

    def test_schema_partial(self, openapi):
        """With ``partial=True`` a required field is reported as not required."""

        class UserSchema(Schema):
            field = fields.Str(required=True)

        partial_schema = UserSchema(partial=True)
        params = openapi.schema2parameters(partial_schema, location="query")

        assert params[0]["required"] is False

    def test_schema_partial_list(self, openapi):
        """Only the fields named in ``partial`` lose their required flag."""

        class UserSchema(Schema):
            field = fields.Str(required=True)
            partial_field = fields.Str(required=True)

        partial_schema = UserSchema(partial=("partial_field",))
        params = openapi.schema2parameters(partial_schema, location="query")

        def named(target):
            # First generated parameter carrying the requested name.
            return next(p for p in params if p["name"] == target)

        assert named("field")["required"] is True
        assert named("partial_field")["required"] is False

    # json/body is invalid for OpenAPI 3
    @pytest.mark.parametrize("openapi", ("2.0",), indirect=True)
    def test_schema_body(self, openapi):
        """A body-located schema gives one parameter referencing the definition."""

        class UserSchema(Schema):
            name = fields.Str()
            email = fields.Email()

        params = openapi.schema2parameters(UserSchema, location="body")
        assert len(params) == 1
        body_param = params[0]
        assert body_param["in"] == "body"
        assert body_param["schema"] == {"$ref": "#/definitions/User"}

    # json/body is invalid for OpenAPI 3
    @pytest.mark.parametrize("openapi", ("2.0",), indirect=True)
    def test_schema_body_with_dump_only(self, openapi):
        """A dump-only field does not change the single body parameter."""

        class UserSchema(Schema):
            name = fields.Str()
            email = fields.Email(dump_only=True)

        params = openapi.schema2parameters(UserSchema, location="body")
        assert len(params) == 1
        body_param = params[0]
        assert body_param["in"] == "body"
        assert body_param["schema"] == build_ref(openapi.spec, "schema", "User")

    # json/body is invalid for OpenAPI 3
    @pytest.mark.parametrize("openapi", ("2.0",), indirect=True)
    def test_schema_body_many(self, openapi):
        """A ``many=True`` body schema is an array whose items reference the definition."""

        class UserSchema(Schema):
            name = fields.Str()
            email = fields.Email()

        many_schema = UserSchema(many=True)
        params = openapi.schema2parameters(many_schema, location="body")
        assert len(params) == 1
        body_param = params[0]
        assert body_param["in"] == "body"
        assert body_param["schema"]["type"] == "array"
        assert body_param["schema"]["items"] == {"$ref": "#/definitions/User"}

    def test_schema_query(self, openapi):
        """A schema class at location "query" yields one parameter per field."""

        class UserSchema(Schema):
            name = fields.Str()
            email = fields.Email()

        params = openapi.schema2parameters(UserSchema, location="query")
        assert len(params) == 2
        # Sort by name so the assertions do not depend on generation order.
        first, second = sorted(params, key=lambda p: p["name"])
        assert first["name"] == "email"
        assert first["in"] == "query"
        assert second["name"] == "name"
        assert second["in"] == "query"

    def test_schema_query_instance(self, openapi):
        """A schema instance at location "query" yields one parameter per field."""

        class UserSchema(Schema):
            name = fields.Str()
            email = fields.Email()

        params = openapi.schema2parameters(UserSchema(), location="query")
        assert len(params) == 2
        # Sort by name so the assertions do not depend on generation order.
        first, second = sorted(params, key=lambda p: p["name"])
        assert first["name"] == "email"
        assert first["in"] == "query"
        assert second["name"] == "name"
        assert second["in"] == "query"

    def test_schema_query_instance_many_should_raise_exception(self, openapi):
        """A ``many=True`` instance at location "query" raises ``AssertionError``."""

        class UserSchema(Schema):
            name = fields.Str()
            email = fields.Email()

        many_schema = UserSchema(many=True)
        with pytest.raises(AssertionError):
            openapi.schema2parameters(many_schema, location="query")

    def test_fields_query(self, openapi):
        """Query parameters are generated for every declared field of the schema."""

        class MySchema(Schema):
            name = fields.Str()
            email = fields.Email()

        params = openapi.schema2parameters(MySchema, location="query")
        assert len(params) == 2
        # Sort by name so the assertions do not depend on generation order.
        first, second = sorted(params, key=lambda p: p["name"])
        assert first["name"] == "email"
        assert first["in"] == "query"
        assert second["name"] == "name"
        assert second["in"] == "query"

    def test_raises_error_if_not_a_schema(self, openapi):
        """Converting a class that is not a Schema raises ``ValueError``."""

        class NotASchema:
            pass

        expected_error = (
            repr(NotASchema) + " is neither a Schema class nor a Schema instance."
        )
        # NOTE(review): this calls schema2jsonschema although the class is
        # about schema2parameters -- confirm whether that is intended.
        with pytest.raises(ValueError, match=expected_error):
            openapi.schema2jsonschema(NotASchema)
367
368
# Module-level schema shared by the nesting tests and the spec-validation
# tests below. `name` is the only required field; `breed` is dump-only.
class CategorySchema(Schema):
    id = fields.Int()
    name = fields.Str(required=True)
    breed = fields.Str(dump_only=True)
373
374
# Two integer fields; the spec-validation tests below turn this schema into
# query parameters via `schema2parameters(PageSchema, location="query")`.
class PageSchema(Schema):
    offset = fields.Int()
    limit = fields.Int()
378
379
# Schema with a nested, many-valued `category` field pointing at
# CategorySchema; used by the nesting and spec-validation tests below.
class PetSchema(Schema):
    category = fields.Nested(CategorySchema, many=True)
    name = fields.Str()
383
384
class TestNesting:
    """Checks on how nested and plucked schemas are converted and registered."""

    def test_schema2jsonschema_with_nested_fields(self, spec_fixture):
        """Items of a nested ``many`` field are a reference to "Category"."""
        definition = spec_fixture.openapi.schema2jsonschema(PetSchema)
        category_items = definition["properties"]["category"]["items"]

        expected_ref = build_ref(spec_fixture.spec, "schema", "Category")
        assert category_items == expected_ref

    @pytest.mark.parametrize("modifier", ("only", "exclude"))
    def test_schema2jsonschema_with_nested_fields_only_exclude(
        self, spec_fixture, modifier
    ):
        """``only``/``exclude`` on a Nested field shape the registered child schema."""

        class Child(Schema):
            i = fields.Int()
            j = fields.Int()

        class Parent(Schema):
            child = fields.Nested(Child, **{modifier: ("i",)})

        spec_fixture.openapi.schema2jsonschema(Parent)
        child_props = get_schemas(spec_fixture.spec)["Child"]["properties"]

        # With "only", just `i` survives; with "exclude", just `j` survives.
        is_only = modifier == "only"
        assert ("i" in child_props) == is_only
        assert ("j" not in child_props) == is_only

    def test_schema2jsonschema_with_plucked_field(self, spec_fixture):
        """A plucked field converts like the field it plucks from the schema."""

        class PetSchema(Schema):
            breed = fields.Pluck(CategorySchema, "breed")

        converter = spec_fixture.openapi
        category_definition = converter.schema2jsonschema(CategorySchema)
        pet_definition = converter.schema2jsonschema(PetSchema)

        plucked = pet_definition["properties"]["breed"]
        assert plucked == category_definition["properties"]["breed"]

    def test_schema2jsonschema_with_nested_fields_with_adhoc_changes(
        self, spec_fixture
    ):
        """Changes made on a nested schema instance reach the registered schema."""
        category_schema = CategorySchema()
        # Ad-hoc change on the instance only: make `id` required as well.
        category_schema.fields["id"].required = True

        class PetSchema(Schema):
            category = fields.Nested(category_schema, many=True)
            name = fields.Str()

        spec_fixture.spec.components.schema("Pet", schema=PetSchema)
        schemas = get_schemas(spec_fixture.spec)
        converter = spec_fixture.openapi

        assert schemas["Category"] == converter.schema2jsonschema(category_schema)
        assert set(schemas["Category"]["required"]) == {"id", "name"}

        # Undo the ad-hoc change in the generated output; the result must
        # then match what the unmodified schema class converts to.
        schemas["Category"]["required"] = ["name"]
        assert schemas["Category"] == converter.schema2jsonschema(CategorySchema)

    def test_schema2jsonschema_with_plucked_fields_with_adhoc_changes(
        self, spec_fixture
    ):
        """A plucked dump-only field is marked ``readOnly`` in the items."""
        category_schema = CategorySchema()
        category_schema.fields["breed"].dump_only = True

        class PetSchema(Schema):
            breed = fields.Pluck(category_schema, "breed", many=True)

        spec_fixture.spec.components.schema("Pet", schema=PetSchema)
        pet_props = get_schemas(spec_fixture.spec)["Pet"]["properties"]

        assert pet_props["breed"]["items"]["readOnly"] is True

    def test_schema2jsonschema_with_nested_excluded_fields(self, spec):
        """Fields excluded on a nested instance are absent from its definition."""
        category_schema = CategorySchema(exclude=("breed",))

        class PetSchema(Schema):
            category = fields.Nested(category_schema)

        spec.components.schema("Pet", schema=PetSchema)

        registered = get_schemas(spec)
        assert "breed" not in registered["Category"]["properties"]
467
468
def test_openapi_tools_validate_v2():
    """Build an OpenAPI 2.0 spec from the module schemas and validate it."""
    ma_plugin = MarshmallowPlugin()
    spec = APISpec(
        title="Pets", version="0.1", plugins=(ma_plugin,), openapi_version="2.0"
    )
    openapi = ma_plugin.converter

    spec.components.schema("Category", schema=CategorySchema)
    spec.components.schema("Pet", {"discriminator": "name"}, schema=PetSchema)

    def category_id_param():
        # Returns a new dict on every call so each operation owns its copy.
        return {
            "name": "category_id",
            "in": "path",
            "required": True,
            "type": "string",
        }

    # --- GET operation -----------------------------------------------------
    get_parameters = [
        {"name": "q", "in": "query", "type": "string"},
        category_id_param(),
        openapi._field2parameter(
            field=fields.List(
                fields.Str(),
                validate=validate.OneOf(["freddie", "roger"]),
            ),
            location="query",
            name="body",
        ),
    ]
    get_parameters = get_parameters + openapi.schema2parameters(
        PageSchema, location="query"
    )
    get_operation = {
        "parameters": get_parameters,
        "responses": {200: {"schema": PetSchema, "description": "A pet"}},
    }

    # --- POST operation ----------------------------------------------------
    post_parameters = [category_id_param()] + openapi.schema2parameters(
        CategorySchema, location="body"
    )
    post_operation = {
        "parameters": post_parameters,
        "responses": {201: {"schema": PetSchema, "description": "A pet"}},
    }

    spec.path(
        view=None,
        path="/category/{category_id}",
        operations={"get": get_operation, "post": post_operation},
    )

    # Report a validation error as a test failure carrying its message.
    try:
        utils.validate_spec(spec)
    except exceptions.OpenAPIError as error:
        pytest.fail(str(error))
524
525
def test_openapi_tools_validate_v3():
    """Build an OpenAPI 3.0.0 spec from the module schemas and validate it."""
    ma_plugin = MarshmallowPlugin()
    spec = APISpec(
        title="Pets", version="0.1", plugins=(ma_plugin,), openapi_version="3.0.0"
    )
    openapi = ma_plugin.converter

    spec.components.schema("Category", schema=CategorySchema)
    spec.components.schema("Pet", schema=PetSchema)

    def category_id_param():
        # Returns a new dict on every call so each operation owns its copy.
        return {
            "name": "category_id",
            "in": "path",
            "required": True,
            "schema": {"type": "string"},
        }

    # --- GET operation -----------------------------------------------------
    get_parameters = [
        {"name": "q", "in": "query", "schema": {"type": "string"}},
        category_id_param(),
        openapi._field2parameter(
            field=fields.List(
                fields.Str(),
                validate=validate.OneOf(["freddie", "roger"]),
            ),
            location="query",
            name="body",
        ),
    ]
    get_parameters = get_parameters + openapi.schema2parameters(
        PageSchema, location="query"
    )
    get_operation = {
        "parameters": get_parameters,
        "responses": {
            200: {
                "description": "success",
                "content": {"application/json": {"schema": PetSchema}},
            }
        },
    }

    # --- POST operation ----------------------------------------------------
    # The request payload is described through "requestBody" here rather
    # than through a body parameter.
    post_operation = {
        "parameters": [category_id_param()],
        "requestBody": {
            "content": {"application/json": {"schema": CategorySchema}}
        },
        "responses": {
            201: {
                "description": "created",
                "content": {"application/json": {"schema": PetSchema}},
            }
        },
    }

    spec.path(
        view=None,
        path="/category/{category_id}",
        operations={"get": get_operation, "post": post_operation},
    )

    # Report a validation error as a test failure carrying its message.
    try:
        utils.validate_spec(spec)
    except exceptions.OpenAPIError as error:
        pytest.fail(str(error))
593
594
class TestFieldValidation:
    """Checks that field validators are translated into schema properties."""

    # One field per validator combination exercised by test_properties.
    class ValidationSchema(Schema):
        id = fields.Int(dump_only=True)

        # Range validators on numeric fields.
        range = fields.Int(validate=validate.Range(min=1, max=10))
        range_no_upper = fields.Float(validate=validate.Range(min=1))
        multiple_ranges = fields.Int(
            validate=[
                validate.Range(min=1),
                validate.Range(min=3),
                validate.Range(max=10),
                validate.Range(max=7),
            ]
        )

        # Length validators on list fields.
        list_length = fields.List(
            fields.Str, validate=validate.Length(min=1, max=10)
        )
        custom_list_length = CustomList(
            fields.Str, validate=validate.Length(min=1, max=10)
        )

        # Length validators on string fields.
        string_length = fields.Str(validate=validate.Length(min=1, max=10))
        custom_field_length = CustomStringField(
            validate=validate.Length(min=1, max=10)
        )
        multiple_lengths = fields.Str(
            validate=[
                validate.Length(min=1),
                validate.Length(min=3),
                validate.Length(max=10),
                validate.Length(max=7),
            ]
        )
        equal_length = fields.Str(
            validate=[validate.Length(equal=5), validate.Length(min=1, max=10)]
        )

        # Range validator with a datetime bound.
        date_range = fields.DateTime(
            validate=validate.Range(min=datetime(1900, 1, 1))
        )

    @pytest.mark.parametrize(
        ("field", "properties"),
        [
            ("range", {"minimum": 1, "maximum": 10}),
            ("range_no_upper", {"minimum": 1}),
            ("multiple_ranges", {"minimum": 3, "maximum": 7}),
            ("list_length", {"minItems": 1, "maxItems": 10}),
            ("custom_list_length", {"minItems": 1, "maxItems": 10}),
            ("string_length", {"minLength": 1, "maxLength": 10}),
            ("custom_field_length", {"minLength": 1, "maxLength": 10}),
            ("multiple_lengths", {"minLength": 3, "maxLength": 7}),
            ("equal_length", {"minLength": 5, "maxLength": 5}),
            ("date_range", {"x-minimum": datetime(1900, 1, 1)}),
        ],
    )
    def test_properties(self, field, properties, spec):
        """Each expected attribute is present on the field with the given value."""
        spec.components.schema("Validation", schema=self.ValidationSchema)
        generated = get_schemas(spec)["Validation"]["properties"][field]

        for attr in properties:
            assert attr in generated
            assert generated[attr] == properties[attr]
653