1 /*
2  * (C) Copyright Projet SECRET, INRIA, Rocquencourt
3  * (C) Bhaskar Biswas and  Nicolas Sendrier
4  *
5  * (C) 2014 cryptosource GmbH
6  * (C) 2014 Falko Strenzke fstrenzke@cryptosource.de
7  *
8  * Botan is released under the Simplified BSD License (see license.txt)
9  *
10  */
11 
12 #include <botan/internal/mce_internal.h>
13 #include <botan/internal/code_based_util.h>
14 
15 namespace Botan {
16 
17 namespace {
18 
matrix_arr_mul(std::vector<uint32_t> matrix,size_t numo_rows,size_t words_per_row,const uint8_t input_vec[],uint32_t output_vec[],size_t output_vec_len)19 void matrix_arr_mul(std::vector<uint32_t> matrix,
20                     size_t numo_rows,
21                     size_t words_per_row,
22                     const uint8_t input_vec[],
23                     uint32_t output_vec[],
24                     size_t output_vec_len)
25    {
26    for(size_t j = 0; j < numo_rows; j++)
27       {
28       if((input_vec[j / 8] >> (j % 8)) & 1)
29          {
30          for(size_t i = 0; i < output_vec_len; i++)
31             {
32             output_vec[i] ^= matrix[j * (words_per_row) + i];
33             }
34          }
35       }
36    }
37 
38 /**
39 * returns the error vector to the syndrome
40 */
goppa_decode(const polyn_gf2m & syndrom_polyn,const polyn_gf2m & g,const std::vector<polyn_gf2m> & sqrtmod,const std::vector<gf2m> & Linv)41 secure_vector<gf2m> goppa_decode(const polyn_gf2m & syndrom_polyn,
42                                  const polyn_gf2m & g,
43                                  const std::vector<polyn_gf2m> & sqrtmod,
44                                  const std::vector<gf2m> & Linv)
45    {
46    const size_t code_length = Linv.size();
47    gf2m a;
48    uint32_t t = g.get_degree();
49 
50    std::shared_ptr<GF2m_Field> sp_field = g.get_sp_field();
51 
52    std::pair<polyn_gf2m, polyn_gf2m> h_aux = polyn_gf2m::eea_with_coefficients( syndrom_polyn, g, 1);
53    polyn_gf2m & h = h_aux.first;
54    polyn_gf2m & aux = h_aux.second;
55    a = sp_field->gf_inv(aux.get_coef(0));
56    gf2m log_a = sp_field->gf_log(a);
57    for(int i = 0; i <= h.get_degree(); ++i)
58       {
59       h.set_coef(i,sp_field->gf_mul_zrz(log_a,h.get_coef(i)));
60       }
61 
62    //  compute h(z) += z
63    h.add_to_coef( 1, 1);
64    // compute S square root of h (using sqrtmod)
65    polyn_gf2m S(t - 1, g.get_sp_field());
66 
67    for(uint32_t i=0;i<t;i++)
68       {
69       a = sp_field->gf_sqrt(h.get_coef(i));
70 
71       if(i & 1)
72          {
73          for(uint32_t j=0;j<t;j++)
74             {
75             S.add_to_coef( j, sp_field->gf_mul(a, sqrtmod[i/2].get_coef(j)));
76             }
77          }
78       else
79          {
80          S.add_to_coef( i/2, a);
81          }
82       } /* end for loop (i) */
83 
84 
85    S.get_degree();
86 
87    std::pair<polyn_gf2m, polyn_gf2m> v_u = polyn_gf2m::eea_with_coefficients(S, g, t/2+1);
88    polyn_gf2m & u = v_u.second;
89    polyn_gf2m & v = v_u.first;
90 
91    // sigma = u^2+z*v^2
92    polyn_gf2m sigma ( t , g.get_sp_field());
93 
94    const int u_deg = u.get_degree();
95    BOTAN_ASSERT(u_deg >= 0, "Valid degree");
96    for(int i = 0; i <= u_deg; ++i)
97       {
98       sigma.set_coef(2*i, sp_field->gf_square(u.get_coef(i)));
99       }
100 
101    const int v_deg = v.get_degree();
102    BOTAN_ASSERT(v_deg >= 0, "Valid degree");
103    for(int i = 0; i <= v_deg; ++i)
104       {
105       sigma.set_coef(2*i+1, sp_field->gf_square(v.get_coef(i)));
106       }
107 
108    secure_vector<gf2m> res = find_roots_gf2m_decomp(sigma, code_length);
109    size_t d = res.size();
110 
111    secure_vector<gf2m> result(d);
112    for(uint32_t i = 0; i < d; ++i)
113       {
114       gf2m current = res[i];
115 
116       gf2m tmp;
117       tmp = gray_to_lex(current);
118       /// XXX double assignment, possible bug?
119       if(tmp >= code_length) /* invalid root */
120          {
121          result[i] = static_cast<gf2m>(i);
122          }
123       result[i] = Linv[tmp];
124       }
125 
126    return result;
127    }
128 }
129 
mceliece_decrypt(secure_vector<uint8_t> & plaintext_out,secure_vector<uint8_t> & error_mask_out,const secure_vector<uint8_t> & ciphertext,const McEliece_PrivateKey & key)130 void mceliece_decrypt(secure_vector<uint8_t>& plaintext_out,
131                       secure_vector<uint8_t>& error_mask_out,
132                       const secure_vector<uint8_t>& ciphertext,
133                       const McEliece_PrivateKey& key)
134    {
135    mceliece_decrypt(plaintext_out, error_mask_out, ciphertext.data(), ciphertext.size(), key);
136    }
137 
mceliece_decrypt(secure_vector<uint8_t> & plaintext,secure_vector<uint8_t> & error_mask,const uint8_t ciphertext[],size_t ciphertext_len,const McEliece_PrivateKey & key)138 void mceliece_decrypt(
139    secure_vector<uint8_t>& plaintext,
140    secure_vector<uint8_t> & error_mask,
141    const uint8_t ciphertext[],
142    size_t ciphertext_len,
143    const McEliece_PrivateKey & key)
144    {
145    secure_vector<gf2m> error_pos;
146    plaintext = mceliece_decrypt(error_pos, ciphertext, ciphertext_len, key);
147 
148    const size_t code_length = key.get_code_length();
149    secure_vector<uint8_t> result((code_length+7)/8);
150    for(auto&& pos : error_pos)
151       {
152       if(pos > code_length)
153          {
154          throw Invalid_Argument("error position larger than code size");
155          }
156       result[pos / 8] |= (1 << (pos % 8));
157       }
158 
159    error_mask = result;
160    }
161 
162 /**
163 * @p p_err_pos_len must point to the available length of @p error_pos on input, the
164 * function will set it to the actual number of errors returned in the @p error_pos
165 * array */
mceliece_decrypt(secure_vector<gf2m> & error_pos,const uint8_t * ciphertext,size_t ciphertext_len,const McEliece_PrivateKey & key)166 secure_vector<uint8_t> mceliece_decrypt(
167    secure_vector<gf2m> & error_pos,
168    const uint8_t *ciphertext, size_t ciphertext_len,
169    const McEliece_PrivateKey & key)
170    {
171 
172    const size_t dimension = key.get_dimension();
173    const size_t codimension = key.get_codimension();
174    const uint32_t t = key.get_goppa_polyn().get_degree();
175    polyn_gf2m syndrome_polyn(key.get_goppa_polyn().get_sp_field()); // init as zero polyn
176    const unsigned unused_pt_bits = dimension % 8;
177    const uint8_t unused_pt_bits_mask = (1 << unused_pt_bits) - 1;
178 
179    if(ciphertext_len != (key.get_code_length()+7)/8)
180       {
181       throw Invalid_Argument("wrong size of McEliece ciphertext");
182       }
183    const size_t cleartext_len = (key.get_message_word_bit_length()+7)/8;
184 
185    if(cleartext_len != bit_size_to_byte_size(dimension))
186       {
187       throw Invalid_Argument("mce-decryption: wrong length of cleartext buffer");
188       }
189 
190    secure_vector<uint32_t> syndrome_vec(bit_size_to_32bit_size(codimension));
191    matrix_arr_mul(key.get_H_coeffs(),
192                   key.get_code_length(),
193                   bit_size_to_32bit_size(codimension),
194                   ciphertext,
195                   syndrome_vec.data(), syndrome_vec.size());
196 
197    secure_vector<uint8_t> syndrome_byte_vec(bit_size_to_byte_size(codimension));
198    const size_t syndrome_byte_vec_size = syndrome_byte_vec.size();
199    for(size_t i = 0; i < syndrome_byte_vec_size; i++)
200       {
201       syndrome_byte_vec[i] = static_cast<uint8_t>(syndrome_vec[i/4] >> (8 * (i % 4)));
202       }
203 
204    syndrome_polyn = polyn_gf2m(t-1, syndrome_byte_vec.data(), bit_size_to_byte_size(codimension), key.get_goppa_polyn().get_sp_field());
205 
206    syndrome_polyn.get_degree();
207    error_pos = goppa_decode(syndrome_polyn, key.get_goppa_polyn(), key.get_sqrtmod(), key.get_Linv());
208 
209    const size_t nb_err = error_pos.size();
210 
211    secure_vector<uint8_t> cleartext(cleartext_len);
212    copy_mem(cleartext.data(), ciphertext, cleartext_len);
213 
214    for(size_t i = 0; i < nb_err; i++)
215       {
216       gf2m current = error_pos[i];
217 
218       if(current >= cleartext_len * 8)
219          {
220          // an invalid position, this shouldn't happen
221          continue;
222          }
223       cleartext[current / 8] ^= (1 << (current % 8));
224       }
225 
226    if(unused_pt_bits)
227       {
228       cleartext[cleartext_len - 1] &= unused_pt_bits_mask;
229       }
230 
231    return cleartext;
232    }
233 
234 }
235