1/* $OpenBSD: wycheproof.go,v 1.121 2021/04/03 13:34:45 tb Exp $ */
2/*
3 * Copyright (c) 2018 Joel Sing <jsing@openbsd.org>
4 * Copyright (c) 2018, 2019 Theo Buehler <tb@openbsd.org>
5 *
6 * Permission to use, copy, modify, and distribute this software for any
7 * purpose with or without fee is hereby granted, provided that the above
8 * copyright notice and this permission notice appear in all copies.
9 *
10 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 */
18
19// Wycheproof runs test vectors from Project Wycheproof against libcrypto.
20package main
21
22/*
23#cgo LDFLAGS: -lcrypto
24
25#include <string.h>
26
27#include <openssl/aes.h>
28#include <openssl/bio.h>
29#include <openssl/bn.h>
30#include <openssl/cmac.h>
31#include <openssl/curve25519.h>
32#include <openssl/dsa.h>
33#include <openssl/ec.h>
34#include <openssl/ecdsa.h>
35#include <openssl/evp.h>
36#include <openssl/hkdf.h>
37#include <openssl/hmac.h>
38#include <openssl/objects.h>
39#include <openssl/pem.h>
40#include <openssl/x509.h>
41#include <openssl/rsa.h>
42
43int
44evpDigestSignUpdate(EVP_MD_CTX *ctx, const void *d, size_t cnt)
45{
46	return EVP_DigestSignUpdate(ctx, d, cnt);
47}
48*/
49import "C"
50
51import (
52	"bytes"
53	"crypto/sha1"
54	"crypto/sha256"
55	"crypto/sha512"
56	"encoding/base64"
57	"encoding/hex"
58	"encoding/json"
59	"flag"
60	"fmt"
61	"hash"
62	"io/ioutil"
63	"log"
64	"os"
65	"path/filepath"
66	"regexp"
67	"sort"
68	"strings"
69	"unsafe"
70)
71
// testVectorPath is the filesystem path of the Wycheproof test vectors.
// NOTE(review): the code that reads from this path is outside this section.
const testVectorPath = "/usr/local/share/wycheproof/testvectors"
73
// testVariant enumerates the flavours of test vector files that the driver
// distinguishes. The code that interprets each variant lives outside this
// section.
type testVariant int

// The known test variants. The numeric values double as indices into the
// name table used by String.
const (
	Normal    testVariant = 0
	EcPoint   testVariant = 1
	P1363     testVariant = 2
	Webcrypto testVariant = 3
	Asn1      testVariant = 4
	Pem       testVariant = 5
	Jwk       testVariant = 6
	Skip      testVariant = 7
)

// String returns the symbolic name of the variant. Values outside the known
// range are rendered as "testVariant(N)" instead of panicking with an index
// out of range, so a bad value can still be logged.
func (variant testVariant) String() string {
	variants := [...]string{
		"Normal",
		"EcPoint",
		"P1363",
		"Webcrypto",
		"Asn1",
		"Pem",
		"Jwk",
		"Skip",
	}
	if variant < 0 || int(variant) >= len(variants) {
		// Convert to int so that fmt does not call String again.
		return fmt.Sprintf("testVariant(%d)", int(variant))
	}
	return variants[variant]
}
100
// Bookkeeping for auditing test cases whose expected result is "acceptable".
var (
	// acceptableAudit gates the calls to gatherAcceptableStatistics.
	acceptableAudit = false

	// acceptableComments counts audited test cases per comment string.
	acceptableComments map[string]int

	// acceptableFlags counts audited test cases per flag.
	acceptableFlags map[string]int
)
104
// wycheproofJWKPublic mirrors the JSON object describing a public key in JWK
// form (keys "crv", "kid", "kty", "x", "y"). All fields are kept as the raw
// strings found in the test vector file.
type wycheproofJWKPublic struct {
	Crv string `json:"crv"`
	KID string `json:"kid"`
	KTY string `json:"kty"`
	X   string `json:"x"`
	Y   string `json:"y"`
}

// wycheproofJWKPrivate mirrors the JSON object describing a private key in
// JWK form. It has the same keys as wycheproofJWKPublic plus "d".
type wycheproofJWKPrivate struct {
	Crv string `json:"crv"`
	D   string `json:"d"`
	KID string `json:"kid"`
	KTY string `json:"kty"`
	X   string `json:"x"`
	Y   string `json:"y"`
}
121
// wycheproofTestGroupAesCbcPkcs5 is one AES-CBC-PKCS5 test group. KeySize
// selects the AES variant (128, 192 or 256) in runAesCbcPkcs5TestGroup.
type wycheproofTestGroupAesCbcPkcs5 struct {
	IVSize  int                          `json:"ivSize"`
	KeySize int                          `json:"keySize"`
	Type    string                       `json:"type"`
	Tests   []*wycheproofTestAesCbcPkcs5 `json:"tests"`
}

// wycheproofTestAesCbcPkcs5 is a single AES-CBC-PKCS5 test case. Key, IV, Msg
// and CT are hex strings (decoded in runAesCbcPkcs5Test). Result is compared
// against "invalid" and "acceptable" by checkAesCbcPkcs5.
type wycheproofTestAesCbcPkcs5 struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Key     string   `json:"key"`
	IV      string   `json:"iv"`
	Msg     string   `json:"msg"`
	CT      string   `json:"ct"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}
139
// wycheproofTestGroupAead is one AEAD test group. For AES-CCM and AES-GCM,
// KeySize selects the cipher (128, 192 or 256) in runAesAeadTestGroup.
type wycheproofTestGroupAead struct {
	IVSize  int                   `json:"ivSize"`
	KeySize int                   `json:"keySize"`
	TagSize int                   `json:"tagSize"`
	Type    string                `json:"type"`
	Tests   []*wycheproofTestAead `json:"tests"`
}

// wycheproofTestAead is a single AEAD test case. Key, IV, AAD, Msg, CT and
// Tag are hex strings (decoded in runAesAeadTest). Result is compared against
// "valid", "acceptable" and "invalid" by checkAesAead.
type wycheproofTestAead struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Key     string   `json:"key"`
	IV      string   `json:"iv"`
	AAD     string   `json:"aad"`
	Msg     string   `json:"msg"`
	CT      string   `json:"ct"`
	Tag     string   `json:"tag"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}
160
// wycheproofTestGroupAesCmac is one AES-CMAC test group. KeySize is used by
// runAesCmacTestGroup to pick the underlying AES-CBC cipher.
type wycheproofTestGroupAesCmac struct {
	KeySize int                      `json:"keySize"`
	TagSize int                      `json:"tagSize"`
	Type    string                   `json:"type"`
	Tests   []*wycheproofTestAesCmac `json:"tests"`
}

// wycheproofTestAesCmac is a single AES-CMAC test case. Key, Msg and Tag are
// hex strings (decoded in runAesCmacTest), which only treats the result
// "valid" as an expected tag match.
type wycheproofTestAesCmac struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Key     string   `json:"key"`
	Msg     string   `json:"msg"`
	Tag     string   `json:"tag"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}
177
// wycheproofDSAKey mirrors the "key" object of a DSA test group (JSON keys
// "g", "keySize", "p", "q", "type", "y").
// NOTE(review): the string encoding of the numbers is not visible in this
// section; presumably hex, confirm against the DSA test runner.
type wycheproofDSAKey struct {
	G       string `json:"g"`
	KeySize int    `json:"keySize"`
	P       string `json:"p"`
	Q       string `json:"q"`
	Type    string `json:"type"`
	Y       string `json:"y"`
}

// wycheproofTestDSA is a single DSA signature test case.
type wycheproofTestDSA struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Msg     string   `json:"msg"`
	Sig     string   `json:"sig"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// wycheproofTestGroupDSA is one DSA test group. The key is provided
// structured (Key) and as the strings under "keyDer" and "keyPem"; SHA names
// the hash.
type wycheproofTestGroupDSA struct {
	Key    *wycheproofDSAKey    `json:"key"`
	KeyDER string               `json:"keyDer"`
	KeyPEM string               `json:"keyPem"`
	SHA    string               `json:"sha"`
	Type   string               `json:"type"`
	Tests  []*wycheproofTestDSA `json:"tests"`
}
204
// wycheproofTestECDH is a single ECDH test case with the public key, private
// key and expected shared secret kept as raw strings.
type wycheproofTestECDH struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Public  string   `json:"public"`
	Private string   `json:"private"`
	Shared  string   `json:"shared"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// wycheproofTestGroupECDH is one ECDH test group, labelled with the curve
// name and the encoding taken from the test vector file.
type wycheproofTestGroupECDH struct {
	Curve    string                `json:"curve"`
	Encoding string                `json:"encoding"`
	Type     string                `json:"type"`
	Tests    []*wycheproofTestECDH `json:"tests"`
}
221
// wycheproofTestECDHWebCrypto is a single ECDH test case whose keys are given
// as JWK objects rather than plain strings.
type wycheproofTestECDHWebCrypto struct {
	TCID    int                   `json:"tcId"`
	Comment string                `json:"comment"`
	Public  *wycheproofJWKPublic  `json:"public"`
	Private *wycheproofJWKPrivate `json:"private"`
	Shared  string                `json:"shared"`
	Result  string                `json:"result"`
	Flags   []string              `json:"flags"`
}

// wycheproofTestGroupECDHWebCrypto is one ECDH test group using JWK keys.
type wycheproofTestGroupECDHWebCrypto struct {
	Curve    string                         `json:"curve"`
	Encoding string                         `json:"encoding"`
	Type     string                         `json:"type"`
	Tests    []*wycheproofTestECDHWebCrypto `json:"tests"`
}
238
// wycheproofECDSAKey mirrors the "key" object of an ECDSA test group (JSON
// keys "curve", "keySize", "type", "uncompressed", "wx", "wy").
type wycheproofECDSAKey struct {
	Curve        string `json:"curve"`
	KeySize      int    `json:"keySize"`
	Type         string `json:"type"`
	Uncompressed string `json:"uncompressed"`
	WX           string `json:"wx"`
	WY           string `json:"wy"`
}

// wycheproofTestECDSA is a single ECDSA signature test case.
type wycheproofTestECDSA struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Msg     string   `json:"msg"`
	Sig     string   `json:"sig"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// wycheproofTestGroupECDSA is one ECDSA test group. The key is provided
// structured (Key) and as the strings under "keyDer" and "keyPem".
type wycheproofTestGroupECDSA struct {
	Key    *wycheproofECDSAKey    `json:"key"`
	KeyDER string                 `json:"keyDer"`
	KeyPEM string                 `json:"keyPem"`
	SHA    string                 `json:"sha"`
	Type   string                 `json:"type"`
	Tests  []*wycheproofTestECDSA `json:"tests"`
}

// wycheproofTestGroupECDSAWebCrypto is wycheproofTestGroupECDSA with an
// additional JWK rendering of the public key under "jwk".
type wycheproofTestGroupECDSAWebCrypto struct {
	JWK    *wycheproofJWKPublic   `json:"jwk"`
	Key    *wycheproofECDSAKey    `json:"key"`
	KeyDER string                 `json:"keyDer"`
	KeyPEM string                 `json:"keyPem"`
	SHA    string                 `json:"sha"`
	Type   string                 `json:"type"`
	Tests  []*wycheproofTestECDSA `json:"tests"`
}
275
// wycheproofTestHkdf is a single HKDF test case, read from the JSON keys
// "ikm", "salt", "info", "size" and "okm".
type wycheproofTestHkdf struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Ikm     string   `json:"ikm"`
	Salt    string   `json:"salt"`
	Info    string   `json:"info"`
	Size    int      `json:"size"`
	Okm     string   `json:"okm"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// wycheproofTestGroupHkdf is one HKDF test group.
type wycheproofTestGroupHkdf struct {
	Type    string                `json:"type"`
	KeySize int                   `json:"keySize"`
	Tests   []*wycheproofTestHkdf `json:"tests"`
}
293
// wycheproofTestHmac is a single HMAC test case with key, message and
// expected tag kept as raw strings.
type wycheproofTestHmac struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Key     string   `json:"key"`
	Msg     string   `json:"msg"`
	Tag     string   `json:"tag"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// wycheproofTestGroupHmac is one HMAC test group.
type wycheproofTestGroupHmac struct {
	KeySize int                   `json:"keySize"`
	TagSize int                   `json:"tagSize"`
	Type    string                `json:"type"`
	Tests   []*wycheproofTestHmac `json:"tests"`
}
310
// wycheproofTestKW is a single test case of the "KW" kind.
// NOTE(review): presumably AES key wrap; the runner is outside this section.
type wycheproofTestKW struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Key     string   `json:"key"`
	Msg     string   `json:"msg"`
	CT      string   `json:"ct"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// wycheproofTestGroupKW is one KW test group.
type wycheproofTestGroupKW struct {
	KeySize int                 `json:"keySize"`
	Type    string              `json:"type"`
	Tests   []*wycheproofTestKW `json:"tests"`
}
326
// wycheproofTestRSA is a single RSA signature test case.
type wycheproofTestRSA struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Msg     string   `json:"msg"`
	Sig     string   `json:"sig"`
	Padding string   `json:"padding"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// wycheproofTestGroupRSA is one RSA signature test group. The public key is
// available as E and N and as the strings under "keyAsn", "keyDer" and
// "keyPem". Note that the JSON key for KeySize is the all-lowercase
// "keysize".
type wycheproofTestGroupRSA struct {
	E       string               `json:"e"`
	KeyASN  string               `json:"keyAsn"`
	KeyDER  string               `json:"keyDer"`
	KeyPEM  string               `json:"keyPem"`
	KeySize int                  `json:"keysize"`
	N       string               `json:"n"`
	SHA     string               `json:"sha"`
	Type    string               `json:"type"`
	Tests   []*wycheproofTestRSA `json:"tests"`
}
348
// wycheproofPrivateKeyJwk mirrors the "privateKeyJwk" object of the RSAES
// test groups. The fields correspond to the JSON keys "alg", "d", "dp", "dq",
// "e", "kid", "kty", "n", "p", "q" and "qi".
type wycheproofPrivateKeyJwk struct {
	Alg string `json:"alg"`
	D   string `json:"d"`
	DP  string `json:"dp"`
	DQ  string `json:"dq"`
	E   string `json:"e"`
	KID string `json:"kid"`
	Kty string `json:"kty"`
	N   string `json:"n"`
	P   string `json:"p"`
	Q   string `json:"q"`
	QI  string `json:"qi"`
}

