1/* $OpenBSD: wycheproof.go,v 1.161 2024/11/24 10:13:16 tb Exp $ */
2/*
3 * Copyright (c) 2018,2023 Joel Sing <jsing@openbsd.org>
4 * Copyright (c) 2018,2019,2022-2024 Theo Buehler <tb@openbsd.org>
5 *
6 * Permission to use, copy, modify, and distribute this software for any
7 * purpose with or without fee is hereby granted, provided that the above
8 * copyright notice and this permission notice appear in all copies.
9 *
10 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 */
18
19// Wycheproof runs test vectors from Project Wycheproof against libcrypto.
20package main
21
22/*
23#cgo LDFLAGS: -lcrypto
24
25#include <limits.h>
26#include <string.h>
27
28#include <openssl/aes.h>
29#include <openssl/bio.h>
30#include <openssl/bn.h>
31#include <openssl/cmac.h>
32#include <openssl/curve25519.h>
33#include <openssl/dsa.h>
34#include <openssl/ec.h>
35#include <openssl/ecdsa.h>
36#include <openssl/evp.h>
37#include <openssl/kdf.h>
38#include <openssl/hmac.h>
39#include <openssl/objects.h>
40#include <openssl/pem.h>
41#include <openssl/x509.h>
42#include <openssl/rsa.h>
43
44int
45wp_EVP_PKEY_CTX_set_hkdf_md(EVP_PKEY_CTX *pctx, const EVP_MD *md)
46{
47	return EVP_PKEY_CTX_set_hkdf_md(pctx, md);
48}
49
50int
51wp_EVP_PKEY_CTX_set1_hkdf_salt(EVP_PKEY_CTX *pctx, const unsigned char *salt, size_t salt_len)
52{
53	if (salt_len > INT_MAX)
54		return 0;
55	return EVP_PKEY_CTX_set1_hkdf_salt(pctx, salt, salt_len);
56}
57
58int
59wp_EVP_PKEY_CTX_set1_hkdf_key(EVP_PKEY_CTX *pctx, const unsigned char *ikm, size_t ikm_len)
60{
61	if (ikm_len > INT_MAX)
62		return 0;
63	return EVP_PKEY_CTX_set1_hkdf_key(pctx, ikm, ikm_len);
64}
65
66int
67wp_EVP_PKEY_CTX_add1_hkdf_info(EVP_PKEY_CTX *pctx, const unsigned char *info, size_t info_len)
68{
69	if (info_len > INT_MAX)
70		return 0;
71	return EVP_PKEY_CTX_add1_hkdf_info(pctx, info, info_len);
72}
73*/
74import "C"
75
76import (
77	"bytes"
78	"encoding/base64"
79	"encoding/hex"
80	"encoding/json"
81	"fmt"
82	"io/ioutil"
83	"log"
84	"os"
85	"path/filepath"
86	"regexp"
87	"runtime"
88	"strings"
89	"sync"
90	"unsafe"
91)
92
// testVectorPath is the directory that is expected to hold the Wycheproof
// test vector files (presumably read by code outside this section; confirm
// against the loader).
const testVectorPath = "/usr/local/share/wycheproof/testvectors"
94
// testVariant identifies which flavour of a Wycheproof test file is being
// run (plain, EC point, P1363, WebCrypto, ASN.1, PEM, JWK), or Skip.
type testVariant int

const (
	Normal    testVariant = 0
	EcPoint   testVariant = 1
	P1363     testVariant = 2
	Webcrypto testVariant = 3
	Asn1      testVariant = 4
	Pem       testVariant = 5
	Jwk       testVariant = 6
	Skip      testVariant = 7
)

// String returns the name of the variant. Values outside the declared range
// are rendered as "testVariant(N)" rather than panicking with an index out
// of range, matching the convention of stringer-generated code.
func (variant testVariant) String() string {
	variants := [...]string{
		"Normal",
		"EcPoint",
		"P1363",
		"Webcrypto",
		"Asn1",
		"Pem",
		"Jwk",
		"Skip",
	}
	if variant < 0 || int(variant) >= len(variants) {
		// Convert to int so formatting cannot recurse into String.
		return fmt.Sprintf("testVariant(%d)", int(variant))
	}
	return variants[variant]
}
121
// wycheproofFormatTestCase renders a test case as one log line: its ID, its
// comment (quoted), its flags and its expected result.
func wycheproofFormatTestCase(tcid int, comment string, flags []string, result string) string {
	const layout = "Test case %d (%q) %v %v"
	line := fmt.Sprintf(layout, tcid, comment, flags, result)
	return line
}
125
// testc is a package-level test coordinator. testCoordinator is declared
// outside this section; how testc is initialized and used is not visible here.
var testc *testCoordinator
127
// wycheproofJWKPublic is a public key in JSON Web Key (JWK) form as it
// appears in the WebCrypto test vectors (curve, key ID, key type and the
// x/y members).
type wycheproofJWKPublic struct {
	Crv string `json:"crv"`
	KID string `json:"kid"`
	KTY string `json:"kty"`
	X   string `json:"x"`
	Y   string `json:"y"`
}

// wycheproofJWKPrivate is a private key in JWK form: the same members as
// wycheproofJWKPublic plus the private "d" member.
type wycheproofJWKPrivate struct {
	Crv string `json:"crv"`
	D   string `json:"d"`
	KID string `json:"kid"`
	KTY string `json:"kty"`
	X   string `json:"x"`
	Y   string `json:"y"`
}
144
// wycheproofTestGroupAesCbcPkcs5 is a group of AES-CBC test cases with
// PKCS#5 padding. KeySize is in bits (it is looked up in a table keyed by
// 128/192/256).
type wycheproofTestGroupAesCbcPkcs5 struct {
	IVSize  int                          `json:"ivSize"`
	KeySize int                          `json:"keySize"`
	Type    string                       `json:"type"`
	Tests   []*wycheproofTestAesCbcPkcs5 `json:"tests"`
}

// wycheproofTestAesCbcPkcs5 is a single AES-CBC/PKCS#5 test case. Key, IV,
// Msg and CT are hex strings; Result is compared against "invalid" when the
// test is checked.
type wycheproofTestAesCbcPkcs5 struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Key     string   `json:"key"`
	IV      string   `json:"iv"`
	Msg     string   `json:"msg"`
	CT      string   `json:"ct"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// String formats the test case for log output.
func (wt *wycheproofTestAesCbcPkcs5) String() string {
	return wycheproofFormatTestCase(wt.TCID, wt.Comment, wt.Flags, wt.Result)
}
166
// wycheproofTestGroupAead is a group of AEAD test cases that share IV, key
// and tag sizes.
type wycheproofTestGroupAead struct {
	IVSize  int                   `json:"ivSize"`
	KeySize int                   `json:"keySize"`
	TagSize int                   `json:"tagSize"`
	Type    string                `json:"type"`
	Tests   []*wycheproofTestAead `json:"tests"`
}

// wycheproofTestGroupAesAead and wycheproofTestGroupChaCha have the same
// layout as wycheproofTestGroupAead but are distinct named types, presumably
// so that each can carry its own run method.
type wycheproofTestGroupAesAead wycheproofTestGroupAead
type wycheproofTestGroupChaCha wycheproofTestGroupAead

// wycheproofTestAead is a single AEAD test case. Key, IV, AAD, Msg, CT and
// Tag are hex strings.
type wycheproofTestAead struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Key     string   `json:"key"`
	IV      string   `json:"iv"`
	AAD     string   `json:"aad"`
	Msg     string   `json:"msg"`
	CT      string   `json:"ct"`
	Tag     string   `json:"tag"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// String formats the test case for log output.
func (wt *wycheproofTestAead) String() string {
	return wycheproofFormatTestCase(wt.TCID, wt.Comment, wt.Flags, wt.Result)
}
194
// wycheproofTestGroupAesCmac is a group of AES-CMAC test cases sharing a key
// size and tag size.
type wycheproofTestGroupAesCmac struct {
	KeySize int                      `json:"keySize"`
	TagSize int                      `json:"tagSize"`
	Type    string                   `json:"type"`
	Tests   []*wycheproofTestAesCmac `json:"tests"`
}

// wycheproofTestAesCmac is a single AES-CMAC test case: a key, a message and
// the expected tag.
type wycheproofTestAesCmac struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Key     string   `json:"key"`
	Msg     string   `json:"msg"`
	Tag     string   `json:"tag"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// String formats the test case for log output.
func (wt *wycheproofTestAesCmac) String() string {
	return wycheproofFormatTestCase(wt.TCID, wt.Comment, wt.Flags, wt.Result)
}
215
// wycheproofDSAKey holds the DSA domain parameters (P, Q, G) and the public
// value Y of a test group's key.
type wycheproofDSAKey struct {
	G       string `json:"g"`
	KeySize int    `json:"keySize"`
	P       string `json:"p"`
	Q       string `json:"q"`
	Type    string `json:"type"`
	Y       string `json:"y"`
}

// wycheproofTestDSA is a single DSA signature verification test case.
type wycheproofTestDSA struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Msg     string   `json:"msg"`
	Sig     string   `json:"sig"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// String formats the test case for log output.
func (wt *wycheproofTestDSA) String() string {
	return wycheproofFormatTestCase(wt.TCID, wt.Comment, wt.Flags, wt.Result)
}

// wycheproofTestGroupDSA is a group of DSA tests sharing one key (given as
// parameters, DER and PEM) and one hash function (SHA).
type wycheproofTestGroupDSA struct {
	Key    *wycheproofDSAKey    `json:"key"`
	KeyDER string               `json:"keyDer"`
	KeyPEM string               `json:"keyPem"`
	SHA    string               `json:"sha"`
	Type   string               `json:"type"`
	Tests  []*wycheproofTestDSA `json:"tests"`
}
246
// wycheproofTestECDH is a single ECDH test case: a peer public key, a
// private key and the expected shared secret.
type wycheproofTestECDH struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Public  string   `json:"public"`
	Private string   `json:"private"`
	Shared  string   `json:"shared"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// String formats the test case for log output.
func (wt *wycheproofTestECDH) String() string {
	return wycheproofFormatTestCase(wt.TCID, wt.Comment, wt.Flags, wt.Result)
}

// wycheproofTestGroupECDH is a group of ECDH tests on one curve with one
// public key encoding.
type wycheproofTestGroupECDH struct {
	Curve    string                `json:"curve"`
	Encoding string                `json:"encoding"`
	Type     string                `json:"type"`
	Tests    []*wycheproofTestECDH `json:"tests"`
}

// wycheproofTestECDHWebCrypto is an ECDH test case whose keys are given in
// JWK form.
type wycheproofTestECDHWebCrypto struct {
	TCID    int                   `json:"tcId"`
	Comment string                `json:"comment"`
	Public  *wycheproofJWKPublic  `json:"public"`
	Private *wycheproofJWKPrivate `json:"private"`
	Shared  string                `json:"shared"`
	Result  string                `json:"result"`
	Flags   []string              `json:"flags"`
}

// String formats the test case for log output.
func (wt *wycheproofTestECDHWebCrypto) String() string {
	return wycheproofFormatTestCase(wt.TCID, wt.Comment, wt.Flags, wt.Result)
}

// wycheproofTestGroupECDHWebCrypto is a group of JWK-keyed ECDH tests.
type wycheproofTestGroupECDHWebCrypto struct {
	Curve    string                         `json:"curve"`
	Encoding string                         `json:"encoding"`
	Type     string                         `json:"type"`
	Tests    []*wycheproofTestECDHWebCrypto `json:"tests"`
}
288
// wycheproofECDSAKey is the public key of an ECDSA test group: its curve and
// size, the point in uncompressed form and its coordinates WX and WY.
type wycheproofECDSAKey struct {
	Curve        string `json:"curve"`
	KeySize      int    `json:"keySize"`
	Type         string `json:"type"`
	Uncompressed string `json:"uncompressed"`
	WX           string `json:"wx"`
	WY           string `json:"wy"`
}

// wycheproofTestECDSA is a single ECDSA signature verification test case.
type wycheproofTestECDSA struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Msg     string   `json:"msg"`
	Sig     string   `json:"sig"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// String formats the test case for log output.
func (wt *wycheproofTestECDSA) String() string {
	return wycheproofFormatTestCase(wt.TCID, wt.Comment, wt.Flags, wt.Result)
}

// wycheproofTestGroupECDSA is a group of ECDSA tests sharing one key (as
// components, DER and PEM) and one hash function.
type wycheproofTestGroupECDSA struct {
	Key    *wycheproofECDSAKey    `json:"key"`
	KeyDER string                 `json:"keyDer"`
	KeyPEM string                 `json:"keyPem"`
	SHA    string                 `json:"sha"`
	Type   string                 `json:"type"`
	Tests  []*wycheproofTestECDSA `json:"tests"`
}

// wycheproofTestGroupECDSAWebCrypto is like wycheproofTestGroupECDSA but also
// carries the key in JWK form.
type wycheproofTestGroupECDSAWebCrypto struct {
	JWK    *wycheproofJWKPublic   `json:"jwk"`
	Key    *wycheproofECDSAKey    `json:"key"`
	KeyDER string                 `json:"keyDer"`
	KeyPEM string                 `json:"keyPem"`
	SHA    string                 `json:"sha"`
	Type   string                 `json:"type"`
	Tests  []*wycheproofTestECDSA `json:"tests"`
}
329
// wycheproofJWKEdDSA is an EdDSA key in JWK form ("d" private, "x" public).
type wycheproofJWKEdDSA struct {
	Crv string `json:"crv"`
	D   string `json:"d"`
	KID string `json:"kid"`
	KTY string `json:"kty"`
	X   string `json:"x"`
}

// wycheproofEdDSAKey is an EdDSA key pair: Pk is the public key and Sk the
// secret key.
type wycheproofEdDSAKey struct {
	Curve   string `json:"curve"`
	KeySize int    `json:"keySize"`
	Pk      string `json:"pk"`
	Sk      string `json:"sk"`
	Type    string `json:"type"`
}

// wycheproofTestEdDSA is a single EdDSA signature verification test case.
type wycheproofTestEdDSA struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Msg     string   `json:"msg"`
	Sig     string   `json:"sig"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// String formats the test case for log output.
func (wt *wycheproofTestEdDSA) String() string {
	return wycheproofFormatTestCase(wt.TCID, wt.Comment, wt.Flags, wt.Result)
}

// wycheproofTestGroupEdDSA is a group of EdDSA tests sharing one key, given
// as a JWK, as raw key material, and in DER and PEM form.
type wycheproofTestGroupEdDSA struct {
	JWK    *wycheproofJWKEdDSA    `json:"jwk"`
	Key    *wycheproofEdDSAKey    `json:"key"`
	KeyDer string                 `json:"keyDer"`
	KeyPem string                 `json:"keyPem"`
	Type   string                 `json:"type"`
	Tests  []*wycheproofTestEdDSA `json:"tests"`
}
367
// wycheproofTestHkdf is a single HKDF test case: input keying material,
// salt and info, the requested output size (presumably in bytes) and the
// expected output keying material.
type wycheproofTestHkdf struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Ikm     string   `json:"ikm"`
	Salt    string   `json:"salt"`
	Info    string   `json:"info"`
	Size    int      `json:"size"`
	Okm     string   `json:"okm"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// String formats the test case for log output.
func (wt *wycheproofTestHkdf) String() string {
	return wycheproofFormatTestCase(wt.TCID, wt.Comment, wt.Flags, wt.Result)
}

