1/* $OpenBSD: wycheproof.go,v 1.107 2019/11/28 23:13:34 tb Exp $ */
2/*
3 * Copyright (c) 2018 Joel Sing <jsing@openbsd.org>
4 * Copyright (c) 2018, 2019 Theo Buehler <tb@openbsd.org>
5 *
6 * Permission to use, copy, modify, and distribute this software for any
7 * purpose with or without fee is hereby granted, provided that the above
8 * copyright notice and this permission notice appear in all copies.
9 *
10 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 */
18
19// Wycheproof runs test vectors from Project Wycheproof against libcrypto.
20package main
21
22/*
23#cgo LDFLAGS: -lcrypto
24
25#include <string.h>
26
27#include <openssl/aes.h>
28#include <openssl/bio.h>
29#include <openssl/bn.h>
30#include <openssl/cmac.h>
31#include <openssl/curve25519.h>
32#include <openssl/dsa.h>
33#include <openssl/ec.h>
34#include <openssl/ecdsa.h>
35#include <openssl/evp.h>
36#include <openssl/hkdf.h>
37#include <openssl/objects.h>
38#include <openssl/pem.h>
39#include <openssl/x509.h>
40#include <openssl/rsa.h>
41*/
42import "C"
43
44import (
45	"bytes"
46	"crypto/sha1"
47	"crypto/sha256"
48	"crypto/sha512"
49	"encoding/base64"
50	"encoding/hex"
51	"encoding/json"
52	"flag"
53	"fmt"
54	"hash"
55	"io/ioutil"
56	"log"
57	"os"
58	"path/filepath"
59	"regexp"
60	"sort"
61	"strings"
62	"unsafe"
63)
64
// testVectorPath is where the Wycheproof test vector files are expected to
// live (NOTE(review): only referenced outside this chunk — confirm usage).
const testVectorPath = "/usr/local/share/wycheproof/testvectors"

// Audit bookkeeping for "acceptable" test cases. acceptableAudit switches the
// collection on. The two maps count comments and flags; they start out nil
// and must be made before gatherAcceptableStatistics writes to them.
var (
	acceptableAudit    = false
	acceptableComments map[string]int
	acceptableFlags    map[string]int
)
70
// wycheproofJWKPublic is a public key in JSON Web Key form as found in
// Wycheproof vectors. It is used for the "public" member of ECDH WebCrypto
// test cases and the "jwk" member of ECDSA WebCrypto test groups.
type wycheproofJWKPublic struct {
	Crv string `json:"crv"`
	KID string `json:"kid"`
	KTY string `json:"kty"`
	X   string `json:"x"`
	Y   string `json:"y"`
}

// wycheproofJWKPrivate is a private key in JSON Web Key form. It has the same
// members as wycheproofJWKPublic plus the private value "d".
type wycheproofJWKPrivate struct {
	Crv string `json:"crv"`
	D   string `json:"d"`
	KID string `json:"kid"`
	KTY string `json:"kty"`
	X   string `json:"x"`
	Y   string `json:"y"`
}
87
// wycheproofTestGroupAesCbcPkcs5 is a group of AES-CBC-PKCS5 test cases.
// KeySize is in bits and selects the AES-128/192/256-CBC cipher in
// runAesCbcPkcs5TestGroup. IVSize is only reported in the progress message.
type wycheproofTestGroupAesCbcPkcs5 struct {
	IVSize  int                          `json:"ivSize"`
	KeySize int                          `json:"keySize"`
	Type    string                       `json:"type"`
	Tests   []*wycheproofTestAesCbcPkcs5 `json:"tests"`
}

// wycheproofTestAesCbcPkcs5 is a single AES-CBC-PKCS5 test case. Key, IV, Msg
// and CT are hex strings (decoded by runAesCbcPkcs5Test). Result is the
// expected outcome; the checks compare it with "invalid" and "acceptable".
type wycheproofTestAesCbcPkcs5 struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Key     string   `json:"key"`
	IV      string   `json:"iv"`
	Msg     string   `json:"msg"`
	CT      string   `json:"ct"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}
105
// wycheproofTestGroupAead is a group of AES-CCM or AES-GCM test cases.
// KeySize is in bits and selects the cipher in runAesAeadTestGroup. IVSize
// and TagSize are only reported in the progress message there.
type wycheproofTestGroupAead struct {
	IVSize  int                   `json:"ivSize"`
	KeySize int                   `json:"keySize"`
	TagSize int                   `json:"tagSize"`
	Type    string                `json:"type"`
	Tests   []*wycheproofTestAead `json:"tests"`
}

// wycheproofTestAead is a single AEAD test case. Key, IV, AAD, Msg, CT and Tag
// are hex strings (decoded by runAesAeadTest). Result is compared with
// "valid", "invalid" and "acceptable" by the checks.
type wycheproofTestAead struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Key     string   `json:"key"`
	IV      string   `json:"iv"`
	AAD     string   `json:"aad"`
	Msg     string   `json:"msg"`
	CT      string   `json:"ct"`
	Tag     string   `json:"tag"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}
126
// wycheproofTestGroupAesCmac is a group of AES-CMAC test cases. KeySize is in
// bits and selects the AES-CBC cipher in runAesCmacTestGroup. TagSize is only
// reported in the progress message.
type wycheproofTestGroupAesCmac struct {
	KeySize int                      `json:"keySize"`
	TagSize int                      `json:"tagSize"`
	Type    string                   `json:"type"`
	Tests   []*wycheproofTestAesCmac `json:"tests"`
}

// wycheproofTestAesCmac is a single AES-CMAC test case. Key, Msg and Tag are
// hex strings (decoded by runAesCmacTest). The computed MAC is expected to
// match Tag exactly when Result is "valid".
type wycheproofTestAesCmac struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Key     string   `json:"key"`
	Msg     string   `json:"msg"`
	Tag     string   `json:"tag"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}
143
// wycheproofDSAKey holds the DSA public key of a test group as given in the
// JSON "key" object: domain parameters P, Q, G and public value Y.
// NOTE(review): the numeric members are presumably hex-encoded integers —
// confirm against the DSA runner.
type wycheproofDSAKey struct {
	G       string `json:"g"`
	KeySize int    `json:"keySize"`
	P       string `json:"p"`
	Q       string `json:"q"`
	Type    string `json:"type"`
	Y       string `json:"y"`
}

// wycheproofTestDSA is a single DSA signature test case: a message, a
// signature and the expected Result.
type wycheproofTestDSA struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Msg     string   `json:"msg"`
	Sig     string   `json:"sig"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// wycheproofTestGroupDSA is a group of DSA test cases sharing one public key,
// given both as parsed parameters (Key) and in DER and PEM form. All cases
// also share one hash name (SHA, presumably a name such as "SHA-256").
type wycheproofTestGroupDSA struct {
	Key    *wycheproofDSAKey    `json:"key"`
	KeyDER string               `json:"keyDer"`
	KeyPEM string               `json:"keyPem"`
	SHA    string               `json:"sha"`
	Type   string               `json:"type"`
	Tests  []*wycheproofTestDSA `json:"tests"`
}
170
// wycheproofTestECDH is a single ECDH test case: a peer public key, a private
// key, the expected shared secret and the expected Result.
type wycheproofTestECDH struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Public  string   `json:"public"`
	Private string   `json:"private"`
	Shared  string   `json:"shared"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// wycheproofTestGroupECDH groups ECDH test cases for one curve (a name such as
// those in the nids table) and one public key Encoding.
type wycheproofTestGroupECDH struct {
	Curve    string                `json:"curve"`
	Encoding string                `json:"encoding"`
	Type     string                `json:"type"`
	Tests    []*wycheproofTestECDH `json:"tests"`
}

// wycheproofTestECDHWebCrypto is the WebCrypto variant of wycheproofTestECDH;
// both keys are carried as JSON Web Keys.
type wycheproofTestECDHWebCrypto struct {
	TCID    int                   `json:"tcId"`
	Comment string                `json:"comment"`
	Public  *wycheproofJWKPublic  `json:"public"`
	Private *wycheproofJWKPrivate `json:"private"`
	Shared  string                `json:"shared"`
	Result  string                `json:"result"`
	Flags   []string              `json:"flags"`
}

// wycheproofTestGroupECDHWebCrypto groups WebCrypto-style ECDH test cases for
// one curve and encoding.
type wycheproofTestGroupECDHWebCrypto struct {
	Curve    string                         `json:"curve"`
	Encoding string                         `json:"encoding"`
	Type     string                         `json:"type"`
	Tests    []*wycheproofTestECDHWebCrypto `json:"tests"`
}
204
// wycheproofECDSAKey describes an ECDSA public key. It gives the curve name,
// the key size, the Uncompressed point encoding and the coordinates WX and WY.
type wycheproofECDSAKey struct {
	Curve        string `json:"curve"`
	KeySize      int    `json:"keySize"`
	Type         string `json:"type"`
	Uncompressed string `json:"uncompressed"`
	WX           string `json:"wx"`
	WY           string `json:"wy"`
}

// wycheproofTestECDSA is a single ECDSA signature test case: a message, a
// signature and the expected Result.
type wycheproofTestECDSA struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Msg     string   `json:"msg"`
	Sig     string   `json:"sig"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// wycheproofTestGroupECDSA groups ECDSA test cases that share one public key
// (parsed, DER and PEM) and one hash name.
type wycheproofTestGroupECDSA struct {
	Key    *wycheproofECDSAKey    `json:"key"`
	KeyDER string                 `json:"keyDer"`
	KeyPEM string                 `json:"keyPem"`
	SHA    string                 `json:"sha"`
	Type   string                 `json:"type"`
	Tests  []*wycheproofTestECDSA `json:"tests"`
}

// wycheproofTestGroupECDSAWebCrypto is the WebCrypto variant of
// wycheproofTestGroupECDSA; it also carries the public key as a JWK.
type wycheproofTestGroupECDSAWebCrypto struct {
	JWK    *wycheproofJWKPublic   `json:"jwk"`
	Key    *wycheproofECDSAKey    `json:"key"`
	KeyDER string                 `json:"keyDer"`
	KeyPEM string                 `json:"keyPem"`
	SHA    string                 `json:"sha"`
	Type   string                 `json:"type"`
	Tests  []*wycheproofTestECDSA `json:"tests"`
}
241
// wycheproofTestHkdf is a single HKDF test case. It gives the input keying
// material (Ikm), Salt, Info, the requested output Size, the expected output
// keying material (Okm) and the expected Result.
type wycheproofTestHkdf struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Ikm     string   `json:"ikm"`
	Salt    string   `json:"salt"`
	Info    string   `json:"info"`
	Size    int      `json:"size"`
	Okm     string   `json:"okm"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// wycheproofTestGroupHkdf groups HKDF test cases that share a key size.
type wycheproofTestGroupHkdf struct {
	Type    string                `json:"type"`
	KeySize int                   `json:"keySize"`
	Tests   []*wycheproofTestHkdf `json:"tests"`
}
259
// wycheproofTestKW is a single key wrap test case: a wrapping Key, a message,
// the expected ciphertext and the expected Result.
type wycheproofTestKW struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Key     string   `json:"key"`
	Msg     string   `json:"msg"`
	CT      string   `json:"ct"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// wycheproofTestGroupKW groups key wrap test cases that share a key size.
type wycheproofTestGroupKW struct {
	KeySize int                 `json:"keySize"`
	Type    string              `json:"type"`
	Tests   []*wycheproofTestKW `json:"tests"`
}
275
// wycheproofTestRSA is a single RSA signature test case: a message, a
// signature, a Padding description and the expected Result.
type wycheproofTestRSA struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Msg     string   `json:"msg"`
	Sig     string   `json:"sig"`
	Padding string   `json:"padding"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// wycheproofTestGroupRSA groups RSA signature test cases that share one hash
// name and one public key. The key is given as modulus N and exponent E, and
// also in ASN.1, DER and PEM form. Note that the JSON tag for KeySize is
// "keysize" (all lower case).
type wycheproofTestGroupRSA struct {
	E       string               `json:"e"`
	KeyASN  string               `json:"keyAsn"`
	KeyDER  string               `json:"keyDer"`
	KeyPEM  string               `json:"keyPem"`
	KeySize int                  `json:"keysize"`
	N       string               `json:"n"`
	SHA     string               `json:"sha"`
	Type    string               `json:"type"`
	Tests   []*wycheproofTestRSA `json:"tests"`
}
297
// wycheproofPrivateKeyJwk is an RSA private key in JSON Web Key form. Besides
// N, E and D it carries the prime factors P and Q and the CRT values DP, DQ
// and QI.
type wycheproofPrivateKeyJwk struct {
	Alg string `json:"alg"`
	D   string `json:"d"`
	DP  string `json:"dp"`
	DQ  string `json:"dq"`
	E   string `json:"e"`
	KID string `json:"kid"`
	Kty string `json:"kty"`
	N   string `json:"n"`
	P   string `json:"p"`
	Q   string `json:"q"`
	QI  string `json:"qi"`
}

// wycheproofTestRsaes is a single RSA encryption test case. It has a message,
// a ciphertext, an optional Label and the expected Result.
type wycheproofTestRsaes struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Msg     string   `json:"msg"`
	CT      string   `json:"ct"`
	Label   string   `json:"label"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// wycheproofTestGroupRsaesOaep groups RSAES-OAEP test cases. They share one
