1"""Example implementation of using a marshmallow Schema for both request input
2and output with a `use_schema` decorator.
3Run the app:
4
5    $ python examples/schema_example.py
6
7Try the following with httpie (a cURL-like utility, http://httpie.org):
8
9    $ pip install httpie
10    $ http GET :5001/users/
11    $ http GET :5001/users/42
12    $ http POST :5001/users/ username=brian first_name=Brian last_name=May
13    $ http PATCH :5001/users/42 username=freddie
14    $ http GET :5001/users/ limit==1
15"""
import functools
import random

from flask import Flask, Response, request
from marshmallow import Schema, fields, post_dump
from webargs.flaskparser import parser, use_kwargs
22
# The Flask application; the routes and error handlers below register on it.
app = Flask(import_name=__name__)
24
25##### Fake database and model #####
26
27
class Model:
    """Minimal in-memory record: keyword arguments become instance attributes.

    Subclasses must define a ``collection`` class attribute naming the key of
    the ``db`` dict under which their records are stored.
    """

    # Inclusive range that auto-generated ids are drawn from; subclasses may
    # override these to use a different id space.
    min_id = 1
    max_id = 9999

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def update(self, **kwargs):
        """Add or overwrite attributes on this record from keyword arguments."""
        self.__dict__.update(kwargs)

    @classmethod
    def insert(cls, db, **kwargs):
        """Create a record, store it in ``db[cls.collection]`` and return it.

        If an ``id`` keyword is given (used for setting up fixtures) it is used
        as-is, replacing any record already stored under that id. Otherwise a
        random id in ``[min_id, max_id]`` that is not yet in the collection is
        chosen.

        Raises:
            KeyError: if ``db`` has no entry for ``cls.collection``.
            RuntimeError: if every id in ``[min_id, max_id]`` is already taken.
        """
        collection = db[cls.collection]
        if "id" in kwargs:  # for setting up fixtures
            new_id = kwargs.pop("id")
        else:
            new_id = cls._new_id(collection)
        new_record = cls(id=new_id, **kwargs)
        collection[new_id] = new_record
        return new_record

    @classmethod
    def _new_id(cls, collection):
        """Return a random id in ``[min_id, max_id]`` not present in *collection*."""
        # Without this check the rejection-sampling loop below would spin
        # forever once the id space is exhausted. ``all`` stops at the first
        # free id, so this is cheap for sparsely populated collections.
        id_range = range(cls.min_id, cls.max_id + 1)
        if all(candidate in collection for candidate in id_range):
            raise RuntimeError(
                "no free ids left in [{}, {}] for collection {!r}".format(
                    cls.min_id, cls.max_id, cls.collection
                )
            )
        while True:
            candidate = random.randint(cls.min_id, cls.max_id)
            if candidate not in collection:
                return candidate
50
51
class User(Model):
    """User record; ``Model.insert`` stores instances under ``db["users"]``."""

    collection: str = "users"
