1package jsonapi
2
import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"time"
)
13
var (
	// ErrBadJSONAPIStructTag is returned when a struct field's `jsonapi` tag
	// is malformed or names an annotation this package does not understand.
	ErrBadJSONAPIStructTag = errors.New("Bad jsonapi struct tag format")

	// ErrBadJSONAPIID is returned when the field annotated as the primary ID
	// is neither a string nor one of the signed/unsigned integer kinds.
	ErrBadJSONAPIID = errors.New("id should be either string, int(8,16,32,64) or uint(8,16,32,64)")

	// ErrExpectedSlice is returned when a slice of struct pointers was
	// required but some other kind of value was supplied (for example by
	// convertToSliceInterface).
	ErrExpectedSlice = errors.New("models should be a slice of struct pointers")

	// ErrUnexpectedType is returned by Marshal when its argument is neither a
	// pointer to a struct nor a slice of struct pointers.
	ErrUnexpectedType = errors.New("models should be a struct pointer or slice of struct pointers")
)
30
31// MarshalPayload writes a jsonapi response for one or many records. The
32// related records are sideloaded into the "included" array. If this method is
33// given a struct pointer as an argument it will serialize in the form
34// "data": {...}. If this method is given a slice of pointers, this method will
35// serialize in the form "data": [...]
36//
37// One Example: you could pass it, w, your http.ResponseWriter, and, models, a
38// ptr to a Blog to be written to the response body:
39//
40//	 func ShowBlog(w http.ResponseWriter, r *http.Request) {
41//		 blog := &Blog{}
42//
43//		 w.Header().Set("Content-Type", jsonapi.MediaType)
44//		 w.WriteHeader(http.StatusOK)
45//
46//		 if err := jsonapi.MarshalPayload(w, blog); err != nil {
47//			 http.Error(w, err.Error(), http.StatusInternalServerError)
48//		 }
49//	 }
50//
51// Many Example: you could pass it, w, your http.ResponseWriter, and, models, a
52// slice of Blog struct instance pointers to be written to the response body:
53//
54//	 func ListBlogs(w http.ResponseWriter, r *http.Request) {
55//     blogs := []*Blog{}
56//
57//		 w.Header().Set("Content-Type", jsonapi.MediaType)
58//		 w.WriteHeader(http.StatusOK)
59//
60//		 if err := jsonapi.MarshalPayload(w, blogs); err != nil {
61//			 http.Error(w, err.Error(), http.StatusInternalServerError)
62//		 }
63//	 }
64//
65func MarshalPayload(w io.Writer, models interface{}) error {
66	payload, err := Marshal(models)
67	if err != nil {
68		return err
69	}
70
71	if err := json.NewEncoder(w).Encode(payload); err != nil {
72		return err
73	}
74	return nil
75}
76
77// Marshal does the same as MarshalPayload except it just returns the payload
78// and doesn't write out results. Useful if you use your own JSON rendering
79// library.
80func Marshal(models interface{}) (Payloader, error) {
81	switch vals := reflect.ValueOf(models); vals.Kind() {
82	case reflect.Slice:
83		m, err := convertToSliceInterface(&models)
84		if err != nil {
85			return nil, err
86		}
87
88		payload, err := marshalMany(m)
89		if err != nil {
90			return nil, err
91		}
92
93		if linkableModels, isLinkable := models.(Linkable); isLinkable {
94			jl := linkableModels.JSONAPILinks()
95			if er := jl.validate(); er != nil {
96				return nil, er
97			}
98			payload.Links = linkableModels.JSONAPILinks()
99		}
100
101		if metableModels, ok := models.(Metable); ok {
102			payload.Meta = metableModels.JSONAPIMeta()
103		}
104
105		return payload, nil
106	case reflect.Ptr:
107		// Check that the pointer was to a struct
108		if reflect.Indirect(vals).Kind() != reflect.Struct {
109			return nil, ErrUnexpectedType
110		}
111		return marshalOne(models)
112	default:
113		return nil, ErrUnexpectedType
114	}
115}
116
117// MarshalPayloadWithoutIncluded writes a jsonapi response with one or many
118// records, without the related records sideloaded into "included" array.
119// If you want to serialize the relations into the "included" array see
120// MarshalPayload.
121//
122// models interface{} should be either a struct pointer or a slice of struct
123// pointers.
124func MarshalPayloadWithoutIncluded(w io.Writer, model interface{}) error {
125	payload, err := Marshal(model)
126	if err != nil {
127		return err
128	}
129	payload.clearIncluded()
130
131	if err := json.NewEncoder(w).Encode(payload); err != nil {
132		return err
133	}
134	return nil
135}
136
137// marshalOne does the same as MarshalOnePayload except it just returns the
138// payload and doesn't write out results. Useful is you use your JSON rendering
139// library.
140func marshalOne(model interface{}) (*OnePayload, error) {
141	included := make(map[string]*Node)
142
143	rootNode, err := visitModelNode(model, &included, true)
144	if err != nil {
145		return nil, err
146	}
147	payload := &OnePayload{Data: rootNode}
148
149	payload.Included = nodeMapValues(&included)
150
151	return payload, nil
152}
153
154// marshalMany does the same as MarshalManyPayload except it just returns the
155// payload and doesn't write out results. Useful is you use your JSON rendering
156// library.
157func marshalMany(models []interface{}) (*ManyPayload, error) {
158	payload := &ManyPayload{
159		Data: []*Node{},
160	}
161	included := map[string]*Node{}
162
163	for _, model := range models {
164		node, err := visitModelNode(model, &included, true)
165		if err != nil {
166			return nil, err
167		}
168		payload.Data = append(payload.Data, node)
169	}
170	payload.Included = nodeMapValues(&included)
171
172	return payload, nil
173}
174
175// MarshalOnePayloadEmbedded - This method not meant to for use in
176// implementation code, although feel free.  The purpose of this
177// method is for use in tests.  In most cases, your request
178// payloads for create will be embedded rather than sideloaded for
179// related records. This method will serialize a single struct
180// pointer into an embedded json response. In other words, there
181// will be no, "included", array in the json all relationships will
182// be serailized inline in the data.
183//
184// However, in tests, you may want to construct payloads to post
185// to create methods that are embedded to most closely resemble
186// the payloads that will be produced by the client. This is what
187// this method is intended for.
188//
189// model interface{} should be a pointer to a struct.
190func MarshalOnePayloadEmbedded(w io.Writer, model interface{}) error {
191	rootNode, err := visitModelNode(model, nil, false)
192	if err != nil {
193		return err
194	}
195
196	payload := &OnePayload{Data: rootNode}
197
198	if err := json.NewEncoder(w).Encode(payload); err != nil {
199		return err
200	}
201
202	return nil
203}
204
205func visitModelNode(model interface{}, included *map[string]*Node,
206	sideload bool) (*Node, error) {
207	node := new(Node)
208
209	var er error
210	value := reflect.ValueOf(model)
211	if value.IsNil() {
212		return nil, nil
213	}
214
215	modelValue := value.Elem()
216	modelType := value.Type().Elem()
217
218	for i := 0; i < modelValue.NumField(); i++ {
219		structField := modelValue.Type().Field(i)
220		tag := structField.Tag.Get(annotationJSONAPI)
221		if tag == "" {
222			continue
223		}
224
225		fieldValue := modelValue.Field(i)
226		fieldType := modelType.Field(i)
227
228		args := strings.Split(tag, annotationSeperator)
229
230		if len(args) < 1 {
231			er = ErrBadJSONAPIStructTag
232			break
233		}
234
235		annotation := args[0]
236
237		if (annotation == annotationClientID && len(args) != 1) ||
238			(annotation != annotationClientID && len(args) < 2) {
239			er = ErrBadJSONAPIStructTag
240			break
241		}
242
243		if annotation == annotationPrimary {
244			v := fieldValue
245
246			// Deal with PTRS
247			var kind reflect.Kind
248			if fieldValue.Kind() == reflect.Ptr {
249				kind = fieldType.Type.Elem().Kind()
250				v = reflect.Indirect(fieldValue)
251			} else {
252				kind = fieldType.Type.Kind()
253			}
254
255			// Handle allowed types
256			switch kind {
257			case reflect.String:
258				node.ID = v.Interface().(string)
259			case reflect.Int:
260				node.ID = strconv.FormatInt(int64(v.Interface().(int)), 10)
261			case reflect.Int8:
262				node.ID = strconv.FormatInt(int64(v.Interface().(int8)), 10)
263			case reflect.Int16:
264				node.ID = strconv.FormatInt(int64(v.Interface().(int16)), 10)
265			case reflect.Int32:
266				node.ID = strconv.FormatInt(int64(v.Interface().(int32)), 10)
267			case reflect.Int64:
268				node.ID = strconv.FormatInt(v.Interface().(int64), 10)
269			case reflect.Uint:
270				node.ID = strconv.FormatUint(uint64(v.Interface().(uint)), 10)
271			case reflect.Uint8:
272				node.ID = strconv.FormatUint(uint64(v.Interface().(uint8)), 10)
273			case reflect.Uint16:
274				node.ID = strconv.FormatUint(uint64(v.Interface().(uint16)), 10)
275			case reflect.Uint32:
276				node.ID = strconv.FormatUint(uint64(v.Interface().(uint32)), 10)
277			case reflect.Uint64:
278				node.ID = strconv.FormatUint(v.Interface().(uint64), 10)
279			default:
280				// We had a JSON float (numeric), but our field was not one of the
281				// allowed numeric types
282				er = ErrBadJSONAPIID
283				break
284			}
285
286			node.Type = args[1]
287		} else if annotation == annotationClientID {
288			clientID := fieldValue.String()
289			if clientID != "" {
290				node.ClientID = clientID
291			}
292		} else if annotation == annotationAttribute {
293			var omitEmpty, iso8601 bool
294
295			if len(args) > 2 {
296				for _, arg := range args[2:] {
297					switch arg {
298					case annotationOmitEmpty:
299						omitEmpty = true
300					case annotationISO8601:
301						iso8601 = true
302					}
303				}
304			}
305
306			if node.Attributes == nil {
307				node.Attributes = make(map[string]interface{})
308			}
309
310			if fieldValue.Type() == reflect.TypeOf(time.Time{}) {
311				t := fieldValue.Interface().(time.Time)
312
313				if t.IsZero() {
314					continue
315				}
316
317				if iso8601 {
318					node.Attributes[args[1]] = t.UTC().Format(iso8601TimeFormat)
319				} else {
320					node.Attributes[args[1]] = t.Unix()
321				}
322			} else if fieldValue.Type() == reflect.TypeOf(new(time.Time)) {
323				// A time pointer may be nil
324				if fieldValue.IsNil() {
325					if omitEmpty {
326						continue
327					}
328
329					node.Attributes[args[1]] = nil
330				} else {
331					tm := fieldValue.Interface().(*time.Time)
332
333					if tm.IsZero() && omitEmpty {
334						continue
335					}
336
337					if iso8601 {
338						node.Attributes[args[1]] = tm.UTC().Format(iso8601TimeFormat)
339					} else {
340						node.Attributes[args[1]] = tm.Unix()
341					}
342				}
343			} else {
344				// Dealing with a fieldValue that is not a time
345				emptyValue := reflect.Zero(fieldValue.Type())
346
347				// See if we need to omit this field
348				if omitEmpty && reflect.DeepEqual(fieldValue.Interface(), emptyValue.Interface()) {
349					continue
350				}
351
352				strAttr, ok := fieldValue.Interface().(string)
353				if ok {
354					node.Attributes[args[1]] = strAttr
355				} else {
356					node.Attributes[args[1]] = fieldValue.Interface()
357				}
358			}
359		} else if annotation == annotationRelation {
360			var omitEmpty bool
361
362			//add support for 'omitempty' struct tag for marshaling as absent
363			if len(args) > 2 {
364				omitEmpty = args[2] == annotationOmitEmpty
365			}
366
367			isSlice := fieldValue.Type().Kind() == reflect.Slice
368			if omitEmpty &&
369				(isSlice && fieldValue.Len() < 1 ||
370					(!isSlice && fieldValue.IsNil())) {
371				continue
372			}
373
374			if node.Relationships == nil {
375				node.Relationships = make(map[string]interface{})
376			}
377
378			var relLinks *Links
379			if linkableModel, ok := model.(RelationshipLinkable); ok {
380				relLinks = linkableModel.JSONAPIRelationshipLinks(args[1])
381			}
382
383			var relMeta *Meta
384			if metableModel, ok := model.(RelationshipMetable); ok {
385				relMeta = metableModel.JSONAPIRelationshipMeta(args[1])
386			}
387
388			if isSlice {
389				// to-many relationship
390				relationship, err := visitModelNodeRelationships(
391					fieldValue,
392					included,
393					sideload,
394				)
395				if err != nil {
396					er = err
397					break
398				}
399				relationship.Links = relLinks
400				relationship.Meta = relMeta
401
402				if sideload {
403					shallowNodes := []*Node{}
404					for _, n := range relationship.Data {
405						appendIncluded(included, n)
406						shallowNodes = append(shallowNodes, toShallowNode(n))
407					}
408
409					node.Relationships[args[1]] = &RelationshipManyNode{
410						Data:  shallowNodes,
411						Links: relationship.Links,
412						Meta:  relationship.Meta,
413					}
414				} else {
415					node.Relationships[args[1]] = relationship
416				}
417			} else {
418				// to-one relationships
419
420				// Handle null relationship case
421				if fieldValue.IsNil() {
422					node.Relationships[args[1]] = &RelationshipOneNode{Data: nil}
423					continue
424				}
425
426				relationship, err := visitModelNode(
427					fieldValue.Interface(),
428					included,
429					sideload,
430				)
431				if err != nil {
432					er = err
433					break
434				}
435
436				if sideload {
437					appendIncluded(included, relationship)
438					node.Relationships[args[1]] = &RelationshipOneNode{
439						Data:  toShallowNode(relationship),
440						Links: relLinks,
441						Meta:  relMeta,
442					}
443				} else {
444					node.Relationships[args[1]] = &RelationshipOneNode{
445						Data:  relationship,
446						Links: relLinks,
447						Meta:  relMeta,
448					}
449				}
450			}
451
452		} else {
453			er = ErrBadJSONAPIStructTag
454			break
455		}
456	}
457
458	if er != nil {
459		return nil, er
460	}
461
462	if linkableModel, isLinkable := model.(Linkable); isLinkable {
463		jl := linkableModel.JSONAPILinks()
464		if er := jl.validate(); er != nil {
465			return nil, er
466		}
467		node.Links = linkableModel.JSONAPILinks()
468	}
469
470	if metableModel, ok := model.(Metable); ok {
471		node.Meta = metableModel.JSONAPIMeta()
472	}
473
474	return node, nil
475}
476
477func toShallowNode(node *Node) *Node {
478	return &Node{
479		ID:   node.ID,
480		Type: node.Type,
481	}
482}
483
484func visitModelNodeRelationships(models reflect.Value, included *map[string]*Node,
485	sideload bool) (*RelationshipManyNode, error) {
486	nodes := []*Node{}
487
488	for i := 0; i < models.Len(); i++ {
489		n := models.Index(i).Interface()
490
491		node, err := visitModelNode(n, included, sideload)
492		if err != nil {
493			return nil, err
494		}
495
496		nodes = append(nodes, node)
497	}
498
499	return &RelationshipManyNode{Data: nodes}, nil
500}
501
502func appendIncluded(m *map[string]*Node, nodes ...*Node) {
503	included := *m
504
505	for _, n := range nodes {
506		k := fmt.Sprintf("%s,%s", n.Type, n.ID)
507
508		if _, hasNode := included[k]; hasNode {
509			continue
510		}
511
512		included[k] = n
513	}
514}
515
516func nodeMapValues(m *map[string]*Node) []*Node {
517	mp := *m
518	nodes := make([]*Node, len(mp))
519
520	i := 0
521	for _, n := range mp {
522		nodes[i] = n
523		i++
524	}
525
526	return nodes
527}
528
529func convertToSliceInterface(i *interface{}) ([]interface{}, error) {
530	vals := reflect.ValueOf(*i)
531	if vals.Kind() != reflect.Slice {
532		return nil, ErrExpectedSlice
533	}
534	var response []interface{}
535	for x := 0; x < vals.Len(); x++ {
536		response = append(response, vals.Index(x).Interface())
537	}
538	return response, nil
539}
540