1package viper
2
3import (
4	"fmt"
5	"strings"
6	"testing"
7
8	"github.com/spf13/cast"
9	"github.com/stretchr/testify/assert"
10)
11
// layer selects which of viper's internal configuration maps a test writes
// its first value to and inspects afterwards: defaultLayer is written with
// SetDefault and inspected through v.defaults, overrideLayer is written with
// Set and inspected through v.override.
type layer int

// Valid layers start at 1, so the zero value of layer matches neither of them.
const (
	defaultLayer  layer = 1
	overrideLayer layer = 2
)
18
19func TestNestedOverrides(t *testing.T) {
20	assert := assert.New(t)
21	var v *Viper
22
23	// Case 0: value overridden by a value
24	overrideDefault(assert, "tom", 10, "tom", 20) // "tom" is first given 10 as default value, then overridden by 20
25	override(assert, "tom", 10, "tom", 20)        // "tom" is first given value 10, then overridden by 20
26	overrideDefault(assert, "tom.age", 10, "tom.age", 20)
27	override(assert, "tom.age", 10, "tom.age", 20)
28	overrideDefault(assert, "sawyer.tom.age", 10, "sawyer.tom.age", 20)
29	override(assert, "sawyer.tom.age", 10, "sawyer.tom.age", 20)
30
31	// Case 1: key:value overridden by a value
32	v = overrideDefault(assert, "tom.age", 10, "tom", "boy") // "tom.age" is first given 10 as default value, then "tom" is overridden by "boy"
33	assert.Nil(v.Get("tom.age"))                             // "tom.age" should not exist anymore
34	v = override(assert, "tom.age", 10, "tom", "boy")
35	assert.Nil(v.Get("tom.age"))
36
37	// Case 2: value overridden by a key:value
38	overrideDefault(assert, "tom", "boy", "tom.age", 10) // "tom" is first given "boy" as default value, then "tom" is overridden by map{"age":10}
39	override(assert, "tom.age", 10, "tom", "boy")
40
41	// Case 3: key:value overridden by a key:value
42	v = overrideDefault(assert, "tom.size", 4, "tom.age", 10)
43	assert.Equal(4, v.Get("tom.size")) // value should still be reachable
44	v = override(assert, "tom.size", 4, "tom.age", 10)
45	assert.Equal(4, v.Get("tom.size"))
46	deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4)
47
48	// Case 4: key:value overridden by a map
49	v = overrideDefault(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10}) // "tom.size" is first given "4" as default value, then "tom" is overridden by map{"age":10}
50	assert.Equal(4, v.Get("tom.size"))                                                   // "tom.size" should still be reachable
51	assert.Equal(10, v.Get("tom.age"))                                                   // new value should be there
52	deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10)                 // new value should be there
53	v = override(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10})
54	assert.Nil(v.Get("tom.size"))
55	assert.Equal(10, v.Get("tom.age"))
56	deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10)
57
58	// Case 5: array overridden by a value
59	overrideDefault(assert, "tom", []int{10, 20}, "tom", 30)
60	override(assert, "tom", []int{10, 20}, "tom", 30)
61	overrideDefault(assert, "tom.age", []int{10, 20}, "tom.age", 30)
62	override(assert, "tom.age", []int{10, 20}, "tom.age", 30)
63
64	// Case 6: array overridden by an array
65	overrideDefault(assert, "tom", []int{10, 20}, "tom", []int{30, 40})
66	override(assert, "tom", []int{10, 20}, "tom", []int{30, 40})
67	overrideDefault(assert, "tom.age", []int{10, 20}, "tom.age", []int{30, 40})
68	v = override(assert, "tom.age", []int{10, 20}, "tom.age", []int{30, 40})
69	// explicit array merge:
70	s, ok := v.Get("tom.age").([]int)
71	if assert.True(ok, "tom[\"age\"] is not a slice") {
72		v.Set("tom.age", append(s, []int{50, 60}...))
73		assert.Equal([]int{30, 40, 50, 60}, v.Get("tom.age"))
74		deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, []int{30, 40, 50, 60})
75	}
76}
77
78func overrideDefault(assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper {
79	return overrideFromLayer(defaultLayer, assert, firstPath, firstValue, secondPath, secondValue)
80}
81func override(assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper {
82	return overrideFromLayer(overrideLayer, assert, firstPath, firstValue, secondPath, secondValue)
83}
84
85// overrideFromLayer performs the sequential override and low-level checks.
86//
87// First assignment is made on layer l for path firstPath with value firstValue,
88// the second one on the override layer (i.e., with the Set() function)
89// for path secondPath with value secondValue.
90//
91// firstPath and secondPath can include an arbitrary number of dots to indicate
92// a nested element.
93//
94// After each assignment, the value is checked, retrieved both by its full path
95// and by its key sequence (successive maps).
96func overrideFromLayer(l layer, assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper {
97	v := New()
98	firstKeys := strings.Split(firstPath, v.keyDelim)
99	if assert == nil ||
100		len(firstKeys) == 0 || len(firstKeys[0]) == 0 {
101		return v
102	}
103
104	// Set and check first value
105	switch l {
106	case defaultLayer:
107		v.SetDefault(firstPath, firstValue)
108	case overrideLayer:
109		v.Set(firstPath, firstValue)
110	default:
111		return v
112	}
113	assert.Equal(firstValue, v.Get(firstPath))
114	deepCheckValue(assert, v, l, firstKeys, firstValue)
115
116	// Override and check new value
117	secondKeys := strings.Split(secondPath, v.keyDelim)
118	if len(secondKeys) == 0 || len(secondKeys[0]) == 0 {
119		return v
120	}
121	v.Set(secondPath, secondValue)
122	assert.Equal(secondValue, v.Get(secondPath))
123	deepCheckValue(assert, v, overrideLayer, secondKeys, secondValue)
124
125	return v
126}
127
128// deepCheckValue checks that all given keys correspond to a valid path in the
129// configuration map of the given layer, and that the final value equals the one given
130func deepCheckValue(assert *assert.Assertions, v *Viper, l layer, keys []string, value interface{}) {
131	if assert == nil || v == nil ||
132		len(keys) == 0 || len(keys[0]) == 0 {
133		return
134	}
135
136	// init
137	var val interface{}
138	var ms string
139	switch l {
140	case defaultLayer:
141		val = v.defaults
142		ms = "v.defaults"
143	case overrideLayer:
144		val = v.override
145		ms = "v.override"
146	}
147
148	// loop through map
149	var m map[string]interface{}
150	err := false
151	for _, k := range keys {
152		if val == nil {
153			assert.Fail(fmt.Sprintf("%s is not a map[string]interface{}", ms))
154			return
155		}
156
157		// deep scan of the map to get the final value
158		switch val.(type) {
159		case map[interface{}]interface{}:
160			m = cast.ToStringMap(val)
161		case map[string]interface{}:
162			m = val.(map[string]interface{})
163		default:
164			assert.Fail(fmt.Sprintf("%s is not a map[string]interface{}", ms))
165			return
166		}
167		ms = ms + "[\"" + k + "\"]"
168		val = m[k]
169	}
170	if !err {
171		assert.Equal(value, val)
172	}
173}
174