54
55
# In-memory "database": maps a collection name to a dict of records keyed by id.
db = dict(users={})
57
58
59##### use_schema #####
60
61
def use_schema(schema_cls, list_view=False, locations=None):
    """View decorator for using a marshmallow schema to
    (1) parse a request's input and
    (2) serialize the view's output to a JSON response.

    The wrapped view receives the parsed arguments as its first positional
    argument. The schema is instantiated with ``partial=True`` for every HTTP
    method except POST, so only POST has to supply all ``required`` fields.

    The view's return value is dumped with the schema (``many=list_view``),
    except for ready-made responses: a tuple such as ``(body, status)`` or a
    ``flask.Response`` is returned unchanged. This lets views send error
    bodies like ``({"message": ...}, 404)`` with their status code intact.

    :param schema_cls: marshmallow ``Schema`` subclass used for input and output.
    :param list_view: dump the view's return value as a collection.
    :param locations: forwarded to ``parser.use_args``. This is the pre-6.0
        webargs keyword. ``None`` means the parser's default locations.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            # Only creation (POST) must supply every required field; other
            # methods (GET filters, PATCH updates) may send a subset.
            partial = request.method != "POST"
            schema = schema_cls(partial=partial)
            # The webargs wrapper is built per request because the schema's
            # ``partial`` setting depends on the HTTP method.
            func_with_args = parser.use_args(schema, locations=locations)(func)
            ret = func_with_args(*args, **kwargs)
            # Dumping a tuple/Response through the schema would treat it as a
            # record and return a plain dict, losing the status code (e.g. a
            # 404 would be sent as 200). Pass such responses through untouched.
            if isinstance(ret, (tuple, Response)):
                return ret
            return schema.dump(ret, many=list_view)

        return wrapped

    return decorator
82
83
84##### Schemas #####
85
86
class UserSchema(Schema):
    """Marshmallow schema for users, used both to parse input and to dump output.

    ``id`` is output-only. ``username`` is required unless the schema is
    instantiated with ``partial=True``. Dumped data is wrapped in a
    ``{"data": ...}`` envelope.
    """

    id = fields.Integer(dump_only=True)
    username = fields.String(required=True)
    first_name = fields.String()
    last_name = fields.String()

    @post_dump(pass_many=True)
    def wrap_with_envelope(self, data, many, **kwargs):
        # pass_many=True: the hook sees the whole dumped result, so a list view
        # gets one envelope around the list rather than one per item.
        envelope = dict(data=data)
        return envelope
96
97
98##### Routes #####
99
100
@app.route("/users/<int:user_id>", methods=["GET", "PATCH"])
@use_schema(UserSchema)
def user_detail(reqargs, user_id):
    """Return the user with ``user_id``, first applying parsed fields on PATCH.

    Returns ``({"message": "User not found"}, 404)`` when no such user exists.
    A PATCH with no parsed fields leaves the user unchanged.
    """
    found = db["users"].get(user_id)
    if not found:
        return {"message": "User not found"}, 404
    wants_update = request.method == "PATCH" and bool(reqargs)
    if wants_update:
        found.update(**reqargs)
    return found
110
111
112# You can add additional arguments with use_kwargs
@app.route("/users/", methods=["GET", "POST"])
@use_kwargs({"limit": fields.Int(missing=10, location="query")})
@use_schema(UserSchema, list_view=True)
def user_list(reqargs, limit):
    """List up to ``limit`` users; on POST, first create one from ``reqargs``."""
    if request.method == "POST":
        User.insert(db=db, **reqargs)
    # Gather the records after the optional insert so a newly created user
    # is part of the listing.
    listing = list(db["users"].values())
    return listing[:limit]
121
122
123# Return validation errors as JSON
@app.errorhandler(422)
@app.errorhandler(400)
def handle_validation_error(err):
    """Render 400/422 errors as a JSON ``{"errors": ...}`` body.

    Errors raised by the webargs parser carry the validation exception as
    ``err.exc``. Its ``messages`` are returned, plus any response headers
    found in ``err.data["headers"]``. Errors without ``exc`` (presumably 400s
    raised outside webargs) get a generic ``["Invalid request."]`` message.
    """
    exc = getattr(err, "exc", None)
    if not exc:
        return {"errors": ["Invalid request."]}, err.code
    messages = exc.messages
    # ``data`` may be absent, None, or lack a "headers" key. Treat all of these
    # as "no extra headers" instead of raising KeyError inside the handler,
    # which would turn a validation error into a 500.
    data = getattr(err, "data", None) or {}
    headers = data.get("headers")
    if headers:
        return {"errors": messages}, err.code, headers
    return {"errors": messages}, err.code
138
139
if __name__ == "__main__":
    # Seed one fixture user so the `GET /users/42` example works right away.
    User.insert(
        db=db,
        id=42,
        username="fred",
        first_name="Freddie",
        last_name="Mercury",
    )
    # Debug mode is for local development only.
    app.run(debug=True, port=5001)
145