1// Copyright (c) 2021, Google Inc.
2//
3// Permission to use, copy, modify, and/or distribute this software for any
4// purpose with or without fee is hereby granted, provided that the above
5// copyright notice and this permission notice appear in all copies.
6//
7// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14
15package main
16
17import (
18	"bytes"
19	"compress/bzip2"
20	"encoding/json"
21	"flag"
22	"fmt"
23	"io"
24	"io/ioutil"
25	"log"
26	"os"
27	"os/exec"
28	"runtime"
29	"strings"
30	"sync"
31	"sync/atomic"
32)
33
// Command-line flags controlling which tool, wrappers and test list are used.
var (
	// toolPath is the acvptool binary invoked once per test.
	toolPath = flag.String("tool", "", "Path to acvptool binary")
	// moduleWrappers maps the wrapper names used in the -tests file to
	// module-wrapper binaries, written as "name:path,name:path".
	moduleWrappers = flag.String("module-wrappers", "", "Comma-separated list of name:path pairs for known module wrappers")
	// testsPath is the JSON file listing the tests to run.
	testsPath = flag.String("tests", "", "Path to JSON file listing tests")
	// update, when set, rewrites missing or mismatched expected-output files.
	update = flag.Bool("update", false, "If true then write updated outputs")
)
40
// invocation describes one test run: the acvptool binary and module wrapper
// to execute, the bzip2-compressed request file to feed them, and the
// bzip2-compressed file holding the expected output. An empty expectedPath
// means the output is not compared.
type invocation struct {
	toolPath, wrapperPath string
	inPath, expectedPath  string
}
47
48func main() {
49	flag.Parse()
50
51	if len(*toolPath) == 0 {
52		log.Fatal("-tool must be given")
53	}
54
55	if len(*moduleWrappers) == 0 {
56		log.Fatal("-module-wrappers must be given")
57	}
58
59	wrappers := make(map[string]string)
60	pairs := strings.Split(*moduleWrappers, ",")
61	for _, pair := range pairs {
62		parts := strings.SplitN(pair, ":", 2)
63		if _, ok := wrappers[parts[0]]; ok {
64			log.Fatalf("wrapper %q defined twice", parts[0])
65		}
66		wrappers[parts[0]] = parts[1]
67	}
68
69	if len(*testsPath) == 0 {
70		log.Fatal("-tests must be given")
71	}
72
73	testsFile, err := os.Open(*testsPath)
74	if err != nil {
75		log.Fatal(err)
76	}
77	defer testsFile.Close()
78
79	decoder := json.NewDecoder(testsFile)
80	var tests []struct {
81		Wrapper string
82		In      string
83		Out     string // Optional, may be empty.
84	}
85	if err := decoder.Decode(&tests); err != nil {
86		log.Fatal(err)
87	}
88
89	work := make(chan invocation, runtime.NumCPU())
90	var numFailed uint32
91
92	var wg sync.WaitGroup
93	for i := 0; i < runtime.NumCPU(); i++ {
94		wg.Add(1)
95		go worker(&wg, work, &numFailed)
96	}
97
98	for _, test := range tests {
99		wrapper, ok := wrappers[test.Wrapper]
100		if !ok {
101			log.Fatalf("wrapper %q not specified on command line", test.Wrapper)
102		}
103		work <- invocation{
104			toolPath:     *toolPath,
105			wrapperPath:  wrapper,
106			inPath:       test.In,
107			expectedPath: test.Out,
108		}
109	}
110
111	close(work)
112	wg.Wait()
113
114	n := atomic.LoadUint32(&numFailed)
115	if n > 0 {
116		log.Printf("Failed %d tests", n)
117		os.Exit(1)
118	} else {
119		log.Printf("%d ACVP tests matched expectations", len(tests))
120	}
121}
122
123func worker(wg *sync.WaitGroup, work <-chan invocation, numFailed *uint32) {
124	defer wg.Done()
125
126	for test := range work {
127		if err := doTest(test); err != nil {
128			log.Printf("Test failed for %q: %s", test.inPath, err)
129			atomic.AddUint32(numFailed, 1)
130		}
131	}
132}
133
134func doTest(test invocation) error {
135	input, err := os.Open(test.inPath)
136	if err != nil {
137		return fmt.Errorf("Failed to open %q: %s", test.inPath, err)
138	}
139	defer input.Close()
140
141	tempFile, err := ioutil.TempFile("", "boringssl-check_expected-")
142	if err != nil {
143		return fmt.Errorf("Failed to create temp file: %s", err)
144	}
145	defer os.Remove(tempFile.Name())
146	defer tempFile.Close()
147
148	decompressor := bzip2.NewReader(input)
149	if _, err := io.Copy(tempFile, decompressor); err != nil {
150		return fmt.Errorf("Failed to decompress %q: %s", test.inPath, err)
151	}
152
153	cmd := exec.Command(test.toolPath, "-wrapper", test.wrapperPath, "-json", tempFile.Name())
154	result, err := cmd.CombinedOutput()
155	if err != nil {
156		os.Stderr.Write(result)
157		return fmt.Errorf("Failed to process %q", test.inPath)
158	}
159
160	if len(test.expectedPath) == 0 {
161		// This test has variable output and thus cannot be compared against a fixed
162		// result.
163		return nil
164	}
165
166	expected, err := os.Open(test.expectedPath)
167	if err != nil {
168		if *update {
169			writeUpdate(test.expectedPath, result)
170		}
171		return fmt.Errorf("Failed to open %q: %s", test.expectedPath, err)
172	}
173	defer expected.Close()
174
175	decompressor = bzip2.NewReader(expected)
176
177	var expectedBuf bytes.Buffer
178	if _, err := io.Copy(&expectedBuf, decompressor); err != nil {
179		return fmt.Errorf("Failed to decompress %q: %s", test.expectedPath, err)
180	}
181
182	if !bytes.Equal(expectedBuf.Bytes(), result) {
183		if *update {
184			writeUpdate(test.expectedPath, result)
185		}
186		return fmt.Errorf("Mismatch for %q", test.expectedPath)
187	}
188
189	return nil
190}
191
// writeUpdate overwrites the file at path with contents (creating it with
// mode 0644 if needed) and logs the outcome. Errors are only logged, not
// returned: every caller reports the test as failed regardless.
//
// NOTE(review): contents are written as-is, while doTest reads expected
// files through bzip2.NewReader; an updated file presumably needs to be
// bzip2-compressed before it will be accepted on the next run.
func writeUpdate(path string, contents []byte) {
	if err := ioutil.WriteFile(path, contents, 0644); err != nil {
		// Reached both for missing and for mismatched expectation files,
		// so do not describe the file as missing.
		log.Printf("Failed to write %q: %s", path, err)
		return
	}
	log.Printf("Wrote %q", path)
}
199