// private key, given as D/E/N, as a JWK, in PEM and in PKCS#8 form. They also
// share a hash (SHA) and a mask generation function (MGF with hash MGFSHA).
type wycheproofTestGroupRsaesOaep struct {
	D               string                   `json:"d"`
	E               string                   `json:"e"`
	KeySize         int                      `json:"keysize"`
	MGF             string                   `json:"mgf"`
	MGFSHA          string                   `json:"mgfSha"`
	N               string                   `json:"n"`
	PrivateKeyJwk   *wycheproofPrivateKeyJwk `json:"privateKeyJwk"`
	PrivateKeyPem   string                   `json:"privateKeyPem"`
	PrivateKeyPkcs8 string                   `json:"privateKeyPkcs8"`
	SHA             string                   `json:"sha"`
	Type            string                   `json:"type"`
	Tests           []*wycheproofTestRsaes   `json:"tests"`
}

// wycheproofTestGroupRsaesPkcs1 groups RSAES-PKCS1 test cases that share one
// private key. Unlike the OAEP group it has no hash or MGF parameters.
type wycheproofTestGroupRsaesPkcs1 struct {
	D               string                   `json:"d"`
	E               string                   `json:"e"`
	KeySize         int                      `json:"keysize"`
	N               string                   `json:"n"`
	PrivateKeyJwk   *wycheproofPrivateKeyJwk `json:"privateKeyJwk"`
	PrivateKeyPem   string                   `json:"privateKeyPem"`
	PrivateKeyPkcs8 string                   `json:"privateKeyPkcs8"`
	Type            string                   `json:"type"`
	Tests           []*wycheproofTestRsaes   `json:"tests"`
}
348
// wycheproofTestRsassa is a single RSASSA signature test case: a message, a
// signature and the expected Result.
type wycheproofTestRsassa struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Msg     string   `json:"msg"`
	Sig     string   `json:"sig"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// wycheproofTestGroupRsassa groups RSASSA test cases that share one public key
// and one hash. MGF, MGFSHA and SLen (salt length) are presumably the
// RSASSA-PSS parameters — confirm against the PSS runner.
type wycheproofTestGroupRsassa struct {
	E       string                  `json:"e"`
	KeyASN  string                  `json:"keyAsn"`
	KeyDER  string                  `json:"keyDer"`
	KeyPEM  string                  `json:"keyPem"`
	KeySize int                     `json:"keysize"`
	MGF     string                  `json:"mgf"`
	MGFSHA  string                  `json:"mgfSha"`
	N       string                  `json:"n"`
	SLen    int                     `json:"sLen"`
	SHA     string                  `json:"sha"`
	Type    string                  `json:"type"`
	Tests   []*wycheproofTestRsassa `json:"tests"`
}
372
// wycheproofTestX25519 is a single X25519 test case: a public key, a private
// key, the expected shared secret and the expected Result.
type wycheproofTestX25519 struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Curve   string   `json:"curve"`
	Public  string   `json:"public"`
	Private string   `json:"private"`
	Shared  string   `json:"shared"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// wycheproofTestGroupX25519 groups X25519 test cases for one curve name.
