1package main
2
3import (
4	"flag"
5	"fmt"
6	"strings"
7)
8
9// PlatformFlag is a flag.Value (and flag.Getter) implementation that
10// is used to track the os/arch flags on the command-line.
11type PlatformFlag struct {
12	OS     []string
13	Arch   []string
14	OSArch []Platform
15}
16
17// Platforms returns the list of platforms that were set by this flag.
18// The default set of platforms must be passed in.
19func (p *PlatformFlag) Platforms(supported []Platform) []Platform {
20	// NOTE: Reading this method alone is a bit hard to understand. It
21	// is much easier to understand this method if you pair this with the
22	// table of test cases it has.
23
24	// Build a list of OS and archs NOT to build
25	ignoreArch := make(map[string]struct{})
26	includeArch := make(map[string]struct{})
27	ignoreOS := make(map[string]struct{})
28	includeOS := make(map[string]struct{})
29	ignoreOSArch := make(map[string]Platform)
30	includeOSArch := make(map[string]Platform)
31	for _, v := range p.Arch {
32		if v[0] == '!' {
33			ignoreArch[v[1:]] = struct{}{}
34		} else {
35			includeArch[v] = struct{}{}
36		}
37	}
38	for _, v := range p.OS {
39		if v[0] == '!' {
40			ignoreOS[v[1:]] = struct{}{}
41		} else {
42			includeOS[v] = struct{}{}
43		}
44	}
45	for _, v := range p.OSArch {
46		if v.OS[0] == '!' {
47			v = Platform{
48				OS:   v.OS[1:],
49				Arch: v.Arch,
50			}
51
52			ignoreOSArch[v.String()] = v
53		} else {
54			includeOSArch[v.String()] = v
55		}
56	}
57
58	// We're building a list of new platforms, so build the list
59	// based only on the configured OS/arch pairs.
60	var prefilter []Platform = nil
61	if len(includeOSArch) > 0 {
62		prefilter = make([]Platform, 0, len(p.Arch)*len(p.OS)+len(includeOSArch))
63		for _, v := range includeOSArch {
64			prefilter = append(prefilter, v)
65		}
66	}
67
68	if len(includeOS) > 0 && len(includeArch) > 0 {
69		// Build up the list of prefiltered by what is specified
70		if prefilter == nil {
71			prefilter = make([]Platform, 0, len(p.Arch)*len(p.OS))
72		}
73
74		for _, os := range p.OS {
75			if _, ok := includeOS[os]; !ok {
76				continue
77			}
78
79			for _, arch := range p.Arch {
80				if _, ok := includeArch[arch]; !ok {
81					continue
82				}
83
84				prefilter = append(prefilter, Platform{
85					OS:   os,
86					Arch: arch,
87				})
88			}
89		}
90	} else if len(includeOS) > 0 {
91		// Build up the list of prefiltered by what is specified
92		if prefilter == nil {
93			prefilter = make([]Platform, 0, len(p.Arch)*len(p.OS))
94		}
95
96		for _, os := range p.OS {
97			for _, platform := range supported {
98				if platform.OS == os {
99					prefilter = append(prefilter, platform)
100				}
101			}
102		}
103	}
104
105	if prefilter != nil {
106		// Remove any that aren't supported
107		result := make([]Platform, 0, len(prefilter))
108		for _, pending := range prefilter {
109			found := false
110			for _, platform := range supported {
111				if pending.String() == platform.String() {
112					found = true
113					break
114				}
115			}
116
117			if found {
118				add := pending
119				add.Default = false
120				result = append(result, add)
121			}
122		}
123
124		prefilter = result
125	}
126
127	if prefilter == nil {
128		prefilter = make([]Platform, 0, len(supported))
129		for _, v := range supported {
130			if v.Default {
131				add := v
132				add.Default = false
133				prefilter = append(prefilter, add)
134			}
135		}
136	}
137
138	// Go through each default platform and filter out the bad ones
139	result := make([]Platform, 0, len(prefilter))
140	for _, platform := range prefilter {
141		if len(ignoreOSArch) > 0 {
142			if _, ok := ignoreOSArch[platform.String()]; ok {
143				continue
144			}
145		}
146
147		// We only want to check the components (OS and Arch) if we didn't
148		// specifically ask to include it via the osarch.
149		checkComponents := true
150		if len(includeOSArch) > 0 {
151			if _, ok := includeOSArch[platform.String()]; ok {
152				checkComponents = false
153			}
154		}
155
156		if checkComponents {
157			if len(ignoreArch) > 0 {
158				if _, ok := ignoreArch[platform.Arch]; ok {
159					continue
160				}
161			}
162			if len(ignoreOS) > 0 {
163				if _, ok := ignoreOS[platform.OS]; ok {
164					continue
165				}
166			}
167			if len(includeArch) > 0 {
168				if _, ok := includeArch[platform.Arch]; !ok {
169					continue
170				}
171			}
172			if len(includeOS) > 0 {
173				if _, ok := includeOS[platform.OS]; !ok {
174					continue
175				}
176			}
177		}
178
179		result = append(result, platform)
180	}
181
182	return result
183}
184
185// ArchFlagValue returns a flag.Value that can be used with the flag
186// package to collect the arches for the flag.
187func (p *PlatformFlag) ArchFlagValue() flag.Value {
188	return (*appendStringValue)(&p.Arch)
189}
190
191// OSFlagValue returns a flag.Value that can be used with the flag
192// package to collect the operating systems for the flag.
193func (p *PlatformFlag) OSFlagValue() flag.Value {
194	return (*appendStringValue)(&p.OS)
195}
196
197// OSArchFlagValue returns a flag.Value that can be used with the flag
198// package to collect complete os and arch pairs for the flag.
199func (p *PlatformFlag) OSArchFlagValue() flag.Value {
200	return (*appendPlatformValue)(&p.OSArch)
201}
202
203// appendPlatformValue is a flag.Value that appends a full platform (os/arch)
204// to a list where the values from space-separated lines. This is used to
205// satisfy the -osarch flag.
206type appendPlatformValue []Platform
207
208func (s *appendPlatformValue) String() string {
209	return ""
210}
211
212func (s *appendPlatformValue) Set(value string) error {
213	if value == "" {
214		return nil
215	}
216
217	for _, v := range strings.Split(value, " ") {
218		parts := strings.Split(v, "/")
219		if len(parts) != 2 {
220			return fmt.Errorf(
221				"Invalid platform syntax: %s should be os/arch", v)
222		}
223
224		platform := Platform{
225			OS:   strings.ToLower(parts[0]),
226			Arch: strings.ToLower(parts[1]),
227		}
228
229		s.appendIfMissing(&platform)
230	}
231
232	return nil
233}
234
235func (s *appendPlatformValue) appendIfMissing(value *Platform) {
236	for _, existing := range *s {
237		if existing == *value {
238			return
239		}
240	}
241
242	*s = append(*s, *value)
243}
244
// appendStringValue is a flag.Value that accumulates strings parsed from a
// space-separated value. It backs flags such as -os="windows linux", which
// ends up as []string{"windows", "linux"}.
type appendStringValue []string

// String renders the collected values joined by single spaces.
func (s *appendStringValue) String() string {
	collected := []string(*s)
	return strings.Join(collected, " ")
}

// Set splits value on spaces and records each non-empty token, lower-cased,
// unless it is already in the list. It never returns an error.
func (s *appendStringValue) Set(value string) error {
	tokens := strings.Split(value, " ")
	for i := 0; i < len(tokens); i++ {
		if tokens[i] == "" {
			continue
		}
		s.appendIfMissing(strings.ToLower(tokens[i]))
	}

	return nil
}

// appendIfMissing adds value to the end of the list when no equal entry
// exists yet.
func (s *appendStringValue) appendIfMissing(value string) {
	current := *s
	for i := range current {
		if current[i] == value {
			return
		}
	}

	*s = append(current, value)
}
273