1/* $OpenBSD: wycheproof.go,v 1.120 2020/05/14 18:11:45 tb Exp $ */
2/*
3 * Copyright (c) 2018 Joel Sing <jsing@openbsd.org>
4 * Copyright (c) 2018, 2019 Theo Buehler <tb@openbsd.org>
5 *
6 * Permission to use, copy, modify, and distribute this software for any
7 * purpose with or without fee is hereby granted, provided that the above
8 * copyright notice and this permission notice appear in all copies.
9 *
10 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 */
18
19// Wycheproof runs test vectors from Project Wycheproof against libcrypto.
20package main
21
22/*
23#cgo LDFLAGS: -lcrypto
24
25#include <string.h>
26
27#include <openssl/aes.h>
28#include <openssl/bio.h>
29#include <openssl/bn.h>
30#include <openssl/cmac.h>
31#include <openssl/curve25519.h>
32#include <openssl/dsa.h>
33#include <openssl/ec.h>
34#include <openssl/ecdsa.h>
35#include <openssl/evp.h>
36#include <openssl/hkdf.h>
37#include <openssl/hmac.h>
38#include <openssl/objects.h>
39#include <openssl/pem.h>
40#include <openssl/x509.h>
41#include <openssl/rsa.h>
42*/
43import "C"
44
45import (
46	"bytes"
47	"crypto/sha1"
48	"crypto/sha256"
49	"crypto/sha512"
50	"encoding/base64"
51	"encoding/hex"
52	"encoding/json"
53	"flag"
54	"fmt"
55	"hash"
56	"io/ioutil"
57	"log"
58	"os"
59	"path/filepath"
60	"regexp"
61	"sort"
62	"strings"
63	"unsafe"
64)
65
// testVectorPath is the directory the Wycheproof JSON test vector files are
// read from.
const testVectorPath = "/usr/local/share/wycheproof/testvectors"

// testVariant names a variant of a test vector family. The constant names
// suggest key/signature encodings (EC point, P1363, WebCrypto, ASN.1, PEM,
// JWK) plus a Skip marker. NOTE(review): confirm the exact meaning against
// the callers that select a variant.
type testVariant int

const (
	Normal    testVariant = 0
	EcPoint   testVariant = 1
	P1363     testVariant = 2
	Webcrypto testVariant = 3
	Asn1      testVariant = 4
	Pem       testVariant = 5
	Jwk       testVariant = 6
	Skip      testVariant = 7
)

