1 package org.bouncycastle.pqc.crypto.ntru;
2 
3 import java.io.ByteArrayInputStream;
4 import java.io.ByteArrayOutputStream;
5 import java.io.IOException;
6 import java.io.InputStream;
7 import java.io.OutputStream;
8 import java.util.ArrayList;
9 import java.util.List;
10 
11 import org.bouncycastle.crypto.params.AsymmetricKeyParameter;
12 import org.bouncycastle.pqc.math.ntru.polynomial.DenseTernaryPolynomial;
13 import org.bouncycastle.pqc.math.ntru.polynomial.IntegerPolynomial;
14 import org.bouncycastle.pqc.math.ntru.polynomial.Polynomial;
15 import org.bouncycastle.pqc.math.ntru.polynomial.ProductFormPolynomial;
16 import org.bouncycastle.pqc.math.ntru.polynomial.SparseTernaryPolynomial;
17 
18 /**
19  * A NtruSign private key comprises one or more {@link NTRUSigningPrivateKeyParameters.Basis} of three polynomials each,
20  * except the zeroth basis for which <code>h</code> is undefined.
21  */
22 public class NTRUSigningPrivateKeyParameters
23     extends AsymmetricKeyParameter
24 {
25     private List<Basis> bases;
26     private NTRUSigningPublicKeyParameters publicKey;
27 
28     /**
29      * Constructs a new private key from a byte array
30      *
31      * @param b      an encoded private key
32      * @param params the NtruSign parameters to use
33      */
NTRUSigningPrivateKeyParameters(byte[] b, NTRUSigningKeyGenerationParameters params)34     public NTRUSigningPrivateKeyParameters(byte[] b, NTRUSigningKeyGenerationParameters params)
35         throws IOException
36     {
37         this(new ByteArrayInputStream(b), params);
38     }
39 
40     /**
41      * Constructs a new private key from an input stream
42      *
43      * @param is     an input stream
44      * @param params the NtruSign parameters to use
45      */
NTRUSigningPrivateKeyParameters(InputStream is, NTRUSigningKeyGenerationParameters params)46     public NTRUSigningPrivateKeyParameters(InputStream is, NTRUSigningKeyGenerationParameters params)
47         throws IOException
48     {
49         super(true);
50         bases = new ArrayList<Basis>();
51         for (int i = 0; i <= params.B; i++)
52         // include a public key h[i] in all bases except for the first one
53         {
54             add(new Basis(is, params, i != 0));
55         }
56         publicKey = new NTRUSigningPublicKeyParameters(is, params.getSigningParameters());
57     }
58 
NTRUSigningPrivateKeyParameters(List<Basis> bases, NTRUSigningPublicKeyParameters publicKey)59     public NTRUSigningPrivateKeyParameters(List<Basis> bases, NTRUSigningPublicKeyParameters publicKey)
60     {
61         super(true);
62         this.bases = new ArrayList<Basis>(bases);
63         this.publicKey = publicKey;
64     }
65 
66     /**
67      * Adds a basis to the key.
68      *
69      * @param b a NtruSign basis
70      */
add(Basis b)71     private void add(Basis b)
72     {
73         bases.add(b);
74     }
75 
76     /**
77      * Returns the <code>i</code>-th basis
78      *
79      * @param i the index
80      * @return the basis at index <code>i</code>
81      */
getBasis(int i)82     public Basis getBasis(int i)
83     {
84         return bases.get(i);
85     }
86 
getPublicKey()87     public NTRUSigningPublicKeyParameters getPublicKey()
88     {
89         return publicKey;
90     }
91 
92     /**
93      * Converts the key to a byte array
94      *
95      * @return the encoded key
96      */
getEncoded()97     public byte[] getEncoded()
98         throws IOException
99     {
100         ByteArrayOutputStream os = new ByteArrayOutputStream();
101         for (int i = 0; i < bases.size(); i++)
102         {
103             // all bases except for the first one contain a public key
104             bases.get(i).encode(os, i != 0);
105         }
106 
107         os.write(publicKey.getEncoded());
108 
109         return os.toByteArray();
110     }
111 
112     /**
113      * Writes the key to an output stream
114      *
115      * @param os an output stream
116      * @throws IOException
117      */
writeTo(OutputStream os)118     public void writeTo(OutputStream os)
119         throws IOException
120     {
121         os.write(getEncoded());
122     }
123 
124     @Override
hashCode()125     public int hashCode()
126     {
127         final int prime = 31;
128         int result = 1;
129         result = prime * result;
130         if (bases==null) return result;
131         result += bases.hashCode();
132         for (Basis basis : bases)
133         {
134             result += basis.hashCode();
135         }
136         return result;
137     }
138 
139     @Override
equals(Object obj)140     public boolean equals(Object obj)
141     {
142         if (this == obj)
143         {
144             return true;
145         }
146         if (obj == null)
147         {
148             return false;
149         }
150         if (getClass() != obj.getClass())
151         {
152             return false;
153         }
154         NTRUSigningPrivateKeyParameters other = (NTRUSigningPrivateKeyParameters)obj;
155         if ((bases == null) != (other.bases == null))
156         {
157             return false;
158         }
159         if (bases == null)
160         {
161             return true;
162         }
163         if (bases.size() != other.bases.size())
164         {
165             return false;
166         }
167         for (int i = 0; i < bases.size(); i++)
168         {
169             Basis basis1 = bases.get(i);
170             Basis basis2 = other.bases.get(i);
171             if (!basis1.f.equals(basis2.f))
172             {
173                 return false;
174             }
175             if (!basis1.fPrime.equals(basis2.fPrime))
176             {
177                 return false;
178             }
179             if (i != 0 && !basis1.h.equals(basis2.h))   // don't compare h for the 0th basis
180             {
181                 return false;
182             }
183             if (!basis1.params.equals(basis2.params))
184             {
185                 return false;
186             }
187         }
188         return true;
189     }
190 
191     /**
192      * A NtruSign basis. Contains three polynomials <code>f, f', h</code>.
193      */
194     public static class Basis
195     {
196         public Polynomial f;
197         public Polynomial fPrime;
198         public IntegerPolynomial h;
199         NTRUSigningKeyGenerationParameters params;
200 
201         /**
202          * Constructs a new basis from polynomials <code>f, f', h</code>.
203          *
204          * @param f
205          * @param fPrime
206          * @param h
207          * @param params NtruSign parameters
208          */
Basis(Polynomial f, Polynomial fPrime, IntegerPolynomial h, NTRUSigningKeyGenerationParameters params)209         protected Basis(Polynomial f, Polynomial fPrime, IntegerPolynomial h, NTRUSigningKeyGenerationParameters params)
210         {
211             this.f = f;
212             this.fPrime = fPrime;
213             this.h = h;
214             this.params = params;
215         }
216 
217         /**
218          * Reads a basis from an input stream and constructs a new basis.
219          *
220          * @param is        an input stream
221          * @param params    NtruSign parameters
222          * @param include_h whether to read the polynomial <code>h</code> (<code>true</code>) or only <code>f</code> and <code>f'</code> (<code>false</code>)
223          */
Basis(InputStream is, NTRUSigningKeyGenerationParameters params, boolean include_h)224         Basis(InputStream is, NTRUSigningKeyGenerationParameters params, boolean include_h)
225             throws IOException
226         {
227             int N = params.N;
228             int q = params.q;
229             int d1 = params.d1;
230             int d2 = params.d2;
231             int d3 = params.d3;
232             boolean sparse = params.sparse;
233             this.params = params;
234 
235             if (params.polyType == NTRUParameters.TERNARY_POLYNOMIAL_TYPE_PRODUCT)
236             {
237                 f = ProductFormPolynomial.fromBinary(is, N, d1, d2, d3 + 1, d3);
238             }
239             else
240             {
241                 IntegerPolynomial fInt = IntegerPolynomial.fromBinary3Tight(is, N);
242                 f = sparse ? new SparseTernaryPolynomial(fInt) : new DenseTernaryPolynomial(fInt);
243             }
244 
245             if (params.basisType == NTRUSigningKeyGenerationParameters.BASIS_TYPE_STANDARD)
246             {
247                 IntegerPolynomial fPrimeInt = IntegerPolynomial.fromBinary(is, N, q);
248                 for (int i = 0; i < fPrimeInt.coeffs.length; i++)
249                 {
250                     fPrimeInt.coeffs[i] -= q / 2;
251                 }
252                 fPrime = fPrimeInt;
253             }
254             else if (params.polyType == NTRUParameters.TERNARY_POLYNOMIAL_TYPE_PRODUCT)
255             {
256                 fPrime = ProductFormPolynomial.fromBinary(is, N, d1, d2, d3 + 1, d3);
257             }
258             else
259             {
260                 fPrime = IntegerPolynomial.fromBinary3Tight(is, N);
261             }
262 
263             if (include_h)
264             {
265                 h = IntegerPolynomial.fromBinary(is, N, q);
266             }
267         }
268 
269         /**
270          * Writes the basis to an output stream
271          *
272          * @param os        an output stream
273          * @param include_h whether to write the polynomial <code>h</code> (<code>true</code>) or only <code>f</code> and <code>f'</code> (<code>false</code>)
274          * @throws IOException
275          */
encode(OutputStream os, boolean include_h)276         void encode(OutputStream os, boolean include_h)
277             throws IOException
278         {
279             int q = params.q;
280 
281             os.write(getEncoded(f));
282             if (params.basisType == NTRUSigningKeyGenerationParameters.BASIS_TYPE_STANDARD)
283             {
284                 IntegerPolynomial fPrimeInt = fPrime.toIntegerPolynomial();
285                 for (int i = 0; i < fPrimeInt.coeffs.length; i++)
286                 {
287                     fPrimeInt.coeffs[i] += q / 2;
288                 }
289                 os.write(fPrimeInt.toBinary(q));
290             }
291             else
292             {
293                 os.write(getEncoded(fPrime));
294             }
295             if (include_h)
296             {
297                 os.write(h.toBinary(q));
298             }
299         }
300 
getEncoded(Polynomial p)301         private byte[] getEncoded(Polynomial p)
302         {
303             if (p instanceof ProductFormPolynomial)
304             {
305                 return ((ProductFormPolynomial)p).toBinary();
306             }
307             else
308             {
309                 return p.toIntegerPolynomial().toBinary3Tight();
310             }
311         }
312 
313         @Override
hashCode()314         public int hashCode()
315         {
316             final int prime = 31;
317             int result = 1;
318             result = prime * result + ((f == null) ? 0 : f.hashCode());
319             result = prime * result + ((fPrime == null) ? 0 : fPrime.hashCode());
320             result = prime * result + ((h == null) ? 0 : h.hashCode());
321             result = prime * result + ((params == null) ? 0 : params.hashCode());
322             return result;
323         }
324 
325         @Override
equals(Object obj)326         public boolean equals(Object obj)
327         {
328             if (this == obj)
329             {
330                 return true;
331             }
332             if (obj == null)
333             {
334                 return false;
335             }
336             if (!(obj instanceof Basis))
337             {
338                 return false;
339             }
340             Basis other = (Basis)obj;
341             if (f == null)
342             {
343                 if (other.f != null)
344                 {
345                     return false;
346                 }
347             }
348             else if (!f.equals(other.f))
349             {
350                 return false;
351             }
352             if (fPrime == null)
353             {
354                 if (other.fPrime != null)
355                 {
356                     return false;
357                 }
358             }
359             else if (!fPrime.equals(other.fPrime))
360             {
361                 return false;
362             }
363             if (h == null)
364             {
365                 if (other.h != null)
366                 {
367                     return false;
368                 }
369             }
370             else if (!h.equals(other.h))
371             {
372                 return false;
373             }
374             if (params == null)
375             {
376                 if (other.params != null)
377                 {
378                     return false;
379                 }
380             }
381             else if (!params.equals(other.params))
382             {
383                 return false;
384             }
385             return true;
386         }
387     }
388 }
389