// wycheproofTestRsaes is a single RSA encryption test case, shared by the
// OAEP and PKCS#1 test groups.
type wycheproofTestRsaes struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Msg     string   `json:"msg"`
	CT      string   `json:"ct"`
	Label   string   `json:"label"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// wycheproofTestGroupRsaesOaep is one RSAES-OAEP test group. In addition to
// the private key material it names the MGF and the hashes to use ("mgf",
// "mgfSha", "sha").
type wycheproofTestGroupRsaesOaep struct {
	D               string                   `json:"d"`
	E               string                   `json:"e"`
	KeySize         int                      `json:"keysize"`
	MGF             string                   `json:"mgf"`
	MGFSHA          string                   `json:"mgfSha"`
	N               string                   `json:"n"`
	PrivateKeyJwk   *wycheproofPrivateKeyJwk `json:"privateKeyJwk"`
	PrivateKeyPem   string                   `json:"privateKeyPem"`
	PrivateKeyPkcs8 string                   `json:"privateKeyPkcs8"`
	SHA             string                   `json:"sha"`
	Type            string                   `json:"type"`
	Tests           []*wycheproofTestRsaes   `json:"tests"`
}

// wycheproofTestGroupRsaesPkcs1 is one RSAES-PKCS1 test group. It carries the
// same private key material as the OAEP group but no MGF or hash names.
type wycheproofTestGroupRsaesPkcs1 struct {
	D               string                   `json:"d"`
	E               string                   `json:"e"`
	KeySize         int                      `json:"keysize"`
	N               string                   `json:"n"`
	PrivateKeyJwk   *wycheproofPrivateKeyJwk `json:"privateKeyJwk"`
	PrivateKeyPem   string                   `json:"privateKeyPem"`
	PrivateKeyPkcs8 string                   `json:"privateKeyPkcs8"`
	Type            string                   `json:"type"`
	Tests           []*wycheproofTestRsaes   `json:"tests"`
}
399
// wycheproofTestRsassa is a single RSASSA signature test case.
type wycheproofTestRsassa struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Msg     string   `json:"msg"`
	Sig     string   `json:"sig"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// wycheproofTestGroupRsassa is one RSASSA test group. Besides the public key