// String returns the name of the variant. Values outside the defined range
// are rendered as "testVariant(N)" instead of panicking on an out-of-range
// array index.
func (variant testVariant) String() string {
	variants := [...]string{
		"Normal",
		"EcPoint",
		"P1363",
		"Webcrypto",
		"Asn1",
		"Pem",
		"Jwk",
		"Skip",
	}
	if variant < 0 || int(variant) >= len(variants) {
		// Convert to int so fmt does not recurse back into String.
		return fmt.Sprintf("testVariant(%d)", int(variant))
	}
	return variants[variant]
}
94
// Bookkeeping for test cases whose expected result is "acceptable". When
// acceptableAudit is set, gatherAcceptableStatistics tallies their comments
// and flags into the two maps, and printAcceptableStatistics reports them.
// NOTE(review): the maps are nil here and must be allocated (presumably in
// main) before gatherAcceptableStatistics writes to them.
var (
	acceptableAudit    = false
	acceptableComments map[string]int
	acceptableFlags    map[string]int
)
98
// JSON Web Key encodings of EC keys, as embedded in the WebCrypto-style
// ECDH and ECDSA test groups.
type (
	// wycheproofJWKPublic is a public key in JWK form (curve name plus the
	// x and y coordinates).
	wycheproofJWKPublic struct {
		Crv string `json:"crv"`
		KID string `json:"kid"`
		KTY string `json:"kty"`
		X   string `json:"x"`
		Y   string `json:"y"`
	}

	// wycheproofJWKPrivate is a private key in JWK form; D carries the
	// private component alongside the public coordinates.
	wycheproofJWKPrivate struct {
		Crv string `json:"crv"`
		D   string `json:"d"`
		KID string `json:"kid"`
		KTY string `json:"kty"`
		X   string `json:"x"`
		Y   string `json:"y"`
	}
)
115
// AES-CBC with PKCS#5 padding test vectors.
type (
	// wycheproofTestGroupAesCbcPkcs5 is a group of test cases sharing one
	// IV size and key size. KeySize is in bits: the group runner selects a
	// cipher for 128, 192 or 256.
	wycheproofTestGroupAesCbcPkcs5 struct {
		IVSize  int                          `json:"ivSize"`
		KeySize int                          `json:"keySize"`
		Type    string                       `json:"type"`
		Tests   []*wycheproofTestAesCbcPkcs5 `json:"tests"`
	}

	// wycheproofTestAesCbcPkcs5 is a single test case. Key, IV, Msg and CT
	// are hex-encoded; Result is "valid", "invalid" or "acceptable".
	wycheproofTestAesCbcPkcs5 struct {
		TCID    int      `json:"tcId"`
		Comment string   `json:"comment"`
		Key     string   `json:"key"`
		IV      string   `json:"iv"`
		Msg     string   `json:"msg"`
		CT      string   `json:"ct"`
		Result  string   `json:"result"`
		Flags   []string `json:"flags"`
	}
)
133
// AEAD (AES-CCM / AES-GCM) test vectors.
type (
	// wycheproofTestGroupAead is a group of AEAD test cases sharing one IV
	// size, key size and tag size.
	wycheproofTestGroupAead struct {
		IVSize  int                   `json:"ivSize"`
		KeySize int                   `json:"keySize"`
		TagSize int                   `json:"tagSize"`
		Type    string                `json:"type"`
		Tests   []*wycheproofTestAead `json:"tests"`
	}

	// wycheproofTestAead is a single AEAD test case. Key, IV, AAD, Msg, CT
	// and Tag are hex-encoded.
	wycheproofTestAead struct {
		TCID    int      `json:"tcId"`
		Comment string   `json:"comment"`
		Key     string   `json:"key"`
		IV      string   `json:"iv"`
		AAD     string   `json:"aad"`
		Msg     string   `json:"msg"`
		CT      string   `json:"ct"`
		Tag     string   `json:"tag"`
		Result  string   `json:"result"`
		Flags   []string `json:"flags"`
	}
)
154
// AES-CMAC test vectors.
type (
	// wycheproofTestGroupAesCmac is a group of CMAC test cases sharing one
	// key size and tag size.
	wycheproofTestGroupAesCmac struct {
		KeySize int                      `json:"keySize"`
		TagSize int                      `json:"tagSize"`
		Type    string                   `json:"type"`
		Tests   []*wycheproofTestAesCmac `json:"tests"`
	}

	// wycheproofTestAesCmac is a single CMAC test case; Key, Msg and Tag are
	// hex-encoded.
	wycheproofTestAesCmac struct {
		TCID    int      `json:"tcId"`
		Comment string   `json:"comment"`
		Key     string   `json:"key"`
		Msg     string   `json:"msg"`
		Tag     string   `json:"tag"`
		Result  string   `json:"result"`
		Flags   []string `json:"flags"`
	}
)
171
// DSA signature verification test vectors.
type (
	// wycheproofDSAKey holds the DSA public key parameters of a group.
	wycheproofDSAKey struct {
		G       string `json:"g"`
		KeySize int    `json:"keySize"`
		P       string `json:"p"`
		Q       string `json:"q"`
		Type    string `json:"type"`
		Y       string `json:"y"`
	}

	// wycheproofTestDSA is a single message/signature pair to verify.
	wycheproofTestDSA struct {
		TCID    int      `json:"tcId"`
		Comment string   `json:"comment"`
		Msg     string   `json:"msg"`
		Sig     string   `json:"sig"`
		Result  string   `json:"result"`
		Flags   []string `json:"flags"`
	}

	// wycheproofTestGroupDSA pairs one key (as parameters, DER and PEM) and
	// hash name with the test cases to verify against it.
	wycheproofTestGroupDSA struct {
		Key    *wycheproofDSAKey    `json:"key"`
		KeyDER string               `json:"keyDer"`
		KeyPEM string               `json:"keyPem"`
		SHA    string               `json:"sha"`
		Type   string               `json:"type"`
		Tests  []*wycheproofTestDSA `json:"tests"`
	}
)
198
199type wycheproofTestECDH struct {
200	TCID    int      `json:"tcId"`
201	Comment string   `json:"comment"`
202	Public  string   `json:"public"`
203	Private string   `json:"private"`
204	Shared  string   `json:"shared"`
205	Result  string   `json:"result"`
206	Flags   []string `json:"flags"`
207}
208
209type wycheproofTestGroupECDH struct {
210	Curve    string                `json:"curve"`
211	Encoding string                `json:"encoding"`
212	Type     string                `json:"type"`
213	Tests    []*wycheproofTestECDH `json:"tests"`
214}
215
216type wycheproofTestECDHWebCrypto struct {
217	TCID    int                   `json:"tcId"`
218	Comment string                `json:"comment"`
219	Public  *wycheproofJWKPublic  `json:"public"`
220	Private *wycheproofJWKPrivate `json:"private"`
221	Shared  string                `json:"shared"`
222	Result  string                `json:"result"`
223	Flags   []string              `json:"flags"`
224}
225
226type wycheproofTestGroupECDHWebCrypto struct {
227	Curve    string                         `json:"curve"`
228	Encoding string                         `json:"encoding"`
229	Type     string                         `json:"type"`
230	Tests    []*wycheproofTestECDHWebCrypto `json:"tests"`
231}
232
233type wycheproofECDSAKey struct {
234	Curve        string `json:"curve"`
235	KeySize      int    `json:"keySize"`
236	Type         string `json:"type"`
237	Uncompressed string `json:"uncompressed"`
238	WX           string `json:"wx"`
239	WY           string `json:"wy"`
240}
241
242type wycheproofTestECDSA struct {
243	TCID    int      `json:"tcId"`
244	Comment string   `json:"comment"`
245	Msg     string   `json:"msg"`
246	Sig     string   `json:"sig"`
247	Result  string   `json:"result"`
248	Flags   []string `json:"flags"`
249}
250
251type wycheproofTestGroupECDSA struct {
252	Key    *wycheproofECDSAKey    `json:"key"`
253	KeyDER string                 `json:"keyDer"`
254	KeyPEM string                 `json:"keyPem"`
255	SHA    string                 `json:"sha"`
256	Type   string                 `json:"type"`
257	Tests  []*wycheproofTestECDSA `json:"tests"`
258}
259
260type wycheproofTestGroupECDSAWebCrypto struct {
261	JWK    *wycheproofJWKPublic   `json:"jwk"`
262	Key    *wycheproofECDSAKey    `json:"key"`
263	KeyDER string                 `json:"keyDer"`
264	KeyPEM string                 `json:"keyPem"`
265	SHA    string                 `json:"sha"`
266	Type   string                 `json:"type"`
267	Tests  []*wycheproofTestECDSA `json:"tests"`
268}
269
// HKDF test vectors.
type (
	// wycheproofTestHkdf is a single HKDF derivation: input keying material,
	// salt and info, the requested output Size and the expected Okm.
	wycheproofTestHkdf struct {
		TCID    int      `json:"tcId"`
		Comment string   `json:"comment"`
		Ikm     string   `json:"ikm"`
		Salt    string   `json:"salt"`
		Info    string   `json:"info"`
		Size    int      `json:"size"`
		Okm     string   `json:"okm"`
		Result  string   `json:"result"`
		Flags   []string `json:"flags"`
	}

	// wycheproofTestGroupHkdf is a group of HKDF cases with one key size.
	wycheproofTestGroupHkdf struct {
		Type    string                `json:"type"`
		KeySize int                   `json:"keySize"`
		Tests   []*wycheproofTestHkdf `json:"tests"`
	}
)
287
// HMAC test vectors.
type (
	// wycheproofTestHmac is a single key/message/tag triple.
	wycheproofTestHmac struct {
		TCID    int      `json:"tcId"`
		Comment string   `json:"comment"`
		Key     string   `json:"key"`
		Msg     string   `json:"msg"`
		Tag     string   `json:"tag"`
		Result  string   `json:"result"`
		Flags   []string `json:"flags"`
	}

	// wycheproofTestGroupHmac is a group of HMAC cases sharing one key size
	// and tag size.
	wycheproofTestGroupHmac struct {
		KeySize int                   `json:"keySize"`
		TagSize int                   `json:"tagSize"`
		Type    string                `json:"type"`
		Tests   []*wycheproofTestHmac `json:"tests"`
	}
)
304
// Key wrap test vectors.
type (
	// wycheproofTestKW is a single wrapping case: key, message to wrap and
	// expected wrapped output CT.
	wycheproofTestKW struct {
		TCID    int      `json:"tcId"`
		Comment string   `json:"comment"`
		Key     string   `json:"key"`
		Msg     string   `json:"msg"`
		CT      string   `json:"ct"`
		Result  string   `json:"result"`
		Flags   []string `json:"flags"`
	}

	// wycheproofTestGroupKW is a group of key wrap cases with one key size.
	wycheproofTestGroupKW struct {
		KeySize int                 `json:"keySize"`
		Type    string              `json:"type"`
		Tests   []*wycheproofTestKW `json:"tests"`
	}
)
320
// RSA signature test vectors.
type (
	// wycheproofTestRSA is a single message/signature pair; Padding is
	// carried as given in the vector file.
	wycheproofTestRSA struct {
		TCID    int      `json:"tcId"`
		Comment string   `json:"comment"`
		Msg     string   `json:"msg"`
		Sig     string   `json:"sig"`
		Padding string   `json:"padding"`
		Result  string   `json:"result"`
		Flags   []string `json:"flags"`
	}

	// wycheproofTestGroupRSA pairs an RSA public key (exponent, modulus and
	// encoded forms) and hash name with its test cases. Note the lowercase
	// "keysize" JSON key.
	wycheproofTestGroupRSA struct {
		E       string               `json:"e"`
		KeyASN  string               `json:"keyAsn"`
		KeyDER  string               `json:"keyDer"`
		KeyPEM  string               `json:"keyPem"`
		KeySize int                  `json:"keysize"`
		N       string               `json:"n"`
		SHA     string               `json:"sha"`
		Type    string               `json:"type"`
		Tests   []*wycheproofTestRSA `json:"tests"`
	}
)
342
// RSA encryption (OAEP and PKCS#1 v1.5) test vectors.
type (
	// wycheproofPrivateKeyJwk is an RSA private key in JWK form.
	wycheproofPrivateKeyJwk struct {
		Alg string `json:"alg"`
		D   string `json:"d"`
		DP  string `json:"dp"`
		DQ  string `json:"dq"`
		E   string `json:"e"`
		KID string `json:"kid"`
		Kty string `json:"kty"`
		N   string `json:"n"`
		P   string `json:"p"`
		Q   string `json:"q"`
		QI  string `json:"qi"`
	}

	// wycheproofTestRsaes is a single decryption case: ciphertext, optional
	// label and expected message.
	wycheproofTestRsaes struct {
		TCID    int      `json:"tcId"`
		Comment string   `json:"comment"`
		Msg     string   `json:"msg"`
		CT      string   `json:"ct"`
		Label   string   `json:"label"`
		Result  string   `json:"result"`
		Flags   []string `json:"flags"`
	}

	// wycheproofTestGroupRsaesOaep is an OAEP group: the private key in
	// several encodings plus the OAEP hash and MGF parameters.
	wycheproofTestGroupRsaesOaep struct {
		D               string                   `json:"d"`
		E               string                   `json:"e"`
		KeySize         int                      `json:"keysize"`
		MGF             string                   `json:"mgf"`
		MGFSHA          string                   `json:"mgfSha"`
		N               string                   `json:"n"`
		PrivateKeyJwk   *wycheproofPrivateKeyJwk `json:"privateKeyJwk"`
		PrivateKeyPem   string                   `json:"privateKeyPem"`
		PrivateKeyPkcs8 string                   `json:"privateKeyPkcs8"`
		SHA             string                   `json:"sha"`
		Type            string                   `json:"type"`
		Tests           []*wycheproofTestRsaes   `json:"tests"`
	}

	// wycheproofTestGroupRsaesPkcs1 is a PKCS#1 v1.5 encryption group; it
	// has no hash or MGF parameters.
	wycheproofTestGroupRsaesPkcs1 struct {
		D               string                   `json:"d"`
		E               string                   `json:"e"`
		KeySize         int                      `json:"keysize"`
		N               string                   `json:"n"`
		PrivateKeyJwk   *wycheproofPrivateKeyJwk `json:"privateKeyJwk"`
		PrivateKeyPem   string                   `json:"privateKeyPem"`
		PrivateKeyPkcs8 string                   `json:"privateKeyPkcs8"`
		Type            string                   `json:"type"`
		Tests           []*wycheproofTestRsaes   `json:"tests"`
	}
)
393
// RSASSA (e.g. PSS) signature test vectors.
type (
	// wycheproofTestRsassa is a single message/signature pair to verify.
	wycheproofTestRsassa struct {
		TCID    int      `json:"tcId"`
		Comment string   `json:"comment"`
		Msg     string   `json:"msg"`
		Sig     string   `json:"sig"`
		Result  string   `json:"result"`
		Flags   []string `json:"flags"`
	}

	// wycheproofTestGroupRsassa pairs an RSA public key with the hash, MGF
	// and salt length (SLen) the group's signatures use.
	wycheproofTestGroupRsassa struct {
		E       string                  `json:"e"`
		KeyASN  string                  `json:"keyAsn"`
		KeyDER  string                  `json:"keyDer"`
		KeyPEM  string                  `json:"keyPem"`
		KeySize int                     `json:"keysize"`
		MGF     string                  `json:"mgf"`
		MGFSHA  string                  `json:"mgfSha"`
		N       string                  `json:"n"`
		SLen    int                     `json:"sLen"`
		SHA     string                  `json:"sha"`
		Type    string                  `json:"type"`
		Tests   []*wycheproofTestRsassa `json:"tests"`
	}
)
417
// X25519 key agreement test vectors.
type (
	// wycheproofTestX25519 is a single public/private/shared triple.
	wycheproofTestX25519 struct {
		TCID    int      `json:"tcId"`
		Comment string   `json:"comment"`
		Curve   string   `json:"curve"`
		Public  string   `json:"public"`
		Private string   `json:"private"`
		Shared  string   `json:"shared"`
		Result  string   `json:"result"`
		Flags   []string `json:"flags"`
	}

	// wycheproofTestGroupX25519 is a group of X25519 cases on one curve.
	wycheproofTestGroupX25519 struct {
		Curve string                  `json:"curve"`
		Tests []*wycheproofTestX25519 `json:"tests"`
	}
)
433
// wycheproofTestVectors is the top level of a Wycheproof JSON file. The test
// groups are kept as raw JSON; presumably each runner decodes them into one
// of the algorithm-specific wycheproofTestGroup* types.
//
// Header holds the file's free-form description lines. It used to be only a
// placeholder comment, so that part of the file was silently dropped.
type wycheproofTestVectors struct {
	Algorithm        string            `json:"algorithm"`
	GeneratorVersion string            `json:"generatorVersion"`
	Notes            map[string]string `json:"notes"`
	NumberOfTests    int               `json:"numberOfTests"`
	Header           []string          `json:"header"`
	TestGroups       []json.RawMessage `json:"testGroups"`
}
442
443var nids = map[string]int{
444	"brainpoolP224r1": C.NID_brainpoolP224r1,
445	"brainpoolP256r1": C.NID_brainpoolP256r1,
446	"brainpoolP320r1": C.NID_brainpoolP320r1,
447	"brainpoolP384r1": C.NID_brainpoolP384r1,
448	"brainpoolP512r1": C.NID_brainpoolP512r1,
449	"brainpoolP224t1": C.NID_brainpoolP224t1,
450	"brainpoolP256t1": C.NID_brainpoolP256t1,
451	"brainpoolP320t1": C.NID_brainpoolP320t1,
452	"brainpoolP384t1": C.NID_brainpoolP384t1,
453	"brainpoolP512t1": C.NID_brainpoolP512t1,
454	"secp224k1":       C.NID_secp224k1,
455	"secp224r1":       C.NID_secp224r1,
456	"secp256k1":       C.NID_secp256k1,
457	"P-256K":          C.NID_secp256k1,
458	"secp256r1":       C.NID_X9_62_prime256v1, // RFC 8422, Table 4, p.32
459	"P-256":           C.NID_X9_62_prime256v1,
460	"secp384r1":       C.NID_secp384r1,
461	"P-384":           C.NID_secp384r1,
462	"secp521r1":       C.NID_secp521r1,
463	"P-521":           C.NID_secp521r1,
464	"SHA-1":           C.NID_sha1,
465	"SHA-224":         C.NID_sha224,
466	"SHA-256":         C.NID_sha256,
467	"SHA-384":         C.NID_sha384,
468	"SHA-512":         C.NID_sha512,
469}
470
471func gatherAcceptableStatistics(testcase int, comment string, flags []string) {
472	fmt.Printf("AUDIT: Test case %d (%q) %v\n", testcase, comment, flags)
473
474	if comment == "" {
475		acceptableComments["No comment"]++
476	} else {
477		acceptableComments[comment]++
478	}
479
480	if len(flags) == 0 {
481		acceptableFlags["NoFlag"]++
482	} else {
483		for _, flag := range flags {
484			acceptableFlags[flag]++
485		}
486	}
487}
488
489func printAcceptableStatistics() {
490	fmt.Printf("\nComment statistics:\n")
491
492	var comments []string
493	for comment := range acceptableComments {
494		comments = append(comments, comment)
495	}
496	sort.Strings(comments)
497	for _, comment := range comments {
498		prcomment := comment
499		if len(comment) > 45 {
500			prcomment = comment[0:42] + "..."
501		}
502		fmt.Printf("%-45v %5d\n", prcomment, acceptableComments[comment])
503	}
504
505	fmt.Printf("\nFlag statistics:\n")
506	var flags []string
507	for flag := range acceptableFlags {
508		flags = append(flags, flag)
509	}
510	sort.Strings(flags)
511	for _, flag := range flags {
512		fmt.Printf("%-45v %5d\n", flag, acceptableFlags[flag])
513	}
514}
515
516func nidFromString(ns string) (int, error) {
517	nid, ok := nids[ns]
518	if ok {
519		return nid, nil
520	}
521	return -1, fmt.Errorf("unknown NID %q", ns)
522}
523
// hashFromString returns a new Go hash.Hash for the Wycheproof hash name hs
// ("SHA-1", "SHA-224", "SHA-256", "SHA-384" or "SHA-512"). Any other name
// yields a nil hash and an error.
func hashFromString(hs string) (hash.Hash, error) {
	constructors := map[string]func() hash.Hash{
		"SHA-1":   sha1.New,
		"SHA-224": sha256.New224,
		"SHA-256": sha256.New,
		"SHA-384": sha512.New384,
		"SHA-512": sha512.New,
	}
	if newHash, ok := constructors[hs]; ok {
		return newHash(), nil
	}
	return nil, fmt.Errorf("unknown hash %q", hs)
}
540
541func hashEvpMdFromString(hs string) (*C.EVP_MD, error) {
542	switch hs {
543	case "SHA-1":
544		return C.EVP_sha1(), nil
545	case "SHA-224":
546		return C.EVP_sha224(), nil
547	case "SHA-256":
548		return C.EVP_sha256(), nil
549	case "SHA-384":
550		return C.EVP_sha384(), nil
551	case "SHA-512":
552		return C.EVP_sha512(), nil
553	default:
554		return nil, fmt.Errorf("unknown hash %q", hs)
555	}
556}
557
558func checkAesCbcPkcs5(ctx *C.EVP_CIPHER_CTX, doEncrypt int, key []byte, keyLen int,
559	iv []byte, ivLen int, in []byte, inLen int, out []byte, outLen int,
560	wt *wycheproofTestAesCbcPkcs5) bool {
561	var action string
562	if doEncrypt == 1 {
563		action = "encrypting"
564	} else {
565		action = "decrypting"
566	}
567
568	ret := C.EVP_CipherInit_ex(ctx, nil, nil, (*C.uchar)(unsafe.Pointer(&key[0])),
569		(*C.uchar)(unsafe.Pointer(&iv[0])), C.int(doEncrypt))
570	if ret != 1 {
571		log.Fatalf("EVP_CipherInit_ex failed: %d", ret)
572	}
573
574	cipherOut := make([]byte, inLen+C.EVP_MAX_BLOCK_LENGTH)
575	var cipherOutLen C.int
576
577	ret = C.EVP_CipherUpdate(ctx, (*C.uchar)(unsafe.Pointer(&cipherOut[0])), &cipherOutLen,
578		(*C.uchar)(unsafe.Pointer(&in[0])), C.int(inLen))
579	if ret != 1 {
580		if wt.Result == "invalid" {
581			fmt.Printf("INFO: Test case %d (%q) [%v] %v - EVP_CipherUpdate() = %d, want %v\n",
582				wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
583			return true
584		}
585		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - EVP_CipherUpdate() = %d, want %v\n",
586			wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
587		return false
588	}
589
590	var finallen C.int
591	ret = C.EVP_CipherFinal_ex(ctx, (*C.uchar)(unsafe.Pointer(&cipherOut[cipherOutLen])), &finallen)
592	if ret != 1 {
593		if wt.Result == "invalid" {
594			return true
595		}
596		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - EVP_CipherFinal_ex() = %d, want %v\n",
597			wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
598		return false
599	}
600
601	cipherOutLen += finallen
602	if cipherOutLen != C.int(outLen) && wt.Result != "invalid" {
603		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - open length mismatch: got %d, want %d\n",
604			wt.TCID, wt.Comment, action, wt.Flags, cipherOutLen, outLen)
605		return false
606	}
607
608	openedMsg := cipherOut[0:cipherOutLen]
609	if outLen == 0 {
610		out = nil
611	}
612
613	success := false
614	if bytes.Equal(openedMsg, out) == (wt.Result != "invalid") {
615		success = true
616		if acceptableAudit && wt.Result == "acceptable" {
617			gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
618		}
619	} else {
620		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - msg match: %t; want %v\n",
621			wt.TCID, wt.Comment, action, wt.Flags, bytes.Equal(openedMsg, out), wt.Result)
622	}
623	return success
624}
625
626func runAesCbcPkcs5Test(ctx *C.EVP_CIPHER_CTX, wt *wycheproofTestAesCbcPkcs5) bool {
627	key, err := hex.DecodeString(wt.Key)
628	if err != nil {
629		log.Fatalf("Failed to decode key %q: %v", wt.Key, err)
630	}
631	iv, err := hex.DecodeString(wt.IV)
632	if err != nil {
633		log.Fatalf("Failed to decode IV %q: %v", wt.IV, err)
634	}
635	ct, err := hex.DecodeString(wt.CT)
636	if err != nil {
637		log.Fatalf("Failed to decode CT %q: %v", wt.CT, err)
638	}
639	msg, err := hex.DecodeString(wt.Msg)
640	if err != nil {
641		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
642	}
643
644	keyLen, ivLen, ctLen, msgLen := len(key), len(iv), len(ct), len(msg)
645
646	if keyLen == 0 {
647		key = append(key, 0)
648	}
649	if ivLen == 0 {
650		iv = append(iv, 0)
651	}
652	if ctLen == 0 {
653		ct = append(ct, 0)
654	}
655	if msgLen == 0 {
656		msg = append(msg, 0)
657	}
658
659	openSuccess := checkAesCbcPkcs5(ctx, 0, key, keyLen, iv, ivLen, ct, ctLen, msg, msgLen, wt)
660	sealSuccess := checkAesCbcPkcs5(ctx, 1, key, keyLen, iv, ivLen, msg, msgLen, ct, ctLen, wt)
661
662	return openSuccess && sealSuccess
663}
664
665func runAesCbcPkcs5TestGroup(algorithm string, wtg *wycheproofTestGroupAesCbcPkcs5) bool {
666	fmt.Printf("Running %v test group %v with IV size %d and key size %d...\n",
667		algorithm, wtg.Type, wtg.IVSize, wtg.KeySize)
668
669	var cipher *C.EVP_CIPHER
670	switch wtg.KeySize {
671	case 128:
672		cipher = C.EVP_aes_128_cbc()
673	case 192:
674		cipher = C.EVP_aes_192_cbc()
675	case 256:
676		cipher = C.EVP_aes_256_cbc()
677	default:
678		log.Fatalf("Unsupported key size: %d", wtg.KeySize)
679	}
680
681	ctx := C.EVP_CIPHER_CTX_new()
682	if ctx == nil {
683		log.Fatal("EVP_CIPHER_CTX_new() failed")
684	}
685	defer C.EVP_CIPHER_CTX_free(ctx)
686
687	ret := C.EVP_CipherInit_ex(ctx, cipher, nil, nil, nil, 0)
688	if ret != 1 {
689		log.Fatalf("EVP_CipherInit_ex failed: %d", ret)
690	}
691
692	success := true
693	for _, wt := range wtg.Tests {
694		if !runAesCbcPkcs5Test(ctx, wt) {
695			success = false
696		}
697	}
698	return success
699}
700
701func checkAesAead(algorithm string, ctx *C.EVP_CIPHER_CTX, doEncrypt int,
702	key []byte, keyLen int, iv []byte, ivLen int, aad []byte, aadLen int,
703	in []byte, inLen int, out []byte, outLen int, tag []byte, tagLen int,
704	wt *wycheproofTestAead) bool {
705	var ctrlSetIVLen C.int
706	var ctrlSetTag C.int
707	var ctrlGetTag C.int
708
709	doCCM := false
710	switch algorithm {
711	case "AES-CCM":
712		doCCM = true
713		ctrlSetIVLen = C.EVP_CTRL_CCM_SET_IVLEN
714		ctrlSetTag = C.EVP_CTRL_CCM_SET_TAG
715		ctrlGetTag = C.EVP_CTRL_CCM_GET_TAG
716	case "AES-GCM":
717		ctrlSetIVLen = C.EVP_CTRL_GCM_SET_IVLEN
718		ctrlSetTag = C.EVP_CTRL_GCM_SET_TAG
719		ctrlGetTag = C.EVP_CTRL_GCM_GET_TAG
720	}
721
722	setTag := unsafe.Pointer(nil)
723	var action string
724
725	if doEncrypt == 1 {
726		action = "encrypting"
727	} else {
728		action = "decrypting"
729		setTag = unsafe.Pointer(&tag[0])
730	}
731
732	ret := C.EVP_CipherInit_ex(ctx, nil, nil, nil, nil, C.int(doEncrypt))
733	if ret != 1 {
734		log.Fatalf("[%v] cipher init failed", action)
735	}
736
737	ret = C.EVP_CIPHER_CTX_ctrl(ctx, ctrlSetIVLen, C.int(ivLen), nil)
738	if ret != 1 {
739		if wt.Comment == "Nonce is too long" || wt.Comment == "Invalid nonce size" ||
740			wt.Comment == "0 size IV is not valid" || wt.Comment == "Very long nonce" {
741			return true
742		}
743		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - setting IV len to %d failed. got %d, want %v\n",
744			wt.TCID, wt.Comment, action, wt.Flags, ivLen, ret, wt.Result)
745		return false
746	}
747
748	if doEncrypt == 0 || doCCM {
749		ret = C.EVP_CIPHER_CTX_ctrl(ctx, ctrlSetTag, C.int(tagLen), setTag)
750		if ret != 1 {
751			if wt.Comment == "Invalid tag size" {
752				return true
753			}
754			fmt.Printf("FAIL: Test case %d (%q) [%v] %v - setting tag length to %d failed. got %d, want %v\n",
755				wt.TCID, wt.Comment, action, wt.Flags, tagLen, ret, wt.Result)
756			return false
757		}
758	}
759
760	ret = C.EVP_CipherInit_ex(ctx, nil, nil, (*C.uchar)(unsafe.Pointer(&key[0])),
761		(*C.uchar)(unsafe.Pointer(&iv[0])), C.int(doEncrypt))
762	if ret != 1 {
763		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - setting key and IV failed. got %d, want %v\n",
764			wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
765		return false
766	}
767
768	var cipherOutLen C.int
769	if doCCM {
770		ret = C.EVP_CipherUpdate(ctx, nil, &cipherOutLen, nil, C.int(inLen))
771		if ret != 1 {
772			fmt.Printf("FAIL: Test case %d (%q) [%v] %v - setting input length to %d failed. got %d, want %v\n",
773				wt.TCID, wt.Comment, action, wt.Flags, inLen, ret, wt.Result)
774			return false
775		}
776	}
777
778	ret = C.EVP_CipherUpdate(ctx, nil, &cipherOutLen, (*C.uchar)(unsafe.Pointer(&aad[0])), C.int(aadLen))
779	if ret != 1 {
780		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - processing AAD failed. got %d, want %v\n",
781			wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
782		return false
783	}
784
785	cipherOutLen = 0
786	cipherOut := make([]byte, inLen)
787	if inLen == 0 {
788		cipherOut = append(cipherOut, 0)
789	}
790
791	ret = C.EVP_CipherUpdate(ctx, (*C.uchar)(unsafe.Pointer(&cipherOut[0])), &cipherOutLen,
792		(*C.uchar)(unsafe.Pointer(&in[0])), C.int(inLen))
793	if ret != 1 {
794		if wt.Result == "invalid" {
795			return true
796		}
797		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - EVP_CipherUpdate() = %d, want %v\n",
798			wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
799		return false
800	}
801
802	if doEncrypt == 1 {
803		var tmpLen C.int
804		dummyOut := make([]byte, 16)
805
806		ret = C.EVP_CipherFinal_ex(ctx, (*C.uchar)(unsafe.Pointer(&dummyOut[0])), &tmpLen)
807		if ret != 1 {
808			fmt.Printf("FAIL: Test case %d (%q) [%v] %v - EVP_CipherFinal_ex() = %d, want %v\n",
809				wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
810			return false
811		}
812		cipherOutLen += tmpLen
813	}
814
815	if cipherOutLen != C.int(outLen) {
816		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - cipherOutLen %d != outLen %d. Result %v\n",
817			wt.TCID, wt.Comment, action, wt.Flags, cipherOutLen, outLen, wt.Result)
818		return false
819	}
820
821	success := true
822	if !bytes.Equal(cipherOut, out) {
823		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - expected and computed output do not match. Result: %v\n",
824			wt.TCID, wt.Comment, action, wt.Flags, wt.Result)
825		success = false
826	}
827	if doEncrypt == 1 {
828		tagOut := make([]byte, tagLen)
829		ret = C.EVP_CIPHER_CTX_ctrl(ctx, ctrlGetTag, C.int(tagLen), unsafe.Pointer(&tagOut[0]))
830		if ret != 1 {
831			fmt.Printf("FAIL: Test case %d (%q) [%v] %v - EVP_CIPHER_CTX_ctrl() = %d, want %v\n",
832				wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
833			return false
834		}
835
836		// There are no acceptable CCM cases. All acceptable GCM tests
837		// pass. They have len(IV) <= 48. NIST SP 800-38D, 5.2.1.1, p.8,
838		// allows 1 <= len(IV) <= 2^64-1, but notes:
839		//   "For IVs it is recommended that implementations restrict
840		//    support to the length of 96 bits, to promote
841		//    interoperability, efficiency and simplicity of design."
842		if bytes.Equal(tagOut, tag) != (wt.Result == "valid" || wt.Result == "acceptable") {
843			fmt.Printf("FAIL: Test case %d (%q) [%v] %v - expected and computed tag do not match - ret: %d, Result: %v\n",
844				wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
845			success = false
846		}
847		if acceptableAudit && bytes.Equal(tagOut, tag) && wt.Result == "acceptable" {
848			gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
849		}
850	}
851	return success
852}
853
854func runAesAeadTest(algorithm string, ctx *C.EVP_CIPHER_CTX, aead *C.EVP_AEAD, wt *wycheproofTestAead) bool {
855	key, err := hex.DecodeString(wt.Key)
856	if err != nil {
857		log.Fatalf("Failed to decode key %q: %v", wt.Key, err)
858	}
859
860	iv, err := hex.DecodeString(wt.IV)
861	if err != nil {
862		log.Fatalf("Failed to decode IV %q: %v", wt.IV, err)
863	}
864
865	aad, err := hex.DecodeString(wt.AAD)
866	if err != nil {
867		log.Fatalf("Failed to decode AAD %q: %v", wt.AAD, err)
868	}
869
870	msg, err := hex.DecodeString(wt.Msg)
871	if err != nil {
872		log.Fatalf("Failed to decode msg %q: %v", wt.Msg, err)
873	}
874
875	ct, err := hex.DecodeString(wt.CT)
876	if err != nil {
877		log.Fatalf("Failed to decode CT %q: %v", wt.CT, err)
878	}
879
880	tag, err := hex.DecodeString(wt.Tag)
881	if err != nil {
882		log.Fatalf("Failed to decode tag %q: %v", wt.Tag, err)
883	}
884
885	keyLen, ivLen, aadLen, msgLen, ctLen, tagLen := len(key), len(iv), len(aad), len(msg), len(ct), len(tag)
886
887	if keyLen == 0 {
888		key = append(key, 0)
889	}
890	if ivLen == 0 {
891		iv = append(iv, 0)
892	}
893	if aadLen == 0 {
894		aad = append(aad, 0)
895	}
896	if msgLen == 0 {
897		msg = append(msg, 0)
898	}
899	if ctLen == 0 {
900		ct = append(ct, 0)
901	}
902	if tagLen == 0 {
903		tag = append(tag, 0)
904	}
905
906	openEvp := checkAesAead(algorithm, ctx, 0, key, keyLen, iv, ivLen, aad, aadLen, ct, ctLen, msg, msgLen, tag, tagLen, wt)
907	sealEvp := checkAesAead(algorithm, ctx, 1, key, keyLen, iv, ivLen, aad, aadLen, msg, msgLen, ct, ctLen, tag, tagLen, wt)
908
909	openAead, sealAead := true, true
910	if aead != nil {
911		var ctx C.EVP_AEAD_CTX
912		if C.EVP_AEAD_CTX_init(&ctx, aead, (*C.uchar)(unsafe.Pointer(&key[0])), C.size_t(keyLen), C.size_t(tagLen), nil) != 1 {
913			log.Fatal("Failed to initialize AEAD context")
914		}
915		defer C.EVP_AEAD_CTX_cleanup(&ctx)
916
917		// Make sure we don't accidentally prepend or compare against a 0.
918		if ctLen == 0 {
919			ct = nil
920		}
921
922		openAead = checkAeadOpen(&ctx, iv, ivLen, aad, aadLen, msg, msgLen, ct, ctLen, tag, tagLen, wt)
923		sealAead = checkAeadSeal(&ctx, iv, ivLen, aad, aadLen, msg, msgLen, ct, ctLen, tag, tagLen, wt)
924	}
925
926	return openEvp && sealEvp && openAead && sealAead
927}
928
929func runAesAeadTestGroup(algorithm string, wtg *wycheproofTestGroupAead) bool {
930	fmt.Printf("Running %v test group %v with IV size %d, key size %d and tag size %d...\n",
931		algorithm, wtg.Type, wtg.IVSize, wtg.KeySize, wtg.TagSize)
932
933	var cipher *C.EVP_CIPHER
934	var aead *C.EVP_AEAD
935	switch algorithm {
936	case "AES-CCM":
937		switch wtg.KeySize {
938		case 128:
939			cipher = C.EVP_aes_128_ccm()
940		case 192:
941			cipher = C.EVP_aes_192_ccm()
942		case 256:
943			cipher = C.EVP_aes_256_ccm()
944		default:
945			fmt.Printf("INFO: Skipping tests with invalid key size %d\n", wtg.KeySize)
946			return true
947		}
948	case "AES-GCM":
949		switch wtg.KeySize {
950		case 128:
951			cipher = C.EVP_aes_128_gcm()
952			aead = C.EVP_aead_aes_128_gcm()
953		case 192:
954			cipher = C.EVP_aes_192_gcm()
955		case 256:
956			cipher = C.EVP_aes_256_gcm()
957			aead = C.EVP_aead_aes_256_gcm()
958		default:
959			fmt.Printf("INFO: Skipping tests with invalid key size %d\n", wtg.KeySize)
960			return true
961		}
962	default:
963		log.Fatalf("runAesAeadTestGroup() - unhandled algorithm: %v", algorithm)
964	}
965
966	ctx := C.EVP_CIPHER_CTX_new()
967	if ctx == nil {
968		log.Fatal("EVP_CIPHER_CTX_new() failed")
969	}
970	defer C.EVP_CIPHER_CTX_free(ctx)
971
972	C.EVP_CipherInit_ex(ctx, cipher, nil, nil, nil, 1)
973
974	success := true
975	for _, wt := range wtg.Tests {
976		if !runAesAeadTest(algorithm, ctx, aead, wt) {
977			success = false
978		}
979	}
980	return success
981}
982
983func runAesCmacTest(cipher *C.EVP_CIPHER, wt *wycheproofTestAesCmac) bool {
984	key, err := hex.DecodeString(wt.Key)
985	if err != nil {
986		log.Fatalf("Failed to decode key %q: %v", wt.Key, err)
987	}
988
989	msg, err := hex.DecodeString(wt.Msg)
990	if err != nil {
991		log.Fatalf("Failed to decode msg %q: %v", wt.Msg, err)
992	}
993
994	tag, err := hex.DecodeString(wt.Tag)
995	if err != nil {
996		log.Fatalf("Failed to decode tag %q: %v", wt.Tag, err)
997	}
998
999	keyLen, msgLen, tagLen := len(key), len(msg), len(tag)
1000
1001	if keyLen == 0 {
1002		key = append(key, 0)
1003	}
1004	if msgLen == 0 {
1005		msg = append(msg, 0)
1006	}
1007	if tagLen == 0 {
1008		tag = append(tag, 0)
1009	}
1010
1011	ctx := C.CMAC_CTX_new()
1012	if ctx == nil {
1013		log.Fatal("CMAC_CTX_new failed")
1014	}
1015	defer C.CMAC_CTX_free(ctx)
1016
1017	ret := C.CMAC_Init(ctx, unsafe.Pointer(&key[0]), C.size_t(keyLen), cipher, nil)
1018	if ret != 1 {
1019		fmt.Printf("FAIL: Test case %d (%q) %v - CMAC_Init() = %d, want %v\n",
1020			wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
1021		return false
1022	}
1023
1024	ret = C.CMAC_Update(ctx, unsafe.Pointer(&msg[0]), C.size_t(msgLen))
1025	if ret != 1 {
1026		fmt.Printf("FAIL: Test case %d (%q) %v - CMAC_Update() = %d, want %v\n",
1027			wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
1028		return false
1029	}
1030
1031	var outLen C.size_t
1032	outTag := make([]byte, 16)
1033
1034	ret = C.CMAC_Final(ctx, (*C.uchar)(unsafe.Pointer(&outTag[0])), &outLen)
1035	if ret != 1 {
1036		fmt.Printf("FAIL: Test case %d (%q) %v - CMAC_Final() = %d, want %v\n",
1037			wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
1038		return false
1039	}
1040
1041	outTag = outTag[0:tagLen]
1042
1043	success := true
1044	if bytes.Equal(tag, outTag) != (wt.Result == "valid") {
1045		fmt.Printf("FAIL: Test case %d (%q) %v - want %v\n",
1046			wt.TCID, wt.Comment, wt.Flags, wt.Result)
1047		success = false
1048	}
1049	return success
1050}
1051
1052func runAesCmacTestGroup(algorithm string, wtg *wycheproofTestGroupAesCmac) bool {
1053	fmt.Printf("Running %v test group %v with key size %d and tag size %d...\n",
1054		algorithm, wtg.Type, wtg.KeySize, wtg.TagSize)
1055	var cipher *C.EVP_CIPHER
1056
1057	switch wtg.KeySize {
1058	case 128:
1059		cipher = C.EVP_aes_128_cbc()
1060	case 192:
1061		cipher = C.EVP_aes_192_cbc()
1062	case 256:
1063		cipher = C.EVP_aes_256_cbc()
1064	default:
1065		fmt.Printf("INFO: Skipping tests with invalid key size %d\n", wtg.KeySize)
1066		return true
1067	}
1068
1069	success := true
1070	for _, wt := range wtg.Tests {
1071		if !runAesCmacTest(cipher, wt) {
1072			success = false
1073		}
1074	}
1075	return success
1076}
1077
1078func checkAeadOpen(ctx *C.EVP_AEAD_CTX, iv []byte, ivLen int, aad []byte, aadLen int, msg []byte, msgLen int,
1079	ct []byte, ctLen int, tag []byte, tagLen int, wt *wycheproofTestAead) bool {
1080	maxOutLen := ctLen + tagLen
1081
1082	opened := make([]byte, maxOutLen)
1083	if maxOutLen == 0 {
1084		opened = append(opened, 0)
1085	}
1086	var openedMsgLen C.size_t
1087
1088	catCtTag := append(ct, tag...)
1089	catCtTagLen := len(catCtTag)
1090	if catCtTagLen == 0 {
1091		catCtTag = append(catCtTag, 0)
1092	}
1093	openRet := C.EVP_AEAD_CTX_open(ctx, (*C.uint8_t)(unsafe.Pointer(&opened[0])),
1094		(*C.size_t)(unsafe.Pointer(&openedMsgLen)), C.size_t(maxOutLen),
1095		(*C.uint8_t)(unsafe.Pointer(&iv[0])), C.size_t(ivLen),
1096		(*C.uint8_t)(unsafe.Pointer(&catCtTag[0])), C.size_t(catCtTagLen),
1097		(*C.uint8_t)(unsafe.Pointer(&aad[0])), C.size_t(aadLen))
1098
1099	if openRet != 1 {
1100		if wt.Result == "invalid" {
1101			return true
1102		}
1103		fmt.Printf("FAIL: Test case %d (%q) %v - EVP_AEAD_CTX_open() = %d, want %v\n",
1104			wt.TCID, wt.Comment, wt.Flags, int(openRet), wt.Result)
1105		return false
1106	}
1107
1108	if openedMsgLen != C.size_t(msgLen) {
1109		fmt.Printf("FAIL: Test case %d (%q) %v - open length mismatch: got %d, want %d\n",
1110			wt.TCID, wt.Comment, wt.Flags, openedMsgLen, msgLen)
1111		return false
1112	}
1113
1114	openedMsg := opened[0:openedMsgLen]
1115	if msgLen == 0 {
1116		msg = nil
1117	}
1118
1119	success := false
1120	if bytes.Equal(openedMsg, msg) == (wt.Result != "invalid") {
1121		if acceptableAudit && wt.Result == "acceptable" {
1122			gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
1123		}
1124		success = true
1125	} else {
1126		fmt.Printf("FAIL: Test case %d (%q) %v - msg match: %t; want %v\n",
1127			wt.TCID, wt.Comment, wt.Flags, bytes.Equal(openedMsg, msg), wt.Result)
1128	}
1129	return success
1130}
1131
1132func checkAeadSeal(ctx *C.EVP_AEAD_CTX, iv []byte, ivLen int, aad []byte, aadLen int, msg []byte,
1133	msgLen int, ct []byte, ctLen int, tag []byte, tagLen int, wt *wycheproofTestAead) bool {
1134	maxOutLen := msgLen + tagLen
1135
1136	sealed := make([]byte, maxOutLen)
1137	if maxOutLen == 0 {
1138		sealed = append(sealed, 0)
1139	}
1140	var sealedLen C.size_t
1141
1142	sealRet := C.EVP_AEAD_CTX_seal(ctx, (*C.uint8_t)(unsafe.Pointer(&sealed[0])),
1143		(*C.size_t)(unsafe.Pointer(&sealedLen)), C.size_t(maxOutLen),
1144		(*C.uint8_t)(unsafe.Pointer(&iv[0])), C.size_t(ivLen),
1145		(*C.uint8_t)(unsafe.Pointer(&msg[0])), C.size_t(msgLen),
1146		(*C.uint8_t)(unsafe.Pointer(&aad[0])), C.size_t(aadLen))
1147
1148	if sealRet != 1 {
1149		success := (wt.Result == "invalid")
1150		if !success {
1151			fmt.Printf("FAIL: Test case %d (%q) %v - EVP_AEAD_CTX_seal() = %d, want %v\n", wt.TCID, wt.Comment, wt.Flags, int(sealRet), wt.Result)
1152		}
1153		return success
1154	}
1155
1156	if sealedLen != C.size_t(maxOutLen) {
1157		fmt.Printf("FAIL: Test case %d (%q) %v - seal length mismatch: got %d, want %d\n",
1158			wt.TCID, wt.Comment, wt.Flags, sealedLen, maxOutLen)
1159		return false
1160	}
1161
1162	sealedCt := sealed[0:msgLen]
1163	sealedTag := sealed[msgLen:maxOutLen]
1164
1165	success := false
1166	if (bytes.Equal(sealedCt, ct) && bytes.Equal(sealedTag, tag)) == (wt.Result != "invalid") {
1167		if acceptableAudit && wt.Result == "acceptable" {
1168			gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
1169		}
1170		success = true
1171	} else {
1172		fmt.Printf("FAIL: Test case %d (%q) %v - EVP_AEAD_CTX_seal() = %d, ct match: %t, tag match: %t; want %v\n",
1173			wt.TCID, wt.Comment, wt.Flags, int(sealRet),
1174			bytes.Equal(sealedCt, ct), bytes.Equal(sealedTag, tag), wt.Result)
1175	}
1176	return success
1177}
1178
1179func runChaCha20Poly1305Test(algorithm string, wt *wycheproofTestAead) bool {
1180	var aead *C.EVP_AEAD
1181	switch algorithm {
1182	case "CHACHA20-POLY1305":
1183		aead = C.EVP_aead_chacha20_poly1305()
1184	case "XCHACHA20-POLY1305":
1185		aead = C.EVP_aead_xchacha20_poly1305()
1186	}
1187
1188	key, err := hex.DecodeString(wt.Key)
1189	if err != nil {
1190		log.Fatalf("Failed to decode key %q: %v", wt.Key, err)
1191	}
1192	iv, err := hex.DecodeString(wt.IV)
1193	if err != nil {
1194		log.Fatalf("Failed to decode key %q: %v", wt.IV, err)
1195	}
1196	aad, err := hex.DecodeString(wt.AAD)
1197	if err != nil {
1198		log.Fatalf("Failed to decode AAD %q: %v", wt.AAD, err)
1199	}
1200	msg, err := hex.DecodeString(wt.Msg)
1201	if err != nil {
1202		log.Fatalf("Failed to decode msg %q: %v", wt.Msg, err)
1203	}
1204	ct, err := hex.DecodeString(wt.CT)
1205	if err != nil {
1206		log.Fatalf("Failed to decode ct %q: %v", wt.CT, err)
1207	}
1208	tag, err := hex.DecodeString(wt.Tag)
1209	if err != nil {
1210		log.Fatalf("Failed to decode tag %q: %v", wt.Tag, err)
1211	}
1212
1213	keyLen, ivLen, aadLen, msgLen, ctLen, tagLen := len(key), len(iv), len(aad), len(msg), len(ct), len(tag)
1214
1215	if ivLen == 0 {
1216		iv = append(iv, 0)
1217	}
1218	if aadLen == 0 {
1219		aad = append(aad, 0)
1220	}
1221	if msgLen == 0 {
1222		msg = append(msg, 0)
1223	}
1224	if ctLen == 0 {
1225		msg = append(ct, 0)
1226	}
1227	if tagLen == 0 {
1228		msg = append(tag, 0)
1229	}
1230
1231	var ctx C.EVP_AEAD_CTX
1232	if C.EVP_AEAD_CTX_init(&ctx, aead, (*C.uchar)(unsafe.Pointer(&key[0])), C.size_t(keyLen), C.size_t(tagLen), nil) != 1 {
1233		log.Fatal("Failed to initialize AEAD context")
1234	}
1235	defer C.EVP_AEAD_CTX_cleanup(&ctx)
1236
1237	openSuccess := checkAeadOpen(&ctx, iv, ivLen, aad, aadLen, msg, msgLen, ct, ctLen, tag, tagLen, wt)
1238	sealSuccess := checkAeadSeal(&ctx, iv, ivLen, aad, aadLen, msg, msgLen, ct, ctLen, tag, tagLen, wt)
1239
1240	return openSuccess && sealSuccess
1241}
1242
1243func runChaCha20Poly1305TestGroup(algorithm string, wtg *wycheproofTestGroupAead) bool {
1244	// ChaCha20-Poly1305 currently only supports nonces of length 12 (96 bits)
1245	if algorithm == "CHACHA20-POLY1305" && wtg.IVSize != 96 {
1246		return true
1247	}
1248
1249	fmt.Printf("Running %v test group %v with IV size %d, key size %d, tag size %d...\n",
1250		algorithm, wtg.Type, wtg.IVSize, wtg.KeySize, wtg.TagSize)
1251
1252	success := true
1253	for _, wt := range wtg.Tests {
1254		if !runChaCha20Poly1305Test(algorithm, wt) {
1255			success = false
1256		}
1257	}
1258	return success
1259}
1260
1261// DER encode the signature (so DSA_verify() can decode and encode it again)
1262func encodeDSAP1363Sig(wtSig string) (*C.uchar, C.int) {
1263	cSig := C.DSA_SIG_new()
1264	if cSig == nil {
1265		log.Fatal("DSA_SIG_new() failed")
1266	}
1267	defer C.DSA_SIG_free(cSig)
1268
1269	sigLen := len(wtSig)
1270	r := C.CString(wtSig[:sigLen/2])
1271	s := C.CString(wtSig[sigLen/2:])
1272	defer C.free(unsafe.Pointer(r))
1273	defer C.free(unsafe.Pointer(s))
1274	if C.BN_hex2bn(&cSig.r, r) == 0 {
1275		return nil, 0
1276	}
1277	if C.BN_hex2bn(&cSig.s, s) == 0 {
1278		return nil, 0
1279	}
1280
1281	derLen := C.i2d_DSA_SIG(cSig, nil)
1282	if derLen == 0 {
1283		return nil, 0
1284	}
1285	cDer := (*C.uchar)(C.malloc(C.ulong(derLen)))
1286	if cDer == nil {
1287		log.Fatal("malloc failed")
1288	}
1289
1290	p := cDer
1291	ret := C.i2d_DSA_SIG(cSig, (**C.uchar)(&p))
1292	if ret == 0 || ret != derLen {
1293		C.free(unsafe.Pointer(cDer))
1294		return nil, 0
1295	}
1296
1297	return cDer, derLen
1298}
1299
1300func runDSATest(dsa *C.DSA, variant testVariant, h hash.Hash, wt *wycheproofTestDSA) bool {
1301	msg, err := hex.DecodeString(wt.Msg)
1302	if err != nil {
1303		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
1304	}
1305
1306	h.Reset()
1307	h.Write(msg)
1308	msg = h.Sum(nil)
1309
1310	msgLen := len(msg)
1311	if msgLen == 0 {
1312		msg = append(msg, 0)
1313	}
1314
1315	var ret C.int
1316	if variant == P1363 {
1317		cDer, derLen := encodeDSAP1363Sig(wt.Sig)
1318		if cDer == nil {
1319			fmt.Print("FAIL: unable to decode signature")
1320			return false
1321		}
1322		defer C.free(unsafe.Pointer(cDer))
1323
1324		ret = C.DSA_verify(0, (*C.uchar)(unsafe.Pointer(&msg[0])), C.int(msgLen),
1325			(*C.uchar)(unsafe.Pointer(cDer)), C.int(derLen), dsa)
1326	} else {
1327		sig, err := hex.DecodeString(wt.Sig)
1328		if err != nil {
1329			log.Fatalf("Failed to decode signature %q: %v", wt.Sig, err)
1330		}
1331		sigLen := len(sig)
1332		if sigLen == 0 {
1333			sig = append(msg, 0)
1334		}
1335		ret = C.DSA_verify(0, (*C.uchar)(unsafe.Pointer(&msg[0])), C.int(msgLen),
1336			(*C.uchar)(unsafe.Pointer(&sig[0])), C.int(sigLen), dsa)
1337	}
1338
1339	success := true
1340	if ret == 1 != (wt.Result == "valid") {
1341		fmt.Printf("FAIL: Test case %d (%q) %v - DSA_verify() = %d, want %v\n",
1342			wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
1343		success = false
1344	}
1345	return success
1346}
1347
1348func runDSATestGroup(algorithm string, variant testVariant, wtg *wycheproofTestGroupDSA) bool {
1349	fmt.Printf("Running %v test group %v, key size %d and %v...\n",
1350		algorithm, wtg.Type, wtg.Key.KeySize, wtg.SHA)
1351
1352	dsa := C.DSA_new()
1353	if dsa == nil {
1354		log.Fatal("DSA_new failed")
1355	}
1356	defer C.DSA_free(dsa)
1357
1358	var bnG *C.BIGNUM
1359	wg := C.CString(wtg.Key.G)
1360	if C.BN_hex2bn(&bnG, wg) == 0 {
1361		log.Fatal("Failed to decode g")
1362	}
1363	C.free(unsafe.Pointer(wg))
1364
1365	var bnP *C.BIGNUM
1366	wp := C.CString(wtg.Key.P)
1367	if C.BN_hex2bn(&bnP, wp) == 0 {
1368		log.Fatal("Failed to decode p")
1369	}
1370	C.free(unsafe.Pointer(wp))
1371
1372	var bnQ *C.BIGNUM
1373	wq := C.CString(wtg.Key.Q)
1374	if C.BN_hex2bn(&bnQ, wq) == 0 {
1375		log.Fatal("Failed to decode q")
1376	}
1377	C.free(unsafe.Pointer(wq))
1378
1379	ret := C.DSA_set0_pqg(dsa, bnP, bnQ, bnG)
1380	if ret != 1 {
1381		log.Fatalf("DSA_set0_pqg returned %d", ret)
1382	}
1383
1384	var bnY *C.BIGNUM
1385	wy := C.CString(wtg.Key.Y)
1386	if C.BN_hex2bn(&bnY, wy) == 0 {
1387		log.Fatal("Failed to decode y")
1388	}
1389	C.free(unsafe.Pointer(wy))
1390
1391	ret = C.DSA_set0_key(dsa, bnY, nil)
1392	if ret != 1 {
1393		log.Fatalf("DSA_set0_key returned %d", ret)
1394	}
1395
1396	h, err := hashFromString(wtg.SHA)
1397	if err != nil {
1398		log.Fatalf("Failed to get hash: %v", err)
1399	}
1400
1401	der, err := hex.DecodeString(wtg.KeyDER)
1402	if err != nil {
1403		log.Fatalf("Failed to decode DER encoded key: %v", err)
1404	}
1405
1406	derLen := len(der)
1407	if derLen == 0 {
1408		der = append(der, 0)
1409	}
1410
1411	Cder := (*C.uchar)(C.malloc(C.ulong(derLen)))
1412	if Cder == nil {
1413		log.Fatal("malloc failed")
1414	}
1415	C.memcpy(unsafe.Pointer(Cder), unsafe.Pointer(&der[0]), C.ulong(derLen))
1416
1417	p := (*C.uchar)(Cder)
1418	dsaDER := C.d2i_DSA_PUBKEY(nil, (**C.uchar)(&p), C.long(derLen))
1419	defer C.DSA_free(dsaDER)
1420	C.free(unsafe.Pointer(Cder))
1421
1422	keyPEM := C.CString(wtg.KeyPEM)
1423	bio := C.BIO_new_mem_buf(unsafe.Pointer(keyPEM), C.int(len(wtg.KeyPEM)))
1424	if bio == nil {
1425		log.Fatal("BIO_new_mem_buf failed")
1426	}
1427	defer C.free(unsafe.Pointer(keyPEM))
1428	defer C.BIO_free(bio)
1429
1430	dsaPEM := C.PEM_read_bio_DSA_PUBKEY(bio, nil, nil, nil)
1431	if dsaPEM == nil {
1432		log.Fatal("PEM_read_bio_DSA_PUBKEY failed")
1433	}
1434	defer C.DSA_free(dsaPEM)
1435
1436	success := true
1437	for _, wt := range wtg.Tests {
1438		if !runDSATest(dsa, variant, h, wt) {
1439			success = false
1440		}
1441		if !runDSATest(dsaDER, variant, h, wt) {
1442			success = false
1443		}
1444		if !runDSATest(dsaPEM, variant, h, wt) {
1445			success = false
1446		}
1447	}
1448	return success
1449}
1450
1451func runECDHTest(nid int, variant testVariant, wt *wycheproofTestECDH) bool {
1452	privKey := C.EC_KEY_new_by_curve_name(C.int(nid))
1453	if privKey == nil {
1454		log.Fatalf("EC_KEY_new_by_curve_name failed")
1455	}
1456	defer C.EC_KEY_free(privKey)
1457
1458	var bnPriv *C.BIGNUM
1459	wPriv := C.CString(wt.Private)
1460	if C.BN_hex2bn(&bnPriv, wPriv) == 0 {
1461		log.Fatal("Failed to decode wPriv")
1462	}
1463	C.free(unsafe.Pointer(wPriv))
1464	defer C.BN_free(bnPriv)
1465
1466	ret := C.EC_KEY_set_private_key(privKey, bnPriv)
1467	if ret != 1 {
1468		fmt.Printf("FAIL: Test case %d (%q) %v - EC_KEY_set_private_key() = %d, want %v\n",
1469			wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
1470		return false
1471	}
1472
1473	pub, err := hex.DecodeString(wt.Public)
1474	if err != nil {
1475		log.Fatalf("Failed to decode public key: %v", err)
1476	}
1477
1478	pubLen := len(pub)
1479	if pubLen == 0 {
1480		pub = append(pub, 0)
1481	}
1482
1483	Cpub := (*C.uchar)(C.malloc(C.ulong(pubLen)))
1484	if Cpub == nil {
1485		log.Fatal("malloc failed")
1486	}
1487	C.memcpy(unsafe.Pointer(Cpub), unsafe.Pointer(&pub[0]), C.ulong(pubLen))
1488
1489	p := (*C.uchar)(Cpub)
1490	var pubKey *C.EC_KEY
1491	if variant == EcPoint {
1492		pubKey = C.EC_KEY_new_by_curve_name(C.int(nid))
1493		if pubKey == nil {
1494			log.Fatal("EC_KEY_new_by_curve_name failed")
1495		}
1496		pubKey = C.o2i_ECPublicKey(&pubKey, (**C.uchar)(&p), C.long(pubLen))
1497	} else {
1498		pubKey = C.d2i_EC_PUBKEY(nil, (**C.uchar)(&p), C.long(pubLen))
1499	}
1500	defer C.EC_KEY_free(pubKey)
1501	C.free(unsafe.Pointer(Cpub))
1502
1503	if pubKey == nil {
1504		if wt.Result == "invalid" || wt.Result == "acceptable" {
1505			return true
1506		}
1507		fmt.Printf("FAIL: Test case %d (%q) %v - ASN decoding failed: want %v\n",
1508			wt.TCID, wt.Comment, wt.Flags, wt.Result)
1509		return false
1510	}
1511
1512	privGroup := C.EC_KEY_get0_group(privKey)
1513
1514	secLen := (C.EC_GROUP_get_degree(privGroup) + 7) / 8
1515
1516	secret := make([]byte, secLen)
1517	if secLen == 0 {
1518		secret = append(secret, 0)
1519	}
1520
1521	pubPoint := C.EC_KEY_get0_public_key(pubKey)
1522
1523	ret = C.ECDH_compute_key(unsafe.Pointer(&secret[0]), C.ulong(secLen), pubPoint, privKey, nil)
1524	if ret != C.int(secLen) {
1525		if wt.Result == "invalid" {
1526			return true
1527		}
1528		fmt.Printf("FAIL: Test case %d (%q) %v - ECDH_compute_key() = %d, want %d, result: %v\n",
1529			wt.TCID, wt.Comment, wt.Flags, ret, int(secLen), wt.Result)
1530		return false
1531	}
1532
1533	shared, err := hex.DecodeString(wt.Shared)
1534	if err != nil {
1535		log.Fatalf("Failed to decode shared secret: %v", err)
1536	}
1537
1538	// XXX The shared fields of the secp224k1 test cases have a 0 byte preprended.
1539	if len(shared) == int(secLen)+1 && shared[0] == 0 {
1540		fmt.Printf("INFO: Test case %d (%q) %v - prepending 0 byte\n", wt.TCID, wt.Comment, wt.Flags)
1541		// shared = shared[1:];
1542		zero := make([]byte, 1, secLen+1)
1543		secret = append(zero, secret...)
1544	}
1545
1546	success := true
1547	if !bytes.Equal(shared, secret) {
1548		fmt.Printf("FAIL: Test case %d (%q) %v - expected and computed shared secret do not match, want %v\n",
1549			wt.TCID, wt.Comment, wt.Flags, wt.Result)
1550		success = false
1551	}
1552	if acceptableAudit && success && wt.Result == "acceptable" {
1553		gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
1554	}
1555	return success
1556}
1557
1558func runECDHTestGroup(algorithm string, variant testVariant, wtg *wycheproofTestGroupECDH) bool {
1559	fmt.Printf("Running %v test group %v with curve %v and %v encoding...\n",
1560		algorithm, wtg.Type, wtg.Curve, wtg.Encoding)
1561
1562	nid, err := nidFromString(wtg.Curve)
1563	if err != nil {
1564		log.Fatalf("Failed to get nid for curve: %v", err)
1565	}
1566
1567	success := true
1568	for _, wt := range wtg.Tests {
1569		if !runECDHTest(nid, variant, wt) {
1570			success = false
1571		}
1572	}
1573	return success
1574}
1575
1576func runECDHWebCryptoTest(nid int, wt *wycheproofTestECDHWebCrypto) bool {
1577	privKey := C.EC_KEY_new_by_curve_name(C.int(nid))
1578	if privKey == nil {
1579		log.Fatalf("EC_KEY_new_by_curve_name failed")
1580	}
1581	defer C.EC_KEY_free(privKey)
1582
1583	d, err := base64.RawURLEncoding.DecodeString(wt.Private.D)
1584	if err != nil {
1585		log.Fatalf("Failed to base64 decode d: %v", err)
1586	}
1587	bnD := C.BN_bin2bn((*C.uchar)(unsafe.Pointer(&d[0])), C.int(len(d)), nil)
1588	if bnD == nil {
1589		log.Fatal("Failed to decode D")
1590	}
1591	defer C.BN_free(bnD)
1592
1593	ret := C.EC_KEY_set_private_key(privKey, bnD)
1594	if ret != 1 {
1595		fmt.Printf("FAIL: Test case %d (%q) %v - EC_KEY_set_private_key() = %d, want %v\n",
1596			wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
1597		return false
1598	}
1599
1600	x, err := base64.RawURLEncoding.DecodeString(wt.Public.X)
1601	if err != nil {
1602		log.Fatalf("Failed to base64 decode x: %v", err)
1603	}
1604	bnX := C.BN_bin2bn((*C.uchar)(unsafe.Pointer(&x[0])), C.int(len(x)), nil)
1605	if bnX == nil {
1606		log.Fatal("Failed to decode X")
1607	}
1608	defer C.BN_free(bnX)
1609
1610	y, err := base64.RawURLEncoding.DecodeString(wt.Public.Y)
1611	if err != nil {
1612		log.Fatalf("Failed to base64 decode y: %v", err)
1613	}
1614	bnY := C.BN_bin2bn((*C.uchar)(unsafe.Pointer(&y[0])), C.int(len(y)), nil)
1615	if bnY == nil {
1616		log.Fatal("Failed to decode Y")
1617	}
1618	defer C.BN_free(bnY)
1619
1620	pubKey := C.EC_KEY_new_by_curve_name(C.int(nid))
1621	if pubKey == nil {
1622		log.Fatal("Failed to create EC_KEY")
1623	}
1624	defer C.EC_KEY_free(pubKey)
1625
1626	ret = C.EC_KEY_set_public_key_affine_coordinates(pubKey, bnX, bnY)
1627	if ret != 1 {
1628		if wt.Result == "invalid" {
1629			return true
1630		}
1631		fmt.Printf("FAIL: Test case %d (%q) %v - EC_KEY_set_public_key_affine_coordinates() = %d, want %v\n",
1632			wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
1633		return false
1634	}
1635	pubPoint := C.EC_KEY_get0_public_key(pubKey)
1636
1637	privGroup := C.EC_KEY_get0_group(privKey)
1638
1639	secLen := (C.EC_GROUP_get_degree(privGroup) + 7) / 8
1640
1641	secret := make([]byte, secLen)
1642	if secLen == 0 {
1643		secret = append(secret, 0)
1644	}
1645
1646	ret = C.ECDH_compute_key(unsafe.Pointer(&secret[0]), C.ulong(secLen), pubPoint, privKey, nil)
1647	if ret != C.int(secLen) {
1648		if wt.Result == "invalid" {
1649			return true
1650		}
1651		fmt.Printf("FAIL: Test case %d (%q) %v - ECDH_compute_key() = %d, want %d, result: %v\n",
1652			wt.TCID, wt.Comment, wt.Flags, ret, int(secLen), wt.Result)
1653		return false
1654	}
1655
1656	shared, err := hex.DecodeString(wt.Shared)
1657	if err != nil {
1658		log.Fatalf("Failed to decode shared secret: %v", err)
1659	}
1660
1661	success := true
1662	if !bytes.Equal(shared, secret) {
1663		fmt.Printf("FAIL: Test case %d (%q) %v - expected and computed shared secret do not match, want %v\n",
1664			wt.TCID, wt.Comment, wt.Flags, wt.Result)
1665		success = false
1666	}
1667	if acceptableAudit && success && wt.Result == "acceptable" {
1668		gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
1669	}
1670	return success
1671}
1672
1673func runECDHWebCryptoTestGroup(algorithm string, wtg *wycheproofTestGroupECDHWebCrypto) bool {
1674	fmt.Printf("Running %v test group %v with curve %v and %v encoding...\n",
1675		algorithm, wtg.Type, wtg.Curve, wtg.Encoding)
1676
1677	nid, err := nidFromString(wtg.Curve)
1678	if err != nil {
1679		log.Fatalf("Failed to get nid for curve: %v", err)
1680	}
1681
1682	success := true
1683	for _, wt := range wtg.Tests {
1684		if !runECDHWebCryptoTest(nid, wt) {
1685			success = false
1686		}
1687	}
1688	return success
1689}
1690
1691func runECDSATest(ecKey *C.EC_KEY, nid int, h hash.Hash, variant testVariant, wt *wycheproofTestECDSA) bool {
1692	msg, err := hex.DecodeString(wt.Msg)
1693	if err != nil {
1694		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
1695	}
1696
1697	h.Reset()
1698	h.Write(msg)
1699	msg = h.Sum(nil)
1700
1701	msgLen := len(msg)
1702	if msgLen == 0 {
1703		msg = append(msg, 0)
1704	}
1705
1706	var ret C.int
1707	if variant == Webcrypto || variant == P1363 {
1708		cDer, derLen := encodeECDSAWebCryptoSig(wt.Sig)
1709		if cDer == nil {
1710			fmt.Print("FAIL: unable to decode signature")
1711			return false
1712		}
1713		defer C.free(unsafe.Pointer(cDer))
1714
1715		ret = C.ECDSA_verify(0, (*C.uchar)(unsafe.Pointer(&msg[0])), C.int(msgLen),
1716			(*C.uchar)(unsafe.Pointer(cDer)), C.int(derLen), ecKey)
1717	} else {
1718		sig, err := hex.DecodeString(wt.Sig)
1719		if err != nil {
1720			log.Fatalf("Failed to decode signature %q: %v", wt.Sig, err)
1721		}
1722
1723		sigLen := len(sig)
1724		if sigLen == 0 {
1725			sig = append(sig, 0)
1726		}
1727		ret = C.ECDSA_verify(0, (*C.uchar)(unsafe.Pointer(&msg[0])), C.int(msgLen),
1728			(*C.uchar)(unsafe.Pointer(&sig[0])), C.int(sigLen), ecKey)
1729	}
1730
1731	// XXX audit acceptable cases...
1732	success := true
1733	if ret == 1 != (wt.Result == "valid") && wt.Result != "acceptable" {
1734		fmt.Printf("FAIL: Test case %d (%q) %v - ECDSA_verify() = %d, want %v\n",
1735			wt.TCID, wt.Comment, wt.Flags, int(ret), wt.Result)
1736		success = false
1737	}
1738	if acceptableAudit && ret == 1 && wt.Result == "acceptable" {
1739		gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
1740	}
1741	return success
1742}
1743
1744func runECDSATestGroup(algorithm string, variant testVariant, wtg *wycheproofTestGroupECDSA) bool {
1745	fmt.Printf("Running %v test group %v with curve %v, key size %d and %v...\n",
1746		algorithm, wtg.Type, wtg.Key.Curve, wtg.Key.KeySize, wtg.SHA)
1747
1748	nid, err := nidFromString(wtg.Key.Curve)
1749	if err != nil {
1750		log.Fatalf("Failed to get nid for curve: %v", err)
1751	}
1752	ecKey := C.EC_KEY_new_by_curve_name(C.int(nid))
1753	if ecKey == nil {
1754		log.Fatal("EC_KEY_new_by_curve_name failed")
1755	}
1756	defer C.EC_KEY_free(ecKey)
1757
1758	var bnX *C.BIGNUM
1759	wx := C.CString(wtg.Key.WX)
1760	if C.BN_hex2bn(&bnX, wx) == 0 {
1761		log.Fatal("Failed to decode WX")
1762	}
1763	C.free(unsafe.Pointer(wx))
1764	defer C.BN_free(bnX)
1765
1766	var bnY *C.BIGNUM
1767	wy := C.CString(wtg.Key.WY)
1768	if C.BN_hex2bn(&bnY, wy) == 0 {
1769		log.Fatal("Failed to decode WY")
1770	}
1771	C.free(unsafe.Pointer(wy))
1772	defer C.BN_free(bnY)
1773
1774	if C.EC_KEY_set_public_key_affine_coordinates(ecKey, bnX, bnY) != 1 {
1775		log.Fatal("Failed to set EC public key")
1776	}
1777
1778	nid, err = nidFromString(wtg.SHA)
1779	if err != nil {
1780		log.Fatalf("Failed to get MD NID: %v", err)
1781	}
1782	h, err := hashFromString(wtg.SHA)
1783	if err != nil {
1784		log.Fatalf("Failed to get hash: %v", err)
1785	}
1786
1787	success := true
1788	for _, wt := range wtg.Tests {
1789		if !runECDSATest(ecKey, nid, h, variant, wt) {
1790			success = false
1791		}
1792	}
1793	return success
1794}
1795
1796// DER encode the signature (so that ECDSA_verify() can decode and encode it again...)
1797func encodeECDSAWebCryptoSig(wtSig string) (*C.uchar, C.int) {
1798	cSig := C.ECDSA_SIG_new()
1799	if cSig == nil {
1800		log.Fatal("ECDSA_SIG_new() failed")
1801	}
1802	defer C.ECDSA_SIG_free(cSig)
1803
1804	sigLen := len(wtSig)
1805	r := C.CString(wtSig[:sigLen/2])
1806	s := C.CString(wtSig[sigLen/2:])
1807	defer C.free(unsafe.Pointer(r))
1808	defer C.free(unsafe.Pointer(s))
1809	if C.BN_hex2bn(&cSig.r, r) == 0 {
1810		return nil, 0
1811	}
1812	if C.BN_hex2bn(&cSig.s, s) == 0 {
1813		return nil, 0
1814	}
1815
1816	derLen := C.i2d_ECDSA_SIG(cSig, nil)
1817	if derLen == 0 {
1818		return nil, 0
1819	}
1820	cDer := (*C.uchar)(C.malloc(C.ulong(derLen)))
1821	if cDer == nil {
1822		log.Fatal("malloc failed")
1823	}
1824
1825	p := cDer
1826	ret := C.i2d_ECDSA_SIG(cSig, (**C.uchar)(&p))
1827	if ret == 0 || ret != derLen {
1828		C.free(unsafe.Pointer(cDer))
1829		return nil, 0
1830	}
1831
1832	return cDer, derLen
1833}
1834
1835func runECDSAWebCryptoTestGroup(algorithm string, wtg *wycheproofTestGroupECDSAWebCrypto) bool {
1836	fmt.Printf("Running %v test group %v with curve %v, key size %d and %v...\n",
1837		algorithm, wtg.Type, wtg.Key.Curve, wtg.Key.KeySize, wtg.SHA)
1838
1839	nid, err := nidFromString(wtg.JWK.Crv)
1840	if err != nil {
1841		log.Fatalf("Failed to get nid for curve: %v", err)
1842	}
1843	ecKey := C.EC_KEY_new_by_curve_name(C.int(nid))
1844	if ecKey == nil {
1845		log.Fatal("EC_KEY_new_by_curve_name failed")
1846	}
1847	defer C.EC_KEY_free(ecKey)
1848
1849	x, err := base64.RawURLEncoding.DecodeString(wtg.JWK.X)
1850	if err != nil {
1851		log.Fatalf("Failed to base64 decode X: %v", err)
1852	}
1853	bnX := C.BN_bin2bn((*C.uchar)(unsafe.Pointer(&x[0])), C.int(len(x)), nil)
1854	if bnX == nil {
1855		log.Fatal("Failed to decode X")
1856	}
1857	defer C.BN_free(bnX)
1858
1859	y, err := base64.RawURLEncoding.DecodeString(wtg.JWK.Y)
1860	if err != nil {
1861		log.Fatalf("Failed to base64 decode Y: %v", err)
1862	}
1863	bnY := C.BN_bin2bn((*C.uchar)(unsafe.Pointer(&y[0])), C.int(len(y)), nil)
1864	if bnY == nil {
1865		log.Fatal("Failed to decode Y")
1866	}
1867	defer C.BN_free(bnY)
1868
1869	if C.EC_KEY_set_public_key_affine_coordinates(ecKey, bnX, bnY) != 1 {
1870		log.Fatal("Failed to set EC public key")
1871	}
1872
1873	nid, err = nidFromString(wtg.SHA)
1874	if err != nil {
1875		log.Fatalf("Failed to get MD NID: %v", err)
1876	}
1877	h, err := hashFromString(wtg.SHA)
1878	if err != nil {
1879		log.Fatalf("Failed to get hash: %v", err)
1880	}
1881
1882	success := true
1883	for _, wt := range wtg.Tests {
1884		if !runECDSATest(ecKey, nid, h, Webcrypto, wt) {
1885			success = false
1886		}
1887	}
1888	return success
1889}
1890
1891func runHkdfTest(md *C.EVP_MD, wt *wycheproofTestHkdf) bool {
1892	ikm, err := hex.DecodeString(wt.Ikm)
1893	if err != nil {
1894		log.Fatalf("Failed to decode ikm %q: %v", wt.Ikm, err)
1895	}
1896	salt, err := hex.DecodeString(wt.Salt)
1897	if err != nil {
1898		log.Fatalf("Failed to decode salt %q: %v", wt.Salt, err)
1899	}
1900	info, err := hex.DecodeString(wt.Info)
1901	if err != nil {
1902		log.Fatalf("Failed to decode info %q: %v", wt.Info, err)
1903	}
1904
1905	ikmLen, saltLen, infoLen := len(ikm), len(salt), len(info)
1906	if ikmLen == 0 {
1907		ikm = append(ikm, 0)
1908	}
1909	if saltLen == 0 {
1910		salt = append(salt, 0)
1911	}
1912	if infoLen == 0 {
1913		info = append(info, 0)
1914	}
1915
1916	outLen := wt.Size
1917	out := make([]byte, outLen)
1918	if outLen == 0 {
1919		out = append(out, 0)
1920	}
1921
1922	ret := C.HKDF((*C.uchar)(unsafe.Pointer(&out[0])), C.size_t(outLen), md, (*C.uchar)(unsafe.Pointer(&ikm[0])), C.size_t(ikmLen), (*C.uchar)(&salt[0]), C.size_t(saltLen), (*C.uchar)(unsafe.Pointer(&info[0])), C.size_t(infoLen))
1923
1924	if ret != 1 {
1925		success := wt.Result == "invalid"
1926		if !success {
1927			fmt.Printf("FAIL: Test case %d (%q) %v - got %d, want %v\n", wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
1928		}
1929		return success
1930	}
1931
1932	okm, err := hex.DecodeString(wt.Okm)
1933	if err != nil {
1934		log.Fatalf("Failed to decode okm %q: %v", wt.Okm, err)
1935	}
1936	if !bytes.Equal(out[:outLen], okm) {
1937		fmt.Printf("FAIL: Test case %d (%q) %v - expected and computed output don't match: %v", wt.TCID, wt.Comment, wt.Flags, wt.Result)
1938	}
1939
1940	return wt.Result == "valid"
1941}
1942
1943func runHkdfTestGroup(algorithm string, wtg *wycheproofTestGroupHkdf) bool {
1944	fmt.Printf("Running %v test group %v with key size %d...\n", algorithm, wtg.Type, wtg.KeySize)
1945	md, err := hashEvpMdFromString(strings.TrimPrefix(algorithm, "HKDF-"))
1946	if err != nil {
1947		log.Fatalf("Failed to get hash: %v", err)
1948	}
1949
1950	success := true
1951	for _, wt := range wtg.Tests {
1952		if !runHkdfTest(md, wt) {
1953			success = false
1954		}
1955	}
1956	return success
1957}
1958
1959func runHmacTest(md *C.EVP_MD, tagBytes int, wt *wycheproofTestHmac) bool {
1960	key, err := hex.DecodeString(wt.Key)
1961	if err != nil {
1962		log.Fatalf("failed to decode key %q: %v", wt.Key, err)
1963	}
1964
1965	msg, err := hex.DecodeString(wt.Msg)
1966	if err != nil {
1967		log.Fatalf("failed to decode msg %q: %v", wt.Msg, err)
1968	}
1969
1970	keyLen, msgLen := len(key), len(msg)
1971
1972	if keyLen == 0 {
1973		key = append(key, 0)
1974	}
1975
1976	if msgLen == 0 {
1977		msg = append(msg, 0)
1978	}
1979
1980	got := make([]byte, C.EVP_MAX_MD_SIZE)
1981	var gotLen C.uint
1982
1983	ret := C.HMAC(md, unsafe.Pointer(&key[0]), C.int(keyLen), (*C.uchar)(unsafe.Pointer(&msg[0])), C.size_t(msgLen), (*C.uchar)(unsafe.Pointer(&got[0])), &gotLen)
1984
1985	success := true
1986	if ret == nil {
1987		if wt.Result != "invalid" {
1988			success = false
1989			fmt.Printf("FAIL: Test case %d (%q) %v - HMAC: got nil, want %v\n", wt.TCID, wt.Comment, wt.Flags, wt.Result)
1990		}
1991		return success
1992	}
1993
1994	if int(gotLen) < tagBytes {
1995		fmt.Printf("FAIL: Test case %d (%q) %v - HMAC length: got %d, want %d, expected %v\n", wt.TCID, wt.Comment, wt.Flags, gotLen, tagBytes, wt.Result)
1996		return false
1997	}
1998
1999	tag, err := hex.DecodeString(wt.Tag)
2000	if err != nil {
2001		log.Fatalf("failed to decode tag %q: %v", wt.Tag, err)
2002	}
2003
2004	success = bytes.Equal(got[:tagBytes], tag) == (wt.Result == "valid")
2005
2006	if !success {
2007		fmt.Printf("FAIL: Test case %d (%q) %v - got %v want %v\n", wt.TCID, wt.Comment, wt.Flags, success, wt.Result)
2008	}
2009
2010	return success
2011}
2012
2013func runHmacTestGroup(algorithm string, wtg *wycheproofTestGroupHmac) bool {
2014	fmt.Printf("Running %v test group %v with key size %d and tag size %d...\n", algorithm, wtg.Type, wtg.KeySize, wtg.TagSize)
2015	md, err := hashEvpMdFromString("SHA-" + strings.TrimPrefix(algorithm, "HMACSHA"))
2016	if err != nil {
2017		log.Fatalf("Failed to get hash: %v", err)
2018	}
2019
2020	success := true
2021	for _, wt := range wtg.Tests {
2022		if !runHmacTest(md, wtg.TagSize/8, wt) {
2023			success = false
2024		}
2025	}
2026	return success
2027}
2028
2029func runKWTestWrap(keySize int, key []byte, keyLen int, msg []byte, msgLen int, ct []byte, ctLen int, wt *wycheproofTestKW) bool {
2030	var aesKey C.AES_KEY
2031
2032	ret := C.AES_set_encrypt_key((*C.uchar)(unsafe.Pointer(&key[0])), (C.int)(keySize), (*C.AES_KEY)(unsafe.Pointer(&aesKey)))
2033	if ret != 0 {
2034		fmt.Printf("FAIL: Test case %d (%q) %v - AES_set_encrypt_key() = %d, want %v\n",
2035			wt.TCID, wt.Comment, wt.Flags, int(ret), wt.Result)
2036		return false
2037	}
2038
2039	outLen := msgLen
2040	out := make([]byte, outLen)
2041	copy(out, msg)
2042	out = append(out, make([]byte, 8)...)
2043	ret = C.AES_wrap_key((*C.AES_KEY)(unsafe.Pointer(&aesKey)), nil, (*C.uchar)(unsafe.Pointer(&out[0])), (*C.uchar)(unsafe.Pointer(&out[0])), (C.uint)(msgLen))
2044	success := false
2045	if ret == C.int(len(out)) && bytes.Equal(out, ct) {
2046		if acceptableAudit && wt.Result == "acceptable" {
2047			gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
2048		}
2049		if wt.Result != "invalid" {
2050			success = true
2051		}
2052	} else if wt.Result != "valid" {
2053		success = true
2054	}
2055	if !success {
2056		fmt.Printf("FAIL: Test case %d (%q) %v - msgLen = %d, AES_wrap_key() = %d, want %v\n",
2057			wt.TCID, wt.Comment, wt.Flags, msgLen, int(ret), wt.Result)
2058	}
2059	return success
2060}
2061
2062func runKWTestUnWrap(keySize int, key []byte, keyLen int, msg []byte, msgLen int, ct []byte, ctLen int, wt *wycheproofTestKW) bool {
2063	var aesKey C.AES_KEY
2064
2065	ret := C.AES_set_decrypt_key((*C.uchar)(unsafe.Pointer(&key[0])), (C.int)(keySize), (*C.AES_KEY)(unsafe.Pointer(&aesKey)))
2066	if ret != 0 {
2067		fmt.Printf("FAIL: Test case %d (%q) %v - AES_set_encrypt_key() = %d, want %v\n",
2068			wt.TCID, wt.Comment, wt.Flags, int(ret), wt.Result)
2069		return false
2070	}
2071
2072	out := make([]byte, ctLen)
2073	copy(out, ct)
2074	if ctLen == 0 {
2075		out = append(out, 0)
2076	}
2077	ret = C.AES_unwrap_key((*C.AES_KEY)(unsafe.Pointer(&aesKey)), nil, (*C.uchar)(unsafe.Pointer(&out[0])), (*C.uchar)(unsafe.Pointer(&out[0])), (C.uint)(ctLen))
2078	success := false
2079	if ret == C.int(ctLen-8) && bytes.Equal(out[0:ret], msg[0:ret]) {
2080		if acceptableAudit && wt.Result == "acceptable" {
2081			gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
2082		}
2083		if wt.Result != "invalid" {
2084			success = true
2085		}
2086	} else if wt.Result != "valid" {
2087		success = true
2088	}
2089	if !success {
2090		fmt.Printf("FAIL: Test case %d (%q) %v - keyLen = %d, AES_unwrap_key() = %d, want %v\n",
2091			wt.TCID, wt.Comment, wt.Flags, keyLen, int(ret), wt.Result)
2092	}
2093	return success
2094}
2095
2096func runKWTest(keySize int, wt *wycheproofTestKW) bool {
2097	key, err := hex.DecodeString(wt.Key)
2098	if err != nil {
2099		log.Fatalf("Failed to decode key %q: %v", wt.Key, err)
2100	}
2101	msg, err := hex.DecodeString(wt.Msg)
2102	if err != nil {
2103		log.Fatalf("Failed to decode msg %q: %v", wt.Msg, err)
2104	}
2105	ct, err := hex.DecodeString(wt.CT)
2106	if err != nil {
2107		log.Fatalf("Failed to decode ct %q: %v", wt.CT, err)
2108	}
2109
2110	keyLen, msgLen, ctLen := len(key), len(msg), len(ct)
2111
2112	if keyLen == 0 {
2113		key = append(key, 0)
2114	}
2115	if msgLen == 0 {
2116		msg = append(msg, 0)
2117	}
2118	if ctLen == 0 {
2119		ct = append(ct, 0)
2120	}
2121
2122	wrapSuccess := runKWTestWrap(keySize, key, keyLen, msg, msgLen, ct, ctLen, wt)
2123	unwrapSuccess := runKWTestUnWrap(keySize, key, keyLen, msg, msgLen, ct, ctLen, wt)
2124
2125	return wrapSuccess && unwrapSuccess
2126}
2127
2128func runKWTestGroup(algorithm string, wtg *wycheproofTestGroupKW) bool {
2129	fmt.Printf("Running %v test group %v with key size %d...\n",
2130		algorithm, wtg.Type, wtg.KeySize)
2131
2132	success := true
2133	for _, wt := range wtg.Tests {
2134		if !runKWTest(wtg.KeySize, wt) {
2135			success = false
2136		}
2137	}
2138	return success
2139}
2140
2141func runRsaesOaepTest(rsa *C.RSA, sha *C.EVP_MD, mgfSha *C.EVP_MD, wt *wycheproofTestRsaes) bool {
2142	ct, err := hex.DecodeString(wt.CT)
2143	if err != nil {
2144		log.Fatalf("Failed to decode cipher text %q: %v", wt.CT, err)
2145	}
2146	ctLen := len(ct)
2147	if ctLen == 0 {
2148		ct = append(ct, 0)
2149	}
2150
2151	rsaSize := C.RSA_size(rsa)
2152	decrypted := make([]byte, rsaSize)
2153
2154	success := true
2155
2156	ret := C.RSA_private_decrypt(C.int(ctLen), (*C.uchar)(unsafe.Pointer(&ct[0])), (*C.uchar)(unsafe.Pointer(&decrypted[0])), rsa, C.RSA_NO_PADDING)
2157
2158	if ret != rsaSize {
2159		success = (wt.Result == "invalid")
2160
2161		if !success {
2162			fmt.Printf("FAIL: Test case %d (%q) %v - got %d, want %d. Expected: %v\n", wt.TCID, wt.Comment, wt.Flags, ret, rsaSize, wt.Result)
2163		}
2164		return success
2165	}
2166
2167	label, err := hex.DecodeString(wt.Label)
2168	if err != nil {
2169		log.Fatalf("Failed to decode label %q: %v", wt.Label, err)
2170	}
2171	labelLen := len(label)
2172	if labelLen == 0 {
2173		label = append(label, 0)
2174	}
2175
2176	msg, err := hex.DecodeString(wt.Msg)
2177	if err != nil {
2178		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
2179	}
2180	msgLen := len(msg)
2181
2182	to := make([]byte, rsaSize)
2183
2184	ret = C.RSA_padding_check_PKCS1_OAEP_mgf1((*C.uchar)(unsafe.Pointer(&to[0])), C.int(rsaSize), (*C.uchar)(unsafe.Pointer(&decrypted[0])), C.int(rsaSize), C.int(rsaSize), (*C.uchar)(unsafe.Pointer(&label[0])), C.int(labelLen), sha, mgfSha)
2185
2186	if int(ret) != msgLen {
2187		success = (wt.Result == "invalid")
2188
2189		if !success {
2190			fmt.Printf("FAIL: Test case %d (%q) %v - got %d, want %d. Expected: %v\n", wt.TCID, wt.Comment, wt.Flags, ret, rsaSize, wt.Result)
2191		}
2192		return success
2193	}
2194
2195	to = to[:msgLen]
2196	if !bytes.Equal(msg, to) {
2197		success = false
2198		fmt.Printf("FAIL: Test case %d (%q) %v - expected and calculated message differ. Expected: %v", wt.TCID, wt.Comment, wt.Flags, wt.Result)
2199	}
2200
2201	return success
2202}
2203
2204func runRsaesOaepTestGroup(algorithm string, wtg *wycheproofTestGroupRsaesOaep) bool {
2205	fmt.Printf("Running %v test group %v with key size %d MGF %v and %v...\n",
2206		algorithm, wtg.Type, wtg.KeySize, wtg.MGFSHA, wtg.SHA)
2207
2208	rsa := C.RSA_new()
2209	if rsa == nil {
2210		log.Fatal("RSA_new failed")
2211	}
2212	defer C.RSA_free(rsa)
2213
2214	d := C.CString(wtg.D)
2215	if C.BN_hex2bn(&rsa.d, d) == 0 {
2216		log.Fatal("Failed to set RSA d")
2217	}
2218	C.free(unsafe.Pointer(d))
2219
2220	e := C.CString(wtg.E)
2221	if C.BN_hex2bn(&rsa.e, e) == 0 {
2222		log.Fatal("Failed to set RSA e")
2223	}
2224	C.free(unsafe.Pointer(e))
2225
2226	n := C.CString(wtg.N)
2227	if C.BN_hex2bn(&rsa.n, n) == 0 {
2228		log.Fatal("Failed to set RSA n")
2229	}
2230	C.free(unsafe.Pointer(n))
2231
2232	sha, err := hashEvpMdFromString(wtg.SHA)
2233	if err != nil {
2234		log.Fatalf("Failed to get hash: %v", err)
2235	}
2236
2237	mgfSha, err := hashEvpMdFromString(wtg.MGFSHA)
2238	if err != nil {
2239		log.Fatalf("Failed to get MGF hash: %v", err)
2240	}
2241
2242	success := true
2243	for _, wt := range wtg.Tests {
2244		if !runRsaesOaepTest(rsa, sha, mgfSha, wt) {
2245			success = false
2246		}
2247	}
2248	return success
2249}
2250
2251func runRsaesPkcs1Test(rsa *C.RSA, wt *wycheproofTestRsaes) bool {
2252	ct, err := hex.DecodeString(wt.CT)
2253	if err != nil {
2254		log.Fatalf("Failed to decode cipher text %q: %v", wt.CT, err)
2255	}
2256	ctLen := len(ct)
2257	if ctLen == 0 {
2258		ct = append(ct, 0)
2259	}
2260
2261	rsaSize := C.RSA_size(rsa)
2262	decrypted := make([]byte, rsaSize)
2263
2264	success := true
2265
2266	ret := C.RSA_private_decrypt(C.int(ctLen), (*C.uchar)(unsafe.Pointer(&ct[0])), (*C.uchar)(unsafe.Pointer(&decrypted[0])), rsa, C.RSA_PKCS1_PADDING)
2267
2268	if ret == -1 {
2269		success = (wt.Result == "invalid")
2270
2271		if !success {
2272			fmt.Printf("FAIL: Test case %d (%q) %v - got %d, want %d. Expected: %v\n", wt.TCID, wt.Comment, wt.Flags, ret, len(wt.Msg)/2, wt.Result)
2273		}
2274		return success
2275	}
2276
2277	msg, err := hex.DecodeString(wt.Msg)
2278	if err != nil {
2279		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
2280	}
2281
2282	if int(ret) != len(msg) {
2283		success = false
2284		fmt.Printf("FAIL: Test case %d (%q) %v - got %d, want %d. Expected: %v\n", wt.TCID, wt.Comment, wt.Flags, ret, len(msg), wt.Result)
2285	} else if !bytes.Equal(msg, decrypted[:len(msg)]) {
2286		success = false
2287		fmt.Printf("FAIL: Test case %d (%q) %v - expected and calculated message differ. Expected: %v", wt.TCID, wt.Comment, wt.Flags, wt.Result)
2288	}
2289
2290	return success
2291}
2292
2293func runRsaesPkcs1TestGroup(algorithm string, wtg *wycheproofTestGroupRsaesPkcs1) bool {
2294	fmt.Printf("Running %v test group %v with key size %d...\n", algorithm, wtg.Type, wtg.KeySize)
2295	rsa := C.RSA_new()
2296	if rsa == nil {
2297		log.Fatal("RSA_new failed")
2298	}
2299	defer C.RSA_free(rsa)
2300
2301	d := C.CString(wtg.D)
2302	if C.BN_hex2bn(&rsa.d, d) == 0 {
2303		log.Fatal("Failed to set RSA d")
2304	}
2305	C.free(unsafe.Pointer(d))
2306
2307	e := C.CString(wtg.E)
2308	if C.BN_hex2bn(&rsa.e, e) == 0 {
2309		log.Fatal("Failed to set RSA e")
2310	}
2311	C.free(unsafe.Pointer(e))
2312
2313	n := C.CString(wtg.N)
2314	if C.BN_hex2bn(&rsa.n, n) == 0 {
2315		log.Fatal("Failed to set RSA n")
2316	}
2317	C.free(unsafe.Pointer(n))
2318
2319	success := true
2320	for _, wt := range wtg.Tests {
2321		if !runRsaesPkcs1Test(rsa, wt) {
2322			success = false
2323		}
2324	}
2325	return success
2326}
2327
2328func runRsassaTest(rsa *C.RSA, h hash.Hash, sha *C.EVP_MD, mgfSha *C.EVP_MD, sLen int, wt *wycheproofTestRsassa) bool {
2329	msg, err := hex.DecodeString(wt.Msg)
2330	if err != nil {
2331		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
2332	}
2333
2334	h.Reset()
2335	h.Write(msg)
2336	msg = h.Sum(nil)
2337
2338	sig, err := hex.DecodeString(wt.Sig)
2339	if err != nil {
2340		log.Fatalf("Failed to decode signature %q: %v", wt.Sig, err)
2341	}
2342
2343	msgLen, sigLen := len(msg), len(sig)
2344	if msgLen == 0 {
2345		msg = append(msg, 0)
2346	}
2347	if sigLen == 0 {
2348		sig = append(sig, 0)
2349	}
2350
2351	sigOut := make([]byte, C.RSA_size(rsa)-11)
2352	if sigLen == 0 {
2353		sigOut = append(sigOut, 0)
2354	}
2355
2356	ret := C.RSA_public_decrypt(C.int(sigLen), (*C.uchar)(unsafe.Pointer(&sig[0])),
2357		(*C.uchar)(unsafe.Pointer(&sigOut[0])), rsa, C.RSA_NO_PADDING)
2358	if ret == -1 {
2359		if wt.Result == "invalid" {
2360			return true
2361		}
2362		fmt.Printf("FAIL: Test case %d (%q) %v - RSA_public_decrypt() = %d, want %v\n",
2363			wt.TCID, wt.Comment, wt.Flags, int(ret), wt.Result)
2364		return false
2365	}
2366
2367	ret = C.RSA_verify_PKCS1_PSS_mgf1(rsa, (*C.uchar)(unsafe.Pointer(&msg[0])), sha, mgfSha,
2368		(*C.uchar)(unsafe.Pointer(&sigOut[0])), C.int(sLen))
2369
2370	success := false
2371	if ret == 1 && (wt.Result == "valid" || wt.Result == "acceptable") {
2372		// All acceptable cases that pass use SHA-1 and are flagged:
2373		// "WeakHash" : "The key for this test vector uses a weak hash function."
2374		if acceptableAudit && wt.Result == "acceptable" {
2375			gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
2376		}
2377		success = true
2378	} else if ret == 0 && (wt.Result == "invalid" || wt.Result == "acceptable") {
2379		success = true
2380	} else {
2381		fmt.Printf("FAIL: Test case %d (%q) %v - RSA_verify_PKCS1_PSS_mgf1() = %d, want %v\n",
2382			wt.TCID, wt.Comment, wt.Flags, int(ret), wt.Result)
2383	}
2384	return success
2385}
2386
2387func runRsassaTestGroup(algorithm string, wtg *wycheproofTestGroupRsassa) bool {
2388	fmt.Printf("Running %v test group %v with key size %d and %v...\n",
2389		algorithm, wtg.Type, wtg.KeySize, wtg.SHA)
2390	rsa := C.RSA_new()
2391	if rsa == nil {
2392		log.Fatal("RSA_new failed")
2393	}
2394	defer C.RSA_free(rsa)
2395
2396	e := C.CString(wtg.E)
2397	if C.BN_hex2bn(&rsa.e, e) == 0 {
2398		log.Fatal("Failed to set RSA e")
2399	}
2400	C.free(unsafe.Pointer(e))
2401
2402	n := C.CString(wtg.N)
2403	if C.BN_hex2bn(&rsa.n, n) == 0 {
2404		log.Fatal("Failed to set RSA n")
2405	}
2406	C.free(unsafe.Pointer(n))
2407
2408	h, err := hashFromString(wtg.SHA)
2409	if err != nil {
2410		log.Fatalf("Failed to get hash: %v", err)
2411	}
2412
2413	sha, err := hashEvpMdFromString(wtg.SHA)
2414	if err != nil {
2415		log.Fatalf("Failed to get hash: %v", err)
2416	}
2417
2418	mgfSha, err := hashEvpMdFromString(wtg.MGFSHA)
2419	if err != nil {
2420		log.Fatalf("Failed to get MGF hash: %v", err)
2421	}
2422
2423	success := true
2424	for _, wt := range wtg.Tests {
2425		if !runRsassaTest(rsa, h, sha, mgfSha, wtg.SLen, wt) {
2426			success = false
2427		}
2428	}
2429	return success
2430}
2431
2432func runRSATest(rsa *C.RSA, nid int, h hash.Hash, wt *wycheproofTestRSA) bool {
2433	msg, err := hex.DecodeString(wt.Msg)
2434	if err != nil {
2435		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
2436	}
2437
2438	h.Reset()
2439	h.Write(msg)
2440	msg = h.Sum(nil)
2441
2442	sig, err := hex.DecodeString(wt.Sig)
2443	if err != nil {
2444		log.Fatalf("Failed to decode signature %q: %v", wt.Sig, err)
2445	}
2446
2447	msgLen, sigLen := len(msg), len(sig)
2448	if msgLen == 0 {
2449		msg = append(msg, 0)
2450	}
2451	if sigLen == 0 {
2452		sig = append(sig, 0)
2453	}
2454
2455	ret := C.RSA_verify(C.int(nid), (*C.uchar)(unsafe.Pointer(&msg[0])), C.uint(msgLen),
2456		(*C.uchar)(unsafe.Pointer(&sig[0])), C.uint(sigLen), rsa)
2457
2458	// XXX audit acceptable cases...
2459	success := true
2460	if ret == 1 != (wt.Result == "valid") && wt.Result != "acceptable" {
2461		fmt.Printf("FAIL: Test case %d (%q) %v - RSA_verify() = %d, want %v\n",
2462			wt.TCID, wt.Comment, wt.Flags, int(ret), wt.Result)
2463		success = false
2464	}
2465	if acceptableAudit && ret == 1 && wt.Result == "acceptable" {
2466		gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
2467	}
2468	return success
2469}
2470
2471func runRSATestGroup(algorithm string, wtg *wycheproofTestGroupRSA) bool {
2472	fmt.Printf("Running %v test group %v with key size %d and %v...\n",
2473		algorithm, wtg.Type, wtg.KeySize, wtg.SHA)
2474
2475	rsa := C.RSA_new()
2476	if rsa == nil {
2477		log.Fatal("RSA_new failed")
2478	}
2479	defer C.RSA_free(rsa)
2480
2481	e := C.CString(wtg.E)
2482	if C.BN_hex2bn(&rsa.e, e) == 0 {
2483		log.Fatal("Failed to set RSA e")
2484	}
2485	C.free(unsafe.Pointer(e))
2486
2487	n := C.CString(wtg.N)
2488	if C.BN_hex2bn(&rsa.n, n) == 0 {
2489		log.Fatal("Failed to set RSA n")
2490	}
2491	C.free(unsafe.Pointer(n))
2492
2493	nid, err := nidFromString(wtg.SHA)
2494	if err != nil {
2495		log.Fatalf("Failed to get MD NID: %v", err)
2496	}
2497	h, err := hashFromString(wtg.SHA)
2498	if err != nil {
2499		log.Fatalf("Failed to get hash: %v", err)
2500	}
2501
2502	success := true
2503	for _, wt := range wtg.Tests {
2504		if !runRSATest(rsa, nid, h, wt) {
2505			success = false
2506		}
2507	}
2508	return success
2509}
2510
2511func runX25519Test(wt *wycheproofTestX25519) bool {
2512	public, err := hex.DecodeString(wt.Public)
2513	if err != nil {
2514		log.Fatalf("Failed to decode public %q: %v", wt.Public, err)
2515	}
2516	private, err := hex.DecodeString(wt.Private)
2517	if err != nil {
2518		log.Fatalf("Failed to decode private %q: %v", wt.Private, err)
2519	}
2520	shared, err := hex.DecodeString(wt.Shared)
2521	if err != nil {
2522		log.Fatalf("Failed to decode shared %q: %v", wt.Shared, err)
2523	}
2524
2525	got := make([]byte, C.X25519_KEY_LENGTH)
2526	result := true
2527
2528	if C.X25519((*C.uint8_t)(unsafe.Pointer(&got[0])), (*C.uint8_t)(unsafe.Pointer(&private[0])), (*C.uint8_t)(unsafe.Pointer(&public[0]))) != 1 {
2529		result = false
2530	} else {
2531		result = bytes.Equal(got, shared)
2532	}
2533
2534	// XXX audit acceptable cases...
2535	success := true
2536	if result != (wt.Result == "valid") && wt.Result != "acceptable" {
2537		fmt.Printf("FAIL: Test case %d (%q) %v - X25519(), want %v\n",
2538			wt.TCID, wt.Comment, wt.Flags, wt.Result)
2539		success = false
2540	}
2541	if acceptableAudit && result && wt.Result == "acceptable" {
2542		gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
2543	}
2544	return success
2545}
2546
2547func runX25519TestGroup(algorithm string, wtg *wycheproofTestGroupX25519) bool {
2548	fmt.Printf("Running %v test group with curve %v...\n", algorithm, wtg.Curve)
2549
2550	success := true
2551	for _, wt := range wtg.Tests {
2552		if !runX25519Test(wt) {
2553			success = false
2554		}
2555	}
2556	return success
2557}
2558
// runTestVectors loads the Wycheproof JSON test vector file at path and runs
// every test group in it with the runner selected by the file's "algorithm"
// field. variant is used to pick between alternative group formats for ECDH
// and ECDSA (Webcrypto vs. the default). It is also passed to the DSA, ECDH
// and ECDSA runners that take it.
//
// Read and JSON decoding errors are fatal (log.Fatalf). An unknown algorithm
// is logged and makes the function return false without running anything.
// Otherwise it reports whether every test group passed.
func runTestVectors(path string, variant testVariant) bool {
	b, err := ioutil.ReadFile(path)
	if err != nil {
		log.Fatalf("Failed to read test vectors: %v", err)
	}
	wtv := &wycheproofTestVectors{}
	if err := json.Unmarshal(b, wtv); err != nil {
		log.Fatalf("Failed to unmarshal JSON: %v", err)
	}
	fmt.Printf("Loaded Wycheproof test vectors for %v with %d tests from %q\n",
		wtv.Algorithm, wtv.NumberOfTests, filepath.Base(path))

	// Choose the Go type that a single test group of this algorithm is
	// decoded into. The type assertions in the dispatch switch below must
	// agree with these choices.
	var wtg interface{}
	switch wtv.Algorithm {
	case "AES-CBC-PKCS5":
		wtg = &wycheproofTestGroupAesCbcPkcs5{}
	case "AES-CCM":
		wtg = &wycheproofTestGroupAead{}
	case "AES-CMAC":
		wtg = &wycheproofTestGroupAesCmac{}
	case "AES-GCM":
		wtg = &wycheproofTestGroupAead{}
	case "CHACHA20-POLY1305", "XCHACHA20-POLY1305":
		wtg = &wycheproofTestGroupAead{}
	case "DSA":
		wtg = &wycheproofTestGroupDSA{}
	case "ECDH":
		switch variant {
		case Webcrypto:
			wtg = &wycheproofTestGroupECDHWebCrypto{}
		default:
			wtg = &wycheproofTestGroupECDH{}
		}
	case "ECDSA":
		switch variant {
		case Webcrypto:
			wtg = &wycheproofTestGroupECDSAWebCrypto{}
		default:
			wtg = &wycheproofTestGroupECDSA{}
		}
	case "HKDF-SHA-1", "HKDF-SHA-256", "HKDF-SHA-384", "HKDF-SHA-512":
		wtg = &wycheproofTestGroupHkdf{}
	case "HMACSHA1", "HMACSHA224", "HMACSHA256", "HMACSHA384", "HMACSHA512":
		wtg = &wycheproofTestGroupHmac{}
	case "KW":
		wtg = &wycheproofTestGroupKW{}
	case "RSAES-OAEP":
		wtg = &wycheproofTestGroupRsaesOaep{}
	case "RSAES-PKCS1-v1_5":
		wtg = &wycheproofTestGroupRsaesPkcs1{}
	case "RSASSA-PSS":
		wtg = &wycheproofTestGroupRsassa{}
	case "RSASSA-PKCS1-v1_5", "RSASig":
		wtg = &wycheproofTestGroupRSA{}
	case "XDH", "X25519":
		wtg = &wycheproofTestGroupX25519{}
	default:
		log.Printf("INFO: Unknown test vector algorithm %q", wtv.Algorithm)
		return false
	}

	success := true
	// The same wtg value is decoded into for every test group of the file;
	// it is not reallocated per group.
	for _, tg := range wtv.TestGroups {
		if err := json.Unmarshal(tg, wtg); err != nil {
			log.Fatalf("Failed to unmarshal test groups JSON: %v", err)
		}
		// Dispatch to the runner for this algorithm (and variant). Every
		// group is run; success only records whether any of them failed.
		switch wtv.Algorithm {
		case "AES-CBC-PKCS5":
			if !runAesCbcPkcs5TestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupAesCbcPkcs5)) {
				success = false
			}
		case "AES-CCM":
			if !runAesAeadTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupAead)) {
				success = false
			}
		case "AES-CMAC":
			if !runAesCmacTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupAesCmac)) {
				success = false
			}
		case "AES-GCM":
			if !runAesAeadTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupAead)) {
				success = false
			}
		case "CHACHA20-POLY1305", "XCHACHA20-POLY1305":
			if !runChaCha20Poly1305TestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupAead)) {
				success = false
			}
		case "DSA":
			if !runDSATestGroup(wtv.Algorithm, variant, wtg.(*wycheproofTestGroupDSA)) {
				success = false
			}
		case "ECDH":
			switch variant {
			case Webcrypto:
				if !runECDHWebCryptoTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupECDHWebCrypto)) {
					success = false
				}
			default:
				if !runECDHTestGroup(wtv.Algorithm, variant, wtg.(*wycheproofTestGroupECDH)) {
					success = false
				}
			}
		case "ECDSA":
			switch variant {
			case Webcrypto:
				if !runECDSAWebCryptoTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupECDSAWebCrypto)) {
					success = false
				}
			default:
				if !runECDSATestGroup(wtv.Algorithm, variant, wtg.(*wycheproofTestGroupECDSA)) {
					success = false
				}
			}
		case "HKDF-SHA-1", "HKDF-SHA-256", "HKDF-SHA-384", "HKDF-SHA-512":
			if !runHkdfTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupHkdf)) {
				success = false
			}
		case "HMACSHA1", "HMACSHA224", "HMACSHA256", "HMACSHA384", "HMACSHA512":
			if !runHmacTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupHmac)) {
				success = false
			}
		case "KW":
			if !runKWTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupKW)) {
				success = false
			}
		case "RSAES-OAEP":
			if !runRsaesOaepTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupRsaesOaep)) {
				success = false
			}
		case "RSAES-PKCS1-v1_5":
			if !runRsaesPkcs1TestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupRsaesPkcs1)) {
				success = false
			}
		case "RSASSA-PSS":
			if !runRsassaTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupRsassa)) {
				success = false
			}
		case "RSASSA-PKCS1-v1_5", "RSASig":
			if !runRSATestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupRSA)) {
				success = false
			}
		case "XDH", "X25519":
			if !runX25519TestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupX25519)) {
				success = false
			}
		default:
			// Unreachable: unknown algorithms already returned above.
			log.Fatalf("Unknown test vector algorithm %q", wtv.Algorithm)
		}
	}
	return success
}
2710
2711func main() {
2712	if _, err := os.Stat(testVectorPath); os.IsNotExist(err) {
2713		fmt.Printf("package wycheproof-testvectors is required for this regress\n")
2714		fmt.Printf("SKIPPED\n")
2715		os.Exit(0)
2716	}
2717
2718	flag.BoolVar(&acceptableAudit, "v", false, "audit acceptable cases")
2719	flag.Parse()
2720
2721	acceptableComments = make(map[string]int)
2722	acceptableFlags = make(map[string]int)
2723
2724	// TODO: Investigate the following new test vectors:
2725	//	primality_test.json
2726	//	x25519_{asn,jwk,pem}_test.json
2727	tests := []struct {
2728		name    string
2729		pattern string
2730		variant testVariant
2731	}{
2732		{"AES", "aes_[cg]*[^xv]_test.json", Normal}, // Skip AES-EAX, AES-GCM-SIV and AES-SIV-CMAC.
2733		{"ChaCha20-Poly1305", "chacha20_poly1305_test.json", Normal},
2734		{"DSA", "dsa_*test.json", Normal},
2735		{"DSA", "dsa_*_p1363_test.json", P1363},
2736		{"ECDH", "ecdh_test.json", Normal},
2737		{"ECDH", "ecdh_[^w_]*_test.json", Normal},
2738		{"ECDH EcPoint", "ecdh_*_ecpoint_test.json", EcPoint},
2739		{"ECDH webcrypto", "ecdh_webcrypto_test.json", Webcrypto},
2740		{"ECDSA", "ecdsa_test.json", Normal},
2741		{"ECDSA", "ecdsa_[^w]*test.json", Normal},
2742		{"ECDSA P1363", "ecdsa_*_p1363_test.json", P1363},
2743		{"ECDSA webcrypto", "ecdsa_webcrypto_test.json", Webcrypto},
2744		{"HKDF", "hkdf_sha*_test.json", Normal},
2745		{"HMAC", "hmac_sha*_test.json", Normal},
2746		{"KW", "kw_test.json", Normal},
2747		{"RSA", "rsa_*test.json", Normal},
2748		{"X25519", "x25519_test.json", Normal},
2749		{"X25519 ASN", "x25519_asn_test.json", Skip},
2750		{"X25519 JWK", "x25519_jwk_test.json", Skip},
2751		{"X25519 PEM", "x25519_pem_test.json", Skip},
2752		{"XCHACHA20-POLY1305", "xchacha20_poly1305_test.json", Normal},
2753	}
2754
2755	success := true
2756
2757	skipNormal := regexp.MustCompile(`_(ecpoint|p1363|sha3|sha512_(224|256))_`)
2758
2759	for _, test := range tests {
2760		tvs, err := filepath.Glob(filepath.Join(testVectorPath, test.pattern))
2761		if err != nil {
2762			log.Fatalf("Failed to glob %v test vectors: %v", test.name, err)
2763		}
2764		if len(tvs) == 0 {
2765			log.Fatalf("Failed to find %v test vectors at %q\n", test.name, testVectorPath)
2766		}
2767		for _, tv := range tvs {
2768			if test.variant == Skip || (test.variant == Normal && skipNormal.Match([]byte(tv))) {
2769				fmt.Printf("INFO: Skipping tests from \"%s\"\n", strings.TrimPrefix(tv, testVectorPath+"/"))
2770				continue
2771			}
2772			if !runTestVectors(tv, test.variant) {
2773				success = false
2774			}
2775		}
2776	}
2777
2778	if acceptableAudit {
2779		printAcceptableStatistics()
2780	}
2781
2782	if !success {
2783		os.Exit(1)
2784	}
2785}
2786