summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/spf13/viper/overrides_test.go
blob: dd2aa9b0dbdb3358815255390ab74952d9b452a3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
package viper

import (
	"fmt"
	"strings"
	"testing"

	"github.com/spf13/cast"
	"github.com/stretchr/testify/assert"
)

// layer selects which of Viper's internal configuration maps the test
// helpers in this file write to and inspect.
type layer int

// The values start at 1, so the zero value of layer matches neither constant.
const (
	defaultLayer  layer = 1 // v.defaults, written through SetDefault
	overrideLayer layer = 2 // v.override, written through Set
)

// TestNestedOverrides checks that a value first stored on the defaults or
// override layer is correctly replaced by a later Set(), for combinations of
// plain values, nested key:value paths, maps and arrays.
func TestNestedOverrides(t *testing.T) {
	assert := assert.New(t)
	var v *Viper

	// Case 0: value overridden by a value
	overrideDefault(assert, "tom", 10, "tom", 20) // "tom" is first given 10 as default value, then overridden by 20
	override(assert, "tom", 10, "tom", 20)        // "tom" is first given value 10, then overridden by 20
	overrideDefault(assert, "tom.age", 10, "tom.age", 20)
	override(assert, "tom.age", 10, "tom.age", 20)
	overrideDefault(assert, "sawyer.tom.age", 10, "sawyer.tom.age", 20)
	override(assert, "sawyer.tom.age", 10, "sawyer.tom.age", 20)

	// Case 1: key:value overridden by a value
	v = overrideDefault(assert, "tom.age", 10, "tom", "boy") // "tom.age" is first given 10 as default value, then "tom" is overridden by "boy"
	assert.Nil(v.Get("tom.age"))                             // "tom.age" should not exist anymore
	v = override(assert, "tom.age", 10, "tom", "boy")
	assert.Nil(v.Get("tom.age"))

	// Case 2: value overridden by a key:value
	overrideDefault(assert, "tom", "boy", "tom.age", 10) // "tom" is first given "boy" as default value, then "tom" is overridden by map{"age":10}
	// Same scenario on the override layer: "tom" is first set to "boy", then
	// replaced by map{"age":10}. (This previously repeated Case 1's order.)
	override(assert, "tom", "boy", "tom.age", 10)

	// Case 3: key:value overridden by a key:value
	v = overrideDefault(assert, "tom.size", 4, "tom.age", 10)
	assert.Equal(4, v.Get("tom.size")) // value should still be reachable
	v = override(assert, "tom.size", 4, "tom.age", 10)
	assert.Equal(4, v.Get("tom.size"))
	deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4)

	// Case 4: key:value overridden by a map
	v = overrideDefault(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10}) // "tom.size" is first given 4 as default value, then "tom" is overridden by map{"age":10}
	assert.Equal(4, v.Get("tom.size"))                                                   // "tom.size" should still be reachable
	assert.Equal(10, v.Get("tom.age"))                                                   // new value should be there
	deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10)                 // new value should be there
	v = override(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10})
	assert.Nil(v.Get("tom.size"))
	assert.Equal(10, v.Get("tom.age"))
	deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10)

	// Case 5: array overridden by a value
	overrideDefault(assert, "tom", []int{10, 20}, "tom", 30)
	override(assert, "tom", []int{10, 20}, "tom", 30)
	overrideDefault(assert, "tom.age", []int{10, 20}, "tom.age", 30)
	override(assert, "tom.age", []int{10, 20}, "tom.age", 30)

	// Case 6: array overridden by an array
	overrideDefault(assert, "tom", []int{10, 20}, "tom", []int{30, 40})
	override(assert, "tom", []int{10, 20}, "tom", []int{30, 40})
	overrideDefault(assert, "tom.age", []int{10, 20}, "tom.age", []int{30, 40})
	v = override(assert, "tom.age", []int{10, 20}, "tom.age", []int{30, 40})
	// explicit array merge:
	s, ok := v.Get("tom.age").([]int)
	if assert.True(ok, "tom[\"age\"] is not a slice") {
		v.Set("tom.age", append(s, []int{50, 60}...))
		assert.Equal([]int{30, 40, 50, 60}, v.Get("tom.age"))
		deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, []int{30, 40, 50, 60})
	}
}