// it carries the MGF name, the MGF hash and the "sLen" value.
// NOTE(review): sLen is presumably the PSS salt length; confirm in the runner.
type wycheproofTestGroupRsassa struct {
	E       string                  `json:"e"`
	KeyASN  string                  `json:"keyAsn"`
	KeyDER  string                  `json:"keyDer"`
	KeyPEM  string                  `json:"keyPem"`
	KeySize int                     `json:"keysize"`
	MGF     string                  `json:"mgf"`
	MGFSHA  string                  `json:"mgfSha"`
	N       string                  `json:"n"`
	SLen    int                     `json:"sLen"`
	SHA     string                  `json:"sha"`
	Type    string                  `json:"type"`
	Tests   []*wycheproofTestRsassa `json:"tests"`
}
423
// wycheproofTestX25519 is a single X25519 test case with the public key,
// private key and expected shared secret kept as raw strings.
type wycheproofTestX25519 struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Curve   string   `json:"curve"`
	Public  string   `json:"public"`
	Private string   `json:"private"`
	Shared  string   `json:"shared"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// wycheproofTestGroupX25519 is one X25519 test group.
type wycheproofTestGroupX25519 struct {
	Curve string                  `json:"curve"`
	Tests []*wycheproofTestX25519 `json:"tests"`
}
439
// wycheproofTestVectors is the top-level object of a test vector file. The
// test groups are kept as raw JSON so that they can be decoded later into the
// group type matching the algorithm.
type wycheproofTestVectors struct {
	Algorithm        string            `json:"algorithm"`
	GeneratorVersion string            `json:"generatorVersion"`
	Notes            map[string]string `json:"notes"`
	NumberOfTests    int               `json:"numberOfTests"`
	// Header
	TestGroups []json.RawMessage `json:"testGroups"`
}
448
// nids maps the curve and hash names used in the test vectors to libcrypto
// NIDs. Several curves are reachable under more than one name (for example
// "secp256r1" and "P-256"). Lookups go through nidFromString.
var nids = map[string]int{
	"brainpoolP224r1": C.NID_brainpoolP224r1,
	"brainpoolP256r1": C.NID_brainpoolP256r1,
	"brainpoolP320r1": C.NID_brainpoolP320r1,
	"brainpoolP384r1": C.NID_brainpoolP384r1,
	"brainpoolP512r1": C.NID_brainpoolP512r1,
	"brainpoolP224t1": C.NID_brainpoolP224t1,
	"brainpoolP256t1": C.NID_brainpoolP256t1,
	"brainpoolP320t1": C.NID_brainpoolP320t1,
	"brainpoolP384t1": C.NID_brainpoolP384t1,
	"brainpoolP512t1": C.NID_brainpoolP512t1,
	"secp224k1":       C.NID_secp224k1,
	"secp224r1":       C.NID_secp224r1,
	"secp256k1":       C.NID_secp256k1,
	"P-256K":          C.NID_secp256k1,
	"secp256r1":       C.NID_X9_62_prime256v1, // RFC 8422, Table 4, p.32
	"P-256":           C.NID_X9_62_prime256v1,
	"secp384r1":       C.NID_secp384r1,
	"P-384":           C.NID_secp384r1,
	"secp521r1":       C.NID_secp521r1,
	"P-521":           C.NID_secp521r1,
	"SHA-1":           C.NID_sha1,
	"SHA-224":         C.NID_sha224,
	"SHA-256":         C.NID_sha256,
	"SHA-384":         C.NID_sha384,
	"SHA-512":         C.NID_sha512,
}
476
477func gatherAcceptableStatistics(testcase int, comment string, flags []string) {
478	fmt.Printf("AUDIT: Test case %d (%q) %v\n", testcase, comment, flags)
479
480	if comment == "" {
481		acceptableComments["No comment"]++
482	} else {
483		acceptableComments[comment]++
484	}
485
486	if len(flags) == 0 {
487		acceptableFlags["NoFlag"]++
488	} else {
489		for _, flag := range flags {
490			acceptableFlags[flag]++
491		}
492	}
493}
494
495func printAcceptableStatistics() {
496	fmt.Printf("\nComment statistics:\n")
497
498	var comments []string
499	for comment := range acceptableComments {
500		comments = append(comments, comment)
501	}
502	sort.Strings(comments)
503	for _, comment := range comments {
504		prcomment := comment
505		if len(comment) > 45 {
506			prcomment = comment[0:42] + "..."
507		}
508		fmt.Printf("%-45v %5d\n", prcomment, acceptableComments[comment])
509	}
510
511	fmt.Printf("\nFlag statistics:\n")
512	var flags []string
513	for flag := range acceptableFlags {
514		flags = append(flags, flag)
515	}
516	sort.Strings(flags)
517	for _, flag := range flags {
518		fmt.Printf("%-45v %5d\n", flag, acceptableFlags[flag])
519	}
520}
521
522func nidFromString(ns string) (int, error) {
523	nid, ok := nids[ns]
524	if ok {
525		return nid, nil
526	}
527	return -1, fmt.Errorf("unknown NID %q", ns)
528}
529
// hashFromString returns a fresh Go hash.Hash for the hash name hs, which
// must be one of "SHA-1", "SHA-224", "SHA-256", "SHA-384" or "SHA-512". Any
// other name yields a nil hash and an error.
func hashFromString(hs string) (hash.Hash, error) {
	constructors := map[string]func() hash.Hash{
		"SHA-1":   sha1.New,
		"SHA-224": sha256.New224,
		"SHA-256": sha256.New,
		"SHA-384": sha512.New384,
		"SHA-512": sha512.New,
	}
	if newHash, ok := constructors[hs]; ok {
		return newHash(), nil
	}
	return nil, fmt.Errorf("unknown hash %q", hs)
}
546
547func hashEvpMdFromString(hs string) (*C.EVP_MD, error) {
548	switch hs {
549	case "SHA-1":
550		return C.EVP_sha1(), nil
551	case "SHA-224":
552		return C.EVP_sha224(), nil
553	case "SHA-256":
554		return C.EVP_sha256(), nil
555	case "SHA-384":
556		return C.EVP_sha384(), nil
557	case "SHA-512":
558		return C.EVP_sha512(), nil
559	default:
560		return nil, fmt.Errorf("unknown hash %q", hs)
561	}
562}
563
564func checkAesCbcPkcs5(ctx *C.EVP_CIPHER_CTX, doEncrypt int, key []byte, keyLen int,
565	iv []byte, ivLen int, in []byte, inLen int, out []byte, outLen int,
566	wt *wycheproofTestAesCbcPkcs5) bool {
567	var action string
568	if doEncrypt == 1 {
569		action = "encrypting"
570	} else {
571		action = "decrypting"
572	}
573
574	ret := C.EVP_CipherInit_ex(ctx, nil, nil, (*C.uchar)(unsafe.Pointer(&key[0])),
575		(*C.uchar)(unsafe.Pointer(&iv[0])), C.int(doEncrypt))
576	if ret != 1 {
577		log.Fatalf("EVP_CipherInit_ex failed: %d", ret)
578	}
579
580	cipherOut := make([]byte, inLen+C.EVP_MAX_BLOCK_LENGTH)
581	var cipherOutLen C.int
582
583	ret = C.EVP_CipherUpdate(ctx, (*C.uchar)(unsafe.Pointer(&cipherOut[0])), &cipherOutLen,
584		(*C.uchar)(unsafe.Pointer(&in[0])), C.int(inLen))
585	if ret != 1 {
586		if wt.Result == "invalid" {
587			fmt.Printf("INFO: Test case %d (%q) [%v] %v - EVP_CipherUpdate() = %d, want %v\n",
588				wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
589			return true
590		}
591		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - EVP_CipherUpdate() = %d, want %v\n",
592			wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
593		return false
594	}
595
596	var finallen C.int
597	ret = C.EVP_CipherFinal_ex(ctx, (*C.uchar)(unsafe.Pointer(&cipherOut[cipherOutLen])), &finallen)
598	if ret != 1 {
599		if wt.Result == "invalid" {
600			return true
601		}
602		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - EVP_CipherFinal_ex() = %d, want %v\n",
603			wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
604		return false
605	}
606
607	cipherOutLen += finallen
608	if cipherOutLen != C.int(outLen) && wt.Result != "invalid" {
609		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - open length mismatch: got %d, want %d\n",
610			wt.TCID, wt.Comment, action, wt.Flags, cipherOutLen, outLen)
611		return false
612	}
613
614	openedMsg := cipherOut[0:cipherOutLen]
615	if outLen == 0 {
616		out = nil
617	}
618
619	success := false
620	if bytes.Equal(openedMsg, out) == (wt.Result != "invalid") {
621		success = true
622		if acceptableAudit && wt.Result == "acceptable" {
623			gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
624		}
625	} else {
626		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - msg match: %t; want %v\n",
627			wt.TCID, wt.Comment, action, wt.Flags, bytes.Equal(openedMsg, out), wt.Result)
628	}
629	return success
630}
631
632func runAesCbcPkcs5Test(ctx *C.EVP_CIPHER_CTX, wt *wycheproofTestAesCbcPkcs5) bool {
633	key, err := hex.DecodeString(wt.Key)
634	if err != nil {
635		log.Fatalf("Failed to decode key %q: %v", wt.Key, err)
636	}
637	iv, err := hex.DecodeString(wt.IV)
638	if err != nil {
639		log.Fatalf("Failed to decode IV %q: %v", wt.IV, err)
640	}
641	ct, err := hex.DecodeString(wt.CT)
642	if err != nil {
643		log.Fatalf("Failed to decode CT %q: %v", wt.CT, err)
644	}
645	msg, err := hex.DecodeString(wt.Msg)
646	if err != nil {
647		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
648	}
649
650	keyLen, ivLen, ctLen, msgLen := len(key), len(iv), len(ct), len(msg)
651
652	if keyLen == 0 {
653		key = append(key, 0)
654	}
655	if ivLen == 0 {
656		iv = append(iv, 0)
657	}
658	if ctLen == 0 {
659		ct = append(ct, 0)
660	}
661	if msgLen == 0 {
662		msg = append(msg, 0)
663	}
664
665	openSuccess := checkAesCbcPkcs5(ctx, 0, key, keyLen, iv, ivLen, ct, ctLen, msg, msgLen, wt)
666	sealSuccess := checkAesCbcPkcs5(ctx, 1, key, keyLen, iv, ivLen, msg, msgLen, ct, ctLen, wt)
667
668	return openSuccess && sealSuccess
669}
670
671func runAesCbcPkcs5TestGroup(algorithm string, wtg *wycheproofTestGroupAesCbcPkcs5) bool {
672	fmt.Printf("Running %v test group %v with IV size %d and key size %d...\n",
673		algorithm, wtg.Type, wtg.IVSize, wtg.KeySize)
674
675	var cipher *C.EVP_CIPHER
676	switch wtg.KeySize {
677	case 128:
678		cipher = C.EVP_aes_128_cbc()
679	case 192:
680		cipher = C.EVP_aes_192_cbc()
681	case 256:
682		cipher = C.EVP_aes_256_cbc()
683	default:
684		log.Fatalf("Unsupported key size: %d", wtg.KeySize)
685	}
686
687	ctx := C.EVP_CIPHER_CTX_new()
688	if ctx == nil {
689		log.Fatal("EVP_CIPHER_CTX_new() failed")
690	}
691	defer C.EVP_CIPHER_CTX_free(ctx)
692
693	ret := C.EVP_CipherInit_ex(ctx, cipher, nil, nil, nil, 0)
694	if ret != 1 {
695		log.Fatalf("EVP_CipherInit_ex failed: %d", ret)
696	}
697
698	success := true
699	for _, wt := range wtg.Tests {
700		if !runAesCbcPkcs5Test(ctx, wt) {
701			success = false
702		}
703	}
704	return success
705}
706
// checkAesAead runs one direction of an AES-CCM or AES-GCM test case through
// the EVP_Cipher* interface on ctx, which must already have its cipher set.
// doEncrypt selects encryption (1) or decryption (anything else). in is
// processed under key, iv and aad and the result is compared with out; when
// encrypting, the computed tag is also compared with tag. The *Len arguments
// carry the real lengths, since the caller pads empty slices to one byte so
// that &slice[0] is addressable.
//
// It returns true if the test case is considered passed. Failures to set the
// IV or tag length are tolerated for test cases whose comment announces an
// invalid nonce or tag size.
func checkAesAead(algorithm string, ctx *C.EVP_CIPHER_CTX, doEncrypt int,
	key []byte, keyLen int, iv []byte, ivLen int, aad []byte, aadLen int,
	in []byte, inLen int, out []byte, outLen int, tag []byte, tagLen int,
	wt *wycheproofTestAead) bool {
	var ctrlSetIVLen C.int
	var ctrlSetTag C.int
	var ctrlGetTag C.int

	// Select the mode specific control commands.
	// NOTE(review): there is no default case, so any other algorithm
	// leaves all three commands at zero; the caller visible in this file
	// only passes "AES-CCM" or "AES-GCM".
	doCCM := false
	switch algorithm {
	case "AES-CCM":
		doCCM = true
		ctrlSetIVLen = C.EVP_CTRL_CCM_SET_IVLEN
		ctrlSetTag = C.EVP_CTRL_CCM_SET_TAG
		ctrlGetTag = C.EVP_CTRL_CCM_GET_TAG
	case "AES-GCM":
		ctrlSetIVLen = C.EVP_CTRL_GCM_SET_IVLEN
		ctrlSetTag = C.EVP_CTRL_GCM_SET_TAG
		ctrlGetTag = C.EVP_CTRL_GCM_GET_TAG
	}

	// The expected tag is only handed to the library when decrypting.
	setTag := unsafe.Pointer(nil)
	var action string

	if doEncrypt == 1 {
		action = "encrypting"
	} else {
		action = "decrypting"
		setTag = unsafe.Pointer(&tag[0])
	}

	// First init call: only switch the direction, keep the cipher.
	ret := C.EVP_CipherInit_ex(ctx, nil, nil, nil, nil, C.int(doEncrypt))
	if ret != 1 {
		log.Fatalf("[%v] cipher init failed", action)
	}

	ret = C.EVP_CIPHER_CTX_ctrl(ctx, ctrlSetIVLen, C.int(ivLen), nil)
	if ret != 1 {
		// Rejecting the IV length is the expected outcome for these
		// test cases, which are recognised by their comment.
		if wt.Comment == "Nonce is too long" || wt.Comment == "Invalid nonce size" ||
			wt.Comment == "0 size IV is not valid" || wt.Comment == "Very long nonce" {
			return true
		}
		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - setting IV len to %d failed. got %d, want %v\n",
			wt.TCID, wt.Comment, action, wt.Flags, ivLen, ret, wt.Result)
		return false
	}

	// The tag is set when decrypting and always for CCM. When encrypting
	// with CCM, setTag is nil and only the tag length is passed.
	if doEncrypt == 0 || doCCM {
		ret = C.EVP_CIPHER_CTX_ctrl(ctx, ctrlSetTag, C.int(tagLen), setTag)
		if ret != 1 {
			if wt.Comment == "Invalid tag size" {
				return true
			}
			fmt.Printf("FAIL: Test case %d (%q) [%v] %v - setting tag length to %d failed. got %d, want %v\n",
				wt.TCID, wt.Comment, action, wt.Flags, tagLen, ret, wt.Result)
			return false
		}
	}

	// Second init call: set key and IV now that the lengths are known.
	ret = C.EVP_CipherInit_ex(ctx, nil, nil, (*C.uchar)(unsafe.Pointer(&key[0])),
		(*C.uchar)(unsafe.Pointer(&iv[0])), C.int(doEncrypt))
	if ret != 1 {
		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - setting key and IV failed. got %d, want %v\n",
			wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
		return false
	}

	var cipherOutLen C.int
	if doCCM {
		// For CCM the total input length is announced up front by an
		// update call without input and output buffers.
		ret = C.EVP_CipherUpdate(ctx, nil, &cipherOutLen, nil, C.int(inLen))
		if ret != 1 {
			fmt.Printf("FAIL: Test case %d (%q) [%v] %v - setting input length to %d failed. got %d, want %v\n",
				wt.TCID, wt.Comment, action, wt.Flags, inLen, ret, wt.Result)
			return false
		}
	}

	// The AAD is fed with a nil output buffer.
	ret = C.EVP_CipherUpdate(ctx, nil, &cipherOutLen, (*C.uchar)(unsafe.Pointer(&aad[0])), C.int(aadLen))
	if ret != 1 {
		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - processing AAD failed. got %d, want %v\n",
			wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
		return false
	}

	cipherOutLen = 0
	cipherOut := make([]byte, inLen)
	if inLen == 0 {
		// Pad to one byte so that &cipherOut[0] is valid. The caller
		// pads empty expected output the same way, so the comparison
		// with out below still works.
		cipherOut = append(cipherOut, 0)
	}

	ret = C.EVP_CipherUpdate(ctx, (*C.uchar)(unsafe.Pointer(&cipherOut[0])), &cipherOutLen,
		(*C.uchar)(unsafe.Pointer(&in[0])), C.int(inLen))
	if ret != 1 {
		if wt.Result == "invalid" {
			return true
		}
		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - EVP_CipherUpdate() = %d, want %v\n",
			wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
		return false
	}

	// NOTE(review): EVP_CipherFinal_ex() is only called when encrypting;
	// the decryption path stops after the update above. Confirm that this
	// is intended for GCM.
	if doEncrypt == 1 {
		var tmpLen C.int
		dummyOut := make([]byte, 16)

		ret = C.EVP_CipherFinal_ex(ctx, (*C.uchar)(unsafe.Pointer(&dummyOut[0])), &tmpLen)
		if ret != 1 {
			fmt.Printf("FAIL: Test case %d (%q) [%v] %v - EVP_CipherFinal_ex() = %d, want %v\n",
				wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
			return false
		}
		cipherOutLen += tmpLen
	}

	if cipherOutLen != C.int(outLen) {
		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - cipherOutLen %d != outLen %d. Result %v\n",
			wt.TCID, wt.Comment, action, wt.Flags, cipherOutLen, outLen, wt.Result)
		return false
	}

	success := true
	if !bytes.Equal(cipherOut, out) {
		fmt.Printf("FAIL: Test case %d (%q) [%v] %v - expected and computed output do not match. Result: %v\n",
			wt.TCID, wt.Comment, action, wt.Flags, wt.Result)
		success = false
	}
	if doEncrypt == 1 {
		// Fetch the computed tag and compare it with the expected one.
		tagOut := make([]byte, tagLen)
		ret = C.EVP_CIPHER_CTX_ctrl(ctx, ctrlGetTag, C.int(tagLen), unsafe.Pointer(&tagOut[0]))
		if ret != 1 {
			fmt.Printf("FAIL: Test case %d (%q) [%v] %v - EVP_CIPHER_CTX_ctrl() = %d, want %v\n",
				wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
			return false
		}

		// There are no acceptable CCM cases. All acceptable GCM tests
		// pass. They have len(IV) <= 48. NIST SP 800-38D, 5.2.1.1, p.8,
		// allows 1 <= len(IV) <= 2^64-1, but notes:
		//   "For IVs it is recommended that implementations restrict
		//    support to the length of 96 bits, to promote
		//    interoperability, efficiency and simplicity of design."
		if bytes.Equal(tagOut, tag) != (wt.Result == "valid" || wt.Result == "acceptable") {
			fmt.Printf("FAIL: Test case %d (%q) [%v] %v - expected and computed tag do not match - ret: %d, Result: %v\n",
				wt.TCID, wt.Comment, action, wt.Flags, ret, wt.Result)
			success = false
		}
		if acceptableAudit && bytes.Equal(tagOut, tag) && wt.Result == "acceptable" {
			gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
		}
	}
	return success
}
859
860func runAesAeadTest(algorithm string, ctx *C.EVP_CIPHER_CTX, aead *C.EVP_AEAD, wt *wycheproofTestAead) bool {
861	key, err := hex.DecodeString(wt.Key)
862	if err != nil {
863		log.Fatalf("Failed to decode key %q: %v", wt.Key, err)
864	}
865
866	iv, err := hex.DecodeString(wt.IV)
867	if err != nil {
868		log.Fatalf("Failed to decode IV %q: %v", wt.IV, err)
869	}
870
871	aad, err := hex.DecodeString(wt.AAD)
872	if err != nil {
873		log.Fatalf("Failed to decode AAD %q: %v", wt.AAD, err)
874	}
875
876	msg, err := hex.DecodeString(wt.Msg)
877	if err != nil {
878		log.Fatalf("Failed to decode msg %q: %v", wt.Msg, err)
879	}
880
881	ct, err := hex.DecodeString(wt.CT)
882	if err != nil {
883		log.Fatalf("Failed to decode CT %q: %v", wt.CT, err)
884	}
885
886	tag, err := hex.DecodeString(wt.Tag)
887	if err != nil {
888		log.Fatalf("Failed to decode tag %q: %v", wt.Tag, err)
889	}
890
891	keyLen, ivLen, aadLen, msgLen, ctLen, tagLen := len(key), len(iv), len(aad), len(msg), len(ct), len(tag)
892
893	if keyLen == 0 {
894		key = append(key, 0)
895	}
896	if ivLen == 0 {
897		iv = append(iv, 0)
898	}
899	if aadLen == 0 {
900		aad = append(aad, 0)
901	}
902	if msgLen == 0 {
903		msg = append(msg, 0)
904	}
905	if ctLen == 0 {
906		ct = append(ct, 0)
907	}
908	if tagLen == 0 {
909		tag = append(tag, 0)
910	}
911
912	openEvp := checkAesAead(algorithm, ctx, 0, key, keyLen, iv, ivLen, aad, aadLen, ct, ctLen, msg, msgLen, tag, tagLen, wt)
913	sealEvp := checkAesAead(algorithm, ctx, 1, key, keyLen, iv, ivLen, aad, aadLen, msg, msgLen, ct, ctLen, tag, tagLen, wt)
914
915	openAead, sealAead := true, true
916	if aead != nil {
917		var ctx C.EVP_AEAD_CTX
918		if C.EVP_AEAD_CTX_init(&ctx, aead, (*C.uchar)(unsafe.Pointer(&key[0])), C.size_t(keyLen), C.size_t(tagLen), nil) != 1 {
919			log.Fatal("Failed to initialize AEAD context")
920		}
921		defer C.EVP_AEAD_CTX_cleanup(&ctx)
922
923		// Make sure we don't accidentally prepend or compare against a 0.
924		if ctLen == 0 {
925			ct = nil
926		}
927
928		openAead = checkAeadOpen(&ctx, iv, ivLen, aad, aadLen, msg, msgLen, ct, ctLen, tag, tagLen, wt)
929		sealAead = checkAeadSeal(&ctx, iv, ivLen, aad, aadLen, msg, msgLen, ct, ctLen, tag, tagLen, wt)
930	}
931
932	return openEvp && sealEvp && openAead && sealAead
933}
934
935func runAesAeadTestGroup(algorithm string, wtg *wycheproofTestGroupAead) bool {
936	fmt.Printf("Running %v test group %v with IV size %d, key size %d and tag size %d...\n",
937		algorithm, wtg.Type, wtg.IVSize, wtg.KeySize, wtg.TagSize)
938
939	var cipher *C.EVP_CIPHER
940	var aead *C.EVP_AEAD
941	switch algorithm {
942	case "AES-CCM":
943		switch wtg.KeySize {
944		case 128:
945			cipher = C.EVP_aes_128_ccm()
946		case 192:
947			cipher = C.EVP_aes_192_ccm()
948		case 256:
949			cipher = C.EVP_aes_256_ccm()
950		default:
951			fmt.Printf("INFO: Skipping tests with invalid key size %d\n", wtg.KeySize)
952			return true
953		}
954	case "AES-GCM":
955		switch wtg.KeySize {
956		case 128:
957			cipher = C.EVP_aes_128_gcm()
958			aead = C.EVP_aead_aes_128_gcm()
959		case 192:
960			cipher = C.EVP_aes_192_gcm()
961		case 256:
962			cipher = C.EVP_aes_256_gcm()
963			aead = C.EVP_aead_aes_256_gcm()
964		default:
965			fmt.Printf("INFO: Skipping tests with invalid key size %d\n", wtg.KeySize)
966			return true
967		}
968	default:
969		log.Fatalf("runAesAeadTestGroup() - unhandled algorithm: %v", algorithm)
970	}
971
972	ctx := C.EVP_CIPHER_CTX_new()
973	if ctx == nil {
974		log.Fatal("EVP_CIPHER_CTX_new() failed")
975	}
976	defer C.EVP_CIPHER_CTX_free(ctx)
977
978	C.EVP_CipherInit_ex(ctx, cipher, nil, nil, nil, 1)
979
980	success := true
981	for _, wt := range wtg.Tests {
982		if !runAesAeadTest(algorithm, ctx, aead, wt) {
983			success = false
984		}
985	}
986	return success
987}
988
989func runAesCmacTest(cipher *C.EVP_CIPHER, wt *wycheproofTestAesCmac) bool {
990	key, err := hex.DecodeString(wt.Key)
991	if err != nil {
992		log.Fatalf("Failed to decode key %q: %v", wt.Key, err)
993	}
994
995	msg, err := hex.DecodeString(wt.Msg)
996	if err != nil {
997		log.Fatalf("Failed to decode msg %q: %v", wt.Msg, err)
998	}
999
1000	tag, err := hex.DecodeString(wt.Tag)
1001	if err != nil {
1002		log.Fatalf("Failed to decode tag %q: %v", wt.Tag, err)
1003	}
1004
1005	keyLen, msgLen, tagLen := len(key), len(msg), len(tag)
1006
1007	if keyLen == 0 {
1008		key = append(key, 0)
1009	}
1010	if msgLen == 0 {
1011		msg = append(msg, 0)
1012	}
1013	if tagLen == 0 {
1014		tag = append(tag, 0)
1015	}
1016
1017	mdctx := C.EVP_MD_CTX_new()
1018	if mdctx == nil {
1019		log.Fatal("EVP_MD_CTX_new failed")
1020	}
1021	defer C.EVP_MD_CTX_free(mdctx)
1022
1023	pkey := C.EVP_PKEY_new_CMAC_key(nil, (*C.uchar)(unsafe.Pointer(&key[0])), C.size_t(keyLen), cipher)
1024	if pkey == nil {
1025		log.Fatal("CMAC_CTX_new failed")
1026	}
1027	defer C.EVP_PKEY_free(pkey)
1028
1029	ret := C.EVP_DigestSignInit(mdctx, nil, nil, nil, pkey)
1030	if ret != 1 {
1031		fmt.Printf("FAIL: Test case %d (%q) %v - EVP_DigestSignInit() = %d, want %v\n",
1032			wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
1033		return false
1034	}
1035
1036	ret = C.evpDigestSignUpdate(mdctx, unsafe.Pointer(&msg[0]), C.size_t(msgLen))
1037	if ret != 1 {
1038		fmt.Printf("FAIL: Test case %d (%q) %v - EVP_DigestSignUpdate() = %d, want %v\n",
1039			wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
1040		return false
1041	}
1042
1043	var outLen C.size_t
1044	outTag := make([]byte, 16)
1045
1046	ret = C.EVP_DigestSignFinal(mdctx, (*C.uchar)(unsafe.Pointer(&outTag[0])), &outLen)
1047	if ret != 1 {
1048		fmt.Printf("FAIL: Test case %d (%q) %v - EVP_DigestSignFinal() = %d, want %v\n",
1049			wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
1050		return false
1051	}
1052
1053	outTag = outTag[0:tagLen]
1054
1055	success := true
1056	if bytes.Equal(tag, outTag) != (wt.Result == "valid") {
1057		fmt.Printf("FAIL: Test case %d (%q) %v - want %v\n",
1058			wt.TCID, wt.Comment, wt.Flags, wt.Result)
1059		success = false
1060	}
1061	return success
1062}
1063
1064func runAesCmacTestGroup(algorithm string, wtg *wycheproofTestGroupAesCmac) bool {
1065	fmt.Printf("Running %v test group %v with key size %d and tag size %d...\n",
1066		algorithm, wtg.Type, wtg.KeySize, wtg.TagSize)
1067	var cipher *C.EVP_CIPHER
1068
1069	switch wtg.KeySize {
1070	case 128:
1071		cipher = C.EVP_aes_128_cbc()
1072	case 192:
1073		cipher = C.EVP_aes_192_cbc()
1074	case 256:
1075		cipher = C.EVP_aes_256_cbc()
1076	default:
1077		fmt.Printf("INFO: Skipping tests with invalid key size %d\n", wtg.KeySize)
1078		return true
1079	}
1080
1081	success := true
1082	for _, wt := range wtg.Tests {
1083		if !runAesCmacTest(cipher, wt) {
1084			success = false
1085		}
1086	}
1087	return success
1088}
1089
1090func checkAeadOpen(ctx *C.EVP_AEAD_CTX, iv []byte, ivLen int, aad []byte, aadLen int, msg []byte, msgLen int,
1091	ct []byte, ctLen int, tag []byte, tagLen int, wt *wycheproofTestAead) bool {
1092	maxOutLen := ctLen + tagLen
1093
1094	opened := make([]byte, maxOutLen)
1095	if maxOutLen == 0 {
1096		opened = append(opened, 0)
1097	}
1098	var openedMsgLen C.size_t
1099
1100	catCtTag := append(ct, tag...)
1101	catCtTagLen := len(catCtTag)
1102	if catCtTagLen == 0 {
1103		catCtTag = append(catCtTag, 0)
1104	}
1105	openRet := C.EVP_AEAD_CTX_open(ctx, (*C.uint8_t)(unsafe.Pointer(&opened[0])),
1106		(*C.size_t)(unsafe.Pointer(&openedMsgLen)), C.size_t(maxOutLen),
1107		(*C.uint8_t)(unsafe.Pointer(&iv[0])), C.size_t(ivLen),
1108		(*C.uint8_t)(unsafe.Pointer(&catCtTag[0])), C.size_t(catCtTagLen),
1109		(*C.uint8_t)(unsafe.Pointer(&aad[0])), C.size_t(aadLen))
1110
1111	if openRet != 1 {
1112		if wt.Result == "invalid" {
1113			return true
1114		}
1115		fmt.Printf("FAIL: Test case %d (%q) %v - EVP_AEAD_CTX_open() = %d, want %v\n",
1116			wt.TCID, wt.Comment, wt.Flags, int(openRet), wt.Result)
1117		return false
1118	}
1119
1120	if openedMsgLen != C.size_t(msgLen) {
1121		fmt.Printf("FAIL: Test case %d (%q) %v - open length mismatch: got %d, want %d\n",
1122			wt.TCID, wt.Comment, wt.Flags, openedMsgLen, msgLen)
1123		return false
1124	}
1125
1126	openedMsg := opened[0:openedMsgLen]
1127	if msgLen == 0 {
1128		msg = nil
1129	}
1130
1131	success := false
1132	if bytes.Equal(openedMsg, msg) == (wt.Result != "invalid") {
1133		if acceptableAudit && wt.Result == "acceptable" {
1134			gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
1135		}
1136		success = true
1137	} else {
1138		fmt.Printf("FAIL: Test case %d (%q) %v - msg match: %t; want %v\n",
1139			wt.TCID, wt.Comment, wt.Flags, bytes.Equal(openedMsg, msg), wt.Result)
1140	}
1141	return success
1142}
1143
1144func checkAeadSeal(ctx *C.EVP_AEAD_CTX, iv []byte, ivLen int, aad []byte, aadLen int, msg []byte,
1145	msgLen int, ct []byte, ctLen int, tag []byte, tagLen int, wt *wycheproofTestAead) bool {
1146	maxOutLen := msgLen + tagLen
1147
1148	sealed := make([]byte, maxOutLen)
1149	if maxOutLen == 0 {
1150		sealed = append(sealed, 0)
1151	}
1152	var sealedLen C.size_t
1153
1154	sealRet := C.EVP_AEAD_CTX_seal(ctx, (*C.uint8_t)(unsafe.Pointer(&sealed[0])),
1155		(*C.size_t)(unsafe.Pointer(&sealedLen)), C.size_t(maxOutLen),
1156		(*C.uint8_t)(unsafe.Pointer(&iv[0])), C.size_t(ivLen),
1157		(*C.uint8_t)(unsafe.Pointer(&msg[0])), C.size_t(msgLen),
1158		(*C.uint8_t)(unsafe.Pointer(&aad[0])), C.size_t(aadLen))
1159
1160	if sealRet != 1 {
1161		success := (wt.Result == "invalid")
1162		if !success {
1163			fmt.Printf("FAIL: Test case %d (%q) %v - EVP_AEAD_CTX_seal() = %d, want %v\n", wt.TCID, wt.Comment, wt.Flags, int(sealRet), wt.Result)
1164		}
1165		return success
1166	}
1167
1168	if sealedLen != C.size_t(maxOutLen) {
1169		fmt.Printf("FAIL: Test case %d (%q) %v - seal length mismatch: got %d, want %d\n",
1170			wt.TCID, wt.Comment, wt.Flags, sealedLen, maxOutLen)
1171		return false
1172	}
1173
1174	sealedCt := sealed[0:msgLen]
1175	sealedTag := sealed[msgLen:maxOutLen]
1176
1177	success := false
1178	if (bytes.Equal(sealedCt, ct) && bytes.Equal(sealedTag, tag)) == (wt.Result != "invalid") {
1179		if acceptableAudit && wt.Result == "acceptable" {
1180			gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
1181		}
1182		success = true
1183	} else {
1184		fmt.Printf("FAIL: Test case %d (%q) %v - EVP_AEAD_CTX_seal() = %d, ct match: %t, tag match: %t; want %v\n",
1185			wt.TCID, wt.Comment, wt.Flags, int(sealRet),
1186			bytes.Equal(sealedCt, ct), bytes.Equal(sealedTag, tag), wt.Result)
1187	}
1188	return success
1189}
1190
1191func runChaCha20Poly1305Test(algorithm string, wt *wycheproofTestAead) bool {
1192	var aead *C.EVP_AEAD
1193	switch algorithm {
1194	case "CHACHA20-POLY1305":
1195		aead = C.EVP_aead_chacha20_poly1305()
1196	case "XCHACHA20-POLY1305":
1197		aead = C.EVP_aead_xchacha20_poly1305()
1198	}
1199
1200	key, err := hex.DecodeString(wt.Key)
1201	if err != nil {
1202		log.Fatalf("Failed to decode key %q: %v", wt.Key, err)
1203	}
1204	iv, err := hex.DecodeString(wt.IV)
1205	if err != nil {
1206		log.Fatalf("Failed to decode key %q: %v", wt.IV, err)
1207	}
1208	aad, err := hex.DecodeString(wt.AAD)
1209	if err != nil {
1210		log.Fatalf("Failed to decode AAD %q: %v", wt.AAD, err)
1211	}
1212	msg, err := hex.DecodeString(wt.Msg)
1213	if err != nil {
1214		log.Fatalf("Failed to decode msg %q: %v", wt.Msg, err)
1215	}
1216	ct, err := hex.DecodeString(wt.CT)
1217	if err != nil {
1218		log.Fatalf("Failed to decode ct %q: %v", wt.CT, err)
1219	}
1220	tag, err := hex.DecodeString(wt.Tag)
1221	if err != nil {
1222		log.Fatalf("Failed to decode tag %q: %v", wt.Tag, err)
1223	}
1224
1225	keyLen, ivLen, aadLen, msgLen, ctLen, tagLen := len(key), len(iv), len(aad), len(msg), len(ct), len(tag)
1226
1227	if ivLen == 0 {
1228		iv = append(iv, 0)
1229	}
1230	if aadLen == 0 {
1231		aad = append(aad, 0)
1232	}
1233	if msgLen == 0 {
1234		msg = append(msg, 0)
1235	}
1236	if ctLen == 0 {
1237		msg = append(ct, 0)
1238	}
1239	if tagLen == 0 {
1240		msg = append(tag, 0)
1241	}
1242
1243	var ctx C.EVP_AEAD_CTX
1244	if C.EVP_AEAD_CTX_init(&ctx, aead, (*C.uchar)(unsafe.Pointer(&key[0])), C.size_t(keyLen), C.size_t(tagLen), nil) != 1 {
1245		log.Fatal("Failed to initialize AEAD context")
1246	}
1247	defer C.EVP_AEAD_CTX_cleanup(&ctx)
1248
1249	openSuccess := checkAeadOpen(&ctx, iv, ivLen, aad, aadLen, msg, msgLen, ct, ctLen, tag, tagLen, wt)
1250	sealSuccess := checkAeadSeal(&ctx, iv, ivLen, aad, aadLen, msg, msgLen, ct, ctLen, tag, tagLen, wt)
1251
1252	return openSuccess && sealSuccess
1253}
1254
1255func runChaCha20Poly1305TestGroup(algorithm string, wtg *wycheproofTestGroupAead) bool {
1256	// ChaCha20-Poly1305 currently only supports nonces of length 12 (96 bits)
1257	if algorithm == "CHACHA20-POLY1305" && wtg.IVSize != 96 {
1258		return true
1259	}
1260
1261	fmt.Printf("Running %v test group %v with IV size %d, key size %d, tag size %d...\n",
1262		algorithm, wtg.Type, wtg.IVSize, wtg.KeySize, wtg.TagSize)
1263
1264	success := true
1265	for _, wt := range wtg.Tests {
1266		if !runChaCha20Poly1305Test(algorithm, wt) {
1267			success = false
1268		}
1269	}
1270	return success
1271}
1272
1273// DER encode the signature (so DSA_verify() can decode and encode it again)
1274func encodeDSAP1363Sig(wtSig string) (*C.uchar, C.int) {
1275	cSig := C.DSA_SIG_new()
1276	if cSig == nil {
1277		log.Fatal("DSA_SIG_new() failed")
1278	}
1279	defer C.DSA_SIG_free(cSig)
1280
1281	sigLen := len(wtSig)
1282	r := C.CString(wtSig[:sigLen/2])
1283	s := C.CString(wtSig[sigLen/2:])
1284	defer C.free(unsafe.Pointer(r))
1285	defer C.free(unsafe.Pointer(s))
1286	if C.BN_hex2bn(&cSig.r, r) == 0 {
1287		return nil, 0
1288	}
1289	if C.BN_hex2bn(&cSig.s, s) == 0 {
1290		return nil, 0
1291	}
1292
1293	derLen := C.i2d_DSA_SIG(cSig, nil)
1294	if derLen == 0 {
1295		return nil, 0
1296	}
1297	cDer := (*C.uchar)(C.malloc(C.ulong(derLen)))
1298	if cDer == nil {
1299		log.Fatal("malloc failed")
1300	}
1301
1302	p := cDer
1303	ret := C.i2d_DSA_SIG(cSig, (**C.uchar)(&p))
1304	if ret == 0 || ret != derLen {
1305		C.free(unsafe.Pointer(cDer))
1306		return nil, 0
1307	}
1308
1309	return cDer, derLen
1310}
1311
1312func runDSATest(dsa *C.DSA, variant testVariant, h hash.Hash, wt *wycheproofTestDSA) bool {
1313	msg, err := hex.DecodeString(wt.Msg)
1314	if err != nil {
1315		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
1316	}
1317
1318	h.Reset()
1319	h.Write(msg)
1320	msg = h.Sum(nil)
1321
1322	msgLen := len(msg)
1323	if msgLen == 0 {
1324		msg = append(msg, 0)
1325	}
1326
1327	var ret C.int
1328	if variant == P1363 {
1329		cDer, derLen := encodeDSAP1363Sig(wt.Sig)
1330		if cDer == nil {
1331			fmt.Print("FAIL: unable to decode signature")
1332			return false
1333		}
1334		defer C.free(unsafe.Pointer(cDer))
1335
1336		ret = C.DSA_verify(0, (*C.uchar)(unsafe.Pointer(&msg[0])), C.int(msgLen),
1337			(*C.uchar)(unsafe.Pointer(cDer)), C.int(derLen), dsa)
1338	} else {
1339		sig, err := hex.DecodeString(wt.Sig)
1340		if err != nil {
1341			log.Fatalf("Failed to decode signature %q: %v", wt.Sig, err)
1342		}
1343		sigLen := len(sig)
1344		if sigLen == 0 {
1345			sig = append(msg, 0)
1346		}
1347		ret = C.DSA_verify(0, (*C.uchar)(unsafe.Pointer(&msg[0])), C.int(msgLen),
1348			(*C.uchar)(unsafe.Pointer(&sig[0])), C.int(sigLen), dsa)
1349	}
1350
1351	success := true
1352	if ret == 1 != (wt.Result == "valid") {
1353		fmt.Printf("FAIL: Test case %d (%q) %v - DSA_verify() = %d, want %v\n",
1354			wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
1355		success = false
1356	}
1357	return success
1358}
1359
1360func runDSATestGroup(algorithm string, variant testVariant, wtg *wycheproofTestGroupDSA) bool {
1361	fmt.Printf("Running %v test group %v, key size %d and %v...\n",
1362		algorithm, wtg.Type, wtg.Key.KeySize, wtg.SHA)
1363
1364	dsa := C.DSA_new()
1365	if dsa == nil {
1366		log.Fatal("DSA_new failed")
1367	}
1368	defer C.DSA_free(dsa)
1369
1370	var bnG *C.BIGNUM
1371	wg := C.CString(wtg.Key.G)
1372	if C.BN_hex2bn(&bnG, wg) == 0 {
1373		log.Fatal("Failed to decode g")
1374	}
1375	C.free(unsafe.Pointer(wg))
1376
1377	var bnP *C.BIGNUM
1378	wp := C.CString(wtg.Key.P)
1379	if C.BN_hex2bn(&bnP, wp) == 0 {
1380		log.Fatal("Failed to decode p")
1381	}
1382	C.free(unsafe.Pointer(wp))
1383
1384	var bnQ *C.BIGNUM
1385	wq := C.CString(wtg.Key.Q)
1386	if C.BN_hex2bn(&bnQ, wq) == 0 {
1387		log.Fatal("Failed to decode q")
1388	}
1389	C.free(unsafe.Pointer(wq))
1390
1391	ret := C.DSA_set0_pqg(dsa, bnP, bnQ, bnG)
1392	if ret != 1 {
1393		log.Fatalf("DSA_set0_pqg returned %d", ret)
1394	}
1395
1396	var bnY *C.BIGNUM
1397	wy := C.CString(wtg.Key.Y)
1398	if C.BN_hex2bn(&bnY, wy) == 0 {
1399		log.Fatal("Failed to decode y")
1400	}
1401	C.free(unsafe.Pointer(wy))
1402
1403	ret = C.DSA_set0_key(dsa, bnY, nil)
1404	if ret != 1 {
1405		log.Fatalf("DSA_set0_key returned %d", ret)
1406	}
1407
1408	h, err := hashFromString(wtg.SHA)
1409	if err != nil {
1410		log.Fatalf("Failed to get hash: %v", err)
1411	}
1412
1413	der, err := hex.DecodeString(wtg.KeyDER)
1414	if err != nil {
1415		log.Fatalf("Failed to decode DER encoded key: %v", err)
1416	}
1417
1418	derLen := len(der)
1419	if derLen == 0 {
1420		der = append(der, 0)
1421	}
1422
1423	Cder := (*C.uchar)(C.malloc(C.ulong(derLen)))
1424	if Cder == nil {
1425		log.Fatal("malloc failed")
1426	}
1427	C.memcpy(unsafe.Pointer(Cder), unsafe.Pointer(&der[0]), C.ulong(derLen))
1428
1429	p := (*C.uchar)(Cder)
1430	dsaDER := C.d2i_DSA_PUBKEY(nil, (**C.uchar)(&p), C.long(derLen))
1431	defer C.DSA_free(dsaDER)
1432	C.free(unsafe.Pointer(Cder))
1433
1434	keyPEM := C.CString(wtg.KeyPEM)
1435	bio := C.BIO_new_mem_buf(unsafe.Pointer(keyPEM), C.int(len(wtg.KeyPEM)))
1436	if bio == nil {
1437		log.Fatal("BIO_new_mem_buf failed")
1438	}
1439	defer C.free(unsafe.Pointer(keyPEM))
1440	defer C.BIO_free(bio)
1441
1442	dsaPEM := C.PEM_read_bio_DSA_PUBKEY(bio, nil, nil, nil)
1443	if dsaPEM == nil {
1444		log.Fatal("PEM_read_bio_DSA_PUBKEY failed")
1445	}
1446	defer C.DSA_free(dsaPEM)
1447
1448	success := true
1449	for _, wt := range wtg.Tests {
1450		if !runDSATest(dsa, variant, h, wt) {
1451			success = false
1452		}
1453		if !runDSATest(dsaDER, variant, h, wt) {
1454			success = false
1455		}
1456		if !runDSATest(dsaPEM, variant, h, wt) {
1457			success = false
1458		}
1459	}
1460	return success
1461}
1462
1463func runECDHTest(nid int, variant testVariant, wt *wycheproofTestECDH) bool {
1464	privKey := C.EC_KEY_new_by_curve_name(C.int(nid))
1465	if privKey == nil {
1466		log.Fatalf("EC_KEY_new_by_curve_name failed")
1467	}
1468	defer C.EC_KEY_free(privKey)
1469
1470	var bnPriv *C.BIGNUM
1471	wPriv := C.CString(wt.Private)
1472	if C.BN_hex2bn(&bnPriv, wPriv) == 0 {
1473		log.Fatal("Failed to decode wPriv")
1474	}
1475	C.free(unsafe.Pointer(wPriv))
1476	defer C.BN_free(bnPriv)
1477
1478	ret := C.EC_KEY_set_private_key(privKey, bnPriv)
1479	if ret != 1 {
1480		fmt.Printf("FAIL: Test case %d (%q) %v - EC_KEY_set_private_key() = %d, want %v\n",
1481			wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
1482		return false
1483	}
1484
1485	pub, err := hex.DecodeString(wt.Public)
1486	if err != nil {
1487		log.Fatalf("Failed to decode public key: %v", err)
1488	}
1489
1490	pubLen := len(pub)
1491	if pubLen == 0 {
1492		pub = append(pub, 0)
1493	}
1494
1495	Cpub := (*C.uchar)(C.malloc(C.ulong(pubLen)))
1496	if Cpub == nil {
1497		log.Fatal("malloc failed")
1498	}
1499	C.memcpy(unsafe.Pointer(Cpub), unsafe.Pointer(&pub[0]), C.ulong(pubLen))
1500
1501	p := (*C.uchar)(Cpub)
1502	var pubKey *C.EC_KEY
1503	if variant == EcPoint {
1504		pubKey = C.EC_KEY_new_by_curve_name(C.int(nid))
1505		if pubKey == nil {
1506			log.Fatal("EC_KEY_new_by_curve_name failed")
1507		}
1508		pubKey = C.o2i_ECPublicKey(&pubKey, (**C.uchar)(&p), C.long(pubLen))
1509	} else {
1510		pubKey = C.d2i_EC_PUBKEY(nil, (**C.uchar)(&p), C.long(pubLen))
1511	}
1512	defer C.EC_KEY_free(pubKey)
1513	C.free(unsafe.Pointer(Cpub))
1514
1515	if pubKey == nil {
1516		if wt.Result == "invalid" || wt.Result == "acceptable" {
1517			return true
1518		}
1519		fmt.Printf("FAIL: Test case %d (%q) %v - ASN decoding failed: want %v\n",
1520			wt.TCID, wt.Comment, wt.Flags, wt.Result)
1521		return false
1522	}
1523
1524	privGroup := C.EC_KEY_get0_group(privKey)
1525
1526	secLen := (C.EC_GROUP_get_degree(privGroup) + 7) / 8
1527
1528	secret := make([]byte, secLen)
1529	if secLen == 0 {
1530		secret = append(secret, 0)
1531	}
1532
1533	pubPoint := C.EC_KEY_get0_public_key(pubKey)
1534
1535	ret = C.ECDH_compute_key(unsafe.Pointer(&secret[0]), C.ulong(secLen), pubPoint, privKey, nil)
1536	if ret != C.int(secLen) {
1537		if wt.Result == "invalid" {
1538			return true
1539		}
1540		fmt.Printf("FAIL: Test case %d (%q) %v - ECDH_compute_key() = %d, want %d, result: %v\n",
1541			wt.TCID, wt.Comment, wt.Flags, ret, int(secLen), wt.Result)
1542		return false
1543	}
1544
1545	shared, err := hex.DecodeString(wt.Shared)
1546	if err != nil {
1547		log.Fatalf("Failed to decode shared secret: %v", err)
1548	}
1549
1550	// XXX The shared fields of the secp224k1 test cases have a 0 byte preprended.
1551	if len(shared) == int(secLen)+1 && shared[0] == 0 {
1552		fmt.Printf("INFO: Test case %d (%q) %v - prepending 0 byte\n", wt.TCID, wt.Comment, wt.Flags)
1553		// shared = shared[1:];
1554		zero := make([]byte, 1, secLen+1)
1555		secret = append(zero, secret...)
1556	}
1557
1558	success := true
1559	if !bytes.Equal(shared, secret) {
1560		fmt.Printf("FAIL: Test case %d (%q) %v - expected and computed shared secret do not match, want %v\n",
1561			wt.TCID, wt.Comment, wt.Flags, wt.Result)
1562		success = false
1563	}
1564	if acceptableAudit && success && wt.Result == "acceptable" {
1565		gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
1566	}
1567	return success
1568}
1569
1570func runECDHTestGroup(algorithm string, variant testVariant, wtg *wycheproofTestGroupECDH) bool {
1571	fmt.Printf("Running %v test group %v with curve %v and %v encoding...\n",
1572		algorithm, wtg.Type, wtg.Curve, wtg.Encoding)
1573
1574	nid, err := nidFromString(wtg.Curve)
1575	if err != nil {
1576		log.Fatalf("Failed to get nid for curve: %v", err)
1577	}
1578
1579	success := true
1580	for _, wt := range wtg.Tests {
1581		if !runECDHTest(nid, variant, wt) {
1582			success = false
1583		}
1584	}
1585	return success
1586}
1587
1588func runECDHWebCryptoTest(nid int, wt *wycheproofTestECDHWebCrypto) bool {
1589	privKey := C.EC_KEY_new_by_curve_name(C.int(nid))
1590	if privKey == nil {
1591		log.Fatalf("EC_KEY_new_by_curve_name failed")
1592	}
1593	defer C.EC_KEY_free(privKey)
1594
1595	d, err := base64.RawURLEncoding.DecodeString(wt.Private.D)
1596	if err != nil {
1597		log.Fatalf("Failed to base64 decode d: %v", err)
1598	}
1599	bnD := C.BN_bin2bn((*C.uchar)(unsafe.Pointer(&d[0])), C.int(len(d)), nil)
1600	if bnD == nil {
1601		log.Fatal("Failed to decode D")
1602	}
1603	defer C.BN_free(bnD)
1604
1605	ret := C.EC_KEY_set_private_key(privKey, bnD)
1606	if ret != 1 {
1607		fmt.Printf("FAIL: Test case %d (%q) %v - EC_KEY_set_private_key() = %d, want %v\n",
1608			wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
1609		return false
1610	}
1611
1612	x, err := base64.RawURLEncoding.DecodeString(wt.Public.X)
1613	if err != nil {
1614		log.Fatalf("Failed to base64 decode x: %v", err)
1615	}
1616	bnX := C.BN_bin2bn((*C.uchar)(unsafe.Pointer(&x[0])), C.int(len(x)), nil)
1617	if bnX == nil {
1618		log.Fatal("Failed to decode X")
1619	}
1620	defer C.BN_free(bnX)
1621
1622	y, err := base64.RawURLEncoding.DecodeString(wt.Public.Y)
1623	if err != nil {
1624		log.Fatalf("Failed to base64 decode y: %v", err)
1625	}
1626	bnY := C.BN_bin2bn((*C.uchar)(unsafe.Pointer(&y[0])), C.int(len(y)), nil)
1627	if bnY == nil {
1628		log.Fatal("Failed to decode Y")
1629	}
1630	defer C.BN_free(bnY)
1631
1632	pubKey := C.EC_KEY_new_by_curve_name(C.int(nid))
1633	if pubKey == nil {
1634		log.Fatal("Failed to create EC_KEY")
1635	}
1636	defer C.EC_KEY_free(pubKey)
1637
1638	ret = C.EC_KEY_set_public_key_affine_coordinates(pubKey, bnX, bnY)
1639	if ret != 1 {
1640		if wt.Result == "invalid" {
1641			return true
1642		}
1643		fmt.Printf("FAIL: Test case %d (%q) %v - EC_KEY_set_public_key_affine_coordinates() = %d, want %v\n",
1644			wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
1645		return false
1646	}
1647	pubPoint := C.EC_KEY_get0_public_key(pubKey)
1648
1649	privGroup := C.EC_KEY_get0_group(privKey)
1650
1651	secLen := (C.EC_GROUP_get_degree(privGroup) + 7) / 8
1652
1653	secret := make([]byte, secLen)
1654	if secLen == 0 {
1655		secret = append(secret, 0)
1656	}
1657
1658	ret = C.ECDH_compute_key(unsafe.Pointer(&secret[0]), C.ulong(secLen), pubPoint, privKey, nil)
1659	if ret != C.int(secLen) {
1660		if wt.Result == "invalid" {
1661			return true
1662		}
1663		fmt.Printf("FAIL: Test case %d (%q) %v - ECDH_compute_key() = %d, want %d, result: %v\n",
1664			wt.TCID, wt.Comment, wt.Flags, ret, int(secLen), wt.Result)
1665		return false
1666	}
1667
1668	shared, err := hex.DecodeString(wt.Shared)
1669	if err != nil {
1670		log.Fatalf("Failed to decode shared secret: %v", err)
1671	}
1672
1673	success := true
1674	if !bytes.Equal(shared, secret) {
1675		fmt.Printf("FAIL: Test case %d (%q) %v - expected and computed shared secret do not match, want %v\n",
1676			wt.TCID, wt.Comment, wt.Flags, wt.Result)
1677		success = false
1678	}
1679	if acceptableAudit && success && wt.Result == "acceptable" {
1680		gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
1681	}
1682	return success
1683}
1684
1685func runECDHWebCryptoTestGroup(algorithm string, wtg *wycheproofTestGroupECDHWebCrypto) bool {
1686	fmt.Printf("Running %v test group %v with curve %v and %v encoding...\n",
1687		algorithm, wtg.Type, wtg.Curve, wtg.Encoding)
1688
1689	nid, err := nidFromString(wtg.Curve)
1690	if err != nil {
1691		log.Fatalf("Failed to get nid for curve: %v", err)
1692	}
1693
1694	success := true
1695	for _, wt := range wtg.Tests {
1696		if !runECDHWebCryptoTest(nid, wt) {
1697			success = false
1698		}
1699	}
1700	return success
1701}
1702
1703func runECDSATest(ecKey *C.EC_KEY, nid int, h hash.Hash, variant testVariant, wt *wycheproofTestECDSA) bool {
1704	msg, err := hex.DecodeString(wt.Msg)
1705	if err != nil {
1706		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
1707	}
1708
1709	h.Reset()
1710	h.Write(msg)
1711	msg = h.Sum(nil)
1712
1713	msgLen := len(msg)
1714	if msgLen == 0 {
1715		msg = append(msg, 0)
1716	}
1717
1718	var ret C.int
1719	if variant == Webcrypto || variant == P1363 {
1720		cDer, derLen := encodeECDSAWebCryptoSig(wt.Sig)
1721		if cDer == nil {
1722			fmt.Print("FAIL: unable to decode signature")
1723			return false
1724		}
1725		defer C.free(unsafe.Pointer(cDer))
1726
1727		ret = C.ECDSA_verify(0, (*C.uchar)(unsafe.Pointer(&msg[0])), C.int(msgLen),
1728			(*C.uchar)(unsafe.Pointer(cDer)), C.int(derLen), ecKey)
1729	} else {
1730		sig, err := hex.DecodeString(wt.Sig)
1731		if err != nil {
1732			log.Fatalf("Failed to decode signature %q: %v", wt.Sig, err)
1733		}
1734
1735		sigLen := len(sig)
1736		if sigLen == 0 {
1737			sig = append(sig, 0)
1738		}
1739		ret = C.ECDSA_verify(0, (*C.uchar)(unsafe.Pointer(&msg[0])), C.int(msgLen),
1740			(*C.uchar)(unsafe.Pointer(&sig[0])), C.int(sigLen), ecKey)
1741	}
1742
1743	// XXX audit acceptable cases...
1744	success := true
1745	if ret == 1 != (wt.Result == "valid") && wt.Result != "acceptable" {
1746		fmt.Printf("FAIL: Test case %d (%q) %v - ECDSA_verify() = %d, want %v\n",
1747			wt.TCID, wt.Comment, wt.Flags, int(ret), wt.Result)
1748		success = false
1749	}
1750	if acceptableAudit && ret == 1 && wt.Result == "acceptable" {
1751		gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
1752	}
1753	return success
1754}
1755
1756func runECDSATestGroup(algorithm string, variant testVariant, wtg *wycheproofTestGroupECDSA) bool {
1757	fmt.Printf("Running %v test group %v with curve %v, key size %d and %v...\n",
1758		algorithm, wtg.Type, wtg.Key.Curve, wtg.Key.KeySize, wtg.SHA)
1759
1760	nid, err := nidFromString(wtg.Key.Curve)
1761	if err != nil {
1762		log.Fatalf("Failed to get nid for curve: %v", err)
1763	}
1764	ecKey := C.EC_KEY_new_by_curve_name(C.int(nid))
1765	if ecKey == nil {
1766		log.Fatal("EC_KEY_new_by_curve_name failed")
1767	}
1768	defer C.EC_KEY_free(ecKey)
1769
1770	var bnX *C.BIGNUM
1771	wx := C.CString(wtg.Key.WX)
1772	if C.BN_hex2bn(&bnX, wx) == 0 {
1773		log.Fatal("Failed to decode WX")
1774	}
1775	C.free(unsafe.Pointer(wx))
1776	defer C.BN_free(bnX)
1777
1778	var bnY *C.BIGNUM
1779	wy := C.CString(wtg.Key.WY)
1780	if C.BN_hex2bn(&bnY, wy) == 0 {
1781		log.Fatal("Failed to decode WY")
1782	}
1783	C.free(unsafe.Pointer(wy))
1784	defer C.BN_free(bnY)
1785
1786	if C.EC_KEY_set_public_key_affine_coordinates(ecKey, bnX, bnY) != 1 {
1787		log.Fatal("Failed to set EC public key")
1788	}
1789
1790	nid, err = nidFromString(wtg.SHA)
1791	if err != nil {
1792		log.Fatalf("Failed to get MD NID: %v", err)
1793	}
1794	h, err := hashFromString(wtg.SHA)
1795	if err != nil {
1796		log.Fatalf("Failed to get hash: %v", err)
1797	}
1798
1799	success := true
1800	for _, wt := range wtg.Tests {
1801		if !runECDSATest(ecKey, nid, h, variant, wt) {
1802			success = false
1803		}
1804	}
1805	return success
1806}
1807
1808// DER encode the signature (so that ECDSA_verify() can decode and encode it again...)
1809func encodeECDSAWebCryptoSig(wtSig string) (*C.uchar, C.int) {
1810	cSig := C.ECDSA_SIG_new()
1811	if cSig == nil {
1812		log.Fatal("ECDSA_SIG_new() failed")
1813	}
1814	defer C.ECDSA_SIG_free(cSig)
1815
1816	sigLen := len(wtSig)
1817	r := C.CString(wtSig[:sigLen/2])
1818	s := C.CString(wtSig[sigLen/2:])
1819	defer C.free(unsafe.Pointer(r))
1820	defer C.free(unsafe.Pointer(s))
1821	if C.BN_hex2bn(&cSig.r, r) == 0 {
1822		return nil, 0
1823	}
1824	if C.BN_hex2bn(&cSig.s, s) == 0 {
1825		return nil, 0
1826	}
1827
1828	derLen := C.i2d_ECDSA_SIG(cSig, nil)
1829	if derLen == 0 {
1830		return nil, 0
1831	}
1832	cDer := (*C.uchar)(C.malloc(C.ulong(derLen)))
1833	if cDer == nil {
1834		log.Fatal("malloc failed")
1835	}
1836
1837	p := cDer
1838	ret := C.i2d_ECDSA_SIG(cSig, (**C.uchar)(&p))
1839	if ret == 0 || ret != derLen {
1840		C.free(unsafe.Pointer(cDer))
1841		return nil, 0
1842	}
1843
1844	return cDer, derLen
1845}
1846
1847func runECDSAWebCryptoTestGroup(algorithm string, wtg *wycheproofTestGroupECDSAWebCrypto) bool {
1848	fmt.Printf("Running %v test group %v with curve %v, key size %d and %v...\n",
1849		algorithm, wtg.Type, wtg.Key.Curve, wtg.Key.KeySize, wtg.SHA)
1850
1851	nid, err := nidFromString(wtg.JWK.Crv)
1852	if err != nil {
1853		log.Fatalf("Failed to get nid for curve: %v", err)
1854	}
1855	ecKey := C.EC_KEY_new_by_curve_name(C.int(nid))
1856	if ecKey == nil {
1857		log.Fatal("EC_KEY_new_by_curve_name failed")
1858	}
1859	defer C.EC_KEY_free(ecKey)
1860
1861	x, err := base64.RawURLEncoding.DecodeString(wtg.JWK.X)
1862	if err != nil {
1863		log.Fatalf("Failed to base64 decode X: %v", err)
1864	}
1865	bnX := C.BN_bin2bn((*C.uchar)(unsafe.Pointer(&x[0])), C.int(len(x)), nil)
1866	if bnX == nil {
1867		log.Fatal("Failed to decode X")
1868	}
1869	defer C.BN_free(bnX)
1870
1871	y, err := base64.RawURLEncoding.DecodeString(wtg.JWK.Y)
1872	if err != nil {
1873		log.Fatalf("Failed to base64 decode Y: %v", err)
1874	}
1875	bnY := C.BN_bin2bn((*C.uchar)(unsafe.Pointer(&y[0])), C.int(len(y)), nil)
1876	if bnY == nil {
1877		log.Fatal("Failed to decode Y")
1878	}
1879	defer C.BN_free(bnY)
1880
1881	if C.EC_KEY_set_public_key_affine_coordinates(ecKey, bnX, bnY) != 1 {
1882		log.Fatal("Failed to set EC public key")
1883	}
1884
1885	nid, err = nidFromString(wtg.SHA)
1886	if err != nil {
1887		log.Fatalf("Failed to get MD NID: %v", err)
1888	}
1889	h, err := hashFromString(wtg.SHA)
1890	if err != nil {
1891		log.Fatalf("Failed to get hash: %v", err)
1892	}
1893
1894	success := true
1895	for _, wt := range wtg.Tests {
1896		if !runECDSATest(ecKey, nid, h, Webcrypto, wt) {
1897			success = false
1898		}
1899	}
1900	return success
1901}
1902
1903func runHkdfTest(md *C.EVP_MD, wt *wycheproofTestHkdf) bool {
1904	ikm, err := hex.DecodeString(wt.Ikm)
1905	if err != nil {
1906		log.Fatalf("Failed to decode ikm %q: %v", wt.Ikm, err)
1907	}
1908	salt, err := hex.DecodeString(wt.Salt)
1909	if err != nil {
1910		log.Fatalf("Failed to decode salt %q: %v", wt.Salt, err)
1911	}
1912	info, err := hex.DecodeString(wt.Info)
1913	if err != nil {
1914		log.Fatalf("Failed to decode info %q: %v", wt.Info, err)
1915	}
1916
1917	ikmLen, saltLen, infoLen := len(ikm), len(salt), len(info)
1918	if ikmLen == 0 {
1919		ikm = append(ikm, 0)
1920	}
1921	if saltLen == 0 {
1922		salt = append(salt, 0)
1923	}
1924	if infoLen == 0 {
1925		info = append(info, 0)
1926	}
1927
1928	outLen := wt.Size
1929	out := make([]byte, outLen)
1930	if outLen == 0 {
1931		out = append(out, 0)
1932	}
1933
1934	ret := C.HKDF((*C.uchar)(unsafe.Pointer(&out[0])), C.size_t(outLen), md, (*C.uchar)(unsafe.Pointer(&ikm[0])), C.size_t(ikmLen), (*C.uchar)(&salt[0]), C.size_t(saltLen), (*C.uchar)(unsafe.Pointer(&info[0])), C.size_t(infoLen))
1935
1936	if ret != 1 {
1937		success := wt.Result == "invalid"
1938		if !success {
1939			fmt.Printf("FAIL: Test case %d (%q) %v - got %d, want %v\n", wt.TCID, wt.Comment, wt.Flags, ret, wt.Result)
1940		}
1941		return success
1942	}
1943
1944	okm, err := hex.DecodeString(wt.Okm)
1945	if err != nil {
1946		log.Fatalf("Failed to decode okm %q: %v", wt.Okm, err)
1947	}
1948	if !bytes.Equal(out[:outLen], okm) {
1949		fmt.Printf("FAIL: Test case %d (%q) %v - expected and computed output don't match: %v", wt.TCID, wt.Comment, wt.Flags, wt.Result)
1950	}
1951
1952	return wt.Result == "valid"
1953}
1954
1955func runHkdfTestGroup(algorithm string, wtg *wycheproofTestGroupHkdf) bool {
1956	fmt.Printf("Running %v test group %v with key size %d...\n", algorithm, wtg.Type, wtg.KeySize)
1957	md, err := hashEvpMdFromString(strings.TrimPrefix(algorithm, "HKDF-"))
1958	if err != nil {
1959		log.Fatalf("Failed to get hash: %v", err)
1960	}
1961
1962	success := true
1963	for _, wt := range wtg.Tests {
1964		if !runHkdfTest(md, wt) {
1965			success = false
1966		}
1967	}
1968	return success
1969}
1970
1971func runHmacTest(md *C.EVP_MD, tagBytes int, wt *wycheproofTestHmac) bool {
1972	key, err := hex.DecodeString(wt.Key)
1973	if err != nil {
1974		log.Fatalf("failed to decode key %q: %v", wt.Key, err)
1975	}
1976
1977	msg, err := hex.DecodeString(wt.Msg)
1978	if err != nil {
1979		log.Fatalf("failed to decode msg %q: %v", wt.Msg, err)
1980	}
1981
1982	keyLen, msgLen := len(key), len(msg)
1983
1984	if keyLen == 0 {
1985		key = append(key, 0)
1986	}
1987
1988	if msgLen == 0 {
1989		msg = append(msg, 0)
1990	}
1991
1992	got := make([]byte, C.EVP_MAX_MD_SIZE)
1993	var gotLen C.uint
1994
1995	ret := C.HMAC(md, unsafe.Pointer(&key[0]), C.int(keyLen), (*C.uchar)(unsafe.Pointer(&msg[0])), C.size_t(msgLen), (*C.uchar)(unsafe.Pointer(&got[0])), &gotLen)
1996
1997	success := true
1998	if ret == nil {
1999		if wt.Result != "invalid" {
2000			success = false
2001			fmt.Printf("FAIL: Test case %d (%q) %v - HMAC: got nil, want %v\n", wt.TCID, wt.Comment, wt.Flags, wt.Result)
2002		}
2003		return success
2004	}
2005
2006	if int(gotLen) < tagBytes {
2007		fmt.Printf("FAIL: Test case %d (%q) %v - HMAC length: got %d, want %d, expected %v\n", wt.TCID, wt.Comment, wt.Flags, gotLen, tagBytes, wt.Result)
2008		return false
2009	}
2010
2011	tag, err := hex.DecodeString(wt.Tag)
2012	if err != nil {
2013		log.Fatalf("failed to decode tag %q: %v", wt.Tag, err)
2014	}
2015
2016	success = bytes.Equal(got[:tagBytes], tag) == (wt.Result == "valid")
2017
2018	if !success {
2019		fmt.Printf("FAIL: Test case %d (%q) %v - got %v want %v\n", wt.TCID, wt.Comment, wt.Flags, success, wt.Result)
2020	}
2021
2022	return success
2023}
2024
2025func runHmacTestGroup(algorithm string, wtg *wycheproofTestGroupHmac) bool {
2026	fmt.Printf("Running %v test group %v with key size %d and tag size %d...\n", algorithm, wtg.Type, wtg.KeySize, wtg.TagSize)
2027	md, err := hashEvpMdFromString("SHA-" + strings.TrimPrefix(algorithm, "HMACSHA"))
2028	if err != nil {
2029		log.Fatalf("Failed to get hash: %v", err)
2030	}
2031
2032	success := true
2033	for _, wt := range wtg.Tests {
2034		if !runHmacTest(md, wtg.TagSize/8, wt) {
2035			success = false
2036		}
2037	}
2038	return success
2039}
2040
2041func runKWTestWrap(keySize int, key []byte, keyLen int, msg []byte, msgLen int, ct []byte, ctLen int, wt *wycheproofTestKW) bool {
2042	var aesKey C.AES_KEY
2043
2044	ret := C.AES_set_encrypt_key((*C.uchar)(unsafe.Pointer(&key[0])), (C.int)(keySize), (*C.AES_KEY)(unsafe.Pointer(&aesKey)))
2045	if ret != 0 {
2046		fmt.Printf("FAIL: Test case %d (%q) %v - AES_set_encrypt_key() = %d, want %v\n",
2047			wt.TCID, wt.Comment, wt.Flags, int(ret), wt.Result)
2048		return false
2049	}
2050
2051	outLen := msgLen
2052	out := make([]byte, outLen)
2053	copy(out, msg)
2054	out = append(out, make([]byte, 8)...)
2055	ret = C.AES_wrap_key((*C.AES_KEY)(unsafe.Pointer(&aesKey)), nil, (*C.uchar)(unsafe.Pointer(&out[0])), (*C.uchar)(unsafe.Pointer(&out[0])), (C.uint)(msgLen))
2056	success := false
2057	if ret == C.int(len(out)) && bytes.Equal(out, ct) {
2058		if acceptableAudit && wt.Result == "acceptable" {
2059			gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
2060		}
2061		if wt.Result != "invalid" {
2062			success = true
2063		}
2064	} else if wt.Result != "valid" {
2065		success = true
2066	}
2067	if !success {
2068		fmt.Printf("FAIL: Test case %d (%q) %v - msgLen = %d, AES_wrap_key() = %d, want %v\n",
2069			wt.TCID, wt.Comment, wt.Flags, msgLen, int(ret), wt.Result)
2070	}
2071	return success
2072}
2073
2074func runKWTestUnWrap(keySize int, key []byte, keyLen int, msg []byte, msgLen int, ct []byte, ctLen int, wt *wycheproofTestKW) bool {
2075	var aesKey C.AES_KEY
2076
2077	ret := C.AES_set_decrypt_key((*C.uchar)(unsafe.Pointer(&key[0])), (C.int)(keySize), (*C.AES_KEY)(unsafe.Pointer(&aesKey)))
2078	if ret != 0 {
2079		fmt.Printf("FAIL: Test case %d (%q) %v - AES_set_encrypt_key() = %d, want %v\n",
2080			wt.TCID, wt.Comment, wt.Flags, int(ret), wt.Result)
2081		return false
2082	}
2083
2084	out := make([]byte, ctLen)
2085	copy(out, ct)
2086	if ctLen == 0 {
2087		out = append(out, 0)
2088	}
2089	ret = C.AES_unwrap_key((*C.AES_KEY)(unsafe.Pointer(&aesKey)), nil, (*C.uchar)(unsafe.Pointer(&out[0])), (*C.uchar)(unsafe.Pointer(&out[0])), (C.uint)(ctLen))
2090	success := false
2091	if ret == C.int(ctLen-8) && bytes.Equal(out[0:ret], msg[0:ret]) {
2092		if acceptableAudit && wt.Result == "acceptable" {
2093			gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
2094		}
2095		if wt.Result != "invalid" {
2096			success = true
2097		}
2098	} else if wt.Result != "valid" {
2099		success = true
2100	}
2101	if !success {
2102		fmt.Printf("FAIL: Test case %d (%q) %v - keyLen = %d, AES_unwrap_key() = %d, want %v\n",
2103			wt.TCID, wt.Comment, wt.Flags, keyLen, int(ret), wt.Result)
2104	}
2105	return success
2106}
2107
2108func runKWTest(keySize int, wt *wycheproofTestKW) bool {
2109	key, err := hex.DecodeString(wt.Key)
2110	if err != nil {
2111		log.Fatalf("Failed to decode key %q: %v", wt.Key, err)
2112	}
2113	msg, err := hex.DecodeString(wt.Msg)
2114	if err != nil {
2115		log.Fatalf("Failed to decode msg %q: %v", wt.Msg, err)
2116	}
2117	ct, err := hex.DecodeString(wt.CT)
2118	if err != nil {
2119		log.Fatalf("Failed to decode ct %q: %v", wt.CT, err)
2120	}
2121
2122	keyLen, msgLen, ctLen := len(key), len(msg), len(ct)
2123
2124	if keyLen == 0 {
2125		key = append(key, 0)
2126	}
2127	if msgLen == 0 {
2128		msg = append(msg, 0)
2129	}
2130	if ctLen == 0 {
2131		ct = append(ct, 0)
2132	}
2133
2134	wrapSuccess := runKWTestWrap(keySize, key, keyLen, msg, msgLen, ct, ctLen, wt)
2135	unwrapSuccess := runKWTestUnWrap(keySize, key, keyLen, msg, msgLen, ct, ctLen, wt)
2136
2137	return wrapSuccess && unwrapSuccess
2138}
2139
2140func runKWTestGroup(algorithm string, wtg *wycheproofTestGroupKW) bool {
2141	fmt.Printf("Running %v test group %v with key size %d...\n",
2142		algorithm, wtg.Type, wtg.KeySize)
2143
2144	success := true
2145	for _, wt := range wtg.Tests {
2146		if !runKWTest(wtg.KeySize, wt) {
2147			success = false
2148		}
2149	}
2150	return success
2151}
2152
2153func runRsaesOaepTest(rsa *C.RSA, sha *C.EVP_MD, mgfSha *C.EVP_MD, wt *wycheproofTestRsaes) bool {
2154	ct, err := hex.DecodeString(wt.CT)
2155	if err != nil {
2156		log.Fatalf("Failed to decode cipher text %q: %v", wt.CT, err)
2157	}
2158	ctLen := len(ct)
2159	if ctLen == 0 {
2160		ct = append(ct, 0)
2161	}
2162
2163	rsaSize := C.RSA_size(rsa)
2164	decrypted := make([]byte, rsaSize)
2165
2166	success := true
2167
2168	ret := C.RSA_private_decrypt(C.int(ctLen), (*C.uchar)(unsafe.Pointer(&ct[0])), (*C.uchar)(unsafe.Pointer(&decrypted[0])), rsa, C.RSA_NO_PADDING)
2169
2170	if ret != rsaSize {
2171		success = (wt.Result == "invalid")
2172
2173		if !success {
2174			fmt.Printf("FAIL: Test case %d (%q) %v - got %d, want %d. Expected: %v\n", wt.TCID, wt.Comment, wt.Flags, ret, rsaSize, wt.Result)
2175		}
2176		return success
2177	}
2178
2179	label, err := hex.DecodeString(wt.Label)
2180	if err != nil {
2181		log.Fatalf("Failed to decode label %q: %v", wt.Label, err)
2182	}
2183	labelLen := len(label)
2184	if labelLen == 0 {
2185		label = append(label, 0)
2186	}
2187
2188	msg, err := hex.DecodeString(wt.Msg)
2189	if err != nil {
2190		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
2191	}
2192	msgLen := len(msg)
2193
2194	to := make([]byte, rsaSize)
2195
2196	ret = C.RSA_padding_check_PKCS1_OAEP_mgf1((*C.uchar)(unsafe.Pointer(&to[0])), C.int(rsaSize), (*C.uchar)(unsafe.Pointer(&decrypted[0])), C.int(rsaSize), C.int(rsaSize), (*C.uchar)(unsafe.Pointer(&label[0])), C.int(labelLen), sha, mgfSha)
2197
2198	if int(ret) != msgLen {
2199		success = (wt.Result == "invalid")
2200
2201		if !success {
2202			fmt.Printf("FAIL: Test case %d (%q) %v - got %d, want %d. Expected: %v\n", wt.TCID, wt.Comment, wt.Flags, ret, rsaSize, wt.Result)
2203		}
2204		return success
2205	}
2206
2207	to = to[:msgLen]
2208	if !bytes.Equal(msg, to) {
2209		success = false
2210		fmt.Printf("FAIL: Test case %d (%q) %v - expected and calculated message differ. Expected: %v", wt.TCID, wt.Comment, wt.Flags, wt.Result)
2211	}
2212
2213	return success
2214}
2215
2216func runRsaesOaepTestGroup(algorithm string, wtg *wycheproofTestGroupRsaesOaep) bool {
2217	fmt.Printf("Running %v test group %v with key size %d MGF %v and %v...\n",
2218		algorithm, wtg.Type, wtg.KeySize, wtg.MGFSHA, wtg.SHA)
2219
2220	rsa := C.RSA_new()
2221	if rsa == nil {
2222		log.Fatal("RSA_new failed")
2223	}
2224	defer C.RSA_free(rsa)
2225
2226	d := C.CString(wtg.D)
2227	if C.BN_hex2bn(&rsa.d, d) == 0 {
2228		log.Fatal("Failed to set RSA d")
2229	}
2230	C.free(unsafe.Pointer(d))
2231
2232	e := C.CString(wtg.E)
2233	if C.BN_hex2bn(&rsa.e, e) == 0 {
2234		log.Fatal("Failed to set RSA e")
2235	}
2236	C.free(unsafe.Pointer(e))
2237
2238	n := C.CString(wtg.N)
2239	if C.BN_hex2bn(&rsa.n, n) == 0 {
2240		log.Fatal("Failed to set RSA n")
2241	}
2242	C.free(unsafe.Pointer(n))
2243
2244	sha, err := hashEvpMdFromString(wtg.SHA)
2245	if err != nil {
2246		log.Fatalf("Failed to get hash: %v", err)
2247	}
2248
2249	mgfSha, err := hashEvpMdFromString(wtg.MGFSHA)
2250	if err != nil {
2251		log.Fatalf("Failed to get MGF hash: %v", err)
2252	}
2253
2254	success := true
2255	for _, wt := range wtg.Tests {
2256		if !runRsaesOaepTest(rsa, sha, mgfSha, wt) {
2257			success = false
2258		}
2259	}
2260	return success
2261}
2262
2263func runRsaesPkcs1Test(rsa *C.RSA, wt *wycheproofTestRsaes) bool {
2264	ct, err := hex.DecodeString(wt.CT)
2265	if err != nil {
2266		log.Fatalf("Failed to decode cipher text %q: %v", wt.CT, err)
2267	}
2268	ctLen := len(ct)
2269	if ctLen == 0 {
2270		ct = append(ct, 0)
2271	}
2272
2273	rsaSize := C.RSA_size(rsa)
2274	decrypted := make([]byte, rsaSize)
2275
2276	success := true
2277
2278	ret := C.RSA_private_decrypt(C.int(ctLen), (*C.uchar)(unsafe.Pointer(&ct[0])), (*C.uchar)(unsafe.Pointer(&decrypted[0])), rsa, C.RSA_PKCS1_PADDING)
2279
2280	if ret == -1 {
2281		success = (wt.Result == "invalid")
2282
2283		if !success {
2284			fmt.Printf("FAIL: Test case %d (%q) %v - got %d, want %d. Expected: %v\n", wt.TCID, wt.Comment, wt.Flags, ret, len(wt.Msg)/2, wt.Result)
2285		}
2286		return success
2287	}
2288
2289	msg, err := hex.DecodeString(wt.Msg)
2290	if err != nil {
2291		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
2292	}
2293
2294	if int(ret) != len(msg) {
2295		success = false
2296		fmt.Printf("FAIL: Test case %d (%q) %v - got %d, want %d. Expected: %v\n", wt.TCID, wt.Comment, wt.Flags, ret, len(msg), wt.Result)
2297	} else if !bytes.Equal(msg, decrypted[:len(msg)]) {
2298		success = false
2299		fmt.Printf("FAIL: Test case %d (%q) %v - expected and calculated message differ. Expected: %v", wt.TCID, wt.Comment, wt.Flags, wt.Result)
2300	}
2301
2302	return success
2303}
2304
2305func runRsaesPkcs1TestGroup(algorithm string, wtg *wycheproofTestGroupRsaesPkcs1) bool {
2306	fmt.Printf("Running %v test group %v with key size %d...\n", algorithm, wtg.Type, wtg.KeySize)
2307	rsa := C.RSA_new()
2308	if rsa == nil {
2309		log.Fatal("RSA_new failed")
2310	}
2311	defer C.RSA_free(rsa)
2312
2313	d := C.CString(wtg.D)
2314	if C.BN_hex2bn(&rsa.d, d) == 0 {
2315		log.Fatal("Failed to set RSA d")
2316	}
2317	C.free(unsafe.Pointer(d))
2318
2319	e := C.CString(wtg.E)
2320	if C.BN_hex2bn(&rsa.e, e) == 0 {
2321		log.Fatal("Failed to set RSA e")
2322	}
2323	C.free(unsafe.Pointer(e))
2324
2325	n := C.CString(wtg.N)
2326	if C.BN_hex2bn(&rsa.n, n) == 0 {
2327		log.Fatal("Failed to set RSA n")
2328	}
2329	C.free(unsafe.Pointer(n))
2330
2331	success := true
2332	for _, wt := range wtg.Tests {
2333		if !runRsaesPkcs1Test(rsa, wt) {
2334			success = false
2335		}
2336	}
2337	return success
2338}
2339
2340func runRsassaTest(rsa *C.RSA, h hash.Hash, sha *C.EVP_MD, mgfSha *C.EVP_MD, sLen int, wt *wycheproofTestRsassa) bool {
2341	msg, err := hex.DecodeString(wt.Msg)
2342	if err != nil {
2343		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
2344	}
2345
2346	h.Reset()
2347	h.Write(msg)
2348	msg = h.Sum(nil)
2349
2350	sig, err := hex.DecodeString(wt.Sig)
2351	if err != nil {
2352		log.Fatalf("Failed to decode signature %q: %v", wt.Sig, err)
2353	}
2354
2355	msgLen, sigLen := len(msg), len(sig)
2356	if msgLen == 0 {
2357		msg = append(msg, 0)
2358	}
2359	if sigLen == 0 {
2360		sig = append(sig, 0)
2361	}
2362
2363	sigOut := make([]byte, C.RSA_size(rsa)-11)
2364	if sigLen == 0 {
2365		sigOut = append(sigOut, 0)
2366	}
2367
2368	ret := C.RSA_public_decrypt(C.int(sigLen), (*C.uchar)(unsafe.Pointer(&sig[0])),
2369		(*C.uchar)(unsafe.Pointer(&sigOut[0])), rsa, C.RSA_NO_PADDING)
2370	if ret == -1 {
2371		if wt.Result == "invalid" {
2372			return true
2373		}
2374		fmt.Printf("FAIL: Test case %d (%q) %v - RSA_public_decrypt() = %d, want %v\n",
2375			wt.TCID, wt.Comment, wt.Flags, int(ret), wt.Result)
2376		return false
2377	}
2378
2379	ret = C.RSA_verify_PKCS1_PSS_mgf1(rsa, (*C.uchar)(unsafe.Pointer(&msg[0])), sha, mgfSha,
2380		(*C.uchar)(unsafe.Pointer(&sigOut[0])), C.int(sLen))
2381
2382	success := false
2383	if ret == 1 && (wt.Result == "valid" || wt.Result == "acceptable") {
2384		// All acceptable cases that pass use SHA-1 and are flagged:
2385		// "WeakHash" : "The key for this test vector uses a weak hash function."
2386		if acceptableAudit && wt.Result == "acceptable" {
2387			gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
2388		}
2389		success = true
2390	} else if ret == 0 && (wt.Result == "invalid" || wt.Result == "acceptable") {
2391		success = true
2392	} else {
2393		fmt.Printf("FAIL: Test case %d (%q) %v - RSA_verify_PKCS1_PSS_mgf1() = %d, want %v\n",
2394			wt.TCID, wt.Comment, wt.Flags, int(ret), wt.Result)
2395	}
2396	return success
2397}
2398
2399func runRsassaTestGroup(algorithm string, wtg *wycheproofTestGroupRsassa) bool {
2400	fmt.Printf("Running %v test group %v with key size %d and %v...\n",
2401		algorithm, wtg.Type, wtg.KeySize, wtg.SHA)
2402	rsa := C.RSA_new()
2403	if rsa == nil {
2404		log.Fatal("RSA_new failed")
2405	}
2406	defer C.RSA_free(rsa)
2407
2408	e := C.CString(wtg.E)
2409	if C.BN_hex2bn(&rsa.e, e) == 0 {
2410		log.Fatal("Failed to set RSA e")
2411	}
2412	C.free(unsafe.Pointer(e))
2413
2414	n := C.CString(wtg.N)
2415	if C.BN_hex2bn(&rsa.n, n) == 0 {
2416		log.Fatal("Failed to set RSA n")
2417	}
2418	C.free(unsafe.Pointer(n))
2419
2420	h, err := hashFromString(wtg.SHA)
2421	if err != nil {
2422		log.Fatalf("Failed to get hash: %v", err)
2423	}
2424
2425	sha, err := hashEvpMdFromString(wtg.SHA)
2426	if err != nil {
2427		log.Fatalf("Failed to get hash: %v", err)
2428	}
2429
2430	mgfSha, err := hashEvpMdFromString(wtg.MGFSHA)
2431	if err != nil {
2432		log.Fatalf("Failed to get MGF hash: %v", err)
2433	}
2434
2435	success := true
2436	for _, wt := range wtg.Tests {
2437		if !runRsassaTest(rsa, h, sha, mgfSha, wtg.SLen, wt) {
2438			success = false
2439		}
2440	}
2441	return success
2442}
2443
2444func runRSATest(rsa *C.RSA, nid int, h hash.Hash, wt *wycheproofTestRSA) bool {
2445	msg, err := hex.DecodeString(wt.Msg)
2446	if err != nil {
2447		log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
2448	}
2449
2450	h.Reset()
2451	h.Write(msg)
2452	msg = h.Sum(nil)
2453
2454	sig, err := hex.DecodeString(wt.Sig)
2455	if err != nil {
2456		log.Fatalf("Failed to decode signature %q: %v", wt.Sig, err)
2457	}
2458
2459	msgLen, sigLen := len(msg), len(sig)
2460	if msgLen == 0 {
2461		msg = append(msg, 0)
2462	}
2463	if sigLen == 0 {
2464		sig = append(sig, 0)
2465	}
2466
2467	ret := C.RSA_verify(C.int(nid), (*C.uchar)(unsafe.Pointer(&msg[0])), C.uint(msgLen),
2468		(*C.uchar)(unsafe.Pointer(&sig[0])), C.uint(sigLen), rsa)
2469
2470	// XXX audit acceptable cases...
2471	success := true
2472	if ret == 1 != (wt.Result == "valid") && wt.Result != "acceptable" {
2473		fmt.Printf("FAIL: Test case %d (%q) %v - RSA_verify() = %d, want %v\n",
2474			wt.TCID, wt.Comment, wt.Flags, int(ret), wt.Result)
2475		success = false
2476	}
2477	if acceptableAudit && ret == 1 && wt.Result == "acceptable" {
2478		gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
2479	}
2480	return success
2481}
2482
2483func runRSATestGroup(algorithm string, wtg *wycheproofTestGroupRSA) bool {
2484	fmt.Printf("Running %v test group %v with key size %d and %v...\n",
2485		algorithm, wtg.Type, wtg.KeySize, wtg.SHA)
2486
2487	rsa := C.RSA_new()
2488	if rsa == nil {
2489		log.Fatal("RSA_new failed")
2490	}
2491	defer C.RSA_free(rsa)
2492
2493	e := C.CString(wtg.E)
2494	if C.BN_hex2bn(&rsa.e, e) == 0 {
2495		log.Fatal("Failed to set RSA e")
2496	}
2497	C.free(unsafe.Pointer(e))
2498
2499	n := C.CString(wtg.N)
2500	if C.BN_hex2bn(&rsa.n, n) == 0 {
2501		log.Fatal("Failed to set RSA n")
2502	}
2503	C.free(unsafe.Pointer(n))
2504
2505	nid, err := nidFromString(wtg.SHA)
2506	if err != nil {
2507		log.Fatalf("Failed to get MD NID: %v", err)
2508	}
2509	h, err := hashFromString(wtg.SHA)
2510	if err != nil {
2511		log.Fatalf("Failed to get hash: %v", err)
2512	}
2513
2514	success := true
2515	for _, wt := range wtg.Tests {
2516		if !runRSATest(rsa, nid, h, wt) {
2517			success = false
2518		}
2519	}
2520	return success
2521}
2522
2523func runX25519Test(wt *wycheproofTestX25519) bool {
2524	public, err := hex.DecodeString(wt.Public)
2525	if err != nil {
2526		log.Fatalf("Failed to decode public %q: %v", wt.Public, err)
2527	}
2528	private, err := hex.DecodeString(wt.Private)
2529	if err != nil {
2530		log.Fatalf("Failed to decode private %q: %v", wt.Private, err)
2531	}
2532	shared, err := hex.DecodeString(wt.Shared)
2533	if err != nil {
2534		log.Fatalf("Failed to decode shared %q: %v", wt.Shared, err)
2535	}
2536
2537	got := make([]byte, C.X25519_KEY_LENGTH)
2538	result := true
2539
2540	if C.X25519((*C.uint8_t)(unsafe.Pointer(&got[0])), (*C.uint8_t)(unsafe.Pointer(&private[0])), (*C.uint8_t)(unsafe.Pointer(&public[0]))) != 1 {
2541		result = false
2542	} else {
2543		result = bytes.Equal(got, shared)
2544	}
2545
2546	// XXX audit acceptable cases...
2547	success := true
2548	if result != (wt.Result == "valid") && wt.Result != "acceptable" {
2549		fmt.Printf("FAIL: Test case %d (%q) %v - X25519(), want %v\n",
2550			wt.TCID, wt.Comment, wt.Flags, wt.Result)
2551		success = false
2552	}
2553	if acceptableAudit && result && wt.Result == "acceptable" {
2554		gatherAcceptableStatistics(wt.TCID, wt.Comment, wt.Flags)
2555	}
2556	return success
2557}
2558
2559func runX25519TestGroup(algorithm string, wtg *wycheproofTestGroupX25519) bool {
2560	fmt.Printf("Running %v test group with curve %v...\n", algorithm, wtg.Curve)
2561
2562	success := true
2563	for _, wt := range wtg.Tests {
2564		if !runX25519Test(wt) {
2565			success = false
2566		}
2567	}
2568	return success
2569}
2570
2571func runTestVectors(path string, variant testVariant) bool {
2572	b, err := ioutil.ReadFile(path)
2573	if err != nil {
2574		log.Fatalf("Failed to read test vectors: %v", err)
2575	}
2576	wtv := &wycheproofTestVectors{}
2577	if err := json.Unmarshal(b, wtv); err != nil {
2578		log.Fatalf("Failed to unmarshal JSON: %v", err)
2579	}
2580	fmt.Printf("Loaded Wycheproof test vectors for %v with %d tests from %q\n",
2581		wtv.Algorithm, wtv.NumberOfTests, filepath.Base(path))
2582
2583	var wtg interface{}
2584	switch wtv.Algorithm {
2585	case "AES-CBC-PKCS5":
2586		wtg = &wycheproofTestGroupAesCbcPkcs5{}
2587	case "AES-CCM":
2588		wtg = &wycheproofTestGroupAead{}
2589	case "AES-CMAC":
2590		wtg = &wycheproofTestGroupAesCmac{}
2591	case "AES-GCM":
2592		wtg = &wycheproofTestGroupAead{}
2593	case "CHACHA20-POLY1305", "XCHACHA20-POLY1305":
2594		wtg = &wycheproofTestGroupAead{}
2595	case "DSA":
2596		wtg = &wycheproofTestGroupDSA{}
2597	case "ECDH":
2598		switch variant {
2599		case Webcrypto:
2600			wtg = &wycheproofTestGroupECDHWebCrypto{}
2601		default:
2602			wtg = &wycheproofTestGroupECDH{}
2603		}
2604	case "ECDSA":
2605		switch variant {
2606		case Webcrypto:
2607			wtg = &wycheproofTestGroupECDSAWebCrypto{}
2608		default:
2609			wtg = &wycheproofTestGroupECDSA{}
2610		}
2611	case "HKDF-SHA-1", "HKDF-SHA-256", "HKDF-SHA-384", "HKDF-SHA-512":
2612		wtg = &wycheproofTestGroupHkdf{}
2613	case "HMACSHA1", "HMACSHA224", "HMACSHA256", "HMACSHA384", "HMACSHA512":
2614		wtg = &wycheproofTestGroupHmac{}
2615	case "KW":
2616		wtg = &wycheproofTestGroupKW{}
2617	case "RSAES-OAEP":
2618		wtg = &wycheproofTestGroupRsaesOaep{}
2619	case "RSAES-PKCS1-v1_5":
2620		wtg = &wycheproofTestGroupRsaesPkcs1{}
2621	case "RSASSA-PSS":
2622		wtg = &wycheproofTestGroupRsassa{}
2623	case "RSASSA-PKCS1-v1_5", "RSASig":
2624		wtg = &wycheproofTestGroupRSA{}
2625	case "XDH", "X25519":
2626		wtg = &wycheproofTestGroupX25519{}
2627	default:
2628		log.Printf("INFO: Unknown test vector algorithm %q", wtv.Algorithm)
2629		return false
2630	}
2631
2632	success := true
2633	for _, tg := range wtv.TestGroups {
2634		if err := json.Unmarshal(tg, wtg); err != nil {
2635			log.Fatalf("Failed to unmarshal test groups JSON: %v", err)
2636		}
2637		switch wtv.Algorithm {
2638		case "AES-CBC-PKCS5":
2639			if !runAesCbcPkcs5TestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupAesCbcPkcs5)) {
2640				success = false
2641			}
2642		case "AES-CCM":
2643			if !runAesAeadTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupAead)) {
2644				success = false
2645			}
2646		case "AES-CMAC":
2647			if !runAesCmacTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupAesCmac)) {
2648				success = false
2649			}
2650		case "AES-GCM":
2651			if !runAesAeadTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupAead)) {
2652				success = false
2653			}
2654		case "CHACHA20-POLY1305", "XCHACHA20-POLY1305":
2655			if !runChaCha20Poly1305TestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupAead)) {
2656				success = false
2657			}
2658		case "DSA":
2659			if !runDSATestGroup(wtv.Algorithm, variant, wtg.(*wycheproofTestGroupDSA)) {
2660				success = false
2661			}
2662		case "ECDH":
2663			switch variant {
2664			case Webcrypto:
2665				if !runECDHWebCryptoTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupECDHWebCrypto)) {
2666					success = false
2667				}
2668			default:
2669				if !runECDHTestGroup(wtv.Algorithm, variant, wtg.(*wycheproofTestGroupECDH)) {
2670					success = false
2671				}
2672			}
2673		case "ECDSA":
2674			switch variant {
2675			case Webcrypto:
2676				if !runECDSAWebCryptoTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupECDSAWebCrypto)) {
2677					success = false
2678				}
2679			default:
2680				if !runECDSATestGroup(wtv.Algorithm, variant, wtg.(*wycheproofTestGroupECDSA)) {
2681					success = false
2682				}
2683			}
2684		case "HKDF-SHA-1", "HKDF-SHA-256", "HKDF-SHA-384", "HKDF-SHA-512":
2685			if !runHkdfTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupHkdf)) {
2686				success = false
2687			}
2688		case "HMACSHA1", "HMACSHA224", "HMACSHA256", "HMACSHA384", "HMACSHA512":
2689			if !runHmacTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupHmac)) {
2690				success = false
2691			}
2692		case "KW":
2693			if !runKWTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupKW)) {
2694				success = false
2695			}
2696		case "RSAES-OAEP":
2697			if !runRsaesOaepTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupRsaesOaep)) {
2698				success = false
2699			}
2700		case "RSAES-PKCS1-v1_5":
2701			if !runRsaesPkcs1TestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupRsaesPkcs1)) {
2702				success = false
2703			}
2704		case "RSASSA-PSS":
2705			if !runRsassaTestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupRsassa)) {
2706				success = false
2707			}
2708		case "RSASSA-PKCS1-v1_5", "RSASig":
2709			if !runRSATestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupRSA)) {
2710				success = false
2711			}
2712		case "XDH", "X25519":
2713			if !runX25519TestGroup(wtv.Algorithm, wtg.(*wycheproofTestGroupX25519)) {
2714				success = false
2715			}
2716		default:
2717			log.Fatalf("Unknown test vector algorithm %q", wtv.Algorithm)
2718		}
2719	}
2720	return success
2721}
2722
2723func main() {
2724	if _, err := os.Stat(testVectorPath); os.IsNotExist(err) {
2725		fmt.Printf("package wycheproof-testvectors is required for this regress\n")
2726		fmt.Printf("SKIPPED\n")
2727		os.Exit(0)
2728	}
2729
2730	flag.BoolVar(&acceptableAudit, "v", false, "audit acceptable cases")
2731	flag.Parse()
2732
2733	acceptableComments = make(map[string]int)
2734	acceptableFlags = make(map[string]int)
2735
2736	// TODO: Investigate the following new test vectors:
2737	//	primality_test.json
2738	//	x25519_{asn,jwk,pem}_test.json
2739	tests := []struct {
2740		name    string
2741		pattern string
2742		variant testVariant
2743	}{
2744		{"AES", "aes_[cg]*[^xv]_test.json", Normal}, // Skip AES-EAX, AES-GCM-SIV and AES-SIV-CMAC.
2745		{"ChaCha20-Poly1305", "chacha20_poly1305_test.json", Normal},
2746		{"DSA", "dsa_*test.json", Normal},
2747		{"DSA", "dsa_*_p1363_test.json", P1363},
2748		{"ECDH", "ecdh_test.json", Normal},
2749		{"ECDH", "ecdh_[^w_]*_test.json", Normal},
2750		{"ECDH EcPoint", "ecdh_*_ecpoint_test.json", EcPoint},
2751		{"ECDH webcrypto", "ecdh_webcrypto_test.json", Webcrypto},
2752		{"ECDSA", "ecdsa_test.json", Normal},
2753		{"ECDSA", "ecdsa_[^w]*test.json", Normal},
2754		{"ECDSA P1363", "ecdsa_*_p1363_test.json", P1363},
2755		{"ECDSA webcrypto", "ecdsa_webcrypto_test.json", Webcrypto},
2756		{"HKDF", "hkdf_sha*_test.json", Normal},
2757		{"HMAC", "hmac_sha*_test.json", Normal},
2758		{"KW", "kw_test.json", Normal},
2759		{"RSA", "rsa_*test.json", Normal},
2760		{"X25519", "x25519_test.json", Normal},
2761		{"X25519 ASN", "x25519_asn_test.json", Skip},
2762		{"X25519 JWK", "x25519_jwk_test.json", Skip},
2763		{"X25519 PEM", "x25519_pem_test.json", Skip},
2764		{"XCHACHA20-POLY1305", "xchacha20_poly1305_test.json", Normal},
2765	}
2766
2767	success := true
2768
2769	skipNormal := regexp.MustCompile(`_(ecpoint|p1363|sha3|sha512_(224|256))_`)
2770
2771	for _, test := range tests {
2772		tvs, err := filepath.Glob(filepath.Join(testVectorPath, test.pattern))
2773		if err != nil {
2774			log.Fatalf("Failed to glob %v test vectors: %v", test.name, err)
2775		}
2776		if len(tvs) == 0 {
2777			log.Fatalf("Failed to find %v test vectors at %q\n", test.name, testVectorPath)
2778		}
2779		for _, tv := range tvs {
2780			if test.variant == Skip || (test.variant == Normal && skipNormal.Match([]byte(tv))) {
2781				fmt.Printf("INFO: Skipping tests from \"%s\"\n", strings.TrimPrefix(tv, testVectorPath+"/"))
2782				continue
2783			}
2784			if !runTestVectors(tv, test.variant) {
2785				success = false
2786			}
2787		}
2788	}
2789
2790	if acceptableAudit {
2791		printAcceptableStatistics()
2792	}
2793
2794	if !success {
2795		os.Exit(1)
2796	}
2797}
2798