1package jsonapi
2
3import (
4	"encoding/json"
5	"errors"
6	"fmt"
7	"io"
8	"reflect"
9	"strconv"
10	"strings"
11	"time"
12)
13
// Sentinel errors used by the marshaling code. Compare against them with
// errors.Is (or ==) rather than by message.
var (
	// ErrUnexpectedType is returned when marshalling an interface that is
	// neither a struct pointer nor a slice.
	ErrUnexpectedType = errors.New("models should be a struct pointer or slice of struct pointers")

	// ErrExpectedSlice is returned when a variable or argument was expected
	// to be a slice of *Structs but was not.
	ErrExpectedSlice = errors.New("models should be a slice of struct pointers")

	// ErrBadJSONAPIStructTag is returned when a struct field's `jsonapi`
	// annotation is malformed or unknown.
	ErrBadJSONAPIStructTag = errors.New("Bad jsonapi struct tag format")

	// ErrBadJSONAPIID is returned when the field annotated as the primary
	// "id" is neither a string nor one of the signed/unsigned integer types.
	ErrBadJSONAPIID = errors.New(
		"id should be either string, int(8,16,32,64) or uint(8,16,32,64)")
)
30
31// MarshalPayload writes a jsonapi response for one or many records. The
32// related records are sideloaded into the "included" array. If this method is
33// given a struct pointer as an argument it will serialize in the form
34// "data": {...}. If this method is given a slice of pointers, this method will
35// serialize in the form "data": [...]
36//
37// One Example: you could pass it, w, your http.ResponseWriter, and, models, a
38// ptr to a Blog to be written to the response body:
39//
40//	 func ShowBlog(w http.ResponseWriter, r *http.Request) {
41//		 blog := &Blog{}
42//
43//		 w.Header().Set("Content-Type", jsonapi.MediaType)
44//		 w.WriteHeader(http.StatusOK)
45//
46//		 if err := jsonapi.MarshalPayload(w, blog); err != nil {
47//			 http.Error(w, err.Error(), http.StatusInternalServerError)
48//		 }
49//	 }
50//
51// Many Example: you could pass it, w, your http.ResponseWriter, and, models, a
52// slice of Blog struct instance pointers to be written to the response body:
53//
54//	 func ListBlogs(w http.ResponseWriter, r *http.Request) {
55//     blogs := []*Blog{}
56//
57//		 w.Header().Set("Content-Type", jsonapi.MediaType)
58//		 w.WriteHeader(http.StatusOK)
59//
60//		 if err := jsonapi.MarshalPayload(w, blogs); err != nil {
61//			 http.Error(w, err.Error(), http.StatusInternalServerError)
62//		 }
63//	 }
64//
65func MarshalPayload(w io.Writer, models interface{}) error {
66	payload, err := Marshal(models)
67	if err != nil {
68		return err
69	}
70
71	return json.NewEncoder(w).Encode(payload)
72}
73
74// Marshal does the same as MarshalPayload except it just returns the payload
75// and doesn't write out results. Useful if you use your own JSON rendering
76// library.
77func Marshal(models interface{}) (Payloader, error) {
78	switch vals := reflect.ValueOf(models); vals.Kind() {
79	case reflect.Slice:
80		m, err := convertToSliceInterface(&models)
81		if err != nil {
82			return nil, err
83		}
84
85		payload, err := marshalMany(m)
86		if err != nil {
87			return nil, err
88		}
89
90		if linkableModels, isLinkable := models.(Linkable); isLinkable {
91			jl := linkableModels.JSONAPILinks()
92			if er := jl.validate(); er != nil {
93				return nil, er
94			}
95			payload.Links = linkableModels.JSONAPILinks()
96		}
97
98		if metableModels, ok := models.(Metable); ok {
99			payload.Meta = metableModels.JSONAPIMeta()
100		}
101
102		return payload, nil
103	case reflect.Ptr:
104		// Check that the pointer was to a struct
105		if reflect.Indirect(vals).Kind() != reflect.Struct {
106			return nil, ErrUnexpectedType
107		}
108		return marshalOne(models)
109	default:
110		return nil, ErrUnexpectedType
111	}
112}
113
114// MarshalPayloadWithoutIncluded writes a jsonapi response with one or many
115// records, without the related records sideloaded into "included" array.
116// If you want to serialize the relations into the "included" array see
117// MarshalPayload.
118//
119// models interface{} should be either a struct pointer or a slice of struct
120// pointers.
121func MarshalPayloadWithoutIncluded(w io.Writer, model interface{}) error {
122	payload, err := Marshal(model)
123	if err != nil {
124		return err
125	}
126	payload.clearIncluded()
127
128	return json.NewEncoder(w).Encode(payload)
129}
130
131// marshalOne does the same as MarshalOnePayload except it just returns the
132// payload and doesn't write out results. Useful is you use your JSON rendering
133// library.
134func marshalOne(model interface{}) (*OnePayload, error) {
135	included := make(map[string]*Node)
136
137	rootNode, err := visitModelNode(model, &included, true)
138	if err != nil {
139		return nil, err
140	}
141	payload := &OnePayload{Data: rootNode}
142
143	payload.Included = nodeMapValues(&included)
144
145	return payload, nil
146}
147
148// marshalMany does the same as MarshalManyPayload except it just returns the
149// payload and doesn't write out results. Useful is you use your JSON rendering
150// library.
151func marshalMany(models []interface{}) (*ManyPayload, error) {
152	payload := &ManyPayload{
153		Data: []*Node{},
154	}
155	included := map[string]*Node{}
156
157	for _, model := range models {
158		node, err := visitModelNode(model, &included, true)
159		if err != nil {
160			return nil, err
161		}
162		payload.Data = append(payload.Data, node)
163	}
164	payload.Included = nodeMapValues(&included)
165
166	return payload, nil
167}
168
169// MarshalOnePayloadEmbedded - This method not meant to for use in
170// implementation code, although feel free.  The purpose of this
171// method is for use in tests.  In most cases, your request
172// payloads for create will be embedded rather than sideloaded for
173// related records. This method will serialize a single struct
174// pointer into an embedded json response. In other words, there
175// will be no, "included", array in the json all relationships will
176// be serailized inline in the data.
177//
178// However, in tests, you may want to construct payloads to post
179// to create methods that are embedded to most closely resemble
180// the payloads that will be produced by the client. This is what
181// this method is intended for.
182//
183// model interface{} should be a pointer to a struct.
184func MarshalOnePayloadEmbedded(w io.Writer, model interface{}) error {
185	rootNode, err := visitModelNode(model, nil, false)
186	if err != nil {
187		return err
188	}
189
190	payload := &OnePayload{Data: rootNode}
191
192	return json.NewEncoder(w).Encode(payload)
193}
194
195func visitModelNode(model interface{}, included *map[string]*Node,
196	sideload bool) (*Node, error) {
197	node := new(Node)
198
199	var er error
200	value := reflect.ValueOf(model)
201	if value.IsNil() {
202		return nil, nil
203	}
204
205	modelValue := value.Elem()
206	modelType := value.Type().Elem()
207
208	for i := 0; i < modelValue.NumField(); i++ {
209		structField := modelValue.Type().Field(i)
210		tag := structField.Tag.Get(annotationJSONAPI)
211		if tag == "" {
212			continue
213		}
214
215		fieldValue := modelValue.Field(i)
216		fieldType := modelType.Field(i)
217
218		args := strings.Split(tag, annotationSeperator)
219
220		if len(args) < 1 {
221			er = ErrBadJSONAPIStructTag
222			break
223		}
224
225		annotation := args[0]
226
227		if (annotation == annotationClientID && len(args) != 1) ||
228			(annotation != annotationClientID && len(args) < 2) {
229			er = ErrBadJSONAPIStructTag
230			break
231		}
232
233		if annotation == annotationPrimary {
234			v := fieldValue
235
236			// Deal with PTRS
237			var kind reflect.Kind
238			if fieldValue.Kind() == reflect.Ptr {
239				kind = fieldType.Type.Elem().Kind()
240				v = reflect.Indirect(fieldValue)
241			} else {
242				kind = fieldType.Type.Kind()
243			}
244
245			// Handle allowed types
246			switch kind {
247			case reflect.String:
248				node.ID = v.Interface().(string)
249			case reflect.Int:
250				node.ID = strconv.FormatInt(int64(v.Interface().(int)), 10)
251			case reflect.Int8:
252				node.ID = strconv.FormatInt(int64(v.Interface().(int8)), 10)
253			case reflect.Int16:
254				node.ID = strconv.FormatInt(int64(v.Interface().(int16)), 10)
255			case reflect.Int32:
256				node.ID = strconv.FormatInt(int64(v.Interface().(int32)), 10)
257			case reflect.Int64:
258				node.ID = strconv.FormatInt(v.Interface().(int64), 10)
259			case reflect.Uint:
260				node.ID = strconv.FormatUint(uint64(v.Interface().(uint)), 10)
261			case reflect.Uint8:
262				node.ID = strconv.FormatUint(uint64(v.Interface().(uint8)), 10)
263			case reflect.Uint16:
264				node.ID = strconv.FormatUint(uint64(v.Interface().(uint16)), 10)
265			case reflect.Uint32:
266				node.ID = strconv.FormatUint(uint64(v.Interface().(uint32)), 10)
267			case reflect.Uint64:
268				node.ID = strconv.FormatUint(v.Interface().(uint64), 10)
269			default:
270				// We had a JSON float (numeric), but our field was not one of the
271				// allowed numeric types
272				er = ErrBadJSONAPIID
273			}
274
275			if er != nil {
276				break
277			}
278
279			node.Type = args[1]
280		} else if annotation == annotationClientID {
281			clientID := fieldValue.String()
282			if clientID != "" {
283				node.ClientID = clientID
284			}
285		} else if annotation == annotationAttribute {
286			var omitEmpty, iso8601, rfc3339 bool
287
288			if len(args) > 2 {
289				for _, arg := range args[2:] {
290					switch arg {
291					case annotationOmitEmpty:
292						omitEmpty = true
293					case annotationISO8601:
294						iso8601 = true
295					case annotationRFC3339:
296						rfc3339 = true
297					}
298				}
299			}
300
301			if node.Attributes == nil {
302				node.Attributes = make(map[string]interface{})
303			}
304
305			if fieldValue.Type() == reflect.TypeOf(time.Time{}) {
306				t := fieldValue.Interface().(time.Time)
307
308				if t.IsZero() {
309					continue
310				}
311
312				if iso8601 {
313					node.Attributes[args[1]] = t.UTC().Format(iso8601TimeFormat)
314				} else if rfc3339 {
315					node.Attributes[args[1]] = t.UTC().Format(time.RFC3339)
316				} else {
317					node.Attributes[args[1]] = t.Unix()
318				}
319			} else if fieldValue.Type() == reflect.TypeOf(new(time.Time)) {
320				// A time pointer may be nil
321				if fieldValue.IsNil() {
322					if omitEmpty {
323						continue
324					}
325
326					node.Attributes[args[1]] = nil
327				} else {
328					tm := fieldValue.Interface().(*time.Time)
329
330					if tm.IsZero() && omitEmpty {
331						continue
332					}
333
334					if iso8601 {
335						node.Attributes[args[1]] = tm.UTC().Format(iso8601TimeFormat)
336					} else if rfc3339 {
337						node.Attributes[args[1]] = tm.UTC().Format(time.RFC3339)
338					} else {
339						node.Attributes[args[1]] = tm.Unix()
340					}
341				}
342			} else {
343				// Dealing with a fieldValue that is not a time
344				emptyValue := reflect.Zero(fieldValue.Type())
345
346				// See if we need to omit this field
347				if omitEmpty && reflect.DeepEqual(fieldValue.Interface(), emptyValue.Interface()) {
348					continue
349				}
350
351				strAttr, ok := fieldValue.Interface().(string)
352				if ok {
353					node.Attributes[args[1]] = strAttr
354				} else {
355					node.Attributes[args[1]] = fieldValue.Interface()
356				}
357			}
358		} else if annotation == annotationRelation {
359			var omitEmpty bool
360
361			//add support for 'omitempty' struct tag for marshaling as absent
362			if len(args) > 2 {
363				omitEmpty = args[2] == annotationOmitEmpty
364			}
365
366			isSlice := fieldValue.Type().Kind() == reflect.Slice
367			if omitEmpty &&
368				(isSlice && fieldValue.Len() < 1 ||
369					(!isSlice && fieldValue.IsNil())) {
370				continue
371			}
372
373			if node.Relationships == nil {
374				node.Relationships = make(map[string]interface{})
375			}
376
377			var relLinks *Links
378			if linkableModel, ok := model.(RelationshipLinkable); ok {
379				relLinks = linkableModel.JSONAPIRelationshipLinks(args[1])
380			}
381
382			var relMeta *Meta
383			if metableModel, ok := model.(RelationshipMetable); ok {
384				relMeta = metableModel.JSONAPIRelationshipMeta(args[1])
385			}
386
387			if isSlice {
388				// to-many relationship
389				relationship, err := visitModelNodeRelationships(
390					fieldValue,
391					included,
392					sideload,
393				)
394				if err != nil {
395					er = err
396					break
397				}
398				relationship.Links = relLinks
399				relationship.Meta = relMeta
400
401				if sideload {
402					shallowNodes := []*Node{}
403					for _, n := range relationship.Data {
404						appendIncluded(included, n)
405						shallowNodes = append(shallowNodes, toShallowNode(n))
406					}
407
408					node.Relationships[args[1]] = &RelationshipManyNode{
409						Data:  shallowNodes,
410						Links: relationship.Links,
411						Meta:  relationship.Meta,
412					}
413				} else {
414					node.Relationships[args[1]] = relationship
415				}
416			} else {
417				// to-one relationships
418
419				// Handle null relationship case
420				if fieldValue.IsNil() {
421					node.Relationships[args[1]] = &RelationshipOneNode{Data: nil}
422					continue
423				}
424
425				relationship, err := visitModelNode(
426					fieldValue.Interface(),
427					included,
428					sideload,
429				)
430				if err != nil {
431					er = err
432					break
433				}
434
435				if sideload {
436					appendIncluded(included, relationship)
437					node.Relationships[args[1]] = &RelationshipOneNode{
438						Data:  toShallowNode(relationship),
439						Links: relLinks,
440						Meta:  relMeta,
441					}
442				} else {
443					node.Relationships[args[1]] = &RelationshipOneNode{
444						Data:  relationship,
445						Links: relLinks,
446						Meta:  relMeta,
447					}
448				}
449			}
450		} else if annotation == annotationLinks {
451			// Nothing. Ignore this field, as Links fields are only for unmarshaling requests.
452			// The Linkable interface methods are used for marshaling data in a response.
453		} else {
454			er = ErrBadJSONAPIStructTag
455			break
456		}
457	}
458
459	if er != nil {
460		return nil, er
461	}
462
463	if linkableModel, isLinkable := model.(Linkable); isLinkable {
464		jl := linkableModel.JSONAPILinks()
465		if er := jl.validate(); er != nil {
466			return nil, er
467		}
468		node.Links = linkableModel.JSONAPILinks()
469	}
470
471	if metableModel, ok := model.(Metable); ok {
472		node.Meta = metableModel.JSONAPIMeta()
473	}
474
475	return node, nil
476}
477
478func toShallowNode(node *Node) *Node {
479	return &Node{
480		ID:   node.ID,
481		Type: node.Type,
482	}
483}
484
485func visitModelNodeRelationships(models reflect.Value, included *map[string]*Node,
486	sideload bool) (*RelationshipManyNode, error) {
487	nodes := []*Node{}
488
489	for i := 0; i < models.Len(); i++ {
490		n := models.Index(i).Interface()
491
492		node, err := visitModelNode(n, included, sideload)
493		if err != nil {
494			return nil, err
495		}
496
497		nodes = append(nodes, node)
498	}
499
500	return &RelationshipManyNode{Data: nodes}, nil
501}
502
503func appendIncluded(m *map[string]*Node, nodes ...*Node) {
504	included := *m
505
506	for _, n := range nodes {
507		k := fmt.Sprintf("%s,%s", n.Type, n.ID)
508
509		if _, hasNode := included[k]; hasNode {
510			continue
511		}
512
513		included[k] = n
514	}
515}
516
517func nodeMapValues(m *map[string]*Node) []*Node {
518	mp := *m
519	nodes := make([]*Node, len(mp))
520
521	i := 0
522	for _, n := range mp {
523		nodes[i] = n
524		i++
525	}
526
527	return nodes
528}
529
530func convertToSliceInterface(i *interface{}) ([]interface{}, error) {
531	vals := reflect.ValueOf(*i)
532	if vals.Kind() != reflect.Slice {
533		return nil, ErrExpectedSlice
534	}
535	var response []interface{}
536	for x := 0; x < vals.Len(); x++ {
537		response = append(response, vals.Index(x).Interface())
538	}
539	return response, nil
540}
541