1package graphql
2
3import (
4	"encoding/json"
5	"fmt"
6	"math"
7	"reflect"
8	"sort"
9	"strings"
10
11	"github.com/graphql-go/graphql/gqlerrors"
12	"github.com/graphql-go/graphql/language/ast"
13	"github.com/graphql-go/graphql/language/kinds"
14	"github.com/graphql-go/graphql/language/printer"
15)
16
17// Prepares an object map of variableValues of the correct type based on the
18// provided variable definitions and arbitrary input. If the input cannot be
19// parsed to match the variable definitions, a GraphQLError will be returned.
20func getVariableValues(
21	schema Schema,
22	definitionASTs []*ast.VariableDefinition,
23	inputs map[string]interface{}) (map[string]interface{}, error) {
24	values := map[string]interface{}{}
25	for _, defAST := range definitionASTs {
26		if defAST == nil || defAST.Variable == nil || defAST.Variable.Name == nil {
27			continue
28		}
29		varName := defAST.Variable.Name.Value
30		if varValue, err := getVariableValue(schema, defAST, inputs[varName]); err != nil {
31			return values, err
32		} else {
33			values[varName] = varValue
34		}
35	}
36	return values, nil
37}
38
39// Prepares an object map of argument values given a list of argument
40// definitions and list of argument AST nodes.
41func getArgumentValues(
42	argDefs []*Argument, argASTs []*ast.Argument,
43	variableValues map[string]interface{}) map[string]interface{} {
44
45	argASTMap := map[string]*ast.Argument{}
46	for _, argAST := range argASTs {
47		if argAST.Name != nil {
48			argASTMap[argAST.Name.Value] = argAST
49		}
50	}
51	results := map[string]interface{}{}
52	for _, argDef := range argDefs {
53		var (
54			tmp   interface{}
55			value ast.Value
56		)
57		if tmpValue, ok := argASTMap[argDef.PrivateName]; ok {
58			value = tmpValue.Value
59		}
60		if tmp = valueFromAST(value, argDef.Type, variableValues); isNullish(tmp) {
61			tmp = argDef.DefaultValue
62		}
63		if !isNullish(tmp) {
64			results[argDef.PrivateName] = tmp
65		}
66	}
67	return results
68}
69
70// Given a variable definition, and any value of input, return a value which
71// adheres to the variable definition, or throw an error.
72func getVariableValue(schema Schema, definitionAST *ast.VariableDefinition, input interface{}) (interface{}, error) {
73	ttype, err := typeFromAST(schema, definitionAST.Type)
74	if err != nil {
75		return nil, err
76	}
77	variable := definitionAST.Variable
78
79	if ttype == nil || !IsInputType(ttype) {
80		return "", gqlerrors.NewError(
81			fmt.Sprintf(`Variable "$%v" expected value of type `+
82				`"%v" which cannot be used as an input type.`, variable.Name.Value, printer.Print(definitionAST.Type)),
83			[]ast.Node{definitionAST},
84			"",
85			nil,
86			[]int{},
87			nil,
88		)
89	}
90
91	isValid, messages := isValidInputValue(input, ttype)
92	if isValid {
93		if isNullish(input) {
94			if definitionAST.DefaultValue != nil {
95				return valueFromAST(definitionAST.DefaultValue, ttype, nil), nil
96			}
97		}
98		return coerceValue(ttype, input), nil
99	}
100	if isNullish(input) {
101		return "", gqlerrors.NewError(
102			fmt.Sprintf(`Variable "$%v" of required type `+
103				`"%v" was not provided.`, variable.Name.Value, printer.Print(definitionAST.Type)),
104			[]ast.Node{definitionAST},
105			"",
106			nil,
107			[]int{},
108			nil,
109		)
110	}
111	// convert input interface into string for error message
112	bts, _ := json.Marshal(input)
113	var (
114		inputStr = string(bts)
115		msg      string
116	)
117	if len(messages) > 0 {
118		msg = "\n" + strings.Join(messages, "\n")
119	}
120
121	return "", gqlerrors.NewError(
122		fmt.Sprintf(`Variable "$%v" got invalid value `+
123			`%v.%v`, variable.Name.Value, inputStr, msg),
124		[]ast.Node{definitionAST},
125		"",
126		nil,
127		[]int{},
128		nil,
129	)
130}
131
132// Given a type and any value, return a runtime value coerced to match the type.
133func coerceValue(ttype Input, value interface{}) interface{} {
134	if isNullish(value) {
135		return nil
136	}
137	switch ttype := ttype.(type) {
138	case *NonNull:
139		return coerceValue(ttype.OfType, value)
140	case *List:
141		var values = []interface{}{}
142		valType := reflect.ValueOf(value)
143		if valType.Kind() == reflect.Slice {
144			for i := 0; i < valType.Len(); i++ {
145				val := valType.Index(i).Interface()
146				values = append(values, coerceValue(ttype.OfType, val))
147			}
148			return values
149		}
150		return append(values, coerceValue(ttype.OfType, value))
151	case *InputObject:
152		var obj = map[string]interface{}{}
153		valueMap, _ := value.(map[string]interface{})
154		if valueMap == nil {
155			valueMap = map[string]interface{}{}
156		}
157
158		for name, field := range ttype.Fields() {
159			fieldValue := coerceValue(field.Type, valueMap[name])
160			if isNullish(fieldValue) {
161				fieldValue = field.DefaultValue
162			}
163			if !isNullish(fieldValue) {
164				obj[name] = fieldValue
165			}
166		}
167		return obj
168	case *Scalar:
169		if parsed := ttype.ParseValue(value); !isNullish(parsed) {
170			return parsed
171		}
172	case *Enum:
173		if parsed := ttype.ParseValue(value); !isNullish(parsed) {
174			return parsed
175		}
176	}
177
178	return nil
179}
180
181// graphql-js/src/utilities.js`
182// TODO: figure out where to organize utils
183// TODO: change to *Schema
184func typeFromAST(schema Schema, inputTypeAST ast.Type) (Type, error) {
185	switch inputTypeAST := inputTypeAST.(type) {
186	case *ast.List:
187		innerType, err := typeFromAST(schema, inputTypeAST.Type)
188		if err != nil {
189			return nil, err
190		}
191		return NewList(innerType), nil
192	case *ast.NonNull:
193		innerType, err := typeFromAST(schema, inputTypeAST.Type)
194		if err != nil {
195			return nil, err
196		}
197		return NewNonNull(innerType), nil
198	case *ast.Named:
199		nameValue := ""
200		if inputTypeAST.Name != nil {
201			nameValue = inputTypeAST.Name.Value
202		}
203		ttype := schema.Type(nameValue)
204		return ttype, nil
205	default:
206		return nil, invariant(inputTypeAST.GetKind() == kinds.Named, "Must be a named type.")
207	}
208}
209
210// isValidInputValue alias isValidJSValue
211// Given a value and a GraphQL type, determine if the value will be
212// accepted for that type. This is primarily useful for validating the
213// runtime values of query variables.
214func isValidInputValue(value interface{}, ttype Input) (bool, []string) {
215	if isNullish(value) {
216		if ttype, ok := ttype.(*NonNull); ok {
217			if ttype.OfType.Name() != "" {
218				return false, []string{fmt.Sprintf(`Expected "%v!", found null.`, ttype.OfType.Name())}
219			}
220			return false, []string{"Expected non-null value, found null."}
221		}
222		return true, nil
223	}
224	switch ttype := ttype.(type) {
225	case *NonNull:
226		return isValidInputValue(value, ttype.OfType)
227	case *List:
228		valType := reflect.ValueOf(value)
229		if valType.Kind() == reflect.Ptr {
230			valType = valType.Elem()
231		}
232		if valType.Kind() == reflect.Slice {
233			messagesReduce := []string{}
234			for i := 0; i < valType.Len(); i++ {
235				val := valType.Index(i).Interface()
236				_, messages := isValidInputValue(val, ttype.OfType)
237				for idx, message := range messages {
238					messagesReduce = append(messagesReduce, fmt.Sprintf(`In element #%v: %v`, idx+1, message))
239				}
240			}
241			return (len(messagesReduce) == 0), messagesReduce
242		}
243		return isValidInputValue(value, ttype.OfType)
244
245	case *InputObject:
246		messagesReduce := []string{}
247
248		valueMap, ok := value.(map[string]interface{})
249		if !ok {
250			return false, []string{fmt.Sprintf(`Expected "%v", found not an object.`, ttype.Name())}
251		}
252		fields := ttype.Fields()
253
254		// to ensure stable order of field evaluation
255		fieldNames := []string{}
256		valueMapFieldNames := []string{}
257
258		for fieldName := range fields {
259			fieldNames = append(fieldNames, fieldName)
260		}
261		sort.Strings(fieldNames)
262
263		for fieldName := range valueMap {
264			valueMapFieldNames = append(valueMapFieldNames, fieldName)
265		}
266		sort.Strings(valueMapFieldNames)
267
268		// Ensure every provided field is defined.
269		for _, fieldName := range valueMapFieldNames {
270			if _, ok := fields[fieldName]; !ok {
271				messagesReduce = append(messagesReduce, fmt.Sprintf(`In field "%v": Unknown field.`, fieldName))
272			}
273		}
274
275		// Ensure every defined field is valid.
276		for _, fieldName := range fieldNames {
277			_, messages := isValidInputValue(valueMap[fieldName], fields[fieldName].Type)
278			if messages != nil {
279				for _, message := range messages {
280					messagesReduce = append(messagesReduce, fmt.Sprintf(`In field "%v": %v`, fieldName, message))
281				}
282			}
283		}
284		return (len(messagesReduce) == 0), messagesReduce
285	case *Scalar:
286		if parsedVal := ttype.ParseValue(value); isNullish(parsedVal) {
287			return false, []string{fmt.Sprintf(`Expected type "%v", found "%v".`, ttype.Name(), value)}
288		}
289	case *Enum:
290		if parsedVal := ttype.ParseValue(value); isNullish(parsedVal) {
291			return false, []string{fmt.Sprintf(`Expected type "%v", found "%v".`, ttype.Name(), value)}
292		}
293	}
294
295	return true, nil
296}
297
// isNullish is the Go analogue of graphql-js isNullish ("null, undefined or
// NaN"). It reports true for:
//   - an untyped nil,
//   - a nil pointer, and
//   - a float32/float64 NaN, either held directly or behind one non-nil
//     pointer.
//
// Every other value, including "", 0, false and nil maps or slices, is not
// nullish.
func isNullish(src interface{}) bool {
	if src == nil {
		return true
	}
	rv := reflect.ValueOf(src)
	if rv.Kind() == reflect.Ptr {
		if rv.IsNil() {
			return true
		}
		rv = rv.Elem()
	}
	switch rv.Kind() {
	case reflect.Float32, reflect.Float64:
		return math.IsNaN(rv.Float())
	default:
		return false
	}
}
323
// isIterable reports whether src's type is a slice or an array, or a pointer
// to one. The pointer is inspected by type only, so even a nil *[]T counts.
// An untyped nil is not iterable.
func isIterable(src interface{}) bool {
	rt := reflect.TypeOf(src)
	if rt == nil {
		return false
	}
	if rt.Kind() == reflect.Ptr {
		rt = rt.Elem()
	}
	switch rt.Kind() {
	case reflect.Array, reflect.Slice:
		return true
	default:
		return false
	}
}
335
336/**
337 * Produces a value given a GraphQL Value AST.
338 *
339 * A GraphQL type must be provided, which will be used to interpret different
340 * GraphQL Value literals.
341 *
342 * | GraphQL Value        | JSON Value    |
343 * | -------------------- | ------------- |
344 * | Input Object         | Object        |
345 * | List                 | Array         |
346 * | Boolean              | Boolean       |
347 * | String / Enum Value  | String        |
348 * | Int / Float          | Number        |
349 *
350 */
351func valueFromAST(valueAST ast.Value, ttype Input, variables map[string]interface{}) interface{} {
352	if valueAST == nil {
353		return nil
354	}
355	// precedence: value > type
356	if valueAST, ok := valueAST.(*ast.Variable); ok {
357		if valueAST.Name == nil || variables == nil {
358			return nil
359		}
360		// Note: we're not doing any checking that this variable is correct. We're
361		// assuming that this query has been validated and the variable usage here
362		// is of the correct type.
363		return variables[valueAST.Name.Value]
364	}
365	switch ttype := ttype.(type) {
366	case *NonNull:
367		return valueFromAST(valueAST, ttype.OfType, variables)
368	case *List:
369		values := []interface{}{}
370		if valueAST, ok := valueAST.(*ast.ListValue); ok {
371			for _, itemAST := range valueAST.Values {
372				values = append(values, valueFromAST(itemAST, ttype.OfType, variables))
373			}
374			return values
375		}
376		return append(values, valueFromAST(valueAST, ttype.OfType, variables))
377	case *InputObject:
378		var (
379			ok bool
380			ov *ast.ObjectValue
381			of *ast.ObjectField
382		)
383		if ov, ok = valueAST.(*ast.ObjectValue); !ok {
384			return nil
385		}
386		fieldASTs := map[string]*ast.ObjectField{}
387		for _, of = range ov.Fields {
388			if of == nil || of.Name == nil {
389				continue
390			}
391			fieldASTs[of.Name.Value] = of
392		}
393		obj := map[string]interface{}{}
394		for name, field := range ttype.Fields() {
395			var value interface{}
396			if of, ok = fieldASTs[name]; ok {
397				value = valueFromAST(of.Value, field.Type, variables)
398			} else {
399				value = field.DefaultValue
400			}
401			if !isNullish(value) {
402				obj[name] = value
403			}
404		}
405		return obj
406	case *Scalar:
407		return ttype.ParseLiteral(valueAST)
408	case *Enum:
409		return ttype.ParseLiteral(valueAST)
410	}
411
412	return nil
413}
414
415func invariant(condition bool, message string) error {
416	if !condition {
417		return gqlerrors.NewFormattedError(message)
418	}
419	return nil
420}
421
422func invariantf(condition bool, format string, a ...interface{}) error {
423	if !condition {
424		return gqlerrors.NewFormattedError(fmt.Sprintf(format, a...))
425	}
426	return nil
427}
428