// wycheproofTestGroupHkdf is a group of HKDF tests sharing a key size.
type wycheproofTestGroupHkdf struct {
	Type    string                `json:"type"`
	KeySize int                   `json:"keySize"`
	Tests   []*wycheproofTestHkdf `json:"tests"`
}
389
// wycheproofTestHmac is a single HMAC test case: key, message and expected tag.
type wycheproofTestHmac struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Key     string   `json:"key"`
	Msg     string   `json:"msg"`
	Tag     string   `json:"tag"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// String formats the test case for log output.
func (wt *wycheproofTestHmac) String() string {
	return wycheproofFormatTestCase(wt.TCID, wt.Comment, wt.Flags, wt.Result)
}

// wycheproofTestGroupHmac is a group of HMAC tests sharing key and tag sizes.
type wycheproofTestGroupHmac struct {
	KeySize int                   `json:"keySize"`
	TagSize int                   `json:"tagSize"`
	Type    string                `json:"type"`
	Tests   []*wycheproofTestHmac `json:"tests"`
}
410
// wycheproofTestKW is a single key wrap test case: the wrapping key, the
// plaintext key material (Msg) and the expected wrapped output (CT).
type wycheproofTestKW struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Key     string   `json:"key"`
	Msg     string   `json:"msg"`
	CT      string   `json:"ct"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// String formats the test case for log output.
func (wt *wycheproofTestKW) String() string {
	return wycheproofFormatTestCase(wt.TCID, wt.Comment, wt.Flags, wt.Result)
}

// wycheproofTestGroupKW is a group of key wrap tests sharing a key size.
type wycheproofTestGroupKW struct {
	KeySize int                 `json:"keySize"`
	Type    string              `json:"type"`
	Tests   []*wycheproofTestKW `json:"tests"`
}
430
// wycheproofTestPrimality is a single primality test case; Value is the
// number under test and Result the expected verdict.
type wycheproofTestPrimality struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Value   string   `json:"value"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// String formats the test case for log output.
func (wt *wycheproofTestPrimality) String() string {
	return wycheproofFormatTestCase(wt.TCID, wt.Comment, wt.Flags, wt.Result)
}

// wycheproofTestGroupPrimality is a group of primality test cases.
type wycheproofTestGroupPrimality struct {
	Type  string                     `json:"type"`
	Tests []*wycheproofTestPrimality `json:"tests"`
}
447
// wycheproofTestRSA is a single RSA signature verification test case.
type wycheproofTestRSA struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Msg     string   `json:"msg"`
	Sig     string   `json:"sig"`
	Padding string   `json:"padding"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// String formats the test case for log output.
func (wt *wycheproofTestRSA) String() string {
	return wycheproofFormatTestCase(wt.TCID, wt.Comment, wt.Flags, wt.Result)
}

// wycheproofTestGroupRSA is a group of RSA signature tests sharing one
// public key (E, N, ASN.1/DER/PEM encodings) and one hash function. Note the
// JSON key for KeySize is "keysize" in lower case.
type wycheproofTestGroupRSA struct {
	E       string               `json:"e"`
	KeyASN  string               `json:"keyAsn"`
	KeyDER  string               `json:"keyDer"`
	KeyPEM  string               `json:"keyPem"`
	KeySize int                  `json:"keysize"`
	N       string               `json:"n"`
	SHA     string               `json:"sha"`
	Type    string               `json:"type"`
	Tests   []*wycheproofTestRSA `json:"tests"`
}
473
// wycheproofPrivateKeyJwk is an RSA private key in JWK form (modulus,
// exponents, primes and CRT values).
type wycheproofPrivateKeyJwk struct {
	Alg string `json:"alg"`
	D   string `json:"d"`
	DP  string `json:"dp"`
	DQ  string `json:"dq"`
	E   string `json:"e"`
	KID string `json:"kid"`
	Kty string `json:"kty"`
	N   string `json:"n"`
	P   string `json:"p"`
	Q   string `json:"q"`
	QI  string `json:"qi"`
}

// wycheproofTestRsaes is a single RSA encryption test case: plaintext,
// ciphertext and an optional label.
type wycheproofTestRsaes struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Msg     string   `json:"msg"`
	CT      string   `json:"ct"`
	Label   string   `json:"label"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// String formats the test case for log output.
func (wt *wycheproofTestRsaes) String() string {
	return wycheproofFormatTestCase(wt.TCID, wt.Comment, wt.Flags, wt.Result)
}

// wycheproofTestGroupRsaesOaep is a group of RSAES-OAEP tests sharing one
// private key (in several encodings), hash function and MGF settings.
type wycheproofTestGroupRsaesOaep struct {
	D               string                   `json:"d"`
	E               string                   `json:"e"`
	KeySize         int                      `json:"keysize"`
	MGF             string                   `json:"mgf"`
	MGFSHA          string                   `json:"mgfSha"`
	N               string                   `json:"n"`
	PrivateKeyJwk   *wycheproofPrivateKeyJwk `json:"privateKeyJwk"`
	PrivateKeyPem   string                   `json:"privateKeyPem"`
	PrivateKeyPkcs8 string                   `json:"privateKeyPkcs8"`
	SHA             string                   `json:"sha"`
	Type            string                   `json:"type"`
	Tests           []*wycheproofTestRsaes   `json:"tests"`
}

// wycheproofTestGroupRsaesPkcs1 is a group of RSAES-PKCS1-v1_5 tests sharing
// one private key (in several encodings).
type wycheproofTestGroupRsaesPkcs1 struct {
	D               string                   `json:"d"`
	E               string                   `json:"e"`
	KeySize         int                      `json:"keysize"`
	N               string                   `json:"n"`
	PrivateKeyJwk   *wycheproofPrivateKeyJwk `json:"privateKeyJwk"`
	PrivateKeyPem   string                   `json:"privateKeyPem"`
	PrivateKeyPkcs8 string                   `json:"privateKeyPkcs8"`
	Type            string                   `json:"type"`
	Tests           []*wycheproofTestRsaes   `json:"tests"`
}
528
// wycheproofTestRsassa is a single RSASSA signature verification test case.
type wycheproofTestRsassa struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Msg     string   `json:"msg"`
	Sig     string   `json:"sig"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// String formats the test case for log output.
func (wt *wycheproofTestRsassa) String() string {
	return wycheproofFormatTestCase(wt.TCID, wt.Comment, wt.Flags, wt.Result)
}

// wycheproofTestGroupRsassa is a group of RSASSA tests sharing one public
// key, hash function and MGF settings; SLen is presumably the PSS salt
// length.
type wycheproofTestGroupRsassa struct {
	E       string                  `json:"e"`
	KeyASN  string                  `json:"keyAsn"`
	KeyDER  string                  `json:"keyDer"`
	KeyPEM  string                  `json:"keyPem"`
	KeySize int                     `json:"keysize"`
	MGF     string                  `json:"mgf"`
	MGFSHA  string                  `json:"mgfSha"`
	N       string                  `json:"n"`
	SLen    int                     `json:"sLen"`
	SHA     string                  `json:"sha"`
	Type    string                  `json:"type"`
	Tests   []*wycheproofTestRsassa `json:"tests"`
}
556
// wycheproofTestX25519 is a single X25519 test case: a peer public key, a
// private key and the expected shared secret.
type wycheproofTestX25519 struct {
	TCID    int      `json:"tcId"`
	Comment string   `json:"comment"`
	Curve   string   `json:"curve"`
	Public  string   `json:"public"`
	Private string   `json:"private"`
	Shared  string   `json:"shared"`
	Result  string   `json:"result"`
	Flags   []string `json:"flags"`
}

// String formats the test case for log output.
func (wt *wycheproofTestX25519) String() string {
	return wycheproofFormatTestCase(wt.TCID, wt.Comment, wt.Flags, wt.Result)
}

// wycheproofTestGroupX25519 is a group of X25519 test cases on one curve.
type wycheproofTestGroupX25519 struct {
	Curve string                  `json:"curve"`
	Tests []*wycheproofTestX25519 `json:"tests"`
}
576
577type wycheproofTestGroupRunner interface {
578	run(string, testVariant) bool
579}
580
// wycheproofTestVectors is the top level of a Wycheproof test vector file.
// Test groups are kept as raw JSON so they can be decoded into the group type
// that matches the file.
type wycheproofTestVectors struct {
	Algorithm        string            `json:"algorithm"`
	GeneratorVersion string            `json:"generatorVersion"`
	Notes            map[string]string `json:"notes"`
	NumberOfTests    int               `json:"numberOfTests"`
	// The file's "header" member is not decoded: no field is declared for it.
	TestGroups []json.RawMessage `json:"testGroups"`
}
589
// nids maps the curve and hash names used in the Wycheproof vectors to
// libcrypto NIDs. The JWK-style names "P-256", "P-256K", "P-384" and "P-521"
// are aliases for the corresponding SEC 2 curves, and secp192r1/secp256r1 map
// to their X9.62 NIDs.
var nids = map[string]int{
	"brainpoolP224r1": C.NID_brainpoolP224r1,
	"brainpoolP256r1": C.NID_brainpoolP256r1,
	"brainpoolP320r1": C.NID_brainpoolP320r1,
	"brainpoolP384r1": C.NID_brainpoolP384r1,
	"brainpoolP512r1": C.NID_brainpoolP512r1,
	"brainpoolP224t1": C.NID_brainpoolP224t1,
	"brainpoolP256t1": C.NID_brainpoolP256t1,
	"brainpoolP320t1": C.NID_brainpoolP320t1,
	"brainpoolP384t1": C.NID_brainpoolP384t1,
	"brainpoolP512t1": C.NID_brainpoolP512t1,
	"FRP256v1":        C.NID_FRP256v1,
	"secp160k1":       C.NID_secp160k1,
	"secp160r1":       C.NID_secp160r1,
	"secp160r2":       C.NID_secp160r2,
	"secp192k1":       C.NID_secp192k1,
	"secp192r1":       C.NID_X9_62_prime192v1, // RFC 8422, Table 4, p.32
	"secp224k1":       C.NID_secp224k1,
	"secp224r1":       C.NID_secp224r1,
	"secp256k1":       C.NID_secp256k1,
	"P-256K":          C.NID_secp256k1,
	"secp256r1":       C.NID_X9_62_prime256v1, // RFC 8422, Table 4, p.32
	"P-256":           C.NID_X9_62_prime256v1,
	"sect283k1":       C.NID_sect283k1,
	"sect283r1":       C.NID_sect283r1,
	"secp384r1":       C.NID_secp384r1,
	"P-384":           C.NID_secp384r1,
	"sect409k1":       C.NID_sect409k1,
	"sect409r1":       C.NID_sect409r1,
	"secp521r1":       C.NID_secp521r1,
	"sect571k1":       C.NID_sect571k1,
	"sect571r1":       C.NID_sect571r1,
	"P-521":           C.NID_secp521r1,
	"SHA-1":           C.NID_sha1,
	"SHA-224":         C.NID_sha224,
	"SHA-256":         C.NID_sha256,
	"SHA-384":         C.NID_sha384,
	"SHA-512":         C.NID_sha512,
	"SHA-512/224":     C.NID_sha512_224,
	"SHA-512/256":     C.NID_sha512_256,
	"SHA3-224":        C.NID_sha3_224,
	"SHA3-256":        C.NID_sha3_256,
	"SHA3-384":        C.NID_sha3_384,
	"SHA3-512":        C.NID_sha3_512,
}
635
636func nidFromString(ns string) (int, error) {
637	nid, ok := nids[ns]
638	if ok {
639		return nid, nil
640	}
641	return -1, fmt.Errorf("unknown NID %q", ns)
642}
643
644func skipSmallCurve(nid int) bool {
645	switch C.int(nid) {
646	case C.NID_secp160k1, C.NID_secp160r1, C.NID_secp160r2, C.NID_secp192k1, C.NID_X9_62_prime192v1:
647		return true
648	}
649	return false
650}
651
// evpMds maps Wycheproof hash names to libcrypto EVP_MD implementations. The
// EVP_* accessors are called once, when package-level variables are
// initialized.
var evpMds = map[string]*C.EVP_MD{
	"SHA-1":       C.EVP_sha1(),
	"SHA-224":     C.EVP_sha224(),
	"SHA-256":     C.EVP_sha256(),
	"SHA-384":     C.EVP_sha384(),
	"SHA-512":     C.EVP_sha512(),
	"SHA-512/224": C.EVP_sha512_224(),
	"SHA-512/256": C.EVP_sha512_256(),
	"SHA3-224":    C.EVP_sha3_224(),
	"SHA3-256":    C.EVP_sha3_256(),
	"SHA3-384":    C.EVP_sha3_384(),
	"SHA3-512":    C.EVP_sha3_512(),
}
665
666func hashEvpMdFromString(hs string) (*C.EVP_MD, error) {
667	md, ok := evpMds[hs]
668	if ok {
669		return md, nil
670	}
671	return nil, fmt.Errorf("unknown hash %q", hs)
672}
673
// aesCbcs, aesCcms and aesGcms map an AES key size in bits to the EVP_CIPHER
// for CBC, CCM and GCM mode respectively.
var aesCbcs = map[int]*C.EVP_CIPHER{
	128: C.EVP_aes_128_cbc(),
	192: C.EVP_aes_192_cbc(),
	256: C.EVP_aes_256_cbc(),
}

var aesCcms = map[int]*C.EVP_CIPHER{
	128: C.EVP_aes_128_ccm(),
	192: C.EVP_aes_192_ccm(),
	256: C.EVP_aes_256_ccm(),
}

var aesGcms = map[int]*C.EVP_CIPHER{
	128: C.EVP_aes_128_gcm(),
	192: C.EVP_aes_192_gcm(),
	256: C.EVP_aes_256_gcm(),
}