type wycheproofTestGroupX25519 struct {
	Curve string                  `json:"curve"`
	Tests []*wycheproofTestX25519 `json:"tests"`
}
388
// wycheproofTestVectors is the top level of a Wycheproof test vector file.
// The test groups are kept as raw JSON because their layout depends on the
// algorithm. They are presumably decoded later into the matching
// algorithm-specific group type.
type wycheproofTestVectors struct {
	Algorithm        string            `json:"algorithm"`
	GeneratorVersion string            `json:"generatorVersion"`
	Notes            map[string]string `json:"notes"`
	NumberOfTests    int               `json:"numberOfTests"`
	// The fields above form the file header; the groups follow.
	TestGroups []json.RawMessage `json:"testGroups"`
}
397
398var nids = map[string]int{
399	"brainpoolP224r1": C.NID_brainpoolP224r1,
400	"brainpoolP256r1": C.NID_brainpoolP256r1,
401	"brainpoolP320r1": C.NID_brainpoolP320r1,
402	"brainpoolP384r1": C.NID_brainpoolP384r1,
403	"brainpoolP512r1": C.NID_brainpoolP512r1,
404	"brainpoolP224t1": C.NID_brainpoolP224t1,
405	"brainpoolP256t1": C.NID_brainpoolP256t1,
406	"brainpoolP320t1": C.NID_brainpoolP320t1,
407	"brainpoolP384t1": C.NID_brainpoolP384t1,
408	"brainpoolP512t1": C.NID_brainpoolP512t1,
409	"secp224k1":       C.NID_secp224k1,
410	"secp224r1":       C.NID_secp224r1,
411	"secp256k1":       C.NID_secp256k1,
412	"P-256K":          C.NID_secp256k1,
413	"secp256r1":       C.NID_X9_62_prime256v1, // RFC 8422, Table 4, p.32
414	"P-256":           C.NID_X9_62_prime256v1,
415	"secp384r1":       C.NID_secp384r1,
416	"P-384":           C.NID_secp384r1,
417	"secp521r1":       C.NID_secp521r1,
418	"P-521":           C.NID_secp521r1,
419	"SHA-1":           C.NID_sha1,
420	"SHA-224":         C.NID_sha224,
421	"SHA-256":         C.NID_sha256,
422	"SHA-384":         C.NID_sha384,
423	"SHA-512":         C.NID_sha512,
424}
425
426func gatherAcceptableStatistics(testcase int, comment string, flags []string) {
427	fmt.Printf("AUDIT: Test case %d (%q) %v\n", testcase, comment, flags)
428
429	if comment == "" {
430		acceptableComments["No comment"]++
431	} else {
432		acceptableComments[comment]++
433	}
434
435	if len(flags) == 0 {
436		acceptableFlags["NoFlag"]++
437	} else {
438		for _, flag := range flags {
439			acceptableFlags[flag]++
440		}
441	}
442}
443
444func printAcceptableStatistics() {
445	fmt.Printf("\nComment statistics:\n")
446
447	var comments []string
448	for comment := range acceptableComments {
449		comments = append(comments, comment)
450	}
451	sort.Strings(comments)
452	for _, comment := range comments {
453		prcomment := comment
454		if len(comment) > 45 {
455			prcomment = comment[0:42] + "..."
456		}
457		fmt.Printf("%-45v %5d\n", prcomment, acceptableComments[comment])
458	}
459
460	fmt.Printf("\nFlag statistics:\n")
461	var flags []string
462	for flag := range acceptableFlags {
463		flags = append(flags, flag)
464	}
465	sort.Strings(flags)
466	for _, flag := range flags {
467		fmt.Printf("%-45v %5d\n", flag, acceptableFlags[flag])
468	}
469}
470
471func nidFromString(ns string) (int, error) {
472	nid, ok := nids[ns]
473	if ok {
474		return nid, nil
475	}
476	return -1, fmt.Errorf("unknown NID %q", ns)
477}
478
// hashFromString returns a fresh Go hash.Hash for a Wycheproof hash name. The
// supported names are "SHA-1", "SHA-224", "SHA-256", "SHA-384" and
// "SHA-512". Any other name yields a nil hash and an error.
func hashFromString(hs string) (hash.Hash, error) {
	var newHash func() hash.Hash
	switch hs {
	case "SHA-1":
		newHash = sha1.New
	case "SHA-224":
		newHash = sha256.New224
	case "SHA-256":
		newHash = sha256.New
	case "SHA-384":
		newHash = sha512.New384
	case "SHA-512":
		newHash = sha512.New
	}
	if newHash == nil {
		return nil, fmt.Errorf("unknown hash %q", hs)
	}
	return newHash(), nil
}
495
496func hashEvpMdFromString(hs string) (*C.EVP_MD, error) {
497	switch hs {
498	case "SHA-1":
499		return C.EVP_sha1(), nil
500	case "SHA-224":
501		return C.EVP_sha224(), nil
502	case "SHA-256":
503		return C.EVP_sha256(), nil
504	case "SHA-384":
505		return C.EVP_sha384(), nil
506	case "SHA-512":
507		return C.EVP_sha512(), nil
508	default:
509		return nil, fmt.Errorf("unknown hash %q", hs)
510	}
511}
512
513func checkAesCbcPkcs5(ctx *C.EVP_CIPHER_CTX, doEncrypt int, key []byte, keyLen int,
514	iv []byte, ivLen int, in []byte, inLen int, out []byte, outLen int,
515	wt *wycheproofTestAesCbcPkcs5) bool {
516	var action string
517	if doEncrypt == 1 {
518		action = "encrypting"
519	} else {
520		action = "decrypting"
521	}
522
523	ret := C.EVP_CipherInit_ex(ctx, nil, nil, (*C.uchar)(unsafe.Pointer(&key[0])),
524		(*C.uchar)(unsafe.Pointer(&iv[0])), C.int(doEncrypt))
525	if ret != 1 {
526		log.Fatalf("EVP_CipherInit_ex failed: %d", ret)
527	}
528
529	cipherOut := make([]byte, inLen+C.EVP_MAX_BLOCK_LENGTH)
530	var cipherOutLen C.int
531
532	ret = C.EVP_CipherUpdate(ctx, (*C.uchar)(unsafe.Pointer(&cipherOut[0])), &cipherOutLen,
533		(*C.uchar)(unsafe.Pointer(&in[0])), C.int(inLen))
534	if ret != 1 {
535		if wt.Result == "invalid" {
536			fmt.Printf("INFO: Test case %d (%q) [%v] %v - EVP_CipherUpdate() = %d, want %v\n",
537				wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
538			return true
539		}
540		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - EVP_CipherUpdate() = %d, want %v\n",
541			wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
542		return false
543	}
544
545	var finallen C.int
546	ret = C.EVP_CipherFinal_ex(ctx, (*C.uchar)(unsafe.Pointer(&cipherOut[cipherOutLen])), &finallen)
547	if ret != 1 {
548		if wt.Result == "invalid" {
549			return true
550		}
551		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - EVP_CipherFinal_ex() = %d, want %v\n",
552			wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
553		return false
554	}
555
556	cipherOutLen += finallen
557	if cipherOutLen != C.int(outLen) && wt.Result != "invalid" {
558		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - open length mismatch: got %d, want %d\n",
559			wt.TCID, wt.Comment, action, wt.Flags, cipherOutLen, outLen)
560		return false
561	}
562
563	openedMsg := out[0:cipherOutLen]
564	if outLen == 0 {
565		out = nil
566	}
567
568	success := false
569	if bytes.Equal(openedMsg, out) || wt.Result == "invalid" {
570		success = true
571		if acceptableAudit && wt.Result == "acceptable" {
572			gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
573		}
574	} else {
575		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - msg match: %t; want %v\n",
576			wt.TCID, wt.Comment, action, wt.Flags, bytes.Equal(openedMsg, out), wt.Result)
577	}
578	return success
579}
580
581func runAesCbcPkcs5Test(ctx *C.EVP_CIPHER_CTX, wt *wycheproofTestAesCbcPkcs5) bool {
582	key, err := hex.DecodeString(wt.Key)
583	if err != nil {
584		log.Fatalf("Failed to decode key %q: %v", wt.Key, err)
585	}
586	iv, err := hex.DecodeString(wt.IV)
587	if err != nil {
588		log.Fatalf("Failed to decode IV %q: %v", wt.IV, err)
589	}
590	ct, err := hex.DecodeString(wt.CT)
591	if err != nil {
592		log.Fatalf("Failed to decode CT %q: %v", wt.CT, err)
593	}
594	msg, err := hex.DecodeString(wt.Msg)
595	if err != nil {
596		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
597	}
598
599	keyLen, ivLen, ctLen, msgLen := len(key), len(iv), len(ct), len(msg)
600
601	if keyLen == 0 {
602		key = append(key, 0)
603	}
604	if ivLen == 0 {
605		iv = append(iv, 0)
606	}
607	if ctLen == 0 {
608		ct = append(ct, 0)
609	}
610	if msgLen == 0 {
611		msg = append(msg, 0)
612	}
613
614	openSuccess := checkAesCbcPkcs5(ctx, 0, key, keyLen, iv, ivLen, ct, ctLen, msg, msgLen, wt)
615	sealSuccess := checkAesCbcPkcs5(ctx, 1, key, keyLen, iv, ivLen, msg, msgLen, ct, ctLen, wt)
616
617	return openSuccess && sealSuccess
618}
619
620func runAesCbcPkcs5TestGroup(algorithm string, wtg *wycheproofTestGroupAesCbcPkcs5) bool {
621	fmt.Printf("Running %v test group %v with IV size %d and key size %d...\n",
622		algorithm, wtg.Type, wtg.IVSize, wtg.KeySize)
623
624	var cipher *C.EVP_CIPHER
625	switch wtg.KeySize {
626	case 128:
627		cipher = C.EVP_aes_128_cbc()
628	case 192:
629		cipher = C.EVP_aes_192_cbc()
630	case 256:
631		cipher = C.EVP_aes_256_cbc()
632	default:
633		log.Fatalf("Unsupported key size: %d", wtg.KeySize)
634	}
635
636	ctx := C.EVP_CIPHER_CTX_new()
637	if ctx == nil {
638		log.Fatal("EVP_CIPHER_CTX_new() failed")
639	}
640	defer C.EVP_CIPHER_CTX_free(ctx)
641
642	ret := C.EVP_CipherInit_ex(ctx, cipher, nil, nil, nil, 0)
643	if ret != 1 {
644		log.Fatalf("EVP_CipherInit_ex failed: %d", ret)
645	}
646
647	success := true
648	for _, wt := range wtg.Tests {
649		if !runAesCbcPkcs5Test(ctx, wt) {
650			success = false
651		}
652	}
653	return success
654}
655
656func checkAesAead(algorithm string, ctx *C.EVP_CIPHER_CTX, doEncrypt int,
657	key []byte, keyLen int, iv []byte, ivLen int, aad []byte, aadLen int,
658	in []byte, inLen int, out []byte, outLen int, tag []byte, tagLen int,
659	wt *wycheproofTestAead) bool {
660	var ctrlSetIVLen C.int
661	var ctrlSetTag C.int
662	var ctrlGetTag C.int
663
664	doCCM := false
665	switch algorithm {
666	case "AES-CCM":
667		doCCM = true
668		ctrlSetIVLen = C.EVP_CTRL_CCM_SET_IVLEN
669		ctrlSetTag = C.EVP_CTRL_CCM_SET_TAG
670		ctrlGetTag = C.EVP_CTRL_CCM_GET_TAG
671	case "AES-GCM":
672		ctrlSetIVLen = C.EVP_CTRL_GCM_SET_IVLEN
673		ctrlSetTag = C.EVP_CTRL_GCM_SET_TAG
674		ctrlGetTag = C.EVP_CTRL_GCM_GET_TAG
675	}
676
677	setTag := unsafe.Pointer(nil)
678	var action string
679
680	if doEncrypt == 1 {
681		action = "encrypting"
682	} else {
683		action = "decrypting"
684		setTag = unsafe.Pointer(&tag[0])
685	}
686
687	ret := C.EVP_CipherInit_ex(ctx, nil, nil, nil, nil, C.int(doEncrypt))
688	if ret != 1 {
689		log.Fatalf("[%v] cipher init failed", action)
690	}
691
692	ret = C.EVP_CIPHER_CTX_ctrl(ctx, ctrlSetIVLen, C.int(ivLen), nil)
693	if ret != 1 {
694		if wt.Comment == "Nonce is too long" || wt.Comment == "Invalid nonce size" ||
695			wt.Comment == "0 size IV is not valid" || wt.Comment == "Very long nonce" {
696			return true
697		}
698		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - setting IV len to %d failed. got %d, want %v\n",
699			wt.TCID, wt.Comment, action, wt.Flags, ivLen, ret, wt.Result)
700		return false
701	}
702
703	if doEncrypt == 0 || doCCM {
704		ret = C.EVP_CIPHER_CTX_ctrl(ctx, ctrlSetTag, C.int(tagLen), setTag)
705		if ret != 1 {
706			if wt.Comment == "Invalid tag size" {
707				return true
708			}
709			fmt.Printf("FAIL: Test case %d (%q) [%v] %v - setting tag length to %d failed. got %d, want %v\n",
710				wt.TCID, wt.Comment, action, wt.Flags, tagLen, ret, wt.Result)
711			return false
712		}
713	}
714
715	ret = C.EVP_CipherInit_ex(ctx, nil, nil, (*C.uchar)(unsafe.Pointer(&key[0])),
716		(*C.uchar)(unsafe.Pointer(&iv[0])), C.int(doEncrypt))
717	if ret != 1 {
718		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - setting key and IV failed. got %d, want %v\n",
719			wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
720		return false
721	}
722
723	var cipherOutLen C.int
724	if doCCM {
725		ret = C.EVP_CipherUpdate(ctx, nil, &cipherOutLen, nil, C.int(inLen))
726		if ret != 1 {
727			fmt.Printf("FAIL: Test case %d (%q) [%v] %v - setting input length to %d failed. got %d, want %v\n",
728				wt.TCID, wt.Comment, action, wt.Flags, inLen, ret, wt.Result)
729			return false
730		}
731	}
732
733	ret = C.EVP_CipherUpdate(ctx, nil, &cipherOutLen, (*C.uchar)(unsafe.Pointer(&aad[0])), C.int(aadLen))
734	if ret != 1 {
735		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - processing AAD failed. got %d, want %v\n",
736			wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
737		return false
738	}
739
740	cipherOutLen = 0
741	cipherOut := make([]byte, inLen)
742	if inLen == 0 {
743		cipherOut = append(cipherOut, 0)
744	}
745
746	ret = C.EVP_CipherUpdate(ctx, (*C.uchar)(unsafe.Pointer(&cipherOut[0])), &cipherOutLen,
747		(*C.uchar)(unsafe.Pointer(&in[0])), C.int(inLen))
748	if ret != 1 {
749		if wt.Result == "invalid" {
750			return true
751		}
752		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - EVP_CipherUpdate() = %d, want %v\n",
753			wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
754		return false
755	}
756
757	if doEncrypt == 1 {
758		var tmpLen C.int
759		dummyOut := make([]byte, 16)
760
761		ret = C.EVP_CipherFinal_ex(ctx, (*C.uchar)(unsafe.Pointer(&dummyOut[0])), &tmpLen)
762		if ret != 1 {
763			fmt.Printf("FAIL: Test case %d (%q) [%v] %v - EVP_CipherFinal_ex() = %d, want %v\n",
764				wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
765			return false
766		}
767		cipherOutLen += tmpLen
768	}
769
770	if cipherOutLen != C.int(outLen) {
771		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - cipherOutLen %d != outLen %d. Result %v\n",
772			wt.TCID, wt.Comment, action, wt.Flags, cipherOutLen, outLen, wt.Result)
773		return false
774	}
775
776	success := true
777	if !bytes.Equal(cipherOut, out) {
778		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - expected and computed output do not match. Result: %v\n",
779			wt.TCID, wt.Comment, action, wt.Flags, wt.Result)
780		success = false
781	}
782	if doEncrypt == 1 {
783		tagOut := make([]byte, tagLen)
784		ret = C.EVP_CIPHER_CTX_ctrl(ctx, ctrlGetTag, C.int(tagLen), unsafe.Pointer(&tagOut[0]))
785		if ret != 1 {
786			fmt.Printf("FAIL: Test case %d (%q) [%v] %v - EVP_CIPHER_CTX_ctrl() = %d, want %v\n",
787				wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
788			return false
789		}
790
791		// There are no acceptable CCM cases. All acceptable GCM tests
792		// pass. They have len(IV) <= 48. NIST SP 800-38D, 5.2.1.1, p.8,
793		// allows 1 <= len(IV) <= 2^64-1, but notes:
794		//   "For IVs it is recommended that implementations restrict
795		//    support to the length of 96 bits, to promote
796		//    interoperability, efficiency and simplicity of design."
797		if bytes.Equal(tagOut, tag) != (wt.Result == "valid" || wt.Result == "acceptable") {
798			fmt.Printf("FAIL: Test case %d (%q) [%v] %v - expected and computed tag do not match - ret: %d, Result: %v\n",
799				wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
800			success = false
801		}
802		if acceptableAudit && bytes.Equal(tagOut, tag) && wt.Result == "acceptable" {
803			gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
804		}
805	}
806	return success
807}
808
809func runAesAeadTest(algorithm string, ctx *C.EVP_CIPHER_CTX, aead *C.EVP_AEAD, wt *wycheproofTestAead) bool {
810	key, err := hex.DecodeString(wt.Key)
811	if err != nil {
812		log.Fatalf("Failed to decode key %q: %v", wt.Key, err)
813	}
814
815	iv, err := hex.DecodeString(wt.IV)
816	if err != nil {
817		log.Fatalf("Failed to decode IV %q: %v", wt.IV, err)
818	}
819
820	aad, err := hex.DecodeString(wt.AAD)
821	if err != nil {
822		log.Fatalf("Failed to decode AAD %q: %v", wt.AAD, err)
823	}
824
825	msg, err := hex.DecodeString(wt.Msg)
826	if err != nil {
827		log.Fatalf("Failed to decode msg %q: %v", wt.Msg, err)
828	}
829
830	ct, err := hex.DecodeString(wt.CT)
831	if err != nil {
832		log.Fatalf("Failed to decode CT %q: %v", wt.CT, err)
833	}
834
835	tag, err := hex.DecodeString(wt.Tag)
836	if err != nil {
837		log.Fatalf("Failed to decode tag %q: %v", wt.Tag, err)
838	}
839
840	keyLen, ivLen, aadLen, msgLen, ctLen, tagLen := len(key), len(iv), len(aad), len(msg), len(ct), len(tag)
841
842	if keyLen == 0 {
843		key = append(key, 0)
844	}
845	if ivLen == 0 {
846		iv = append(iv, 0)
847	}
848	if aadLen == 0 {
849		aad = append(aad, 0)
850	}
851	if msgLen == 0 {
852		msg = append(msg, 0)
853	}
854	if ctLen == 0 {
855		ct = append(ct, 0)
856	}
857	if tagLen == 0 {
858		tag = append(tag, 0)
859	}
860
861	openEvp := checkAesAead(algorithm, ctx, 0, key, keyLen, iv, ivLen, aad, aadLen, ct, ctLen, msg, msgLen, tag, tagLen, wt)
862	sealEvp := checkAesAead(algorithm, ctx, 1, key, keyLen, iv, ivLen, aad, aadLen, msg, msgLen, ct, ctLen, tag, tagLen, wt)
863
864	openAead, sealAead := true, true
865	if aead != nil {
866		var ctx C.EVP_AEAD_CTX
867		if C.EVP_AEAD_CTX_init(&ctx, aead, (*C.uchar)(unsafe.Pointer(&key[0])), C.size_t(keyLen), C.size_t(tagLen), nil) != 1 {
868			log.Fatal("Failed to initialize AEAD context")
869		}
870		defer C.EVP_AEAD_CTX_cleanup(&ctx)
871
872		// Make sure we don't accidentally prepend or compare against a 0.
873		if ctLen == 0 {
874			ct = nil
875		}
876
877		openAead = checkAeadOpen(&ctx, iv, ivLen, aad, aadLen, msg, msgLen, ct, ctLen, tag, tagLen, wt)
878		sealAead = checkAeadSeal(&ctx, iv, ivLen, aad, aadLen, msg, msgLen, ct, ctLen, tag, tagLen, wt)
879	}
880
881	return openEvp && sealEvp && openAead && sealAead
882}
883
884func runAesAeadTestGroup(algorithm string, wtg *wycheproofTestGroupAead) bool {
885	fmt.Printf("Running %v test group %v with IV size %d, key size %d and tag size %d...\n",
886		algorithm, wtg.Type, wtg.IVSize, wtg.KeySize, wtg.TagSize)
887
888	var cipher *C.EVP_CIPHER
889	var aead *C.EVP_AEAD
890	switch algorithm {
891	case "AES-CCM":
892		switch wtg.KeySize {
893		case 128:
894			cipher = C.EVP_aes_128_ccm()
895		case 192:
896			cipher = C.EVP_aes_192_ccm()
897		case 256:
898			cipher = C.EVP_aes_256_ccm()
899		default:
900			fmt.Printf("INFO: Skipping tests with invalid key size %d\n", wtg.KeySize)
901			return true
902		}
903	case "AES-GCM":
904		switch wtg.KeySize {
905		case 128:
906			cipher = C.EVP_aes_128_gcm()
907			aead = C.EVP_aead_aes_128_gcm()
908		case 192:
909			cipher = C.EVP_aes_192_gcm()
910		case 256:
911			cipher = C.EVP_aes_256_gcm()
912			aead = C.EVP_aead_aes_256_gcm()
913		default:
914			fmt.Printf("INFO: Skipping tests with invalid key size %d\n", wtg.KeySize)
915			return true
916		}
917	default:
918		log.Fatalf("runAesAeadTestGroup() - unhandled algorithm: %v", algorithm)
919	}
920
921	ctx := C.EVP_CIPHER_CTX_new()
922	if ctx == nil {
923		log.Fatal("EVP_CIPHER_CTX_new() failed")
924	}
925	defer C.EVP_CIPHER_CTX_free(ctx)
926
927	C.EVP_CipherInit_ex(ctx, cipher, nil, nil, nil, 1)
928
929	success := true
930	for _, wt := range wtg.Tests {
931		if !runAesAeadTest(algorithm, ctx, aead, wt) {
932			success = false
933		}
934	}
935	return success
936}
937
938func runAesCmacTest(cipher *C.EVP_CIPHER, wt *wycheproofTestAesCmac) bool {
939	key, err := hex.DecodeString(wt.Key)
940	if err != nil {
941		log.Fatalf("Failed to decode key %q: %v", wt.Key, err)
942	}
943
944	msg, err := hex.DecodeString(wt.Msg)
945	if err != nil {
946		log.Fatalf("Failed to decode msg %q: %v", wt.Msg, err)
947	}
948
949	tag, err := hex.DecodeString(wt.Tag)
950	if err != nil {
951		log.Fatalf("Failed to decode tag %q: %v", wt.Tag, err)
952	}
953
954	keyLen, msgLen, tagLen := len(key), len(msg), len(tag)
955
956	if keyLen == 0 {
957		key = append(key, 0)
958	}
959	if msgLen == 0 {
960		msg = append(msg, 0)
961	}
962	if tagLen == 0 {
963		tag = append(tag, 0)
964	}
965
966	ctx := C.CMAC_CTX_new()
967	if ctx == nil {
968		log.Fatal("CMAC_CTX_new failed")
969	}
970	defer C.CMAC_CTX_free(ctx)
971
972	ret := C.CMAC_Init(ctx, unsafe.Pointer(&key[0]), C.size_t(keyLen), cipher, nil)
973	if ret != 1 {
974		fmt.Printf("FAIL: Test case %d (%q) %v - CMAC_Init() = %d, want %v\n",
975			wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
976		return false
977	}
978
979	ret = C.CMAC_Update(ctx, unsafe.Pointer(&msg[0]), C.size_t(msgLen))
980	if ret != 1 {
981		fmt.Printf("FAIL: Test case %d (%q) %v - CMAC_Update() = %d, want %v\n",
982			wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
983		return false
984	}
985
986	var outLen C.size_t
987	outTag := make([]byte, 16)
988
989	ret = C.CMAC_Final(ctx, (*C.uchar)(unsafe.Pointer(&outTag[0])), &outLen)
990	if ret != 1 {
991		fmt.Printf("FAIL: Test case %d (%q) %v - CMAC_Final() = %d, want %v\n",
992			wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
993		return false
994	}
995
996	outTag = outTag[0:tagLen]
997
998	success := true
999	if bytes.Equal(tag, outTag) != (wt.Result == "valid") {
1000		fmt.Printf("FAIL: Test case %d (%q) %v - want %v\n",
1001			wt.TCID, wt.Comment, wt.Flags, wt.Result)
1002		success = false
1003	}
1004	return success
1005}
1006
1007func runAesCmacTestGroup(algorithm string, wtg *wycheproofTestGroupAesCmac) bool {
1008	fmt.Printf("Running %v test group %v with key size %d and tag size %d...\n",
1009		algorithm, wtg.Type, wtg.KeySize, wtg.TagSize)
1010	var cipher *C.EVP_CIPHER
1011
1012	switch wtg.KeySize {
1013	case 128:
1014		cipher = C.EVP_aes_128_cbc()
1015	case 192:
1016		cipher = C.EVP_aes_192_cbc()
1017	case 256:
1018		cipher = C.EVP_aes_256_cbc()
1019	default:
1020		fmt.Printf("INFO: Skipping tests with invalid key size %d\n", wtg.KeySize)
1021		return true
1022	}
1023
1024	success := true
1025	for _, wt := range wtg.Tests {
1026		if !runAesCmacTest(cipher, wt) {
1027			success = false
1028		}
1029	}
1030	return success
1031}
1032
1033func checkAeadOpen(ctx *C.EVP_AEAD_CTX, iv []byte, ivLen int, aad []byte, aadLen int, msg []byte, msgLen int,
1034	ct []byte, ctLen int, tag []byte, tagLen int, wt *wycheproofTestAead) bool {
1035	maxOutLen := ctLen + tagLen
1036
1037	opened := make([]byte, maxOutLen)
1038	if maxOutLen == 0 {
1039		opened = append(opened, 0)
1040	}
1041	var openedMsgLen C.size_t
1042
1043	catCtTag := append(ct, tag...)
1044	catCtTagLen := len(catCtTag)
1045	if catCtTagLen == 0 {
1046		catCtTag = append(catCtTag, 0)
1047	}
1048	openRet := C.EVP_AEAD_CTX_open(ctx, (*C.uint8_t)(unsafe.Pointer(&opened[0])),
1049		(*C.size_t)(unsafe.Pointer(&openedMsgLen)), C.size_t(maxOutLen),
1050		(*C.uint8_t)(unsafe.Pointer(&iv[0])), C.size_t(ivLen),
1051		(*C.uint8_t)(unsafe.Pointer(&catCtTag[0])), C.size_t(catCtTagLen),
1052		(*C.uint8_t)(unsafe.Pointer(&aad[0])), C.size_t(aadLen))
1053
1054	if openRet != 1 {
1055		if wt.Result == "invalid" {
1056			return true
1057		}
1058		fmt.Printf("FAIL: Test case %d (%q) %v - EVP_AEAD_CTX_open() = %d, want %v\n",
1059			wt.TCID, wt.Comment, wt.Flags, int(openRet), wt.Result)
1060		return false
1061	}
1062
1063	if openedMsgLen != C.size_t(msgLen) {
1064		fmt.Printf("FAIL: Test case %d (%q) %v - open length mismatch: got %d, want %d\n",
1065			wt.TCID, wt.Comment, wt.Flags, openedMsgLen, msgLen)
1066		return false
1067	}
1068
1069	openedMsg := opened[0:openedMsgLen]
1070	if msgLen == 0 {
1071		msg = nil
1072	}
1073
1074	success := false
1075	if bytes.Equal(openedMsg, msg) || wt.Result == "invalid" {
1076		if acceptableAudit && wt.Result == "acceptable" {
1077			gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
1078		}
1079		success = true
1080	} else {
1081		fmt.Printf("FAIL: Test case %d (%q) %v - msg match: %t; want %v\n",
1082			wt.TCID, wt.Comment, wt.Flags, bytes.Equal(openedMsg, msg), wt.Result)
1083	}
1084	return success
1085}
1086
1087func checkAeadSeal(ctx *C.EVP_AEAD_CTX, iv []byte, ivLen int, aad []byte, aadLen int, msg []byte,
1088	msgLen int, ct []byte, ctLen int, tag []byte, tagLen int, wt *wycheproofTestAead) bool {
1089	maxOutLen := msgLen + tagLen
1090
1091	sealed := make([]byte, maxOutLen)
1092	if maxOutLen == 0 {
1093		sealed = append(sealed, 0)
1094	}
1095	var sealedLen C.size_t
1096
1097	sealRet := C.EVP_AEAD_CTX_seal(ctx, (*C.uint8_t)(unsafe.Pointer(&sealed[0])),
1098		(*C.size_t)(unsafe.Pointer(&sealedLen)), C.size_t(maxOutLen),
1099		(*C.uint8_t)(unsafe.Pointer(&iv[0])), C.size_t(ivLen),
1100		(*C.uint8_t)(unsafe.Pointer(&msg[0])), C.size_t(msgLen),
1101		(*C.uint8_t)(unsafe.Pointer(&aad[0])), C.size_t(aadLen))
1102
1103	if sealRet != 1 {
1104		success := (wt.Result == "invalid")
1105		if !success {
1106			fmt.Printf("FAIL: Test case %d (%q) %v - EVP_AEAD_CTX_seal() = %d, want %v\n", wt.TCID, wt.Comment, wt.Flags, int(sealRet), wt.Result)
1107		}
1108		return success
1109	}
1110
1111	if sealedLen != C.size_t(maxOutLen) {
1112		fmt.Printf("FAIL: Test case %d (%q) %v - seal length mismatch: got %d, want %d\n",
1113			wt.TCID, wt.Comment, wt.Flags, sealedLen, maxOutLen)
1114		return false
1115	}
1116
1117	sealedCt := sealed[0:msgLen]
1118	sealedTag := sealed[msgLen:maxOutLen]
1119
1120	success := false
1121	if bytes.Equal(sealedCt, ct) && bytes.Equal(sealedTag, tag) || wt.Result == "invalid" {
1122		if acceptableAudit && wt.Result == "acceptable" {
1123			gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
1124		}
1125		success = true
1126	} else {
1127		fmt.Printf("FAIL: Test case %d (%q) %v - EVP_AEAD_CTX_seal() = %d, ct match: %t, tag match: %t; want %v\n",
1128			wt.TCID, wt.Comment, wt.Flags, int(sealRet),
1129			bytes.Equal(sealedCt, ct), bytes.Equal(sealedTag, tag), wt.Result)
1130	}
1131	return success
1132}
1133
1134func runChaCha20Poly1305Test(algorithm string, wt *wycheproofTestAead) bool {
1135	var aead *C.EVP_AEAD
1136	switch algorithm {
1137	case "CHACHA20-POLY1305":
1138		aead = C.EVP_aead_chacha20_poly1305()
1139	case "XCHACHA20-POLY1305":
1140		aead = C.EVP_aead_xchacha20_poly1305()
1141	}
1142
1143	key, err := hex.DecodeString(wt.Key)
1144	if err != nil {
1145		log.Fatalf("Failed to decode key %q: %v", wt.Key, err)
1146	}
1147	iv, err := hex.DecodeString(wt.IV)
1148	if err != nil {
1149		log.Fatalf("Failed to decode key %q: %v", wt.IV, err)
1150	}
1151	aad, err := hex.DecodeString(wt.AAD)
1152	if err != nil {
1153		log.Fatalf("Failed to decode AAD %q: %v", wt.AAD, err)
1154	}
1155	msg, err := hex.DecodeString(wt.Msg)
1156	if err != nil {
1157		log.Fatalf("Failed to decode msg %q: %v", wt.Msg, err)
1158	}
1159	ct, err := hex.DecodeString(wt.CT)
1160	if err != nil {
1161		log.Fatalf("Failed to decode ct %q: %v", wt.CT, err)
1162	}
1163	tag, err := hex.DecodeString(wt.Tag)
1164	if err != nil {
1165		log.Fatalf("Failed to decode tag %q: %v", wt.Tag, err)
1166	}
1167
1168	keyLen, ivLen, aadLen, msgLen, ctLen, tagLen := len(key), len(iv), len(aad), len(msg), len(ct), len(tag)
1169
1170	if ivLen == 0 {
1171		iv = append(iv, 0)
1172	}
1173	if aadLen == 0 {
1174		aad = append(aad, 0)
1175	}
1176	if msgLen == 0 {
1177		msg = append(msg, 0)
1178	}
1179	if ctLen == 0 {
1180		msg = append(ct, 0)
1181	}
1182	if tagLen == 0 {
1183		msg = append(tag, 0)
1184	}
1185
1186	var ctx C.EVP_AEAD_CTX
1187	if C.EVP_AEAD_CTX_init(&ctx, aead, (*C.uchar)(unsafe.Pointer(&key[0])), C.size_t(keyLen), C.size_t(tagLen), nil) != 1 {
1188		log.Fatal("Failed to initialize AEAD context")
1189	}
1190	defer C.EVP_AEAD_CTX_cleanup(&ctx)
1191
1192	openSuccess := checkAeadOpen(&ctx, iv, ivLen, aad, aadLen, msg, msgLen, ct, ctLen, tag, tagLen, wt)
1193	sealSuccess := checkAeadSeal(&ctx, iv, ivLen, aad, aadLen, msg, msgLen, ct, ctLen, tag, tagLen, wt)
1194
1195	return openSuccess && sealSuccess
1196}
1197
1198func runChaCha20Poly1305TestGroup(algorithm string, wtg *wycheproofTestGroupAead) bool {
1199	// ChaCha20-Poly1305 currently only supports nonces of length 12 (96 bits)
1200	if algorithm == "CHACHA20-POLY1305" && wtg.IVSize != 96 {
1201		return true
1202	}
1203
1204	fmt.Printf("Running %v test group %v with IV size %d, key size %d, tag size %d...\n",
1205		algorithm, wtg.Type, wtg.IVSize, wtg.KeySize, wtg.TagSize)
1206
1207	success := true
1208	for _, wt := range wtg.Tests {
1209		if !runChaCha20Poly1305Test(algorithm, wt) {
1210			success = false
1211		}
1212	}
1213	return success
1214}
1215
1216func runDSATest(dsa *C.DSA, h hash.Hash, wt *wycheproofTestDSA) bool {
1217	msg, err := hex.DecodeString(wt.Msg)
1218	if err != nil {
1219		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
1220	}
1221
1222	h.Reset()
1223	h.Write(msg)
1224	msg = h.Sum(nil)
1225
1226	sig, err := hex.DecodeString(wt.Sig)
1227	if err != nil {
1228		log.Fatalf("Failed to decode signature %q: %v", wt.Sig, err)
1229	}
1230
1231	msgLen, sigLen := len(msg), len(sig)
1232	if msgLen == 0 {
1233		msg = append(msg, 0)
1234	}
1235	if sigLen == 0 {
1236		sig = append(msg, 0)
1237	}
1238
1239	ret := C.DSA_verify(0, (*C.uchar)(unsafe.Pointer(&msg[0])), C.int(msgLen),
1240		(*C.uchar)(unsafe.Pointer(&sig[0])), C.int(sigLen), dsa)
1241
1242	success := true
1243	if ret == 1 != (wt.Result == "valid") {
1244		fmt.Printf("FAIL: Test case %d (%q) %v - DSA_verify() = %d, want %v\n",
1245			wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
1246		success = false
1247	}
1248	return success
1249}
1250
1251func runDSATestGroup(algorithm string, wtg *wycheproofTestGroupDSA) bool {
1252	fmt.Printf("Running %v test group %v, key size %d and %v...\n",
1253		algorithm, wtg.Type, wtg.Key.KeySize, wtg.SHA)
1254
1255	dsa := C.DSA_new()
1256	if dsa == nil {
1257		log.Fatal("DSA_new failed")
1258	}
1259	defer C.DSA_free(dsa)
1260
1261	var bnG *C.BIGNUM
1262	wg := C.CString(wtg.Key.G)
1263	if C.BN_hex2bn(&bnG, wg) == 0 {
1264		log.Fatal("Failed to decode g")
1265	}
1266	C.free(unsafe.Pointer(wg))
1267
1268	var bnP *C.BIGNUM
1269	wp := C.CString(wtg.Key.P)
1270	if C.BN_hex2bn(&bnP, wp) == 0 {
1271		log.Fatal("Failed to decode p")
1272	}
1273	C.free(unsafe.Pointer(wp))
1274
1275	var bnQ *C.BIGNUM
1276	wq := C.CString(wtg.Key.Q)
1277	if C.BN_hex2bn(&bnQ, wq) == 0 {
1278		log.Fatal("Failed to decode q")
1279	}
1280	C.free(unsafe.Pointer(wq))
1281
1282	ret := C.DSA_set0_pqg(dsa, bnP, bnQ, bnG)
1283	if ret != 1 {
1284		log.Fatalf("DSA_set0_pqg returned %d", ret)
1285	}
1286
1287	var bnY *C.BIGNUM
1288	wy := C.CString(wtg.Key.Y)
1289	if C.BN_hex2bn(&bnY, wy) == 0 {
1290		log.Fatal("Failed to decode y")
1291	}
1292	C.free(unsafe.Pointer(wy))
1293
1294	ret = C.DSA_set0_key(dsa, bnY, nil)
1295	if ret != 1 {
1296		log.Fatalf("DSA_set0_key returned %d", ret)
1297	}
1298
1299	h, err := hashFromString(wtg.SHA)
1300	if err != nil {
1301		log.Fatalf("Failed to get hash: %v", err)
1302	}
1303
1304	der, err := hex.DecodeString(wtg.KeyDER)
1305	if err != nil {
1306		log.Fatalf("Failed to decode DER encoded key: %v", err)
1307	}
1308
1309	derLen := len(der)
1310	if derLen == 0 {
1311		der = append(der, 0)
1312	}
1313
1314	Cder := (*C.uchar)(C.malloc(C.ulong(derLen)))
1315	if Cder == nil {
1316		log.Fatal("malloc failed")
1317	}
1318	C.memcpy(unsafe.Pointer(Cder), unsafe.Pointer(&der[0]), C.ulong(derLen))
1319
1320	p := (*C.uchar)(Cder)
1321	dsaDER := C.d2i_DSA_PUBKEY(nil, (**C.uchar)(&p), C.long(derLen))
1322	defer C.DSA_free(dsaDER)
1323	C.free(unsafe.Pointer(Cder))
1324
1325	keyPEM := C.CString(wtg.KeyPEM)
1326	bio := C.BIO_new_mem_buf(unsafe.Pointer(keyPEM), C.int(len(wtg.KeyPEM)))
1327	if bio == nil {
1328		log.Fatal("BIO_new_mem_buf failed")
1329	}
1330	defer C.free(unsafe.Pointer(keyPEM))
1331	defer C.BIO_free(bio)
1332
1333	dsaPEM := C.PEM_read_bio_DSA_PUBKEY(bio, nil, nil, nil)
1334	if dsaPEM == nil {
1335		log.Fatal("PEM_read_bio_DSA_PUBKEY failed")
1336	}
1337	defer C.DSA_free(dsaPEM)
1338
1339	success := true
1340	for _, wt := range wtg.Tests {
1341		if !runDSATest(dsa, h, wt) {
1342			success = false
1343		}
1344		if !runDSATest(dsaDER, h, wt) {
1345			success = false
1346		}
1347		if !runDSATest(dsaPEM, h, wt) {
1348			success = false
1349		}
1350	}
1351	return success
1352}
1353
1354func runECDHTest(nid int, doECpoint bool, wt *wycheproofTestECDH) bool {
1355	privKey := C.EC_KEY_new_by_curve_name(C.int(nid))
1356	if privKey == nil {
1357		log.Fatalf("EC_KEY_new_by_curve_name failed")
1358	}
1359	defer C.EC_KEY_free(privKey)
1360
1361	var bnPriv *C.BIGNUM
1362	wPriv := C.CString(wt.Private)
1363	if C.BN_hex2bn(&bnPriv, wPriv) == 0 {
1364		log.Fatal("Failed to decode wPriv")
1365	}
1366	C.free(unsafe.Pointer(wPriv))
1367	defer C.BN_free(bnPriv)
1368
1369	ret := C.EC_KEY_set_private_key(privKey, bnPriv)
1370	if ret != 1 {
1371		fmt.Printf("FAIL: Test case %d (%q) %v - EC_KEY_set_private_key() = %d, want %v\n",
1372			wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
1373		return false
1374	}
1375
1376	pub, err := hex.DecodeString(wt.Public)
1377	if err != nil {
1378		log.Fatalf("Failed to decode public key: %v", err)
1379	}
1380
1381	pubLen := len(pub)
1382	if pubLen == 0 {
1383		pub = append(pub, 0)
1384	}
1385
1386	Cpub := (*C.uchar)(C.malloc(C.ulong(pubLen)))
1387	if Cpub == nil {
1388		log.Fatal("malloc failed")
1389	}
1390	C.memcpy(unsafe.Pointer(Cpub), unsafe.Pointer(&pub[0]), C.ulong(pubLen))
1391
1392	p := (*C.uchar)(Cpub)
1393	var pubKey *C.EC_KEY
1394	if doECpoint {
1395		pubKey = C.EC_KEY_new_by_curve_name(C.int(nid))
1396		if pubKey == nil {
1397			log.Fatal("EC_KEY_new_by_curve_name failed")
1398		}
1399		pubKey = C.o2i_ECPublicKey(&pubKey, (**C.uchar)(&p), C.long(pubLen))
1400	} else {
1401		pubKey = C.d2i_EC_PUBKEY(nil, (**C.uchar)(&p), C.long(pubLen))
1402	}
1403	defer C.EC_KEY_free(pubKey)
1404	C.free(unsafe.Pointer(Cpub))
1405
1406	if pubKey == nil {
1407		if wt.Result == "invalid" || wt.Result == "acceptable" {
1408			return true
1409		}
1410		fmt.Printf("FAIL: Test case %d (%q) %v - ASN decoding failed: want %v\n",
1411			wt.TCID, wt.Comment, wt.Flags, wt.Result)
1412		return false
1413	}
1414
1415	privGroup := C.EC_KEY_get0_group(privKey)
1416
1417	secLen := (C.EC_GROUP_get_degree(privGroup) + 7) / 8
1418
1419	secret := make([]byte, secLen)
1420	if secLen == 0 {
1421		secret = append(secret, 0)
1422	}
1423
1424	pubPoint := C.EC_KEY_get0_public_key(pubKey)
1425
1426	ret = C.ECDH_compute_key(unsafe.Pointer(&secret[0]), C.ulong(secLen), pubPoint, privKey, nil)
1427	if ret != C.int(secLen) {
1428		if wt.Result == "invalid" {
1429			return true
1430		}
1431		fmt.Printf("FAIL: Test case %d (%q) %v - ECDH_compute_key() = %d, want %d, result: %v\n",
1432			wt.TCID, wt.Comment, wt.Flags, ret, int(secLen), wt.Result)
1433		return false
1434	}
1435
1436	shared, err := hex.DecodeString(wt.Shared)
1437	if err != nil {
1438		log.Fatalf("Failed to decode shared secret: %v", err)
1439	}
1440
1441	success := true
1442	if !bytes.Equal(shared, secret) {
1443		fmt.Printf("FAIL: Test case %d (%q) %v - expected and computed shared secret do not match, want %v\n",
1444			wt.TCID, wt.Comment, wt.Flags, wt.Result)
1445		success = false
1446	}
1447	if acceptableAudit && success && wt.Result == "acceptable" {
1448		gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
1449	}
1450	return success
1451}
1452
1453func runECDHTestGroup(algorithm string, wtg *wycheproofTestGroupECDH) bool {
1454	doECpoint := false
1455	if wtg.Encoding == "ecpoint" {
1456		doECpoint = true
1457	}
1458
1459	// XXX
1460	if wtg.Curve == "secp224k1" {
1461		fmt.Printf("INFO: skipping %v test group %v with curve %v and %v encoding...\n", algorithm, wtg.Type, wtg.Curve, wtg.Encoding)
1462		return true
1463	}
1464
1465	fmt.Printf("Running %v test group %v with curve %v and %v encoding...\n",
1466		algorithm, wtg.Type, wtg.Curve, wtg.Encoding)
1467
1468	nid, err := nidFromString(wtg.Curve)
1469	if err != nil {
1470		log.Fatalf("Failed to get nid for curve: %v", err)
1471	}
1472
1473	success := true
1474	for _, wt := range wtg.Tests {
1475		if !runECDHTest(nid, doECpoint, wt) {
1476			success = false
1477		}
1478	}
1479	return success
1480}
1481
1482func runECDHWebCryptoTest(nid int, wt *wycheproofTestECDHWebCrypto) bool {
1483	privKey := C.EC_KEY_new_by_curve_name(C.int(nid))
1484	if privKey == nil {
1485		log.Fatalf("EC_KEY_new_by_curve_name failed")
1486	}
1487	defer C.EC_KEY_free(privKey)
1488
1489	d, err := base64.RawURLEncoding.DecodeString(wt.Private.D)
1490	if err != nil {
1491		log.Fatalf("Failed to base64 decode d: %v", err)
1492	}
1493	bnD := C.BN_bin2bn((*C.uchar)(unsafe.Pointer(&d[0])), C.int(len(d)), nil)
1494	if bnD == nil {
1495		log.Fatal("Failed to decode D")
1496	}
1497	defer C.BN_free(bnD)
1498
1499	ret := C.EC_KEY_set_private_key(privKey, bnD)
1500	if ret != 1 {
1501		fmt.Printf("FAIL: Test case %d (%q) %v - EC_KEY_set_private_key() = %d, want %v\n",
1502			wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
1503		return false
1504	}
1505
1506	x, err := base64.RawURLEncoding.DecodeString(wt.Public.X)
1507	if err != nil {
1508		log.Fatalf("Failed to base64 decode x: %v", err)
1509	}
1510	bnX := C.BN_bin2bn((*C.uchar)(unsafe.Pointer(&x[0])), C.int(len(x)), nil)
1511	if bnX == nil {
1512		log.Fatal("Failed to decode X")
1513	}
1514	defer C.BN_free(bnX)
1515
1516	y, err := base64.RawURLEncoding.DecodeString(wt.Public.Y)
1517	if err != nil {
1518		log.Fatalf("Failed to base64 decode y: %v", err)
1519	}
1520	bnY := C.BN_bin2bn((*C.uchar)(unsafe.Pointer(&y[0])), C.int(len(y)), nil)
1521	if bnY == nil {
1522		log.Fatal("Failed to decode Y")
1523	}
1524	defer C.BN_free(bnY)
1525
1526	pubKey := C.EC_KEY_new_by_curve_name(C.int(nid))
1527	if pubKey == nil {
1528		log.Fatal("Failed to create EC_KEY")
1529	}
1530	defer C.EC_KEY_free(pubKey)
1531
1532	ret = C.EC_KEY_set_public_key_affine_coordinates(pubKey, bnX, bnY)
1533	if ret != 1 {
1534		if wt.Result == "invalid" {
1535			return true
1536		}
1537		fmt.Printf("FAIL: Test case %d (%q) %v - EC_KEY_set_public_key_affine_coordinates() = %d, want %v\n",
1538			wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
1539		return false
1540	}
1541	pubPoint := C.EC_KEY_get0_public_key(pubKey)
1542
1543	privGroup := C.EC_KEY_get0_group(privKey)
1544
1545	secLen := (C.EC_GROUP_get_degree(privGroup) + 7) / 8
1546
1547	secret := make([]byte, secLen)
1548	if secLen == 0 {
1549		secret = append(secret, 0)
1550	}
1551
1552	ret = C.ECDH_compute_key(unsafe.Pointer(&secret[0]), C.ulong(secLen), pubPoint, privKey, nil)
1553	if ret != C.int(secLen) {
1554		if wt.Result == "invalid" {
1555			return true
1556		}
1557		fmt.Printf("FAIL: Test case %d (%q) %v - ECDH_compute_key() = %d, want %d, result: %v\n",
1558			wt.TCID, wt.Comment, wt.Flags, ret, int(secLen), wt.Result)
1559		return false
1560	}
1561
1562	shared, err := hex.DecodeString(wt.Shared)
1563	if err != nil {
1564		log.Fatalf("Failed to decode shared secret: %v", err)
1565	}
1566
1567	success := true
1568	if !bytes.Equal(shared, secret) {
1569		fmt.Printf("FAIL: Test case %d (%q) %v - expected and computed shared secret do not match, want %v\n",
1570			wt.TCID, wt.Comment, wt.Flags, wt.Result)
1571		success = false
1572	}
1573	if acceptableAudit && success && wt.Result == "acceptable" {
1574		gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
1575	}
1576	return success
1577}
1578
1579func runECDHWebCryptoTestGroup(algorithm string, wtg *wycheproofTestGroupECDHWebCrypto) bool {
1580	fmt.Printf("Running %v test group %v with curve %v and %v encoding...\n",
1581		algorithm, wtg.Type, wtg.Curve, wtg.Encoding)
1582
1583	nid, err := nidFromString(wtg.Curve)
1584	if err != nil {
1585		log.Fatalf("Failed to get nid for curve: %v", err)
1586	}
1587
1588	success := true
1589	for _, wt := range wtg.Tests {
1590		if !runECDHWebCryptoTest(nid, wt) {
1591			success = false
1592		}
1593	}
1594	return success
1595}
1596
1597func runECDSATest(ecKey *C.EC_KEY, nid int, h hash.Hash, webcrypto bool, wt *wycheproofTestECDSA) bool {
1598	msg, err := hex.DecodeString(wt.Msg)
1599	if err != nil {
1600		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
1601	}
1602
1603	h.Reset()
1604	h.Write(msg)
1605	msg = h.Sum(nil)
1606
1607	msgLen := len(msg)
1608	if msgLen == 0 {
1609		msg = append(msg, 0)
1610	}
1611
1612	var ret C.int
1613	if webcrypto {
1614		cDer, derLen := encodeECDSAWebCryptoSig(wt.Sig)
1615		if cDer == nil {
1616			fmt.Print("FAIL: unable to decode signature")
1617			return false
1618		}
1619		defer C.free(unsafe.Pointer(cDer))
1620
1621		ret = C.ECDSA_verify(0, (*C.uchar)(unsafe.Pointer(&msg[0])), C.int(msgLen),
1622			(*C.uchar)(unsafe.Pointer(cDer)), C.int(derLen), ecKey)
1623	} else {
1624		sig, err := hex.DecodeString(wt.Sig)
1625		if err != nil {
1626			log.Fatalf("Failed to decode signature %q: %v", wt.Sig, err)
1627		}
1628
1629		sigLen := len(sig)
1630		if sigLen == 0 {
1631			sig = append(sig, 0)
1632		}
1633		ret = C.ECDSA_verify(0, (*C.uchar)(unsafe.Pointer(&msg[0])), C.int(msgLen),
1634			(*C.uchar)(unsafe.Pointer(&sig[0])), C.int(sigLen), ecKey)
1635	}
1636
1637	// XXX audit acceptable cases...
1638	success := true
1639	if ret == 1 != (wt.Result == "valid") && wt.Result != "acceptable" {
1640		fmt.Printf("FAIL: Test case %d (%q) %v - ECDSA_verify() = %d, want %v\n",
1641			wt.TCID, wt.Comment, wt.Flags, int(ret), wt.Result)
1642		success = false
1643	}
1644	if acceptableAudit && ret == 1 && wt.Result == "acceptable" {
1645		gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
1646	}
1647	return success
1648}
1649
1650func runECDSATestGroup(algorithm string, wtg *wycheproofTestGroupECDSA) bool {
1651	fmt.Printf("Running %v test group %v with curve %v, key size %d and %v...\n",
1652		algorithm, wtg.Type, wtg.Key.Curve, wtg.Key.KeySize, wtg.SHA)
1653
1654	nid, err := nidFromString(wtg.Key.Curve)
1655	if err != nil {
1656		log.Fatalf("Failed to get nid for curve: %v", err)
1657	}
1658	ecKey := C.EC_KEY_new_by_curve_name(C.int(nid))
1659	if ecKey == nil {
1660		log.Fatal("EC_KEY_new_by_curve_name failed")
1661	}
1662	defer C.EC_KEY_free(ecKey)
1663
1664	var bnX *C.BIGNUM
1665	wx := C.CString(wtg.Key.WX)
1666	if C.BN_hex2bn(&bnX, wx) == 0 {
1667		log.Fatal("Failed to decode WX")
1668	}
1669	C.free(unsafe.Pointer(wx))
1670	defer C.BN_free(bnX)
1671
1672	var bnY *C.BIGNUM
1673	wy := C.CString(wtg.Key.WY)
1674	if C.BN_hex2bn(&bnY, wy) == 0 {
1675		log.Fatal("Failed to decode WY")
1676	}
1677	C.free(unsafe.Pointer(wy))
1678	defer C.BN_free(bnY)
1679
1680	if C.EC_KEY_set_public_key_affine_coordinates(ecKey, bnX, bnY) != 1 {
1681		log.Fatal("Failed to set EC public key")
1682	}
1683
1684	nid, err = nidFromString(wtg.SHA)
1685	if err != nil {
1686		log.Fatalf("Failed to get MD NID: %v", err)
1687	}
1688	h, err := hashFromString(wtg.SHA)
1689	if err != nil {
1690		log.Fatalf("Failed to get hash: %v", err)
1691	}
1692
1693	success := true
1694	for _, wt := range wtg.Tests {
1695		if !runECDSATest(ecKey, nid, h, false, wt) {
1696			success = false
1697		}
1698	}
1699	return success
1700}
1701
1702// DER encode the signature (so that ECDSA_verify() can decode and encode it again...)
1703func encodeECDSAWebCryptoSig(wtSig string) (*C.uchar, C.int) {
1704	cSig := C.ECDSA_SIG_new()
1705	if cSig == nil {
1706		log.Fatal("ECDSA_SIG_new() failed")
1707	}
1708	defer C.ECDSA_SIG_free(cSig)
1709
1710	sigLen := len(wtSig)
1711	r := C.CString(wtSig[:sigLen/2])
1712	s := C.CString(wtSig[sigLen/2:])
1713	defer C.free(unsafe.Pointer(r))
1714	defer C.free(unsafe.Pointer(s))
1715	if C.BN_hex2bn(&cSig.r, r) == 0 {
1716		return nil, 0
1717	}
1718	if C.BN_hex2bn(&cSig.s, s) == 0 {
1719		return nil, 0
1720	}
1721
1722	derLen := C.i2d_ECDSA_SIG(cSig, nil)
1723	if derLen == 0 {
1724		return nil, 0
1725	}
1726	cDer := (*C.uchar)(C.malloc(C.ulong(derLen)))
1727	if cDer == nil {
1728		log.Fatal("malloc failed")
1729	}
1730
1731	p := cDer
1732	ret := C.i2d_ECDSA_SIG(cSig, (**C.uchar)(&p))
1733	if ret == 0 || ret != derLen {
1734		C.free(unsafe.Pointer(cDer))
1735		return nil, 0
1736	}
1737
1738	return cDer, derLen
1739}
1740
1741func runECDSAWebCryptoTestGroup(algorithm string, wtg *wycheproofTestGroupECDSAWebCrypto) bool {
1742	fmt.Printf("Running %v test group %v with curve %v, key size %d and %v...\n",
1743		algorithm, wtg.Type, wtg.Key.Curve, wtg.Key.KeySize, wtg.SHA)
1744
1745	nid, err := nidFromString(wtg.JWK.Crv)
1746	if err != nil {
1747		log.Fatalf("Failed to get nid for curve: %v", err)
1748	}
1749	ecKey := C.EC_KEY_new_by_curve_name(C.int(nid))
1750	if ecKey == nil {
1751		log.Fatal("EC_KEY_new_by_curve_name failed")
1752	}
1753	defer C.EC_KEY_free(ecKey)
1754
1755	x, err := base64.RawURLEncoding.DecodeString(wtg.JWK.X)
1756	if err != nil {
1757		log.Fatalf("Failed to base64 decode X: %v", err)
1758	}
1759	bnX := C.BN_bin2bn((*C.uchar)(unsafe.Pointer(&x[0])), C.int(len(x)), nil)
1760	if bnX == nil {
1761		log.Fatal("Failed to decode X")
1762	}
1763	defer C.BN_free(bnX)
1764
1765	y, err := base64.RawURLEncoding.DecodeString(wtg.JWK.Y)
1766	if err != nil {
1767		log.Fatalf("Failed to base64 decode Y: %v", err)
1768	}
1769	bnY := C.BN_bin2bn((*C.uchar)(unsafe.Pointer(&y[0])), C.int(len(y)), nil)
1770	if bnY == nil {
1771		log.Fatal("Failed to decode Y")
1772	}
1773	defer C.BN_free(bnY)
1774
1775	if C.EC_KEY_set_public_key_affine_coordinates(ecKey, bnX, bnY) != 1 {
1776		log.Fatal("Failed to set EC public key")
1777	}
1778
1779	nid, err = nidFromString(wtg.SHA)
1780	if err != nil {
1781		log.Fatalf("Failed to get MD NID: %v", err)
1782	}
1783	h, err := hashFromString(wtg.SHA)
1784	if err != nil {
1785		log.Fatalf("Failed to get hash: %v", err)
1786	}
1787
1788	success := true
1789	for _, wt := range wtg.Tests {
1790		if !runECDSATest(ecKey, nid, h, true, wt) {
1791			success = false
1792		}
1793	}
1794	return success
1795}
1796
1797func runHkdfTest(md *C.EVP_MD, wt *wycheproofTestHkdf) bool {
1798	ikm, err := hex.DecodeString(wt.Ikm)
1799	if err != nil {
1800		log.Fatalf("Failed to decode ikm %q: %v", wt.Ikm, err)
1801	}
1802	salt, err := hex.DecodeString(wt.Salt)
1803	if err != nil {
1804		log.Fatalf("Failed to decode salt %q: %v", wt.Salt, err)
1805	}
1806	info, err := hex.DecodeString(wt.Info)
1807	if err != nil {
1808		log.Fatalf("Failed to decode info %q: %v", wt.Info, err)
1809	}
1810
1811	ikmLen, saltLen, infoLen := len(ikm), len(salt), len(info)
1812	if ikmLen == 0 {
1813		ikm = append(ikm, 0)
1814	}
1815	if saltLen == 0 {
1816		salt = append(salt, 0)
1817	}
1818	if infoLen == 0 {
1819		info = append(info, 0)
1820	}
1821
1822	outLen := wt.Size
1823	out := make([]byte, outLen)
1824	if outLen == 0 {
1825		out = append(out, 0)
1826	}
1827
1828	ret := C.HKDF((*C.uchar)(unsafe.Pointer(&out[0])), C.size_t(outLen), md, (*C.uchar)(unsafe.Pointer(&ikm[0])), C.size_t(ikmLen), (*C.uchar)(&salt[0]), C.size_t(saltLen), (*C.uchar)(unsafe.Pointer(&info[0])), C.size_t(infoLen))
1829
1830	if ret != 1 {
1831		success := wt.Result == "invalid"
1832		if !success {
1833			fmt.Printf("FAIL: Test case %d (%q) %v - got %d, want %v\n", wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
1834		}
1835		return success
1836	}
1837
1838	okm, err := hex.DecodeString(wt.Okm)
1839	if err != nil {
1840		log.Fatalf("Failed to decode okm %q: %v", wt.Okm, err)
1841	}
1842	if !bytes.Equal(out[:outLen], okm) {
1843		fmt.Printf("FAIL: Test case %d (%q) %v - expected and computed output don't match: %v", wt.TCID, wt.Comment, wt.Flags, wt.Result)
1844	}
1845
1846	return wt.Result == "valid"
1847}
1848
1849func runHkdfTestGroup(algorithm string, wtg *wycheproofTestGroupHkdf) bool {
1850	fmt.Printf("Running %v test group %v with key size %d...\n", algorithm, wtg.Type, wtg.KeySize)
1851	md, err := hashEvpMdFromString(strings.TrimPrefix(algorithm, "HKDF-"))
1852	if err != nil {
1853		log.Fatalf("Failed to get hash: %v", err)
1854	}
1855
1856	success := true
1857	for _, wt := range wtg.Tests {
1858		if !runHkdfTest(md, wt) {
1859			success = false
1860		}
1861	}
1862	return success
1863}
1864
1865func runKWTestWrap(keySize int, key []byte, keyLen int, msg []byte, msgLen int, ct []byte, ctLen int, wt *wycheproofTestKW) bool {
1866	var aesKey C.AES_KEY
1867
1868	ret := C.AES_set_encrypt_key((*C.uchar)(unsafe.Pointer(&key[0])), (C.int)(keySize), (*C.AES_KEY)(unsafe.Pointer(&aesKey)))
1869	if ret != 0 {
1870		fmt.Printf("FAIL: Test case %d (%q) %v - AES_set_encrypt_key() = %d, want %v\n",
1871			wt.TCID, wt.Comment, wt.Flags, int(ret), wt.Result)
1872		return false
1873	}
1874
1875	outLen := msgLen
1876	out := make([]byte, outLen)
1877	copy(out, msg)
1878	out = append(out, make([]byte, 8)...)
1879	ret = C.AES_wrap_key((*C.AES_KEY)(unsafe.Pointer(&aesKey)), nil, (*C.uchar)(unsafe.Pointer(&out[0])), (*C.uchar)(unsafe.Pointer(&out[0])), (C.uint)(msgLen))
1880	success := false
1881	if ret == C.int(len(out)) && bytes.Equal(out, ct) {
1882		if acceptableAudit && wt.Result == "acceptable" {
1883			gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
1884		}
1885		if wt.Result != "invalid" {
1886			success = true
1887		}
1888	} else if wt.Result != "valid" {
1889		success = true
1890	}
1891	if !success {
1892		fmt.Printf("FAIL: Test case %d (%q) %v - msgLen = %d, AES_wrap_key() = %d, want %v\n",
1893			wt.TCID, wt.Comment, wt.Flags, msgLen, int(ret), wt.Result)
1894	}
1895	return success
1896}
1897
1898func runKWTestUnWrap(keySize int, key []byte, keyLen int, msg []byte, msgLen int, ct []byte, ctLen int, wt *wycheproofTestKW) bool {
1899	var aesKey C.AES_KEY
1900
1901	ret := C.AES_set_decrypt_key((*C.uchar)(unsafe.Pointer(&key[0])), (C.int)(keySize), (*C.AES_KEY)(unsafe.Pointer(&aesKey)))
1902	if ret != 0 {
1903		fmt.Printf("FAIL: Test case %d (%q) %v - AES_set_encrypt_key() = %d, want %v\n",
1904			wt.TCID, wt.Comment, wt.Flags, int(ret), wt.Result)
1905		return false
1906	}
1907
1908	out := make([]byte, ctLen)
1909	copy(out, ct)
1910	if ctLen == 0 {
1911		out = append(out, 0)
1912	}
1913	ret = C.AES_unwrap_key((*C.AES_KEY)(unsafe.Pointer(&aesKey)), nil, (*C.uchar)(unsafe.Pointer(&out[0])), (*C.uchar)(unsafe.Pointer(&out[0])), (C.uint)(ctLen))
1914	success := false
1915	if ret == C.int(ctLen-8) && bytes.Equal(out[0:ret], msg[0:ret]) {
1916		if acceptableAudit && wt.Result == "acceptable" {
1917			gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
1918		}
1919		if wt.Result != "invalid" {
1920			success = true
1921		}
1922	} else if wt.Result != "valid" {
1923		success = true
1924	}
1925	if !success {
1926		fmt.Printf("FAIL: Test case %d (%q) %v - keyLen = %d, AES_unwrap_key() = %d, want %v\n",
1927			wt.TCID, wt.Comment, wt.Flags, keyLen, int(ret), wt.Result)
1928	}
1929	return success
1930}
1931
1932func runKWTest(keySize int, wt *wycheproofTestKW) bool {
1933	key, err := hex.DecodeString(wt.Key)
1934	if err != nil {
1935		log.Fatalf("Failed to decode key %q: %v", wt.Key, err)
1936	}
1937	msg, err := hex.DecodeString(wt.Msg)
1938	if err != nil {
1939		log.Fatalf("Failed to decode msg %q: %v", wt.Msg, err)
1940	}
1941	ct, err := hex.DecodeString(wt.CT)
1942	if err != nil {
1943		log.Fatalf("Failed to decode ct %q: %v", wt.CT, err)
1944	}
1945
1946	keyLen, msgLen, ctLen := len(key), len(msg), len(ct)
1947
1948	if keyLen == 0 {
1949		key = append(key, 0)
1950	}
1951	if msgLen == 0 {
1952		msg = append(msg, 0)
1953	}
1954	if ctLen == 0 {
1955		ct = append(ct, 0)
1956	}
1957
1958	wrapSuccess := runKWTestWrap(keySize, key, keyLen, msg, msgLen, ct, ctLen, wt)
1959	unwrapSuccess := runKWTestUnWrap(keySize, key, keyLen, msg, msgLen, ct, ctLen, wt)
1960
1961	return wrapSuccess && unwrapSuccess
1962}
1963
1964func runKWTestGroup(algorithm string, wtg *wycheproofTestGroupKW) bool {
1965	fmt.Printf("Running %v test group %v with key size %d...\n",
1966		algorithm, wtg.Type, wtg.KeySize)
1967
1968	success := true
1969	for _, wt := range wtg.Tests {
1970		if !runKWTest(wtg.KeySize, wt) {
1971			success = false
1972		}
1973	}
1974	return success
1975}
1976
1977func runRsaesOaepTest(rsa *C.RSA, sha *C.EVP_MD, mgfSha *C.EVP_MD, wt *wycheproofTestRsaes) bool {
1978	ct, err := hex.DecodeString(wt.CT)
1979	if err != nil {
1980		log.Fatalf("Failed to decode cipher text %q: %v", wt.CT, err)
1981	}
1982	ctLen := len(ct)
1983	if ctLen == 0 {
1984		ct = append(ct, 0)
1985	}
1986
1987	rsaSize := C.RSA_size(rsa)
1988	decrypted := make([]byte, rsaSize)
1989
1990	success := true
1991
1992	ret := C.RSA_private_decrypt(C.int(ctLen), (*C.uchar)(unsafe.Pointer(&ct[0])), (*C.uchar)(unsafe.Pointer(&decrypted[0])), rsa, C.RSA_NO_PADDING)
1993
1994	if ret != rsaSize {
1995		success = (wt.Result == "invalid")
1996
1997		if !success {
1998			fmt.Printf("FAIL: Test case %d (%q) %v - got %d, want %d. Expected: %v\n", wt.TCID, wt.Comment, wt.Flags, ret, rsaSize, wt.Result)
1999		}
2000		return success
2001	}
2002
2003	label, err := hex.DecodeString(wt.Label)
2004	if err != nil {
2005		log.Fatalf("Failed to decode label %q: %v", wt.Label, err)
2006	}
2007	labelLen := len(label)
2008	if labelLen == 0 {
2009		label = append(label, 0)
2010	}
2011
2012	msg, err := hex.DecodeString(wt.Msg)
2013	if err != nil {
2014		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
2015	}
2016	msgLen := len(msg)
2017
2018	to := make([]byte, rsaSize)
2019
2020	ret = C.RSA_padding_check_PKCS1_OAEP_mgf1((*C.uchar)(unsafe.Pointer(&to[0])), C.int(rsaSize), (*C.uchar)(unsafe.Pointer(&decrypted[0])), C.int(rsaSize), C.int(rsaSize), (*C.uchar)(unsafe.Pointer(&label[0])), C.int(labelLen), sha, mgfSha)
2021
2022	if int(ret) != msgLen {
2023		success = (wt.Result == "invalid")
2024
2025		if !success {
2026			fmt.Printf("FAIL: Test case %d (%q) %v - got %d, want %d. Expected: %v\n", wt.TCID, wt.Comment, wt.Flags, ret, rsaSize, wt.Result)
2027		}
2028		return success
2029	}
2030
2031	to = to[:msgLen]
2032	if !bytes.Equal(msg, to) {
2033		success = false
2034		fmt.Printf("FAIL: Test case %d (%q) %v - expected and calculated message differ. Expected: %v", wt.TCID, wt.Comment, wt.Flags, wt.Result)
2035	}
2036
2037	return success
2038}
2039
2040func runRsaesOaepTestGroup(algorithm string, wtg *wycheproofTestGroupRsaesOaep) bool {
2041	fmt.Printf("Running %v test group %v with key size %d MGF %v and %v...\n",
2042		algorithm, wtg.Type, wtg.KeySize, wtg.MGFSHA, wtg.SHA)
2043
2044	rsa := C.RSA_new()
2045	if rsa == nil {
2046		log.Fatal("RSA_new failed")
2047	}
2048	defer C.RSA_free(rsa)
2049
2050	d := C.CString(wtg.D)
2051	if C.BN_hex2bn(&rsa.d, d) == 0 {
2052		log.Fatal("Failed to set RSA d")
2053	}
2054	C.free(unsafe.Pointer(d))
2055
2056	e := C.CString(wtg.E)
2057	if C.BN_hex2bn(&rsa.e, e) == 0 {
2058		log.Fatal("Failed to set RSA e")
2059	}
2060	C.free(unsafe.Pointer(e))
2061
2062	n := C.CString(wtg.N)
2063	if C.BN_hex2bn(&rsa.n, n) == 0 {
2064		log.Fatal("Failed to set RSA n")
2065	}
2066	C.free(unsafe.Pointer(n))
2067
2068	sha, err := hashEvpMdFromString(wtg.SHA)
2069	if err != nil {
2070		log.Fatalf("Failed to get hash: %v", err)
2071	}
2072
2073	mgfSha, err := hashEvpMdFromString(wtg.MGFSHA)
2074	if err != nil {
2075		log.Fatalf("Failed to get MGF hash: %v", err)
2076	}
2077
2078	success := true
2079	for _, wt := range wtg.Tests {
2080		if !runRsaesOaepTest(rsa, sha, mgfSha, wt) {
2081			success = false
2082		}
2083	}
2084	return success
2085}
2086
2087func runRsaesPkcs1Test(rsa *C.RSA, wt *wycheproofTestRsaes) bool {
2088	ct, err := hex.DecodeString(wt.CT)
2089	if err != nil {
2090		log.Fatalf("Failed to decode cipher text %q: %v", wt.CT, err)
2091	}
2092	ctLen := len(ct)
2093	if ctLen == 0 {
2094		ct = append(ct, 0)
2095	}
2096
2097	rsaSize := C.RSA_size(rsa)
2098	decrypted := make([]byte, rsaSize)
2099
2100	success := true
2101
2102	ret := C.RSA_private_decrypt(C.int(ctLen), (*C.uchar)(unsafe.Pointer(&ct[0])), (*C.uchar)(unsafe.Pointer(&decrypted[0])), rsa, C.RSA_PKCS1_PADDING)
2103
2104	if ret == -1 {
2105		success = (wt.Result == "invalid")
2106
2107		if !success {
2108			fmt.Printf("FAIL: Test case %d (%q) %v - got %d, want %d. Expected: %v\n", wt.TCID, wt.Comment, wt.Flags, ret, len(wt.Msg)/2, wt.Result)
2109		}
2110		return success
2111	}
2112
2113	msg, err := hex.DecodeString(wt.Msg)
2114	if err != nil {
2115		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
2116	}
2117
2118	if int(ret) != len(msg) {
2119		success = false
2120		fmt.Printf("FAIL: Test case %d (%q) %v - got %d, want %d. Expected: %v\n", wt.TCID, wt.Comment, wt.Flags, ret, len(msg), wt.Result)
2121	} else if !bytes.Equal(msg, decrypted[:len(msg)]) {
2122		success = false
2123		fmt.Printf("FAIL: Test case %d (%q) %v - expected and calculated message differ. Expected: %v", wt.TCID, wt.Comment, wt.Flags, wt.Result)
2124	}
2125
2126	return success
2127}
2128
2129func runRsaesPkcs1TestGroup(algorithm string, wtg *wycheproofTestGroupRsaesPkcs1) bool {
2130	fmt.Printf("Running %v test group %v with key size %d...\n", algorithm, wtg.Type, wtg.KeySize)
2131	rsa := C.RSA_new()
2132	if rsa == nil {
2133		log.Fatal("RSA_new failed")
2134	}
2135	defer C.RSA_free(rsa)
2136
2137	d := C.CString(wtg.D)
2138	if C.BN_hex2bn(&rsa.d, d) == 0 {
2139		log.Fatal("Failed to set RSA d")
2140	}
2141	C.free(unsafe.Pointer(d))
2142
2143	e := C.CString(wtg.E)
2144	if C.BN_hex2bn(&rsa.e, e) == 0 {
2145		log.Fatal("Failed to set RSA e")
2146	}
2147	C.free(unsafe.Pointer(e))
2148
2149	n := C.CString(wtg.N)
2150	if C.BN_hex2bn(&rsa.n, n) == 0 {
2151		log.Fatal("Failed to set RSA n")
2152	}
2153	C.free(unsafe.Pointer(n))
2154
2155	success := true
2156	for _, wt := range wtg.Tests {
2157		if !runRsaesPkcs1Test(rsa, wt) {
2158			success = false
2159		}
2160	}
2161	return success
2162}
2163
2164func runRsassaTest(rsa *C.RSA, h hash.Hash, sha *C.EVP_MD, mgfSha *C.EVP_MD, sLen int, wt *wycheproofTestRsassa) bool {
2165	msg, err := hex.DecodeString(wt.Msg)
2166	if err != nil {
2167		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
2168	}
2169
2170	h.Reset()
2171	h.Write(msg)
2172	msg = h.Sum(nil)
2173
2174	sig, err := hex.DecodeString(wt.Sig)
2175	if err != nil {
2176		log.Fatalf("Failed to decode signature %q: %v", wt.Sig, err)
2177	}
2178
2179	msgLen, sigLen := len(msg), len(sig)
2180	if msgLen == 0 {
2181		msg = append(msg, 0)
2182	}
2183	if sigLen == 0 {
2184		sig = append(sig, 0)
2185	}
2186
2187	sigOut := make([]byte, C.RSA_size(rsa)-11)
2188	if sigLen == 0 {
2189		sigOut = append(sigOut, 0)
2190	}
2191
2192	ret := C.RSA_public_decrypt(C.int(sigLen), (*C.uchar)(unsafe.Pointer(&sig[0])),
2193		(*C.uchar)(unsafe.Pointer(&sigOut[0])), rsa, C.RSA_NO_PADDING)
2194	if ret == -1 {
2195		if wt.Result == "invalid" {
2196			return true
2197		}
2198		fmt.Printf("FAIL: Test case %d (%q) %v - RSA_public_decrypt() = %d, want %v\n",
2199			wt.TCID, wt.Comment, wt.Flags, int(ret), wt.Result)
2200		return false
2201	}
2202
2203	ret = C.RSA_verify_PKCS1_PSS_mgf1(rsa, (*C.uchar)(unsafe.Pointer(&msg[0])), sha, mgfSha,
2204		(*C.uchar)(unsafe.Pointer(&sigOut[0])), C.int(sLen))
2205
2206	success := false
2207	if ret == 1 && (wt.Result == "valid" || wt.Result == "acceptable") {
2208		// All acceptable cases that pass use SHA-1 and are flagged:
2209		// "WeakHash" : "The key for this test vector uses a weak hash function."
2210		if acceptableAudit && wt.Result == "acceptable" {
2211			gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
2212		}
2213		success = true
2214	} else if ret == 0 && (wt.Result == "invalid" || wt.Result == "acceptable") {
2215		success = true
2216	} else {
2217		fmt.Printf("FAIL: Test case %d (%q) %v - RSA_verify_PKCS1_PSS_mgf1() = %d, want %v\n",
2218			wt.TCID, wt.Comment, wt.Flags, int(ret), wt.Result)
2219	}
2220	return success
2221}
2222
2223func runRsassaTestGroup(algorithm string, wtg *wycheproofTestGroupRsassa) bool {
2224	fmt.Printf("Running %v test group %v with key size %d and %v...\n",
2225		algorithm, wtg.Type, wtg.KeySize, wtg.SHA)
2226	rsa := C.RSA_new()
2227	if rsa == nil {
2228		log.Fatal("RSA_new failed")
2229	}
2230	defer C.RSA_free(rsa)
2231
2232	e := C.CString(wtg.E)
2233	if C.BN_hex2bn(&rsa.e, e) == 0 {
2234		log.Fatal("Failed to set RSA e")
2235	}
2236	C.free(unsafe.Pointer(e))
2237
2238	n := C.CString(wtg.N)
2239	if C.BN_hex2bn(&rsa.n, n) == 0 {
2240		log.Fatal("Failed to set RSA n")
2241	}
2242	C.free(unsafe.Pointer(n))
2243
2244	h, err := hashFromString(wtg.SHA)
2245	if err != nil {
2246		log.Fatalf("Failed to get hash: %v", err)
2247	}
2248
2249	sha, err := hashEvpMdFromString(wtg.SHA)
2250	if err != nil {
2251		log.Fatalf("Failed to get hash: %v", err)
2252	}
2253
2254	mgfSha, err := hashEvpMdFromString(wtg.MGFSHA)
2255	if err != nil {
2256		log.Fatalf("Failed to get MGF hash: %v", err)
2257	}
2258
2259	success := true
2260	for _, wt := range wtg.Tests {
2261		if !runRsassaTest(rsa, h, sha, mgfSha, wtg.SLen, wt) {
2262			success = false
2263		}
2264	}
2265	return success
2266}
2267
2268func runRSATest(rsa *C.RSA, nid int, h hash.Hash, wt *wycheproofTestRSA) bool {
2269	msg, err := hex.DecodeString(wt.Msg)
2270	if err != nil {
2271		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
2272	}
2273
2274	h.Reset()
2275	h.Write(msg)
2276	msg = h.Sum(nil)
2277
2278	sig, err := hex.DecodeString(wt.Sig)
2279	if err != nil {
2280		log.Fatalf("Failed to decode signature %q: %v", wt.Sig, err)
2281	}
2282
2283	msgLen, sigLen := len(msg), len(sig)
2284	if msgLen == 0 {
2285		msg = append(msg, 0)
2286	}
2287	if sigLen == 0 {
2288		sig = append(sig, 0)
2289	}
2290
2291	ret := C.RSA_verify(C.int(nid), (*C.uchar)(unsafe.Pointer(&msg[0])), C.uint(msgLen),
2292		(*C.uchar)(unsafe.Pointer(&sig[0])), C.uint(sigLen), rsa)
2293
2294	// XXX audit acceptable cases...
2295	success := true
2296	if ret == 1 != (wt.Result == "valid") && wt.Result != "acceptable" {
2297		fmt.Printf("FAIL: Test case %d (%q) %v - RSA_verify() = %d, want %v\n",
2298			wt.TCID, wt.Comment, wt.Flags, int(ret), wt.Result)
2299		success = false
2300	}
2301	if acceptableAudit && ret == 1 && wt.Result == "acceptable" {
2302		gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
2303	}
2304	return success
2305}
2306
2307func runRSATestGroup(algorithm string, wtg *wycheproofTestGroupRSA) bool {
2308	fmt.Printf("Running %v test group %v with key size %d and %v...\n",
2309		algorithm, wtg.Type, wtg.KeySize, wtg.SHA)
2310
2311	rsa := C.RSA_new()
2312	if rsa == nil {
2313		log.Fatal("RSA_new failed")
2314	}
2315	defer C.RSA_free(rsa)
2316
2317	e := C.CString(wtg.E)
2318	if C.BN_hex2bn(&rsa.e, e) == 0 {
2319		log.Fatal("Failed to set RSA e")
2320	}
2321	C.free(unsafe.Pointer(e))
2322
2323	n := C.CString(wtg.N)
2324	if C.BN_hex2bn(&rsa.n, n) == 0 {
2325		log.Fatal("Failed to set RSA n")
2326	}
2327	C.free(unsafe.Pointer(n))
2328
2329	nid, err := nidFromString(wtg.SHA)
2330	if err != nil {
2331		log.Fatalf("Failed to get MD NID: %v", err)
2332	}
2333	h, err := hashFromString(wtg.SHA)
2334	if err != nil {
2335		log.Fatalf("Failed to get hash: %v", err)
2336	}
2337
2338	success := true
2339	for _, wt := range wtg.Tests {
2340		if !runRSATest(rsa, nid, h, wt) {
2341			success = false
2342		}
2343	}
2344	return success
2345}
2346
2347func runX25519Test(wt *wycheproofTestX25519) bool {
2348	public, err := hex.DecodeString(wt.Public)
2349	if err != nil {
2350		log.Fatalf("Failed to decode public %q: %v", wt.Public, err)
2351	}
2352	private, err := hex.DecodeString(wt.Private)
2353	if err != nil {
2354		log.Fatalf("Failed to decode private %q: %v", wt.Private, err)
2355	}
2356	shared, err := hex.DecodeString(wt.Shared)
2357	if err != nil {
2358		log.Fatalf("Failed to decode shared %q: %v", wt.Shared, err)
2359	}
2360
2361	got := make([]byte, C.X25519_KEY_LENGTH)
2362	result := true
2363
2364	if C.X25519((*C.uint8_t)(unsafe.Pointer(&got[0])), (*C.uint8_t)(unsafe.Pointer(&private[0])), (*C.uint8_t)(unsafe.Pointer(&public[0]))) != 1 {
2365		result = false
2366	} else {
2367		result = bytes.Equal(got, shared)
2368	}
2369
2370	// XXX audit acceptable cases...
2371	success := true
2372	if result != (wt.Result == "valid") && wt.Result != "acceptable" {
2373		fmt.Printf("FAIL: Test case %d (%q) %v - X25519(), want %v\n",
2374			wt.TCID, wt.Comment, wt.Flags, wt.Result)
2375		success = false
2376	}
2377	if acceptableAudit && result && wt.Result == "acceptable" {
2378		gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
2379	}
2380	return success
2381}
2382
2383func runX25519TestGroup(algorithm string, wtg *wycheproofTestGroupX25519) bool {
2384	fmt.Printf("Running %v test group with curve %v...\n", algorithm, wtg.Curve)
2385
2386	success := true
2387	for _, wt := range wtg.Tests {
2388		if !runX25519Test(wt) {
2389			success = false
2390		}
2391	}
2392	return success
2393}
2394
2395func runTestVectors(path string, webcrypto bool) bool {
2396	b, err := ioutil.ReadFile(path)
2397	if err != nil {
2398		log.Fatalf("Failed to read test vectors: %v", err)
2399	}
2400	wtv := &wycheproofTestVectors{}
2401	if err := json.Unmarshal(b, wtv); err != nil {
2402		log.Fatalf("Failed to unmarshal JSON: %v", err)
2403	}
2404	fmt.Printf("Loaded Wycheproof test vectors for %v with %d tests from %q\n",
2405		wtv.Algorithm, wtv.NumberOfTests, filepath.Base(path))
2406
2407	var wtg interface{}
2408	switch wtv.Algorithm {
2409	case "AES-CBC-PKCS5":
2410		wtg = &wycheproofTestGroupAesCbcPkcs5{}
2411	case "AES-CCM":
2412		wtg = &wycheproofTestGroupAead{}
2413	case "AES-CMAC":
2414		wtg = &wycheproofTestGroupAesCmac{}
2415	case "AES-GCM":
2416		wtg = &wycheproofTestGroupAead{}
2417	case "CHACHA20-POLY1305", "XCHACHA20-POLY1305":
2418		wtg = &wycheproofTestGroupAead{}
2419	case "DSA":
2420		wtg = &wycheproofTestGroupDSA{}
2421	case "ECDH":
2422		if webcrypto {
2423			wtg = &wycheproofTestGroupECDHWebCrypto{}
2424		} else {
2425			wtg = &wycheproofTestGroupECDH{}
2426		}
2427	case "ECDSA":
2428		if webcrypto {
2429			wtg = &wycheproofTestGroupECDSAWebCrypto{}
2430		} else {
2431			wtg = &wycheproofTestGroupECDSA{}
2432		}
2433	case "HKDF-SHA-1", "HKDF-SHA-256", "HKDF-SHA-384", "HKDF-SHA-512":
2434		wtg = &wycheproofTestGroupHkdf{}
2435	case "KW":
2436		wtg = &wycheproofTestGroupKW{}
2437	case "RSAES-OAEP":
2438		wtg = &wycheproofTestGroupRsaesOaep{}
2439	case "RSAES-PKCS1-v1_5":
2440		wtg = &wycheproofTestGroupRsaesPkcs1{}
2441	case "RSASSA-PSS":
2442		wtg = &wycheproofTestGroupRsassa{}
2443	case "RSASSA-PKCS1-v1_5", "RSASig":
2444		wtg = &wycheproofTestGroupRSA{}
2445	case "XDH", "X25519":
2446		wtg = &wycheproofTestGroupX25519{}
2447	default:
2448		log.Printf("INFO: Unknown test vector algorithm %q", wtv.Algorithm)
2449		return false
2450	}
2451
2452	success := true
2453	for _, tg := range wtv.TestGroups {
2454		if err := json.Unmarshal(tg, wtg); err != nil {
2455			log.Fatalf("Failed to unmarshal test groups JSON: %v", err)
2456		}
2457		switch wtv.Algorithm {
2458		case "AES-CBC-PKCS5":
2459			if !runAesCbcPkcs5TestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupAesCbcPkcs5)) {
2460				success = false
2461			}
2462		case "AES-CCM":
2463			if !runAesAeadTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupAead)) {
2464				success = false
2465			}
2466		case "AES-CMAC":
2467			if !runAesCmacTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupAesCmac)) {
2468				success = false
2469			}
2470		case "AES-GCM":
2471			if !runAesAeadTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupAead)) {
2472				success = false
2473			}
2474		case "CHACHA20-POLY1305", "XCHACHA20-POLY1305":
2475			if !runChaCha20Poly1305TestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupAead)) {
2476				success = false
2477			}
2478		case "DSA":
2479			if !runDSATestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupDSA)) {
2480				success = false
2481			}
2482		case "ECDH":
2483			if webcrypto {
2484				if !runECDHWebCryptoTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupECDHWebCrypto)) {
2485					success = false
2486				}
2487			} else {
2488				if !runECDHTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupECDH)) {
2489					success = false
2490				}
2491			}
2492		case "ECDSA":
2493			if webcrypto {
2494				if !runECDSAWebCryptoTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupECDSAWebCrypto)) {
2495					success = false
2496				}
2497			} else {
2498				if !runECDSATestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupECDSA)) {
2499					success = false
2500				}
2501			}
2502		case "HKDF-SHA-1", "HKDF-SHA-256", "HKDF-SHA-384", "HKDF-SHA-512":
2503			if !runHkdfTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupHkdf)) {
2504				success = false
2505			}
2506		case "KW":
2507			if !runKWTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupKW)) {
2508				success = false
2509			}
2510		case "RSAES-OAEP":
2511			if !runRsaesOaepTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupRsaesOaep)) {
2512				success = false
2513			}
2514		case "RSAES-PKCS1-v1_5":
2515			if !runRsaesPkcs1TestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupRsaesPkcs1)) {
2516				success = false
2517			}
2518		case "RSASSA-PSS":
2519			if !runRsassaTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupRsassa)) {
2520				success = false
2521			}
2522		case "RSASSA-PKCS1-v1_5", "RSASig":
2523			if !runRSATestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupRSA)) {
2524				success = false
2525			}
2526		case "XDH", "X25519":
2527			if !runX25519TestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupX25519)) {
2528				success = false
2529			}
2530		default:
2531			log.Fatalf("Unknown test vector algorithm %q", wtv.Algorithm)
2532		}
2533	}
2534	return success
2535}
2536
2537func main() {
2538	if _, err := os.Stat(testVectorPath); os.IsNotExist(err) {
2539		fmt.Printf("package wycheproof-testvectors is required for this regress\n")
2540		fmt.Printf("SKIPPED\n")
2541		os.Exit(0)
2542	}
2543
2544	flag.BoolVar(&acceptableAudit, "v", false, "audit acceptable cases")
2545	flag.Parse()
2546
2547	acceptableComments = make(map[string]int)
2548	acceptableFlags = make(map[string]int)
2549
2550	// TODO: Investigate the following new test vectors:
2551	//	primality_test.json
2552	//	x25519_{asn,jwk,pem}_test.json
2553	// What's up with the *_p1363_test.json files?
2554	tests := []struct {
2555		name    string
2556		pattern string
2557	}{
2558		{"AES", "aes_[cg]*[^xv]_test.json"}, // Skip AES-EAX, AES-GCM-SIV and AES-SIV-CMAC.
2559		{"ChaCha20-Poly1305", "chacha20_poly1305_test.json"},
2560		{"DSA", "dsa_*test.json"},
2561		{"ECDH", "ecdh_test.json"},
2562		{"ECDH", "ecdh_[^w]*test.json"},
2563		{"ECDHWebCrypto", "ecdh_webcrypto_test.json"},
2564		{"ECDSA", "ecdsa_[^w]*test.json"},
2565		{"ECDSA", "ecdsa_test.json"},
2566		{"ECDSAWebCrypto", "ecdsa_webcrypto_test.json"},
2567		{"HKDF", "hkdf_sha*_test.json"},
2568		{"KW", "kw_test.json"},
2569		{"RSA", "rsa_*test.json"},
2570		{"X25519", "x25519_test.json"},
2571		{"XCHACHA20-POLY1305", "xchacha20_poly1305_test.json"},
2572	}
2573
2574	success := true
2575
2576	skip := regexp.MustCompile(`_(p1363|sha3|sha512_(224|256))_`)
2577
2578	for _, test := range tests {
2579		webcrypto := test.name == "ECDSAWebCrypto" || test.name == "ECDHWebCrypto"
2580		tvs, err := filepath.Glob(filepath.Join(testVectorPath, test.pattern))
2581		if err != nil {
2582			log.Fatalf("Failed to glob %v test vectors: %v", test.name, err)
2583		}
2584		// XXX put check back after wycheproof-testvectors update to 20191126
2585		// if len(tvs) == 0 {
2586		// 	log.Fatalf("Failed to find %v test vectors at %q\n", test.name, testVectorPath)
2587		// }
2588		for _, tv := range tvs {
2589			if skip.Match([]byte(tv)) {
2590				fmt.Printf("INFO: Skipping tests from \"%s\"\n", strings.TrimPrefix(tv, testVectorPath+"/"))
2591				continue
2592			}
2593			if !runTestVectors(tv, webcrypto) {
2594				success = false
2595			}
2596		}
2597	}
2598
2599	if acceptableAudit {
2600		printAcceptableStatistics()
2601	}
2602
2603	if !success {
2604		os.Exit(1)
2605	}
2606}
2607