// overrideDefault stores firstValue at firstPath through SetDefault, then
// stores secondValue at secondPath through Set, delegating both steps and
// their checks to overrideFromLayer. It returns the resulting *Viper.
func overrideDefault(assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper {
	first := defaultLayer
	return overrideFromLayer(first, assert, firstPath, firstValue, secondPath, secondValue)
}

// override stores firstValue at firstPath through Set, then stores
// secondValue at secondPath through Set again, delegating both steps and
// their checks to overrideFromLayer. It returns the resulting *Viper.
func override(assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper {
	first := overrideLayer
	return overrideFromLayer(first, assert, firstPath, firstValue, secondPath, secondValue)
}

// overrideFromLayer performs the sequential override and low-level checks.
//
// The first assignment is made on layer l for path firstPath with value
// firstValue. The second is always made on the override layer (i.e., with
// the Set() function) for path secondPath with value secondValue.
//
// firstPath and secondPath can include an arbitrary number of key delimiters
// (v.keyDelim) to address a nested element.
//
// After each assignment the value is checked both through v.Get() with its
// full path and by walking the layer's underlying maps key by key (see
// deepCheckValue). Viper keys are case-insensitive and stored lower-cased,
// so the key sequence used for the map walk is lower-cased too.
//
// The *Viper instance is returned so callers can make further assertions.
// It is returned without checks when assert is nil or a path is empty. An
// unknown layer is reported as a test failure.
func overrideFromLayer(l layer, assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper {
	v := New()
	firstKeys := strings.Split(strings.ToLower(firstPath), v.keyDelim)
	if assert == nil ||
		len(firstKeys) == 0 || len(firstKeys[0]) == 0 {
		return v
	}

	// Set and check first value
	switch l {
	case defaultLayer:
		v.SetDefault(firstPath, firstValue)
	case overrideLayer:
		v.Set(firstPath, firstValue)
	default:
		// Returning silently here would let the caller's test pass vacuously.
		assert.Fail(fmt.Sprintf("unknown layer %d", l))
		return v
	}
	assert.Equal(firstValue, v.Get(firstPath))
	deepCheckValue(assert, v, l, firstKeys, firstValue)

	// Override and check new value
	secondKeys := strings.Split(strings.ToLower(secondPath), v.keyDelim)
	if len(secondKeys) == 0 || len(secondKeys[0]) == 0 {
		return v
	}
	v.Set(secondPath, secondValue)
	assert.Equal(secondValue, v.Get(secondPath))
	deepCheckValue(assert, v, overrideLayer, secondKeys, secondValue)

	return v
}

// deepCheckValue checks that keys form a valid path through the configuration
// map of layer l (v.defaults or v.override), and that the value found at the
// end of that path equals value.
//
// Intermediate values may be map[string]interface{} or
// map[interface{}]interface{}; the latter is converted with cast.ToStringMap.
// Keys are matched exactly as given. Failures, including an unknown layer,
// are reported through assert. Nothing is checked when assert or v is nil,
// or when keys is empty.
func deepCheckValue(assert *assert.Assertions, v *Viper, l layer, keys []string, value interface{}) {
	if assert == nil || v == nil ||
		len(keys) == 0 || len(keys[0]) == 0 {
		return
	}

	// Select the root map of the requested layer. ms is a readable
	// description of the current position, used in failure messages.
	var val interface{}
	var ms string
	switch l {
	case defaultLayer:
		val = v.defaults
		ms = "v.defaults"
	case overrideLayer:
		val = v.override
		ms = "v.override"
	default:
		assert.Fail(fmt.Sprintf("unknown layer %d", l))
		return
	}

	// Walk down the nested maps, one key per level.
	for _, k := range keys {
		var m map[string]interface{}
		switch t := val.(type) {
		case map[string]interface{}:
			m = t
		case map[interface{}]interface{}:
			m = cast.ToStringMap(t)
		case nil:
			// The previous key (or the layer itself) holds nothing.
			assert.Fail(fmt.Sprintf("%s is nil, expected a map[string]interface{}", ms))
			return
		default:
			assert.Fail(fmt.Sprintf("%s is not a map[string]interface{}", ms))
			return
		}
		ms = fmt.Sprintf("%s[%q]", ms, k)
		val = m[k]
	}
	assert.Equal(value, val)
}