// aeses indexes the per-mode tables by Wycheproof algorithm name.
var aeses = map[string]map[int]*C.EVP_CIPHER{
	"AES-CBC": aesCbcs,
	"AES-CCM": aesCcms,
	"AES-GCM": aesGcms,
}
697
698func cipherAes(algorithm string, size int) (*C.EVP_CIPHER, error) {
699	cipher, ok := aeses[algorithm][size]
700	if ok {
701		return cipher, nil
702	}
703	return nil, fmt.Errorf("invalid key size: %d", size)
704}
705
// aesAeads maps an AES key size in bits to the AES-GCM EVP_AEAD. The 192-bit
// entry is present but nil (presumably because libcrypto provides no
// EVP_AEAD for AES-192-GCM). A lookup for 192 therefore succeeds with a nil
// AEAD, and the EVP_AEAD checks are skipped for it.
var aesAeads = map[int]*C.EVP_AEAD{
	128: C.EVP_aead_aes_128_gcm(),
	192: nil,
	256: C.EVP_aead_aes_256_gcm(),
}
711
712func aeadAes(size int) (*C.EVP_AEAD, error) {
713	aead, ok := aesAeads[size]
714	if ok {
715		return aead, nil
716	}
717	return nil, fmt.Errorf("invalid key size: %d", size)
718}
719
720func mustHashHexMessage(md *C.EVP_MD, message string) ([]byte, int) {
721	size := C.EVP_MD_size(md)
722	if size <= 0 || size > C.EVP_MAX_MD_SIZE {
723		log.Fatalf("unexpected MD size %d", size)
724	}
725
726	msg, msgLen := mustDecodeHexString(message, "for hex message digest")
727
728	digest := make([]byte, size)
729
730	if C.EVP_Digest(unsafe.Pointer(&msg[0]), C.size_t(msgLen), (*C.uchar)(unsafe.Pointer(&digest[0])), nil, md, nil) != 1 {
731		log.Fatalf("EVP_Digest failed")
732	}
733
734	return digest, int(size)
735}
736
// mustDecodeHexString decodes the hex string str and returns the bytes and
// their length. Malformed input is fatal; descr names the value in the error
// message. An empty result comes back as the one-byte slice {0} with
// outLen 0, so that callers can always pass &out[0] to C.
func mustDecodeHexString(str, descr string) (out []byte, outLen int) {
	decoded, err := hex.DecodeString(str)
	if err != nil {
		log.Fatalf("Failed to decode %s %q: %v", descr, str, err)
	}
	n := len(decoded)
	if n == 0 {
		decoded = append(decoded, 0)
	}
	return decoded, n
}
748
749func checkAesCbcPkcs5(ctx *C.EVP_CIPHER_CTX, doEncrypt int, key []byte, keyLen int, iv []byte, ivLen int, in []byte, inLen int, out []byte, outLen int, wt *wycheproofTestAesCbcPkcs5) bool {
750	var action string
751	if doEncrypt == 1 {
752		action = "encrypting"
753	} else {
754		action = "decrypting"
755	}
756
757	ret := C.EVP_CipherInit_ex(ctx, nil, nil, (*C.uchar)(unsafe.Pointer(&key[0])), (*C.uchar)(unsafe.Pointer(&iv[0])), C.int(doEncrypt))
758	if ret != 1 {
759		log.Fatalf("EVP_CipherInit_ex failed: %d", ret)
760	}
761
762	cipherOut := make([]byte, inLen+C.EVP_MAX_BLOCK_LENGTH)
763	var cipherOutLen C.int
764
765	ret = C.EVP_CipherUpdate(ctx, (*C.uchar)(unsafe.Pointer(&cipherOut[0])), &cipherOutLen, (*C.uchar)(unsafe.Pointer(&in[0])), C.int(inLen))
766	if ret != 1 {
767		if wt.Result == "invalid" {
768			fmt.Printf("INFO: %s [%v] - EVP_CipherUpdate() = %d\n", wt, action, ret)
769			return true
770		}
771		fmt.Printf("FAIL: %s [%v] - EVP_CipherUpdate() = %d\n", wt, action, ret)
772		return false
773	}
774
775	var finallen C.int
776	ret = C.EVP_CipherFinal_ex(ctx, (*C.uchar)(unsafe.Pointer(&cipherOut[cipherOutLen])), &finallen)
777	if ret != 1 {
778		if wt.Result == "invalid" {
779			return true
780		}
781		fmt.Printf("FAIL: %s [%v] - EVP_CipherFinal_ex() = %d\n", wt, action, ret)
782		return false
783	}
784
785	cipherOutLen += finallen
786	if cipherOutLen != C.int(outLen) && wt.Result != "invalid" {
787		fmt.Printf("FAIL: %s [%v] - open length mismatch: got %d, want %d\n", wt, action, cipherOutLen, outLen)
788		return false
789	}
790
791	openedMsg := cipherOut[0:cipherOutLen]
792	if outLen == 0 {
793		out = nil
794	}
795
796	success := false
797	if bytes.Equal(openedMsg, out) == (wt.Result != "invalid") {
798		success = true
799	} else {
800		fmt.Printf("FAIL: %s [%v] - msg match: %t\n", wt, action, bytes.Equal(openedMsg, out))
801	}
802	return success
803}
804
805func runAesCbcPkcs5Test(ctx *C.EVP_CIPHER_CTX, wt *wycheproofTestAesCbcPkcs5) bool {
806	key, keyLen := mustDecodeHexString(wt.Key, "key")
807	iv, ivLen := mustDecodeHexString(wt.IV, "iv")
808	ct, ctLen := mustDecodeHexString(wt.CT, "ct")
809	msg, msgLen := mustDecodeHexString(wt.Msg, "message")
810
811	openSuccess := checkAesCbcPkcs5(ctx, 0, key, keyLen, iv, ivLen, ct, ctLen, msg, msgLen, wt)
812	sealSuccess := checkAesCbcPkcs5(ctx, 1, key, keyLen, iv, ivLen, msg, msgLen, ct, ctLen, wt)
813
814	return openSuccess && sealSuccess
815}
816
817func (wtg *wycheproofTestGroupAesCbcPkcs5) run(algorithm string, variant testVariant) bool {
818	fmt.Printf("Running %v test group %v with IV size %d and key size %d...\n", algorithm, wtg.Type, wtg.IVSize, wtg.KeySize)
819
820	cipher, err := cipherAes("AES-CBC", wtg.KeySize)
821	if err != nil {
822		log.Fatal(err)
823	}
824
825	ctx := C.EVP_CIPHER_CTX_new()
826	if ctx == nil {
827		log.Fatal("EVP_CIPHER_CTX_new() failed")
828	}
829	defer C.EVP_CIPHER_CTX_free(ctx)
830
831	ret := C.EVP_CipherInit_ex(ctx, cipher, nil, nil, nil, 0)
832	if ret != 1 {
833		log.Fatalf("EVP_CipherInit_ex failed: %d", ret)
834	}
835
836	success := true
837	for _, wt := range wtg.Tests {
838		if !runAesCbcPkcs5Test(ctx, wt) {
839			success = false
840		}
841	}
842	return success
843}
844
845func checkAesAead(algorithm string, ctx *C.EVP_CIPHER_CTX, doEncrypt int, key []byte, keyLen int, iv []byte, ivLen int, aad []byte, aadLen int, in []byte, inLen int, out []byte, outLen int, tag []byte, tagLen int, wt *wycheproofTestAead) bool {
846	var ctrlSetIVLen C.int
847	var ctrlSetTag C.int
848	var ctrlGetTag C.int
849
850	doCCM := false
851	switch algorithm {
852	case "AES-CCM":
853		doCCM = true
854		ctrlSetIVLen = C.EVP_CTRL_CCM_SET_IVLEN
855		ctrlSetTag = C.EVP_CTRL_CCM_SET_TAG
856		ctrlGetTag = C.EVP_CTRL_CCM_GET_TAG
857	case "AES-GCM":
858		ctrlSetIVLen = C.EVP_CTRL_GCM_SET_IVLEN
859		ctrlSetTag = C.EVP_CTRL_GCM_SET_TAG
860		ctrlGetTag = C.EVP_CTRL_GCM_GET_TAG
861	}
862
863	setTag := unsafe.Pointer(nil)
864	var action string
865
866	if doEncrypt == 1 {
867		action = "encrypting"
868	} else {
869		action = "decrypting"
870		setTag = unsafe.Pointer(&tag[0])
871	}
872
873	ret := C.EVP_CipherInit_ex(ctx, nil, nil, nil, nil, C.int(doEncrypt))
874	if ret != 1 {
875		log.Fatalf("[%v] cipher init failed", action)
876	}
877
878	ret = C.EVP_CIPHER_CTX_ctrl(ctx, ctrlSetIVLen, C.int(ivLen), nil)
879	if ret != 1 {
880		if wt.Comment == "Nonce is too long" || wt.Comment == "Invalid nonce size" ||
881			wt.Comment == "0 size IV is not valid" || wt.Comment == "Very long nonce" {
882			return true
883		}
884		fmt.Printf("FAIL: %s [%v] - setting IV len to %d failed: %d.\n", wt, action, ivLen, ret)
885		return false
886	}
887
888	if doEncrypt == 0 || doCCM {
889		ret = C.EVP_CIPHER_CTX_ctrl(ctx, ctrlSetTag, C.int(tagLen), setTag)
890		if ret != 1 {
891			if wt.Comment == "Invalid tag size" {
892				return true
893			}
894			fmt.Printf("FAIL: %s [%v] - setting tag length to %d failed: %d.\n", wt, action, tagLen, ret)
895			return false
896		}
897	}
898
899	ret = C.EVP_CipherInit_ex(ctx, nil, nil, (*C.uchar)(unsafe.Pointer(&key[0])), (*C.uchar)(unsafe.Pointer(&iv[0])), C.int(doEncrypt))
900	if ret != 1 {
901		fmt.Printf("FAIL: %s [%v] - setting key and IV failed: %d.\n", wt, action, ret)
902		return false
903	}
904
905	var cipherOutLen C.int
906	if doCCM {
907		ret = C.EVP_CipherUpdate(ctx, nil, &cipherOutLen, nil, C.int(inLen))
908		if ret != 1 {
909			fmt.Printf("FAIL: %s [%v] - setting input length to %d failed: %d.\n", wt, action, inLen, ret)
910			return false
911		}
912	}
913
914	ret = C.EVP_CipherUpdate(ctx, nil, &cipherOutLen, (*C.uchar)(unsafe.Pointer(&aad[0])), C.int(aadLen))
915	if ret != 1 {
916		fmt.Printf("FAIL: %s [%v] - processing AAD failed: %d.\n", wt, action, ret)
917		return false
918	}
919
920	cipherOutLen = 0
921	cipherOut := make([]byte, inLen)
922	if inLen == 0 {
923		cipherOut = append(cipherOut, 0)
924	}
925
926	ret = C.EVP_CipherUpdate(ctx, (*C.uchar)(unsafe.Pointer(&cipherOut[0])), &cipherOutLen, (*C.uchar)(unsafe.Pointer(&in[0])), C.int(inLen))
927	if ret != 1 {
928		if wt.Result == "invalid" {
929			return true
930		}
931		fmt.Printf("FAIL: %s [%v] - EVP_CipherUpdate() = %d.\n", wt, action, ret)
932		return false
933	}
934
935	if doEncrypt == 1 {
936		var tmpLen C.int
937		dummyOut := make([]byte, 16)
938
939		ret = C.EVP_CipherFinal_ex(ctx, (*C.uchar)(unsafe.Pointer(&dummyOut[0])), &tmpLen)
940		if ret != 1 {
941			fmt.Printf("FAIL: %s [%v] - EVP_CipherFinal_ex() = %d.\n", wt, action, ret)
942			return false
943		}
944		cipherOutLen += tmpLen
945	}
946
947	if cipherOutLen != C.int(outLen) {
948		fmt.Printf("FAIL: %s [%v] - cipherOutLen %d != outLen %d.\n", wt, action, cipherOutLen, outLen)
949		return false
950	}
951
952	success := true
953	if !bytes.Equal(cipherOut, out) {
954		fmt.Printf("FAIL: %s [%v] - expected and computed output do not match.\n", wt, action)
955		success = false
956	}
957	if doEncrypt == 1 {
958		tagOut := make([]byte, tagLen)
959		ret = C.EVP_CIPHER_CTX_ctrl(ctx, ctrlGetTag, C.int(tagLen), unsafe.Pointer(&tagOut[0]))
960		if ret != 1 {
961			fmt.Printf("FAIL: %s [%v] - EVP_CIPHER_CTX_ctrl() = %d.\n", wt, action, ret)
962			return false
963		}
964
965		// There are no acceptable CCM cases. All acceptable GCM tests
966		// pass. They have len(IV) <= 48. NIST SP 800-38D, 5.2.1.1, p.8,
967		// allows 1 <= len(IV) <= 2^64-1, but notes:
968		//   "For IVs it is recommended that implementations restrict
969		//    support to the length of 96 bits, to promote
970		//    interoperability, efficiency and simplicity of design."
971		if bytes.Equal(tagOut, tag) != (wt.Result == "valid" || wt.Result == "acceptable") {
972			fmt.Printf("FAIL: %s [%v] - expected and computed tag do not match: %d.\n", wt, action, ret)
973			success = false
974		}
975	}
976	return success
977}
978
979func runAesAeadTest(algorithm string, ctx *C.EVP_CIPHER_CTX, aead *C.EVP_AEAD, wt *wycheproofTestAead) bool {
980	key, keyLen := mustDecodeHexString(wt.Key, "key")
981	iv, ivLen := mustDecodeHexString(wt.IV, "iv")
982	aad, aadLen := mustDecodeHexString(wt.AAD, "aad")
983	msg, msgLen := mustDecodeHexString(wt.Msg, "msg")
984	ct, ctLen := mustDecodeHexString(wt.CT, "CT")
985	tag, tagLen := mustDecodeHexString(wt.Tag, "tag")
986
987	openEvp := checkAesAead(algorithm, ctx, 0, key, keyLen, iv, ivLen, aad, aadLen, ct, ctLen, msg, msgLen, tag, tagLen, wt)
988	sealEvp := checkAesAead(algorithm, ctx, 1, key, keyLen, iv, ivLen, aad, aadLen, msg, msgLen, ct, ctLen, tag, tagLen, wt)
989
990	openAead, sealAead := true, true
991	if aead != nil {
992		ctx := C.EVP_AEAD_CTX_new()
993		if ctx == nil {
994			log.Fatal("EVP_AEAD_CTX_new() failed")
995		}
996		defer C.EVP_AEAD_CTX_free(ctx)
997
998		if C.EVP_AEAD_CTX_init(ctx, aead, (*C.uchar)(unsafe.Pointer(&key[0])), C.size_t(keyLen), C.size_t(tagLen), nil) != 1 {
999			log.Fatal("Failed to initialize AEAD context")
1000		}
1001
1002		// Make sure we don't accidentally prepend or compare against a 0.
1003		if ctLen == 0 {
1004			ct = nil
1005		}
1006
1007		openAead = checkAeadOpen(ctx, iv, ivLen, aad, aadLen, msg, msgLen, ct, ctLen, tag, tagLen, wt)
1008		sealAead = checkAeadSeal(ctx, iv, ivLen, aad, aadLen, msg, msgLen, ct, ctLen, tag, tagLen, wt)
1009	}
1010
1011	return openEvp && sealEvp && openAead && sealAead
1012}
1013
1014func (wtg *wycheproofTestGroupAesAead) run(algorithm string, variant testVariant) bool {
1015	fmt.Printf("Running %v test group %v with IV size %d, key size %d and tag size %d...\n", algorithm, wtg.Type, wtg.IVSize, wtg.KeySize, wtg.TagSize)
1016
1017	cipher, err := cipherAes(algorithm, wtg.KeySize)
1018	if err != nil {
1019		fmt.Printf("INFO: Skipping tests with %s\n", err)
1020		return true
1021	}
1022	var aead *C.EVP_AEAD
1023	if algorithm == "AES-GCM" {
1024		aead, err = aeadAes(wtg.KeySize)
1025		if err != nil {
1026			log.Fatalf("%s", err)
1027		}
1028	}
1029
1030	ctx := C.EVP_CIPHER_CTX_new()
1031	if ctx == nil {
1032		log.Fatal("EVP_CIPHER_CTX_new() failed")
1033	}
1034	defer C.EVP_CIPHER_CTX_free(ctx)
1035
1036	C.EVP_CipherInit_ex(ctx, cipher, nil, nil, nil, 1)
1037
1038	success := true
1039	for _, wt := range wtg.Tests {
1040		if !runAesAeadTest(algorithm, ctx, aead, wt) {
1041			success = false
1042		}
1043	}
1044	return success
1045}
1046
1047func runAesCmacTest(cipher *C.EVP_CIPHER, wt *wycheproofTestAesCmac) bool {
1048	key, keyLen := mustDecodeHexString(wt.Key, "key")
1049	msg, msgLen := mustDecodeHexString(wt.Msg, "msg")
1050	tag, tagLen := mustDecodeHexString(wt.Tag, "tag")
1051
1052	mdctx := C.EVP_MD_CTX_new()
1053	if mdctx == nil {
1054		log.Fatal("EVP_MD_CTX_new failed")
1055	}
1056	defer C.EVP_MD_CTX_free(mdctx)
1057
1058	pkey := C.EVP_PKEY_new_CMAC_key(nil, (*C.uchar)(unsafe.Pointer(&key[0])), C.size_t(keyLen), cipher)
1059	if pkey == nil {
1060		log.Fatal("CMAC_CTX_new failed")
1061	}
1062	defer C.EVP_PKEY_free(pkey)
1063
1064	ret := C.EVP_DigestSignInit(mdctx, nil, nil, nil, pkey)
1065	if ret != 1 {
1066		fmt.Printf("FAIL: %s - EVP_DigestSignInit() = %d.\n", wt, ret)
1067		return false
1068	}
1069
1070	var outLen C.size_t
1071	outTag := make([]byte, 16)
1072
1073	ret = C.EVP_DigestSign(mdctx, (*C.uchar)(unsafe.Pointer(&outTag[0])), &outLen, (*C.uchar)(unsafe.Pointer(&msg[0])), C.size_t(msgLen))
1074	if ret != 1 {
1075		fmt.Printf("FAIL: %s - EVP_DigestSign() = %d.\n", wt, ret)
1076		return false
1077	}
1078
1079	outTag = outTag[0:tagLen]
1080
1081	success := true
1082	if bytes.Equal(tag, outTag) != (wt.Result == "valid") {
1083		fmt.Printf("FAIL: %s.\n", wt)
1084		success = false
1085	}
1086	return success
1087}
1088
1089func (wtg *wycheproofTestGroupAesCmac) run(algorithm string, variant testVariant) bool {
1090	fmt.Printf("Running %v test group %v with key size %d and tag size %d...\n", algorithm, wtg.Type, wtg.KeySize, wtg.TagSize)
1091
1092	cipher, err := cipherAes("AES-CBC", wtg.KeySize)
1093	if err != nil {
1094		fmt.Printf("INFO: Skipping tests with %d.\n", err)
1095		return true
1096	}
1097
1098	success := true
1099	for _, wt := range wtg.Tests {
1100		if !runAesCmacTest(cipher, wt) {
1101			success = false
1102		}
1103	}
1104	return success
1105}
1106
1107func checkAeadOpen(ctx *C.EVP_AEAD_CTX, iv []byte, ivLen int, aad []byte, aadLen int, msg []byte, msgLen int, ct []byte, ctLen int, tag []byte, tagLen int, wt *wycheproofTestAead) bool {
1108	maxOutLen := ctLen + tagLen
1109
1110	opened := make([]byte, maxOutLen)
1111	if maxOutLen == 0 {
1112		opened = append(opened, 0)
1113	}
1114	var openedMsgLen C.size_t
1115
1116	catCtTag := append(ct, tag...)
1117	catCtTagLen := len(catCtTag)
1118	if catCtTagLen == 0 {
1119		catCtTag = append(catCtTag, 0)
1120	}
1121	openRet := C.EVP_AEAD_CTX_open(ctx, (*C.uint8_t)(unsafe.Pointer(&opened[0])), (*C.size_t)(unsafe.Pointer(&openedMsgLen)), C.size_t(maxOutLen), (*C.uint8_t)(unsafe.Pointer(&iv[0])), C.size_t(ivLen), (*C.uint8_t)(unsafe.Pointer(&catCtTag[0])), C.size_t(catCtTagLen), (*C.uint8_t)(unsafe.Pointer(&aad[0])), C.size_t(aadLen))
1122
1123	if openRet != 1 {
1124		if wt.Result == "invalid" {
1125			return true
1126		}
1127		fmt.Printf("FAIL: %s - EVP_AEAD_CTX_open() = %d.\n", wt, int(openRet))
1128		return false
1129	}
1130
1131	if openedMsgLen != C.size_t(msgLen) {
1132		fmt.Printf("FAIL: %s - open length mismatch: got %d, want %d.\n", wt, openedMsgLen, msgLen)
1133		return false
1134	}
1135
1136	openedMsg := opened[0:openedMsgLen]
1137	if msgLen == 0 {
1138		msg = nil
1139	}
1140
1141	success := false
1142	if bytes.Equal(openedMsg, msg) == (wt.Result != "invalid") {
1143		success = true
1144	} else {
1145		fmt.Printf("FAIL: %s - msg match: %t.\n", wt, bytes.Equal(openedMsg, msg))
1146	}
1147	return success
1148}
1149
1150func checkAeadSeal(ctx *C.EVP_AEAD_CTX, iv []byte, ivLen int, aad []byte, aadLen int, msg []byte, msgLen int, ct []byte, ctLen int, tag []byte, tagLen int, wt *wycheproofTestAead) bool {
1151	maxOutLen := msgLen + tagLen
1152
1153	sealed := make([]byte, maxOutLen)
1154	if maxOutLen == 0 {
1155		sealed = append(sealed, 0)
1156	}
1157	var sealedLen C.size_t
1158
1159	sealRet := C.EVP_AEAD_CTX_seal(ctx, (*C.uint8_t)(unsafe.Pointer(&sealed[0])), (*C.size_t)(unsafe.Pointer(&sealedLen)), C.size_t(maxOutLen), (*C.uint8_t)(unsafe.Pointer(&iv[0])), C.size_t(ivLen), (*C.uint8_t)(unsafe.Pointer(&msg[0])), C.size_t(msgLen), (*C.uint8_t)(unsafe.Pointer(&aad[0])), C.size_t(aadLen))
1160
1161	if sealRet != 1 {
1162		success := (wt.Result == "invalid")
1163		if !success {
1164			fmt.Printf("FAIL: %s - EVP_AEAD_CTX_seal() = %d.\n", wt, int(sealRet))
1165		}
1166		return success
1167	}
1168
1169	if sealedLen != C.size_t(maxOutLen) {
1170		fmt.Printf("FAIL: %s - seal length mismatch: got %d, want %d.\n", wt, sealedLen, maxOutLen)
1171		return false
1172	}
1173
1174	sealedCt := sealed[0:msgLen]
1175	sealedTag := sealed[msgLen:maxOutLen]
1176
1177	success := false
1178	if (bytes.Equal(sealedCt, ct) && bytes.Equal(sealedTag, tag)) == (wt.Result != "invalid") {
1179		success = true
1180	} else {
1181		fmt.Printf("FAIL: %s - EVP_AEAD_CTX_seal() = %d, ct match: %t, tag match: %t.\n", wt, int(sealRet), bytes.Equal(sealedCt, ct), bytes.Equal(sealedTag, tag))
1182	}
1183	return success
1184}
1185
1186func runChaCha20Poly1305Test(algorithm string, wt *wycheproofTestAead) bool {
1187	var aead *C.EVP_AEAD
1188	switch algorithm {
1189	case "CHACHA20-POLY1305":
1190		aead = C.EVP_aead_chacha20_poly1305()
1191	case "XCHACHA20-POLY1305":
1192		aead = C.EVP_aead_xchacha20_poly1305()
1193	}
1194
1195	key, keyLen := mustDecodeHexString(wt.Key, "key")
1196	iv, ivLen := mustDecodeHexString(wt.IV, "iv")
1197	aad, aadLen := mustDecodeHexString(wt.AAD, "aad")
1198	msg, msgLen := mustDecodeHexString(wt.Msg, "msg")
1199
1200	// ct and tag are concatenated in checkAeadOpen(), so don't use mustDecodeHexString()
1201	ct, err := hex.DecodeString(wt.CT)
1202	if err != nil {
1203		log.Fatalf("Failed to decode ct %q: %v", wt.CT, err)
1204	}
1205	tag, err := hex.DecodeString(wt.Tag)
1206	if err != nil {
1207		log.Fatalf("Failed to decode tag %q: %v", wt.Tag, err)
1208	}
1209	ctLen, tagLen := len(ct), len(tag)
1210
1211	ctx := C.EVP_AEAD_CTX_new()
1212	if ctx == nil {
1213		log.Fatal("EVP_AEAD_CTX_new() failed")
1214	}
1215	defer C.EVP_AEAD_CTX_free(ctx)
1216	if C.EVP_AEAD_CTX_init(ctx, aead, (*C.uchar)(unsafe.Pointer(&key[0])), C.size_t(keyLen), C.size_t(tagLen), nil) != 1 {
1217		log.Fatal("Failed to initialize AEAD context")
1218	}
1219
1220	openSuccess := checkAeadOpen(ctx, iv, ivLen, aad, aadLen, msg, msgLen, ct, ctLen, tag, tagLen, wt)
1221	sealSuccess := checkAeadSeal(ctx, iv, ivLen, aad, aadLen, msg, msgLen, ct, ctLen, tag, tagLen, wt)
1222
1223	return openSuccess && sealSuccess
1224}
1225
1226func runEvpChaCha20Poly1305Test(ctx *C.EVP_CIPHER_CTX, algorithm string, wt *wycheproofTestAead) bool {
1227	var aead *C.EVP_CIPHER
1228	switch algorithm {
1229	case "CHACHA20-POLY1305":
1230		aead = C.EVP_chacha20_poly1305()
1231	case "XCHACHA20-POLY1305":
1232		return true
1233	}
1234
1235	key, _ := mustDecodeHexString(wt.Key, "key")
1236	iv, ivLen := mustDecodeHexString(wt.IV, "iv")
1237	aad, aadLen := mustDecodeHexString(wt.AAD, "aad")
1238	msg, msgLen := mustDecodeHexString(wt.Msg, "msg")
1239
1240	ct, err := hex.DecodeString(wt.CT)
1241	if err != nil {
1242		log.Fatalf("Failed to decode ct %q: %v", wt.CT, err)
1243	}
1244	tag, err := hex.DecodeString(wt.Tag)
1245	if err != nil {
1246		log.Fatalf("Failed to decode tag %q: %v", wt.Tag, err)
1247	}
1248	ctLen, tagLen := len(ct), len(tag)
1249
1250	if C.EVP_EncryptInit_ex(ctx, aead, nil, nil, nil) != 1 {
1251		log.Fatal("Failed to initialize EVP_CIPHER_CTX with cipher")
1252	}
1253	if C.EVP_CIPHER_CTX_ctrl(ctx, C.EVP_CTRL_AEAD_SET_IVLEN, C.int(ivLen), nil) != 1 {
1254		log.Fatal("Failed EVP_CTRL_AEAD_SET_IVLEN")
1255	}
1256	if C.EVP_EncryptInit_ex(ctx, nil, nil, (*C.uchar)(unsafe.Pointer(&key[0])), nil) != 1 {
1257		log.Fatal("Failed EVP_EncryptInit_ex key")
1258	}
1259	if C.EVP_EncryptInit_ex(ctx, nil, nil, nil, (*C.uchar)(unsafe.Pointer(&iv[0]))) != 1 {
1260		log.Fatal("Failed EVP_EncryptInit_ex iv")
1261	}
1262
1263	var len C.int
1264
1265	if C.EVP_EncryptUpdate(ctx, nil, (*C.int)(unsafe.Pointer(&len)), (*C.uchar)(&aad[0]), (C.int)(aadLen)) != 1 {
1266		log.Fatal("Failed EVP_EncryptUpdate aad")
1267	}
1268
1269	sealed := make([]byte, ctLen + tagLen)
1270	copy(sealed, msg)
1271	if C.EVP_EncryptUpdate(ctx, (*C.uchar)(unsafe.Pointer(&sealed[0])), (*C.int)(unsafe.Pointer(&len)), (*C.uchar)(unsafe.Pointer(&sealed[0])), (C.int)(msgLen)) != 1 {
1272		log.Fatal("Failed EVP_EncryptUpdate msg")
1273	}
1274	outLen := len
1275	if C.EVP_EncryptFinal_ex(ctx, (*C.uchar)(unsafe.Pointer(&sealed[outLen])), (*C.int)(unsafe.Pointer(&len))) != 1 {
1276		log.Fatal("Failed EVP_EncryptFinal msg")
1277	}
1278	outLen += len
1279	if C.EVP_CIPHER_CTX_ctrl(ctx, C.EVP_CTRL_AEAD_GET_TAG, (C.int)(tagLen), unsafe.Pointer(&sealed[outLen])) != 1 {
1280		log.Fatal("Failed EVP_CTRL_AEAD_GET_TAG")
1281	}
1282	outLen += (C.int)(tagLen)
1283
1284	if (C.int)(ctLen + tagLen) != outLen {
1285		fmt.Printf("%s\n", wt)
1286	}
1287
1288	sealSuccess := false
1289	ctMatch := bytes.Equal(ct, sealed[:ctLen])
1290	tagMatch := bytes.Equal(tag, sealed[ctLen:])
1291	if (ctMatch && tagMatch) == (wt.Result != "invalid") {
1292		sealSuccess = true
1293	} else  {
1294		fmt.Printf("%s - ct match: %t tag match: %t\n", wt, ctMatch, tagMatch)
1295	}
1296
1297	if C.EVP_DecryptInit_ex(ctx, aead, nil, nil, nil) != 1 {
1298		log.Fatal("Failed to initialize EVP_CIPHER_CTX with cipher")
1299	}
1300	if C.EVP_DecryptInit_ex(ctx, nil, nil, (*C.uchar)(unsafe.Pointer(&key[0])), nil) != 1 {
1301		log.Fatal("Failed EVP_EncryptInit_ex key")
1302	}
1303
1304	if C.EVP_CIPHER_CTX_ctrl(ctx, C.EVP_CTRL_AEAD_SET_IVLEN, C.int(ivLen), nil) != 1 {
1305		log.Fatal("Failed EVP_CTRL_AEAD_SET_IVLEN")
1306	}
1307	if C.EVP_DecryptInit_ex(ctx, nil, nil, nil, (*C.uchar)(unsafe.Pointer(&iv[0]))) != 1 {
1308		log.Fatal("Failed EVP_EncryptInit_ex iv")
1309	}
1310
1311	if C.EVP_CIPHER_CTX_ctrl(ctx, C.EVP_CTRL_AEAD_SET_TAG, (C.int)(tagLen), unsafe.Pointer(&tag[0])) != 1 {
1312		log.Fatal("Failed EVP_CTRL_AEAD_SET_TAG")
1313	}
1314
1315	if ctLen == 0 {
1316		ct = append(ct, 0)
1317	}
1318
1319	opened := make([]byte, msgLen + tagLen)
1320	copy(opened, ct)
1321	if msgLen + aadLen == 0 {
1322		opened = append(opened, 0)
1323	}
1324
1325	if C.EVP_DecryptUpdate(ctx, nil, (*C.int)(unsafe.Pointer(&len)), (*C.uchar)(unsafe.Pointer(&aad[0])), C.int(aadLen)) != 1 {
1326		log.Fatal("Failed EVP_EncryptUpdate msg")
1327	}
1328
1329	if C.EVP_DecryptUpdate(ctx, (*C.uchar)(unsafe.Pointer(&opened[0])), (*C.int)(unsafe.Pointer(&len)), (*C.uchar)(unsafe.Pointer(&opened[0])), (C.int)(ctLen)) != 1 {
1330		log.Fatal("Failed EVP_EncryptUpdate msg")
1331	}
1332	outLen = len
1333
1334	var ret C.int
1335	if wt.Result != "invalid" {
1336		ret = 1
1337	}
1338
1339	if C.EVP_DecryptFinal_ex(ctx, (*C.uchar)(unsafe.Pointer(&opened[outLen])), (*C.int)(unsafe.Pointer(&len))) != ret {
1340		log.Fatalf("Failed EVP_EncryptFinal msg %s\n", wt)
1341	}
1342	outLen += len
1343
1344	openSuccess := true
1345	if (C.int)(msgLen) != outLen {
1346		openSuccess = false
1347		fmt.Printf("%s\n", wt)
1348	}
1349
1350	if wt.Result != "invalid" && !bytes.Equal(opened[:outLen], msg[:msgLen]) {
1351		fmt.Printf("failed %s\n", wt)
1352		openSuccess = false
1353	}
1354
1355	return sealSuccess && openSuccess
1356}
1357
1358func (wtg *wycheproofTestGroupChaCha) run(algorithm string, variant testVariant) bool {
1359	// ChaCha20-Poly1305 currently only supports nonces of length 12 (96 bits)
1360	if algorithm == "CHACHA20-POLY1305" && wtg.IVSize != 96 {
1361		return true
1362	}
1363
1364	fmt.Printf("Running %v test group %v with IV size %d, key size %d, tag size %d...\n", algorithm, wtg.Type, wtg.IVSize, wtg.KeySize, wtg.TagSize)
1365
1366	ctx := C.EVP_CIPHER_CTX_new()
1367	if ctx == nil {
1368		log.Fatal("EVP_CIPHER_CTX_new() failed")
1369	}
1370	defer C.EVP_CIPHER_CTX_free(ctx)
1371
1372	success := true
1373	for _, wt := range wtg.Tests {
1374		if !runChaCha20Poly1305Test(algorithm, wt) {
1375			success = false
1376		}
1377		if !runEvpChaCha20Poly1305Test(ctx, algorithm, wt) {
1378			success = false
1379		}
1380	}
1381	return success
1382}
1383
1384// DER encode the signature (so DSA_verify() can decode and encode it again)
1385func encodeDSAP1363Sig(wtSig string) (*C.uchar, C.int) {
1386	cSig := C.DSA_SIG_new()
1387	if cSig == nil {
1388		log.Fatal("DSA_SIG_new() failed")
1389	}
1390	defer C.DSA_SIG_free(cSig)
1391
1392	sigLen := len(wtSig)
1393	r := C.CString(wtSig[:sigLen/2])
1394	s := C.CString(wtSig[sigLen/2:])
1395	defer C.free(unsafe.Pointer(r))
1396	defer C.free(unsafe.Pointer(s))
1397	var sigR *C.BIGNUM
1398	var sigS *C.BIGNUM
1399	defer C.BN_free(sigR)
1400	defer C.BN_free(sigS)
1401	if C.BN_hex2bn(&sigR, r) == 0 {
1402		return nil, 0
1403	}
1404	if C.BN_hex2bn(&sigS, s) == 0 {
1405		return nil, 0
1406	}
1407	if C.DSA_SIG_set0(cSig, sigR, sigS) == 0 {
1408		return nil, 0
1409	}
1410	sigR = nil
1411	sigS = nil
1412
1413	derLen := C.i2d_DSA_SIG(cSig, nil)
1414	if derLen == 0 {
1415		return nil, 0
1416	}
1417	cDer := (*C.uchar)(C.malloc(C.ulong(derLen)))
1418	if cDer == nil {
1419		log.Fatal("malloc failed")
1420	}
1421
1422	p := cDer
1423	ret := C.i2d_DSA_SIG(cSig, (**C.uchar)(&p))
1424	if ret == 0 || ret != derLen {
1425		C.free(unsafe.Pointer(cDer))
1426		return nil, 0
1427	}
1428
1429	return cDer, derLen
1430}
1431
1432func runDSATest(dsa *C.DSA, md *C.EVP_MD, variant testVariant, wt *wycheproofTestDSA) bool {
1433	msg, msgLen := mustHashHexMessage(md, wt.Msg)
1434
1435	var ret C.int
1436	if variant == P1363 {
1437		cDer, derLen := encodeDSAP1363Sig(wt.Sig)
1438		if cDer == nil {
1439			fmt.Print("FAIL: unable to decode signature")
1440			return false
1441		}
1442		defer C.free(unsafe.Pointer(cDer))
1443
1444		ret = C.DSA_verify(0, (*C.uchar)(unsafe.Pointer(&msg[0])), C.int(msgLen), (*C.uchar)(unsafe.Pointer(cDer)), C.int(derLen), dsa)
1445	} else {
1446		sig, sigLen := mustDecodeHexString(wt.Sig, "sig")
1447		ret = C.DSA_verify(0, (*C.uchar)(unsafe.Pointer(&msg[0])), C.int(msgLen), (*C.uchar)(unsafe.Pointer(&sig[0])), C.int(sigLen), dsa)
1448	}
1449
1450	success := true
1451	if ret == 1 != (wt.Result == "valid") {
1452		fmt.Printf("FAIL: %s - DSA_verify() = %d.\n", wt, wt.Result)
1453		success = false
1454	}
1455	return success
1456}
1457
1458func (wtg *wycheproofTestGroupDSA) run(algorithm string, variant testVariant) bool {
1459	fmt.Printf("Running %v test group %v, key size %d and %v...\n", algorithm, wtg.Type, wtg.Key.KeySize, wtg.SHA)
1460
1461	dsa := C.DSA_new()
1462	if dsa == nil {
1463		log.Fatal("DSA_new failed")
1464	}
1465	defer C.DSA_free(dsa)
1466
1467	var bnG *C.BIGNUM
1468	wg := C.CString(wtg.Key.G)
1469	if C.BN_hex2bn(&bnG, wg) == 0 {
1470		log.Fatal("Failed to decode g")
1471	}
1472	C.free(unsafe.Pointer(wg))
1473
1474	var bnP *C.BIGNUM
1475	wp := C.CString(wtg.Key.P)
1476	if C.BN_hex2bn(&bnP, wp) == 0 {
1477		log.Fatal("Failed to decode p")
1478	}
1479	C.free(unsafe.Pointer(wp))
1480
1481	var bnQ *C.BIGNUM
1482	wq := C.CString(wtg.Key.Q)
1483	if C.BN_hex2bn(&bnQ, wq) == 0 {
1484		log.Fatal("Failed to decode q")
1485	}
1486	C.free(unsafe.Pointer(wq))
1487
1488	ret := C.DSA_set0_pqg(dsa, bnP, bnQ, bnG)
1489	if ret != 1 {
1490		log.Fatalf("DSA_set0_pqg returned %d", ret)
1491	}
1492
1493	var bnY *C.BIGNUM
1494	wy := C.CString(wtg.Key.Y)
1495	if C.BN_hex2bn(&bnY, wy) == 0 {
1496		log.Fatal("Failed to decode y")
1497	}
1498	C.free(unsafe.Pointer(wy))
1499
1500	ret = C.DSA_set0_key(dsa, bnY, nil)
1501	if ret != 1 {
1502		log.Fatalf("DSA_set0_key returned %d", ret)
1503	}
1504
1505	md, err := hashEvpMdFromString(wtg.SHA)
1506	if err != nil {
1507		log.Fatalf("Failed to get hash: %v", err)
1508	}
1509
1510	der, derLen := mustDecodeHexString(wtg.KeyDER, "DER encoded key")
1511
1512	Cder := (*C.uchar)(C.malloc(C.ulong(derLen)))
1513	if Cder == nil {
1514		log.Fatal("malloc failed")
1515	}
1516	C.memcpy(unsafe.Pointer(Cder), unsafe.Pointer(&der[0]), C.ulong(derLen))
1517
1518	p := (*C.uchar)(Cder)
1519	dsaDER := C.d2i_DSA_PUBKEY(nil, (**C.uchar)(&p), C.long(derLen))
1520	defer C.DSA_free(dsaDER)
1521	C.free(unsafe.Pointer(Cder))
1522
1523	keyPEM := C.CString(wtg.KeyPEM)
1524	bio := C.BIO_new_mem_buf(unsafe.Pointer(keyPEM), C.int(len(wtg.KeyPEM)))
1525	if bio == nil {
1526		log.Fatal("BIO_new_mem_buf failed")
1527	}
1528	defer C.free(unsafe.Pointer(keyPEM))
1529	defer C.BIO_free(bio)
1530
1531	dsaPEM := C.PEM_read_bio_DSA_PUBKEY(bio, nil, nil, nil)
1532	if dsaPEM == nil {
1533		log.Fatal("PEM_read_bio_DSA_PUBKEY failed")
1534	}
1535	defer C.DSA_free(dsaPEM)
1536
1537	success := true
1538	for _, wt := range wtg.Tests {
1539		if !runDSATest(dsa, md, variant, wt) {
1540			success = false
1541		}
1542		if !runDSATest(dsaDER, md, variant, wt) {
1543			success = false
1544		}
1545		if !runDSATest(dsaPEM, md, variant, wt) {
1546			success = false
1547		}
1548	}
1549	return success
1550}
1551
1552func runECDHTest(nid int, variant testVariant, wt *wycheproofTestECDH) bool {
1553	privKey := C.EC_KEY_new_by_curve_name(C.int(nid))
1554	if privKey == nil {
1555		log.Fatalf("EC_KEY_new_by_curve_name failed")
1556	}
1557	defer C.EC_KEY_free(privKey)
1558
1559	var bnPriv *C.BIGNUM
1560	wPriv := C.CString(wt.Private)
1561	if C.BN_hex2bn(&bnPriv, wPriv) == 0 {
1562		log.Fatal("Failed to decode wPriv")
1563	}
1564	C.free(unsafe.Pointer(wPriv))
1565	defer C.BN_free(bnPriv)
1566
1567	ret := C.EC_KEY_set_private_key(privKey, bnPriv)
1568	if ret != 1 {
1569		fmt.Printf("FAIL: %s - EC_KEY_set_private_key() = %d.\n", wt, ret)
1570		return false
1571	}
1572
1573	pub, pubLen := mustDecodeHexString(wt.Public, "public key")
1574
1575	Cpub := (*C.uchar)(C.malloc(C.ulong(pubLen)))
1576	if Cpub == nil {
1577		log.Fatal("malloc failed")
1578	}
1579	C.memcpy(unsafe.Pointer(Cpub), unsafe.Pointer(&pub[0]), C.ulong(pubLen))
1580
1581	p := (*C.uchar)(Cpub)
1582	var pubKey *C.EC_KEY
1583	if variant == EcPoint {
1584		pubKey = C.EC_KEY_new_by_curve_name(C.int(nid))
1585		if pubKey == nil {
1586			log.Fatal("EC_KEY_new_by_curve_name failed")
1587		}
1588		pubKey = C.o2i_ECPublicKey(&pubKey, (**C.uchar)(&p), C.long(pubLen))
1589	} else {
1590		pubKey = C.d2i_EC_PUBKEY(nil, (**C.uchar)(&p), C.long(pubLen))
1591	}
1592	defer C.EC_KEY_free(pubKey)
1593	C.free(unsafe.Pointer(Cpub))
1594
1595	if pubKey == nil {
1596		if wt.Result == "invalid" || wt.Result == "acceptable" {
1597			return true
1598		}
1599		fmt.Printf("FAIL: %s - ASN decoding failed.\n", wt)
1600		return false
1601	}
1602
1603	privGroup := C.EC_KEY_get0_group(privKey)
1604
1605	secLen := (C.EC_GROUP_get_degree(privGroup) + 7) / 8
1606
1607	secret := make([]byte, secLen)
1608	if secLen == 0 {
1609		secret = append(secret, 0)
1610	}
1611
1612	pubPoint := C.EC_KEY_get0_public_key(pubKey)
1613
1614	ret = C.ECDH_compute_key(unsafe.Pointer(&secret[0]), C.ulong(secLen), pubPoint, privKey, nil)
1615	if ret != C.int(secLen) {
1616		if wt.Result == "invalid" {
1617			return true
1618		}
1619		fmt.Printf("FAIL: %s - ECDH_compute_key() = %d, want %d.\n", wt, ret, int(secLen))
1620		return false
1621	}
1622
1623	shared, sharedLen := mustDecodeHexString(wt.Shared, "shared secret")
1624
1625	// XXX The shared fields of the secp224k1 test cases have a 0 byte preprended.
1626	if sharedLen == int(secLen)+1 && shared[0] == 0 {
1627		fmt.Printf("INFO: %s - prepending 0 byte.\n", wt)
1628		// shared = shared[1:];
1629		zero := make([]byte, 1, secLen+1)
1630		secret = append(zero, secret...)
1631	}
1632
1633	success := true
1634	if !bytes.Equal(shared, secret) {
1635		fmt.Printf("FAIL: %s - expected and computed shared secret do not match.\n", wt)
1636		success = false
1637	}
1638	return success
1639}
1640
1641func (wtg *wycheproofTestGroupECDH) run(algorithm string, variant testVariant) bool {
1642	fmt.Printf("Running %v test group %v with curve %v and %v encoding...\n", algorithm, wtg.Type, wtg.Curve, wtg.Encoding)
1643
1644	nid, err := nidFromString(wtg.Curve)
1645	if err != nil {
1646		log.Fatalf("Failed to get nid for curve: %v", err)
1647	}
1648	if skipSmallCurve(nid) {
1649		return true
1650	}
1651
1652	success := true
1653	for _, wt := range wtg.Tests {
1654		if !runECDHTest(nid, variant, wt) {
1655			success = false
1656		}
1657	}
1658	return success
1659}
1660
1661func runECDHWebCryptoTest(nid int, wt *wycheproofTestECDHWebCrypto) bool {
1662	privKey := C.EC_KEY_new_by_curve_name(C.int(nid))
1663	if privKey == nil {
1664		log.Fatalf("EC_KEY_new_by_curve_name failed")
1665	}
1666	defer C.EC_KEY_free(privKey)
1667
1668	d, err := base64.RawURLEncoding.DecodeString(wt.Private.D)
1669	if err != nil {
1670		log.Fatalf("Failed to base64 decode d: %v", err)
1671	}
1672	bnD := C.BN_bin2bn((*C.uchar)(unsafe.Pointer(&d[0])), C.int(len(d)), nil)
1673	if bnD == nil {
1674		log.Fatal("Failed to decode D")
1675	}
1676	defer C.BN_free(bnD)
1677
1678	ret := C.EC_KEY_set_private_key(privKey, bnD)
1679	if ret != 1 {
1680		fmt.Printf("FAIL: %s - EC_KEY_set_private_key() = %d.\n", wt, ret)
1681		return false
1682	}
1683
1684	x, err := base64.RawURLEncoding.DecodeString(wt.Public.X)
1685	if err != nil {
1686		log.Fatalf("Failed to base64 decode x: %v", err)
1687	}
1688	bnX := C.BN_bin2bn((*C.uchar)(unsafe.Pointer(&x[0])), C.int(len(x)), nil)
1689	if bnX == nil {
1690		log.Fatal("Failed to decode X")
1691	}
1692	defer C.BN_free(bnX)
1693
1694	y, err := base64.RawURLEncoding.DecodeString(wt.Public.Y)
1695	if err != nil {
1696		log.Fatalf("Failed to base64 decode y: %v", err)
1697	}
1698	bnY := C.BN_bin2bn((*C.uchar)(unsafe.Pointer(&y[0])), C.int(len(y)), nil)
1699	if bnY == nil {
1700		log.Fatal("Failed to decode Y")
1701	}
1702	defer C.BN_free(bnY)
1703
1704	pubKey := C.EC_KEY_new_by_curve_name(C.int(nid))
1705	if pubKey == nil {
1706		log.Fatal("Failed to create EC_KEY")
1707	}
1708	defer C.EC_KEY_free(pubKey)
1709
1710	ret = C.EC_KEY_set_public_key_affine_coordinates(pubKey, bnX, bnY)
1711	if ret != 1 {
1712		if wt.Result == "invalid" {
1713			return true
1714		}
1715		fmt.Printf("FAIL: %s - EC_KEY_set_public_key_affine_coordinates() = %d.\n", wt, ret)
1716		return false
1717	}
1718	pubPoint := C.EC_KEY_get0_public_key(pubKey)
1719
1720	privGroup := C.EC_KEY_get0_group(privKey)
1721
1722	secLen := (C.EC_GROUP_get_degree(privGroup) + 7) / 8
1723
1724	secret := make([]byte, secLen)
1725	if secLen == 0 {
1726		secret = append(secret, 0)
1727	}
1728
1729	ret = C.ECDH_compute_key(unsafe.Pointer(&secret[0]), C.ulong(secLen), pubPoint, privKey, nil)
1730	if ret != C.int(secLen) {
1731		if wt.Result == "invalid" {
1732			return true
1733		}
1734		fmt.Printf("FAIL: %s - ECDH_compute_key() = %d, want %d.\n", wt, ret, int(secLen))
1735		return false
1736	}
1737
1738	shared, _ := mustDecodeHexString(wt.Shared, "shared secret")
1739
1740	success := true
1741	if !bytes.Equal(shared, secret) {
1742		fmt.Printf("FAIL: %s - expected and computed shared secret do not match.\n", wt)
1743		success = false
1744	}
1745	return success
1746}
1747
1748func (wtg *wycheproofTestGroupECDHWebCrypto) run(algorithm string, variant testVariant) bool {
1749	fmt.Printf("Running %v test group %v with curve %v and %v encoding...\n", algorithm, wtg.Type, wtg.Curve, wtg.Encoding)
1750
1751	nid, err := nidFromString(wtg.Curve)
1752	if err != nil {
1753		log.Fatalf("Failed to get nid for curve: %v", err)
1754	}
1755
1756	success := true
1757	for _, wt := range wtg.Tests {
1758		if !runECDHWebCryptoTest(nid, wt) {
1759			success = false
1760		}
1761	}
1762	return success
1763}
1764
1765func runECDSATest(ecKey *C.EC_KEY, md *C.EVP_MD, nid int, variant testVariant, wt *wycheproofTestECDSA) bool {
1766	msg, msgLen := mustHashHexMessage(md, wt.Msg)
1767
1768	var ret C.int
1769	if variant == Webcrypto || variant == P1363 {
1770		cDer, derLen := encodeECDSAWebCryptoSig(wt.Sig)
1771		if cDer == nil {
1772			fmt.Print("FAIL: unable to decode signature")
1773			return false
1774		}
1775		defer C.free(unsafe.Pointer(cDer))
1776
1777		ret = C.ECDSA_verify(0, (*C.uchar)(unsafe.Pointer(&msg[0])), C.int(msgLen), (*C.uchar)(unsafe.Pointer(cDer)), C.int(derLen), ecKey)
1778	} else {
1779		sig, sigLen := mustDecodeHexString(wt.Sig, "sig")
1780
1781		ret = C.ECDSA_verify(0, (*C.uchar)(unsafe.Pointer(&msg[0])), C.int(msgLen), (*C.uchar)(unsafe.Pointer(&sig[0])), C.int(sigLen), ecKey)
1782	}
1783
1784	// XXX audit acceptable cases...
1785	success := true
1786	if ret == 1 != (wt.Result == "valid") && wt.Result != "acceptable" {
1787		fmt.Printf("FAIL: %s - ECDSA_verify() = %d.\n", wt, int(ret))
1788		success = false
1789	}
1790	return success
1791}
1792
1793func (wtg *wycheproofTestGroupECDSA) run(algorithm string, variant testVariant) bool {
1794	fmt.Printf("Running %v test group %v with curve %v, key size %d and %v...\n", algorithm, wtg.Type, wtg.Key.Curve, wtg.Key.KeySize, wtg.SHA)
1795
1796	nid, err := nidFromString(wtg.Key.Curve)
1797	if err != nil {
1798		log.Fatalf("Failed to get nid for curve: %v", err)
1799	}
1800	if skipSmallCurve(nid) {
1801		return true
1802	}
1803	ecKey := C.EC_KEY_new_by_curve_name(C.int(nid))
1804	if ecKey == nil {
1805		log.Fatal("EC_KEY_new_by_curve_name failed")
1806	}
1807	defer C.EC_KEY_free(ecKey)
1808
1809	var bnX *C.BIGNUM
1810	wx := C.CString(wtg.Key.WX)
1811	if C.BN_hex2bn(&bnX, wx) == 0 {
1812		log.Fatal("Failed to decode WX")
1813	}
1814	C.free(unsafe.Pointer(wx))
1815	defer C.BN_free(bnX)
1816
1817	var bnY *C.BIGNUM
1818	wy := C.CString(wtg.Key.WY)
1819	if C.BN_hex2bn(&bnY, wy) == 0 {
1820		log.Fatal("Failed to decode WY")
1821	}
1822	C.free(unsafe.Pointer(wy))
1823	defer C.BN_free(bnY)
1824
1825	if C.EC_KEY_set_public_key_affine_coordinates(ecKey, bnX, bnY) != 1 {
1826		log.Fatal("Failed to set EC public key")
1827	}
1828
1829	nid, err = nidFromString(wtg.SHA)
1830	if err != nil {
1831		log.Fatalf("Failed to get MD NID: %v", err)
1832	}
1833	md, err := hashEvpMdFromString(wtg.SHA)
1834	if err != nil {
1835		log.Fatalf("Failed to get hash: %v", err)
1836	}
1837
1838	success := true
1839	for _, wt := range wtg.Tests {
1840		if !runECDSATest(ecKey, md, nid, variant, wt) {
1841			success = false
1842		}
1843	}
1844	return success
1845}
1846
1847// DER encode the signature (so that ECDSA_verify() can decode and encode it again...)
1848func encodeECDSAWebCryptoSig(wtSig string) (*C.uchar, C.int) {
1849	cSig := C.ECDSA_SIG_new()
1850	if cSig == nil {
1851		log.Fatal("ECDSA_SIG_new() failed")
1852	}
1853	defer C.ECDSA_SIG_free(cSig)
1854
1855	sigLen := len(wtSig)
1856	r := C.CString(wtSig[:sigLen/2])
1857	s := C.CString(wtSig[sigLen/2:])
1858	defer C.free(unsafe.Pointer(r))
1859	defer C.free(unsafe.Pointer(s))
1860	var sigR *C.BIGNUM
1861	var sigS *C.BIGNUM
1862	defer C.BN_free(sigR)
1863	defer C.BN_free(sigS)
1864	if C.BN_hex2bn(&sigR, r) == 0 {
1865		return nil, 0
1866	}
1867	if C.BN_hex2bn(&sigS, s) == 0 {
1868		return nil, 0
1869	}
1870	if C.ECDSA_SIG_set0(cSig, sigR, sigS) == 0 {
1871		return nil, 0
1872	}
1873	sigR = nil
1874	sigS = nil
1875
1876	derLen := C.i2d_ECDSA_SIG(cSig, nil)
1877	if derLen == 0 {
1878		return nil, 0
1879	}
1880	cDer := (*C.uchar)(C.malloc(C.ulong(derLen)))
1881	if cDer == nil {
1882		log.Fatal("malloc failed")
1883	}
1884
1885	p := cDer
1886	ret := C.i2d_ECDSA_SIG(cSig, (**C.uchar)(&p))
1887	if ret == 0 || ret != derLen {
1888		C.free(unsafe.Pointer(cDer))
1889		return nil, 0
1890	}
1891
1892	return cDer, derLen
1893}
1894
1895func (wtg *wycheproofTestGroupECDSAWebCrypto) run(algorithm string, variant testVariant) bool {
1896	fmt.Printf("Running %v test group %v with curve %v, key size %d and %v...\n", algorithm, wtg.Type, wtg.Key.Curve, wtg.Key.KeySize, wtg.SHA)
1897
1898	nid, err := nidFromString(wtg.JWK.Crv)
1899	if err != nil {
1900		log.Fatalf("Failed to get nid for curve: %v", err)
1901	}
1902	ecKey := C.EC_KEY_new_by_curve_name(C.int(nid))
1903	if ecKey == nil {
1904		log.Fatal("EC_KEY_new_by_curve_name failed")
1905	}
1906	defer C.EC_KEY_free(ecKey)
1907
1908	x, err := base64.RawURLEncoding.DecodeString(wtg.JWK.X)
1909	if err != nil {
1910		log.Fatalf("Failed to base64 decode X: %v", err)
1911	}
1912	bnX := C.BN_bin2bn((*C.uchar)(unsafe.Pointer(&x[0])), C.int(len(x)), nil)
1913	if bnX == nil {
1914		log.Fatal("Failed to decode X")
1915	}
1916	defer C.BN_free(bnX)
1917
1918	y, err := base64.RawURLEncoding.DecodeString(wtg.JWK.Y)
1919	if err != nil {
1920		log.Fatalf("Failed to base64 decode Y: %v", err)
1921	}
1922	bnY := C.BN_bin2bn((*C.uchar)(unsafe.Pointer(&y[0])), C.int(len(y)), nil)
1923	if bnY == nil {
1924		log.Fatal("Failed to decode Y")
1925	}
1926	defer C.BN_free(bnY)
1927
1928	if C.EC_KEY_set_public_key_affine_coordinates(ecKey, bnX, bnY) != 1 {
1929		log.Fatal("Failed to set EC public key")
1930	}
1931
1932	nid, err = nidFromString(wtg.SHA)
1933	if err != nil {
1934		log.Fatalf("Failed to get MD NID: %v", err)
1935	}
1936	md, err := hashEvpMdFromString(wtg.SHA)
1937	if err != nil {
1938		log.Fatalf("Failed to get hash: %v", err)
1939	}
1940
1941	success := true
1942	for _, wt := range wtg.Tests {
1943		if !runECDSATest(ecKey, md, nid, Webcrypto, wt) {
1944			success = false
1945		}
1946	}
1947	return success
1948}
1949
1950func runEdDSATest(pkey *C.EVP_PKEY, wt *wycheproofTestEdDSA) bool {
1951	mdctx := C.EVP_MD_CTX_new()
1952	if mdctx == nil {
1953		log.Fatal("EVP_MD_CTX_new failed")
1954	}
1955	defer C.EVP_MD_CTX_free(mdctx)
1956
1957	if C.EVP_DigestVerifyInit(mdctx, nil, nil, nil, pkey) != 1 {
1958		log.Fatal("EVP_DigestVerifyInit failed")
1959	}
1960
1961	msg, msgLen := mustDecodeHexString(wt.Msg, "msg")
1962	sig, sigLen := mustDecodeHexString(wt.Sig, "sig")
1963
1964	ret := C.EVP_DigestVerify(mdctx, (*C.uchar)(unsafe.Pointer(&sig[0])), (C.size_t)(sigLen), (*C.uchar)(unsafe.Pointer(&msg[0])), (C.size_t)(msgLen))
1965
1966	success := true
1967	if (ret == 1) != (wt.Result == "valid") {
1968		fmt.Printf("FAIL: %s - EVP_DigestVerify() = %d.\n", wt, int(ret))
1969		success = false
1970	}
1971	return success
1972}
1973
1974func (wtg *wycheproofTestGroupEdDSA) run(algorithm string, variant testVariant) bool {
1975	fmt.Printf("Running %v test group %v...\n", algorithm, wtg.Type)
1976
1977	if wtg.Key.Curve != "edwards25519" || wtg.Key.KeySize != 255 {
1978		fmt.Printf("INFO: Unexpected curve or key size. want (\"edwards25519\", 255), got (%q, %d)\n", wtg.Key.Curve, wtg.Key.KeySize)
1979		return false
1980	}
1981
1982	pubKey, pubKeyLen := mustDecodeHexString(wtg.Key.Pk, "pubkey")
1983
1984	pkey := C.EVP_PKEY_new_raw_public_key(C.EVP_PKEY_ED25519, nil, (*C.uchar)(unsafe.Pointer(&pubKey[0])), (C.size_t)(pubKeyLen))
1985	if pkey == nil {
1986		log.Fatal("EVP_PKEY_new_raw_public_key failed")
1987	}
1988	defer C.EVP_PKEY_free(pkey)
1989
1990	success := true
1991	for _, wt := range wtg.Tests {
1992		if !runEdDSATest(pkey, wt) {
1993			success = false
1994		}
1995	}
1996	return success
1997}
1998
1999func runHkdfTest(md *C.EVP_MD, wt *wycheproofTestHkdf) bool {
2000	ikm, ikmLen := mustDecodeHexString(wt.Ikm, "ikm")
2001	salt, saltLen := mustDecodeHexString(wt.Salt, "salt")
2002	info, infoLen := mustDecodeHexString(wt.Info, "info")
2003
2004	outLen := wt.Size
2005	out := make([]byte, outLen)
2006	if outLen == 0 {
2007		out = append(out, 0)
2008	}
2009
2010	pctx := C.EVP_PKEY_CTX_new_id(C.EVP_PKEY_HKDF, nil)
2011	if pctx == nil {
2012		log.Fatalf("EVP_PKEY_CTX_new_id failed")
2013	}
2014	defer C.EVP_PKEY_CTX_free(pctx)
2015
2016	ret := C.EVP_PKEY_derive_init(pctx)
2017	if ret <= 0 {
2018		log.Fatalf("EVP_PKEY_derive_init failed, want 1, got %d", ret)
2019	}
2020
2021	ret = C.wp_EVP_PKEY_CTX_set_hkdf_md(pctx, md)
2022	if ret <= 0 {
2023		log.Fatalf("EVP_PKEY_CTX_set_hkdf_md failed, want 1, got %d", ret)
2024	}
2025
2026	ret = C.wp_EVP_PKEY_CTX_set1_hkdf_salt(pctx, (*C.uchar)(&salt[0]), C.size_t(saltLen))
2027	if ret <= 0 {
2028		log.Fatalf("EVP_PKEY_CTX_set1_hkdf_salt failed, want 1, got %d", ret)
2029	}
2030
2031	ret = C.wp_EVP_PKEY_CTX_set1_hkdf_key(pctx, (*C.uchar)(&ikm[0]), C.size_t(ikmLen))
2032	if ret <= 0 {
2033		log.Fatalf("EVP_PKEY_CTX_set1_hkdf_key failed, want 1, got %d", ret)
2034	}
2035
2036	ret = C.wp_EVP_PKEY_CTX_add1_hkdf_info(pctx, (*C.uchar)(&info[0]), C.size_t(infoLen))
2037	if ret <= 0 {
2038		log.Fatalf("EVP_PKEY_CTX_add1_hkdf_info failed, want 1, got %d", ret)
2039	}
2040
2041	ret = C.EVP_PKEY_derive(pctx, (*C.uchar)(unsafe.Pointer(&out[0])), (*C.size_t)(unsafe.Pointer(&outLen)))
2042	if ret <= 0 {
2043		success := wt.Result == "invalid"
2044		if !success {
2045			fmt.Printf("FAIL: %s - got %d.\n", wt, ret)
2046		}
2047		return success
2048	}
2049
2050	okm, _ := mustDecodeHexString(wt.Okm, "okm")
2051	if !bytes.Equal(out[:outLen], okm) {
2052		fmt.Printf("FAIL: %s - expected and computed output don't match.\n", wt)
2053	}
2054
2055	return wt.Result == "valid"
2056}
2057
2058func (wtg *wycheproofTestGroupHkdf) run(algorithm string, variant testVariant) bool {
2059	fmt.Printf("Running %v test group %v with key size %d...\n", algorithm, wtg.Type, wtg.KeySize)
2060	md, err := hashEvpMdFromString(strings.TrimPrefix(algorithm, "HKDF-"))
2061	if err != nil {
2062		log.Fatalf("Failed to get hash: %v", err)
2063	}
2064
2065	success := true
2066	for _, wt := range wtg.Tests {
2067		if !runHkdfTest(md, wt) {
2068			success = false
2069		}
2070	}
2071	return success
2072}
2073
2074func runHmacTest(md *C.EVP_MD, tagBytes int, wt *wycheproofTestHmac) bool {
2075	key, keyLen := mustDecodeHexString(wt.Key, "key")
2076	msg, msgLen := mustDecodeHexString(wt.Msg, "msg")
2077
2078	got := make([]byte, C.EVP_MAX_MD_SIZE)
2079	var gotLen C.uint
2080
2081	ret := C.HMAC(md, unsafe.Pointer(&key[0]), C.int(keyLen), (*C.uchar)(unsafe.Pointer(&msg[0])), C.size_t(msgLen), (*C.uchar)(unsafe.Pointer(&got[0])), &gotLen)
2082
2083	success := true
2084	if ret == nil {
2085		if wt.Result != "invalid" {
2086			success = false
2087			fmt.Printf("FAIL: Test case %s - HMAC: got nil.\n", wt)
2088		}
2089		return success
2090	}
2091
2092	if int(gotLen) < tagBytes {
2093		fmt.Printf("FAIL: %s - HMAC length: got %d, want %d.\n", wt, gotLen, tagBytes)
2094		return false
2095	}
2096
2097	tag, _ := mustDecodeHexString(wt.Tag, "tag")
2098	success = bytes.Equal(got[:tagBytes], tag) == (wt.Result == "valid")
2099
2100	if !success {
2101		fmt.Printf("FAIL: %s - got %t.\n", wt, success)
2102	}
2103
2104	return success
2105}
2106
2107func (wtg *wycheproofTestGroupHmac) run(algorithm string, variant testVariant) bool {
2108	fmt.Printf("Running %v test group %v with key size %d and tag size %d...\n", algorithm, wtg.Type, wtg.KeySize, wtg.TagSize)
2109	prefix := "SHA-"
2110	if strings.HasPrefix(algorithm, "HMACSHA3-") {
2111		prefix = "SHA"
2112	}
2113	md, err := hashEvpMdFromString(prefix + strings.TrimPrefix(algorithm, "HMACSHA"))
2114	if err != nil {
2115		log.Fatalf("Failed to get hash: %v", err)
2116	}
2117
2118	success := true
2119	for _, wt := range wtg.Tests {
2120		if !runHmacTest(md, wtg.TagSize/8, wt) {
2121			success = false
2122		}
2123	}
2124	return success
2125}
2126
2127func runKWTestWrap(keySize int, key []byte, keyLen int, msg []byte, msgLen int, ct []byte, ctLen int, wt *wycheproofTestKW) bool {
2128	var aesKey C.AES_KEY
2129
2130	ret := C.AES_set_encrypt_key((*C.uchar)(unsafe.Pointer(&key[0])), (C.int)(keySize), (*C.AES_KEY)(unsafe.Pointer(&aesKey)))
2131	if ret != 0 {
2132		fmt.Printf("FAIL: %s - AES_set_encrypt_key() = %d.\n", wt, int(ret))
2133		return false
2134	}
2135
2136	outLen := msgLen
2137	out := make([]byte, outLen)
2138	copy(out, msg)
2139	out = append(out, make([]byte, 8)...)
2140	ret = C.AES_wrap_key((*C.AES_KEY)(unsafe.Pointer(&aesKey)), nil, (*C.uchar)(unsafe.Pointer(&out[0])), (*C.uchar)(unsafe.Pointer(&out[0])), (C.uint)(msgLen))
2141	success := false
2142	if ret == C.int(len(out)) && bytes.Equal(out, ct) {
2143		if wt.Result != "invalid" {
2144			success = true
2145		}
2146	} else if wt.Result != "valid" {
2147		success = true
2148	}
2149	if !success {
2150		fmt.Printf("FAIL: %s - msgLen = %d, AES_wrap_key() = %d.\n", wt, msgLen, int(ret))
2151	}
2152	return success
2153}
2154
2155func runKWTestUnWrap(keySize int, key []byte, keyLen int, msg []byte, msgLen int, ct []byte, ctLen int, wt *wycheproofTestKW) bool {
2156	var aesKey C.AES_KEY
2157
2158	ret := C.AES_set_decrypt_key((*C.uchar)(unsafe.Pointer(&key[0])), (C.int)(keySize), (*C.AES_KEY)(unsafe.Pointer(&aesKey)))
2159	if ret != 0 {
2160		fmt.Printf("FAIL: %s - AES_set_encrypt_key() = %d.\n", wt, int(ret))
2161		return false
2162	}
2163
2164	out := make([]byte, ctLen)
2165	copy(out, ct)
2166	if ctLen == 0 {
2167		out = append(out, 0)
2168	}
2169	ret = C.AES_unwrap_key((*C.AES_KEY)(unsafe.Pointer(&aesKey)), nil, (*C.uchar)(unsafe.Pointer(&out[0])), (*C.uchar)(unsafe.Pointer(&out[0])), (C.uint)(ctLen))
2170	success := false
2171	if ret == C.int(ctLen-8) && bytes.Equal(out[0:ret], msg[0:ret]) {
2172		if wt.Result != "invalid" {
2173			success = true
2174		}
2175	} else if wt.Result != "valid" {
2176		success = true
2177	}
2178	if !success {
2179		fmt.Printf("FAIL: %s - keyLen = %d, AES_unwrap_key() = %d.\n", wt, keyLen, int(ret))
2180	}
2181	return success
2182}
2183
2184func runKWTest(keySize int, wt *wycheproofTestKW) bool {
2185	key, keyLen := mustDecodeHexString(wt.Key, "key")
2186	msg, msgLen := mustDecodeHexString(wt.Msg, "msg")
2187	ct, ctLen := mustDecodeHexString(wt.CT, "CT")
2188
2189	wrapSuccess := runKWTestWrap(keySize, key, keyLen, msg, msgLen, ct, ctLen, wt)
2190	unwrapSuccess := runKWTestUnWrap(keySize, key, keyLen, msg, msgLen, ct, ctLen, wt)
2191
2192	return wrapSuccess && unwrapSuccess
2193}
2194
2195func (wtg *wycheproofTestGroupKW) run(algorithm string, variant testVariant) bool {
2196	fmt.Printf("Running %v test group %v with key size %d...\n", algorithm, wtg.Type, wtg.KeySize)
2197
2198	success := true
2199	for _, wt := range wtg.Tests {
2200		if !runKWTest(wtg.KeySize, wt) {
2201			success = false
2202		}
2203	}
2204	return success
2205}
2206
2207func runPrimalityTest(wt *wycheproofTestPrimality) bool {
2208	var bnValue *C.BIGNUM
2209	value := C.CString(wt.Value)
2210	if C.BN_hex2bn(&bnValue, value) == 0 {
2211		log.Fatal("Failed to set bnValue")
2212	}
2213	C.free(unsafe.Pointer(value))
2214	defer C.BN_free(bnValue)
2215
2216	ret := C.BN_is_prime_ex(bnValue, C.BN_prime_checks, (*C.BN_CTX)(unsafe.Pointer(nil)), (*C.BN_GENCB)(unsafe.Pointer(nil)))
2217	success := wt.Result == "acceptable" || (ret == 0 && wt.Result == "invalid") || (ret == 1 && wt.Result == "valid")
2218	if !success {
2219		fmt.Printf("FAIL: %s - got %d.\n", wt, ret)
2220	}
2221	return success
2222}
2223
2224func (wtg *wycheproofTestGroupPrimality) run(algorithm string, variant testVariant) bool {
2225	fmt.Printf("Running %v test group...\n", algorithm)
2226
2227	success := true
2228	for _, wt := range wtg.Tests {
2229		if !runPrimalityTest(wt) {
2230			success = false
2231		}
2232	}
2233	return success
2234}
2235
2236func runRsaesOaepTest(rsa *C.RSA, sha *C.EVP_MD, mgfSha *C.EVP_MD, wt *wycheproofTestRsaes) bool {
2237	ct, ctLen := mustDecodeHexString(wt.CT, "CT")
2238
2239	rsaSize := C.RSA_size(rsa)
2240	decrypted := make([]byte, rsaSize)
2241
2242	success := true
2243
2244	ret := C.RSA_private_decrypt(C.int(ctLen), (*C.uchar)(unsafe.Pointer(&ct[0])), (*C.uchar)(unsafe.Pointer(&decrypted[0])), rsa, C.RSA_NO_PADDING)
2245
2246	if ret != rsaSize {
2247		success = (wt.Result == "invalid")
2248
2249		if !success {
2250			fmt.Printf("FAIL: %s - RSA size got %d, want %d.\n", wt, ret, rsaSize)
2251		}
2252		return success
2253	}
2254
2255	label, labelLen := mustDecodeHexString(wt.Label, "label")
2256	msg, msgLen := mustDecodeHexString(wt.Msg, "msg")
2257
2258	to := make([]byte, rsaSize)
2259
2260	ret = C.RSA_padding_check_PKCS1_OAEP_mgf1((*C.uchar)(unsafe.Pointer(&to[0])), C.int(rsaSize), (*C.uchar)(unsafe.Pointer(&decrypted[0])), C.int(rsaSize), C.int(rsaSize), (*C.uchar)(unsafe.Pointer(&label[0])), C.int(labelLen), sha, mgfSha)
2261
2262	if int(ret) != msgLen {
2263		success = (wt.Result == "invalid")
2264
2265		if !success {
2266			fmt.Printf("FAIL: %s - got %d, want %d.\n", wt, ret, msgLen)
2267		}
2268		return success
2269	}
2270
2271	to = to[:msgLen]
2272	if !bytes.Equal(msg[:msgLen], to) {
2273		success = false
2274		fmt.Printf("FAIL: %s - expected and calculated message differ.\n", wt)
2275	}
2276
2277	return success
2278}
2279
2280func (wtg *wycheproofTestGroupRsaesOaep) run(algorithm string, variant testVariant) bool {
2281	fmt.Printf("Running %v test group %v with key size %d MGF %v and %v...\n", algorithm, wtg.Type, wtg.KeySize, wtg.MGFSHA, wtg.SHA)
2282
2283	rsa := C.RSA_new()
2284	if rsa == nil {
2285		log.Fatal("RSA_new failed")
2286	}
2287	defer C.RSA_free(rsa)
2288
2289	d := C.CString(wtg.D)
2290	var rsaD *C.BIGNUM
2291	defer C.BN_free(rsaD)
2292	if C.BN_hex2bn(&rsaD, d) == 0 {
2293		log.Fatal("Failed to set RSA d")
2294	}
2295	C.free(unsafe.Pointer(d))
2296
2297	e := C.CString(wtg.E)
2298	var rsaE *C.BIGNUM
2299	defer C.BN_free(rsaE)
2300	if C.BN_hex2bn(&rsaE, e) == 0 {
2301		log.Fatal("Failed to set RSA e")
2302	}
2303	C.free(unsafe.Pointer(e))
2304
2305	n := C.CString(wtg.N)
2306	var rsaN *C.BIGNUM
2307	defer C.BN_free(rsaN)
2308	if C.BN_hex2bn(&rsaN, n) == 0 {
2309		log.Fatal("Failed to set RSA n")
2310	}
2311	C.free(unsafe.Pointer(n))
2312
2313	if C.RSA_set0_key(rsa, rsaN, rsaE, rsaD) == 0 {
2314		log.Fatal("RSA_set0_key failed")
2315	}
2316	rsaN = nil
2317	rsaE = nil
2318	rsaD = nil
2319
2320	sha, err := hashEvpMdFromString(wtg.SHA)
2321	if err != nil {
2322		log.Fatalf("Failed to get hash: %v", err)
2323	}
2324
2325	mgfSha, err := hashEvpMdFromString(wtg.MGFSHA)
2326	if err != nil {
2327		log.Fatalf("Failed to get MGF hash: %v", err)
2328	}
2329
2330	success := true
2331	for _, wt := range wtg.Tests {
2332		if !runRsaesOaepTest(rsa, sha, mgfSha, wt) {
2333			success = false
2334		}
2335	}
2336	return success
2337}
2338
2339func runRsaesPkcs1Test(rsa *C.RSA, wt *wycheproofTestRsaes) bool {
2340	ct, ctLen := mustDecodeHexString(wt.CT, "CT")
2341
2342	rsaSize := C.RSA_size(rsa)
2343	decrypted := make([]byte, rsaSize)
2344
2345	success := true
2346
2347	ret := C.RSA_private_decrypt(C.int(ctLen), (*C.uchar)(unsafe.Pointer(&ct[0])), (*C.uchar)(unsafe.Pointer(&decrypted[0])), rsa, C.RSA_PKCS1_PADDING)
2348
2349	if ret == -1 {
2350		success = (wt.Result == "invalid")
2351
2352		if !success {
2353			fmt.Printf("FAIL: %s - got %d, want %d.\n", wt, ret, len(wt.Msg)/2)
2354		}
2355		return success
2356	}
2357
2358	msg, msgLen := mustDecodeHexString(wt.Msg, "msg")
2359
2360	if int(ret) != msgLen {
2361		success = false
2362		fmt.Printf("FAIL: %s - got %d, want %d.\n", wt, ret, msgLen)
2363	} else if !bytes.Equal(msg[:msgLen], decrypted[:msgLen]) {
2364		success = false
2365		fmt.Printf("FAIL: %s - expected and calculated message differ.\n", wt)
2366	}
2367
2368	return success
2369}
2370
2371func (wtg *wycheproofTestGroupRsaesPkcs1) run(algorithm string, variant testVariant) bool {
2372	fmt.Printf("Running %v test group %v with key size %d...\n", algorithm, wtg.Type, wtg.KeySize)
2373	rsa := C.RSA_new()
2374	if rsa == nil {
2375		log.Fatal("RSA_new failed")
2376	}
2377	defer C.RSA_free(rsa)
2378
2379	d := C.CString(wtg.D)
2380	var rsaD *C.BIGNUM
2381	defer C.BN_free(rsaD)
2382	if C.BN_hex2bn(&rsaD, d) == 0 {
2383		log.Fatal("Failed to set RSA d")
2384	}
2385	C.free(unsafe.Pointer(d))
2386
2387	e := C.CString(wtg.E)
2388	var rsaE *C.BIGNUM
2389	defer C.BN_free(rsaE)
2390	if C.BN_hex2bn(&rsaE, e) == 0 {
2391		log.Fatal("Failed to set RSA e")
2392	}
2393	C.free(unsafe.Pointer(e))
2394
2395	n := C.CString(wtg.N)
2396	var rsaN *C.BIGNUM
2397	defer C.BN_free(rsaN)
2398	if C.BN_hex2bn(&rsaN, n) == 0 {
2399		log.Fatal("Failed to set RSA n")
2400	}
2401	C.free(unsafe.Pointer(n))
2402
2403	if C.RSA_set0_key(rsa, rsaN, rsaE, rsaD) == 0 {
2404		log.Fatal("RSA_set0_key failed")
2405	}
2406	rsaN = nil
2407	rsaE = nil
2408	rsaD = nil
2409
2410	success := true
2411	for _, wt := range wtg.Tests {
2412		if !runRsaesPkcs1Test(rsa, wt) {
2413			success = false
2414		}
2415	}
2416	return success
2417}
2418
2419func runRsassaTest(rsa *C.RSA, sha *C.EVP_MD, mgfSha *C.EVP_MD, sLen int, wt *wycheproofTestRsassa) bool {
2420	msg, _ := mustHashHexMessage(sha, wt.Msg)
2421	sig, sigLen := mustDecodeHexString(wt.Sig, "sig")
2422
2423	sigOut := make([]byte, C.RSA_size(rsa)-11)
2424	if sigLen == 0 {
2425		sigOut = append(sigOut, 0)
2426	}
2427
2428	ret := C.RSA_public_decrypt(C.int(sigLen), (*C.uchar)(unsafe.Pointer(&sig[0])), (*C.uchar)(unsafe.Pointer(&sigOut[0])), rsa, C.RSA_NO_PADDING)
2429	if ret == -1 {
2430		if wt.Result == "invalid" {
2431			return true
2432		}
2433		fmt.Printf("FAIL: %s - RSA_public_decrypt() = %d.\n", wt, int(ret))
2434		return false
2435	}
2436
2437	ret = C.RSA_verify_PKCS1_PSS_mgf1(rsa, (*C.uchar)(unsafe.Pointer(&msg[0])), sha, mgfSha, (*C.uchar)(unsafe.Pointer(&sigOut[0])), C.int(sLen))
2438
2439	success := false
2440	if ret == 1 && (wt.Result == "valid" || wt.Result == "acceptable") {
2441		// All acceptable cases that pass use SHA-1 and are flagged:
2442		// "WeakHash" : "The key for this test vector uses a weak hash function."
2443		success = true
2444	} else if ret == 0 && (wt.Result == "invalid" || wt.Result == "acceptable") {
2445		success = true
2446	} else {
2447		fmt.Printf("FAIL: %s - RSA_verify_PKCS1_PSS_mgf1() = %d.\n", wt, int(ret))
2448	}
2449	return success
2450}
2451
2452func (wtg *wycheproofTestGroupRsassa) run(algorithm string, variant testVariant) bool {
2453	fmt.Printf("Running %v test group %v with key size %d and %v...\n", algorithm, wtg.Type, wtg.KeySize, wtg.SHA)
2454	rsa := C.RSA_new()
2455	if rsa == nil {
2456		log.Fatal("RSA_new failed")
2457	}
2458	defer C.RSA_free(rsa)
2459
2460	e := C.CString(wtg.E)
2461	var rsaE *C.BIGNUM
2462	defer C.BN_free(rsaE)
2463	if C.BN_hex2bn(&rsaE, e) == 0 {
2464		log.Fatal("Failed to set RSA e")
2465	}
2466	C.free(unsafe.Pointer(e))
2467
2468	n := C.CString(wtg.N)
2469	var rsaN *C.BIGNUM
2470	defer C.BN_free(rsaN)
2471	if C.BN_hex2bn(&rsaN, n) == 0 {
2472		log.Fatal("Failed to set RSA n")
2473	}
2474	C.free(unsafe.Pointer(n))
2475
2476	if C.RSA_set0_key(rsa, rsaN, rsaE, nil) == 0 {
2477		log.Fatal("RSA_set0_key failed")
2478	}
2479	rsaN = nil
2480	rsaE = nil
2481
2482	sha, err := hashEvpMdFromString(wtg.SHA)
2483	if err != nil {
2484		log.Fatalf("Failed to get hash: %v", err)
2485	}
2486
2487	mgfSha, err := hashEvpMdFromString(wtg.MGFSHA)
2488	if err != nil {
2489		log.Fatalf("Failed to get MGF hash: %v", err)
2490	}
2491
2492	success := true
2493	for _, wt := range wtg.Tests {
2494		if !runRsassaTest(rsa, sha, mgfSha, wtg.SLen, wt) {
2495			success = false
2496		}
2497	}
2498	return success
2499}
2500
2501func runRSATest(rsa *C.RSA, md *C.EVP_MD, nid int, wt *wycheproofTestRSA) bool {
2502	msg, msgLen := mustHashHexMessage(md, wt.Msg)
2503	sig, sigLen := mustDecodeHexString(wt.Sig, "sig")
2504
2505	ret := C.RSA_verify(C.int(nid), (*C.uchar)(unsafe.Pointer(&msg[0])), C.uint(msgLen), (*C.uchar)(unsafe.Pointer(&sig[0])), C.uint(sigLen), rsa)
2506
2507	// XXX audit acceptable cases...
2508	success := true
2509	if ret == 1 != (wt.Result == "valid") && wt.Result != "acceptable" {
2510		fmt.Printf("FAIL: %s - RSA_verify() = %d.\n", wt, int(ret))
2511		success = false
2512	}
2513	return success
2514}
2515
2516func (wtg *wycheproofTestGroupRSA) run(algorithm string, variant testVariant) bool {
2517	fmt.Printf("Running %v test group %v with key size %d and %v...\n", algorithm, wtg.Type, wtg.KeySize, wtg.SHA)
2518
2519	rsa := C.RSA_new()
2520	if rsa == nil {
2521		log.Fatal("RSA_new failed")
2522	}
2523	defer C.RSA_free(rsa)
2524
2525	e := C.CString(wtg.E)
2526	var rsaE *C.BIGNUM
2527	defer C.BN_free(rsaE)
2528	if C.BN_hex2bn(&rsaE, e) == 0 {
2529		log.Fatal("Failed to set RSA e")
2530	}
2531	C.free(unsafe.Pointer(e))
2532
2533	n := C.CString(wtg.N)
2534	var rsaN *C.BIGNUM
2535	defer C.BN_free(rsaN)
2536	if C.BN_hex2bn(&rsaN, n) == 0 {
2537		log.Fatal("Failed to set RSA n")
2538	}
2539	C.free(unsafe.Pointer(n))
2540
2541	if C.RSA_set0_key(rsa, rsaN, rsaE, nil) == 0 {
2542		log.Fatal("RSA_set0_key failed")
2543	}
2544	rsaN = nil
2545	rsaE = nil
2546
2547	nid, err := nidFromString(wtg.SHA)
2548	if err != nil {
2549		log.Fatalf("Failed to get MD NID: %v", err)
2550	}
2551	md, err := hashEvpMdFromString(wtg.SHA)
2552	if err != nil {
2553		log.Fatalf("Failed to get hash: %v", err)
2554	}
2555
2556	success := true
2557	for _, wt := range wtg.Tests {
2558		if !runRSATest(rsa, md, nid, wt) {
2559			success = false
2560		}
2561	}
2562	return success
2563}
2564
2565func runX25519Test(wt *wycheproofTestX25519) bool {
2566	public, _ := mustDecodeHexString(wt.Public, "public")
2567	private, _ := mustDecodeHexString(wt.Private, "private")
2568	shared, _ := mustDecodeHexString(wt.Shared, "shared")
2569
2570	got := make([]byte, C.X25519_KEY_LENGTH)
2571	result := true
2572
2573	if C.X25519((*C.uint8_t)(unsafe.Pointer(&got[0])), (*C.uint8_t)(unsafe.Pointer(&private[0])), (*C.uint8_t)(unsafe.Pointer(&public[0]))) != 1 {
2574		result = false
2575	} else {
2576		result = bytes.Equal(got, shared)
2577	}
2578
2579	// XXX audit acceptable cases...
2580	success := true
2581	if result != (wt.Result == "valid") && wt.Result != "acceptable" {
2582		fmt.Printf("FAIL: %s - X25519().\n", wt)
2583		success = false
2584	}
2585	return success
2586}
2587
2588func (wtg *wycheproofTestGroupX25519) run(algorithm string, variant testVariant) bool {
2589	fmt.Printf("Running %v test group with curve %v...\n", algorithm, wtg.Curve)
2590
2591	success := true
2592	for _, wt := range wtg.Tests {
2593		if !runX25519Test(wt) {
2594			success = false
2595		}
2596	}
2597	return success
2598}
2599
2600func testGroupFromAlgorithm(algorithm string, variant testVariant) wycheproofTestGroupRunner {
2601	if algorithm == "ECDH" && variant == Webcrypto {
2602		return &wycheproofTestGroupECDHWebCrypto{}
2603	}
2604	if algorithm == "ECDSA" && variant == Webcrypto {
2605		return &wycheproofTestGroupECDSAWebCrypto{}
2606	}
2607	switch algorithm {
2608	case "AES-CBC-PKCS5":
2609		return &wycheproofTestGroupAesCbcPkcs5{}
2610	case "AES-CCM", "AES-GCM":
2611		return &wycheproofTestGroupAesAead{}
2612	case "AES-CMAC":
2613		return &wycheproofTestGroupAesCmac{}
2614	case "CHACHA20-POLY1305", "XCHACHA20-POLY1305":
2615		return &wycheproofTestGroupChaCha{}
2616	case "DSA":
2617		return &wycheproofTestGroupDSA{}
2618	case "ECDH":
2619		return &wycheproofTestGroupECDH{}
2620	case "ECDSA":
2621		return &wycheproofTestGroupECDSA{}
2622	case "EDDSA":
2623		return &wycheproofTestGroupEdDSA{}
2624	case "HKDF-SHA-1", "HKDF-SHA-256", "HKDF-SHA-384", "HKDF-SHA-512":
2625		return &wycheproofTestGroupHkdf{}
2626	case "HMACSHA1", "HMACSHA224", "HMACSHA256", "HMACSHA384", "HMACSHA512", "HMACSHA3-224", "HMACSHA3-256", "HMACSHA3-384", "HMACSHA3-512":
2627		return &wycheproofTestGroupHmac{}
2628	case "KW":
2629		return &wycheproofTestGroupKW{}
2630	case "PrimalityTest":
2631		return &wycheproofTestGroupPrimality{}
2632	case "RSAES-OAEP":
2633		return &wycheproofTestGroupRsaesOaep{}
2634	case "RSAES-PKCS1-v1_5":
2635		return &wycheproofTestGroupRsaesPkcs1{}
2636	case "RSASSA-PSS":
2637		return &wycheproofTestGroupRsassa{}
2638	case "RSASSA-PKCS1-v1_5", "RSASig":
2639		return &wycheproofTestGroupRSA{}
2640	case "XDH", "X25519":
2641		return &wycheproofTestGroupX25519{}
2642	default:
2643		return nil
2644	}
2645}
2646
2647func runTestVectors(path string, variant testVariant) bool {
2648	b, err := ioutil.ReadFile(path)
2649	if err != nil {
2650		log.Fatalf("Failed to read test vectors: %v", err)
2651	}
2652	wtv := &wycheproofTestVectors{}
2653	if err := json.Unmarshal(b, wtv); err != nil {
2654		log.Fatalf("Failed to unmarshal JSON: %v", err)
2655	}
2656	fmt.Printf("Loaded Wycheproof test vectors for %v with %d tests from %q\n", wtv.Algorithm, wtv.NumberOfTests, filepath.Base(path))
2657
2658	success := true
2659	for _, tg := range wtv.TestGroups {
2660		wtg := testGroupFromAlgorithm(wtv.Algorithm, variant)
2661		if wtg == nil {
2662			log.Printf("INFO: Unknown test vector algorithm %q", wtv.Algorithm)
2663			return false
2664		}
2665		if err := json.Unmarshal(tg, wtg); err != nil {
2666			log.Fatalf("Failed to unmarshal test groups JSON: %v", err)
2667		}
2668		testc.runTest(func() bool {
2669			return wtg.run(wtv.Algorithm, variant)
2670		})
2671	}
2672	for _ = range wtv.TestGroups {
2673		result := <-testc.resultCh
2674		if !result {
2675			success = false
2676		}
2677	}
2678	return success
2679}
2680
// testCoordinator passes test functions to a pool of runner goroutines
// through testFuncCh and collects their boolean results on resultCh.
type testCoordinator struct {
	resultCh   chan bool        // receives the result of each executed test function
	testFuncCh chan func() bool // carries pending test functions to the runners
}
2685
2686func newTestCoordinator() *testCoordinator {
2687	runnerCount := runtime.NumCPU()
2688	tc := &testCoordinator{
2689		testFuncCh: make(chan func() bool, runnerCount),
2690		resultCh:   make(chan bool, 1024),
2691	}
2692	for i := 0; i < runnerCount; i++ {
2693		go tc.testRunner(tc.testFuncCh, tc.resultCh)
2694	}
2695	return tc
2696}
2697
2698func (tc *testCoordinator) testRunner(testFuncCh <-chan func() bool, resultCh chan<- bool) {
2699	for testFunc := range testFuncCh {
2700		select {
2701		case resultCh <- testFunc():
2702		default:
2703			log.Fatal("result channel is full")
2704		}
2705	}
2706}
2707
2708func (tc *testCoordinator) runTest(testFunc func() bool) {
2709	tc.testFuncCh <- testFunc
2710}
2711
2712func (tc *testCoordinator) shutdown() {
2713	close(tc.testFuncCh)
2714}
2715
2716func main() {
2717	if _, err := os.Stat(testVectorPath); os.IsNotExist(err) {
2718		fmt.Printf("package wycheproof-testvectors is required for this regress\n")
2719		fmt.Printf("SKIPPED\n")
2720		os.Exit(0)
2721	}
2722
2723	tests := []struct {
2724		name    string
2725		pattern string
2726		variant testVariant
2727	}{
2728		{"AES", "aes_[cg]*[^xv]_test.json", Normal}, // Skip AES-EAX, AES-GCM-SIV and AES-SIV-CMAC.
2729		{"ChaCha20-Poly1305", "chacha20_poly1305_test.json", Normal},
2730		{"DSA", "dsa_*test.json", Normal},
2731		{"DSA", "dsa_*_p1363_test.json", P1363},
2732		{"ECDH", "ecdh_test.json", Normal},
2733		{"ECDH", "ecdh_[^w_]*_test.json", Normal},
2734		{"ECDH EcPoint", "ecdh_*_ecpoint_test.json", EcPoint},
2735		{"ECDH webcrypto", "ecdh_webcrypto_test.json", Webcrypto},
2736		{"ECDSA", "ecdsa_test.json", Normal},
2737		{"ECDSA", "ecdsa_[^w]*test.json", Normal},
2738		{"ECDSA P1363", "ecdsa_*_p1363_test.json", P1363},
2739		{"ECDSA webcrypto", "ecdsa_webcrypto_test.json", Webcrypto},
2740		{"EDDSA", "eddsa_test.json", Normal},
2741		{"ED448", "ed448_test.json", Skip},
2742		{"HKDF", "hkdf_sha*_test.json", Normal},
2743		{"HMAC", "hmac_sha*_test.json", Normal},
2744		{"JSON webcrypto", "json_web_*_test.json", Skip},
2745		{"KW", "kw_test.json", Normal},
2746		{"Primality test", "primality_test.json", Normal},
2747		{"RSA", "rsa_*test.json", Normal},
2748		{"X25519", "x25519_test.json", Normal},
2749		{"X25519 ASN", "x25519_asn_test.json", Skip},
2750		{"X25519 JWK", "x25519_jwk_test.json", Skip},
2751		{"X25519 PEM", "x25519_pem_test.json", Skip},
2752		{"XCHACHA20-POLY1305", "xchacha20_poly1305_test.json", Normal},
2753	}
2754
2755	success := true
2756
2757	var wg sync.WaitGroup
2758
2759	vectorsRateLimitCh := make(chan bool, 4)
2760	for i := 0; i < cap(vectorsRateLimitCh); i++ {
2761		vectorsRateLimitCh <- true
2762	}
2763	resultCh := make(chan bool, 1024)
2764
2765	testc = newTestCoordinator()
2766
2767	skipNormal := regexp.MustCompile(`_(ecpoint|p1363|sect\d{3}[rk]1|secp(160|192))_`)
2768
2769	for _, test := range tests {
2770		tvs, err := filepath.Glob(filepath.Join(testVectorPath, test.pattern))
2771		if err != nil {
2772			log.Fatalf("Failed to glob %v test vectors: %v", test.name, err)
2773		}
2774		if len(tvs) == 0 {
2775			log.Fatalf("Failed to find %v test vectors at %q\n", test.name, testVectorPath)
2776		}
2777		for _, tv := range tvs {
2778			if test.variant == Skip || (test.variant == Normal && skipNormal.Match([]byte(tv))) {
2779				fmt.Printf("INFO: Skipping tests from \"%s\"\n", strings.TrimPrefix(tv, testVectorPath+"/"))
2780				continue
2781			}
2782			wg.Add(1)
2783			<-vectorsRateLimitCh
2784			go func(tv string, variant testVariant) {
2785				select {
2786				case resultCh <- runTestVectors(tv, variant):
2787				default:
2788					log.Fatal("result channel is full")
2789				}
2790				vectorsRateLimitCh <- true
2791				wg.Done()
2792			}(tv, test.variant)
2793		}
2794	}
2795
2796	wg.Wait()
2797	close(resultCh)
2798
2799	for result := range resultCh {
2800		if !result {
2801			success = false
2802		}
2803	}
2804
2805	testc.shutdown()
2806
2807	C.OPENSSL_cleanup()
2808
2809	if !success {
2810		os.Exit(1)
2811	}
2812}
2813