1 /*
2  * Copyright 2019 Google Inc.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     https://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "third_party/private-join-and-compute/src/crypto/ec_group.h"
17 
18 #include <algorithm>
19 #include <utility>
20 
21 #include "third_party/private-join-and-compute/src/crypto/ec_point.h"
22 #include "third_party/private-join-and-compute/src/crypto/openssl.inc"
23 #include "third_party/private-join-and-compute/src/util/status.inc"
24 
25 namespace private_join_and_compute {
26 
27 namespace {
28 
29 // Returns a group using the predefined underlying operations suggested by
30 // OpenSSL.
CreateGroup(int curve_id)31 StatusOr<ECGroup::ECGroupPtr> CreateGroup(int curve_id) {
32   auto ec_group_ptr = EC_GROUP_new_by_curve_name(curve_id);
33   // If this fails, this is usually due to an invalid curve id.
34   if (ec_group_ptr == nullptr) {
35     return InvalidArgumentError(
36         "ECGroup::CreateGroup() - Could not create group. " +
37                      OpenSSLErrorString());
38   }
39   return ECGroup::ECGroupPtr(ec_group_ptr);
40 }
41 
42 // Returns the order of the group. For more information, see
43 // https://en.wikipedia.org/wiki/Elliptic-curve_cryptography#Domain_parameters.
CreateOrder(const EC_GROUP * group,Context * context)44 StatusOr<BigNum> CreateOrder(const EC_GROUP* group, Context* context) {
45   BIGNUM* bn = BN_new();
46   if (bn == nullptr) {
47     return InternalError(
48         "ECGroup::CreateOrder - Could not create BIGNUM. " +
49                      OpenSSLErrorString());
50   }
51   BigNum::BignumPtr order = BigNum::BignumPtr(bn);
52   if (EC_GROUP_get_order(group, order.get(), context->GetBnCtx()) != 1) {
53     return InternalError(
54         "ECGroup::CreateOrder - Could not get order. " + OpenSSLErrorString());
55   }
56   return context->CreateBigNum(std::move(order));
57 }
58 
59 // Returns the cofactor of the group.
CreateCofactor(const EC_GROUP * group,Context * context)60 StatusOr<BigNum> CreateCofactor(const EC_GROUP* group, Context* context) {
61   BIGNUM* bn = BN_new();
62   if (bn == nullptr) {
63     return InternalError(
64         "ECGroup::CreateCofactor - Could not create BIGNUM. " +
65                      OpenSSLErrorString());
66   }
67   BigNum::BignumPtr cofactor = BigNum::BignumPtr(bn);
68   if (EC_GROUP_get_cofactor(group, cofactor.get(), context->GetBnCtx()) != 1) {
69     return InternalError(
70         "ECGroup::CreateCofactor - Could not get cofactor. " +
71                      OpenSSLErrorString());
72   }
73   return context->CreateBigNum(std::move(cofactor));
74 }
75 
76 // Returns the parameters that define the curve. For more information, see
77 // https://en.wikipedia.org/wiki/Elliptic-curve_cryptography#Domain_parameters.
CreateCurveParams(const EC_GROUP * group,Context * context)78 StatusOr<ECGroup::CurveParams> CreateCurveParams(const EC_GROUP* group,
79                                                  Context* context) {
80   BIGNUM* bn1 = BN_new();
81   BIGNUM* bn2 = BN_new();
82   BIGNUM* bn3 = BN_new();
83   if (bn1 == nullptr || bn2 == nullptr || bn3 == nullptr) {
84     return InternalError(
85         "ECGroup::CreateCurveParams - Could not create BIGNUM. " +
86                      OpenSSLErrorString());
87   }
88   BigNum::BignumPtr p = BigNum::BignumPtr(bn1);
89   BigNum::BignumPtr a = BigNum::BignumPtr(bn2);
90   BigNum::BignumPtr b = BigNum::BignumPtr(bn3);
91   if (EC_GROUP_get_curve_GFp(group, p.get(), a.get(), b.get(),
92                              context->GetBnCtx()) != 1) {
93     return InternalError(
94         "ECGroup::CreateCurveParams - Could not get params. " +
95                      OpenSSLErrorString());
96   }
97   BigNum p_bn = context->CreateBigNum(std::move(p));
98   if (!p_bn.IsPrime()) {
99     return InternalError(
100         "ECGroup::CreateCurveParams - p is not prime. " + OpenSSLErrorString());
101   }
102   return ECGroup::CurveParams{std::move(p_bn),
103                               context->CreateBigNum(std::move(a)),
104                               context->CreateBigNum(std::move(b))};
105 }
106 
107 // Returns (p - 1) / 2 where p is a curve-defining parameter.
GetPMinusOneOverTwo(const ECGroup::CurveParams & curve_params,Context * context)108 BigNum GetPMinusOneOverTwo(const ECGroup::CurveParams& curve_params,
109                            Context* context) {
110   return (curve_params.p - context->One()) / context->Two();
111 }
112 
113 }  // namespace
114 
ECGroup(Context * context,ECGroupPtr group,BigNum order,BigNum cofactor,CurveParams curve_params,BigNum p_minus_one_over_two)115 ECGroup::ECGroup(Context* context, ECGroupPtr group, BigNum order,
116                  BigNum cofactor, CurveParams curve_params,
117                  BigNum p_minus_one_over_two)
118     : context_(context),
119       group_(std::move(group)),
120       order_(std::move(order)),
121       cofactor_(std::move(cofactor)),
122       curve_params_(std::move(curve_params)),
123       p_minus_one_over_two_(std::move(p_minus_one_over_two)) {}
124 
Create(int curve_id,Context * context)125 StatusOr<ECGroup> ECGroup::Create(int curve_id, Context* context) {
126   ASSIGN_OR_RETURN(ECGroupPtr g, CreateGroup(curve_id));
127   ASSIGN_OR_RETURN(BigNum order, CreateOrder(g.get(), context));
128   ASSIGN_OR_RETURN(BigNum cofactor, CreateCofactor(g.get(), context));
129   ASSIGN_OR_RETURN(CurveParams params, CreateCurveParams(g.get(), context));
130   BigNum p_minus_one_over_two = GetPMinusOneOverTwo(params, context);
131   return ECGroup(context, std::move(g), std::move(order), std::move(cofactor),
132                  std::move(params), std::move(p_minus_one_over_two));
133 }
134 
GeneratePrivateKey() const135 BigNum ECGroup::GeneratePrivateKey() const {
136   return context_->GenerateRandBetween(context_->One(), order_);
137 }
138 
CheckPrivateKey(const BigNum & priv_key) const139 Status ECGroup::CheckPrivateKey(const BigNum& priv_key) const {
140   if (context_->Zero() >= priv_key || priv_key >= order_) {
141     return InvalidArgumentError(
142         "The given key is out of bounds, needs to be in [1, order) instead.");
143   }
144   return OkStatus();
145 }
146 
GetPointByHashingToCurveInternal(const BigNum & x) const147 StatusOr<ECPoint> ECGroup::GetPointByHashingToCurveInternal(
148     const BigNum& x) const {
149   BigNum mod_x = x.Mod(curve_params_.p);
150   BigNum y2 = ComputeYSquare(mod_x);
151   if (IsSquare(y2)) {
152     BigNum sqrt = y2.ModSqrt(curve_params_.p);
153     if (sqrt.IsBitSet(0)) {
154       return CreateECPoint(mod_x, sqrt.ModNegate(curve_params_.p));
155     }
156     return CreateECPoint(mod_x, sqrt);
157   }
158   return InternalError("Could not hash x to the curve.");
159 }
160 
GetPointByHashingToCurveSha256(const std::string & m) const161 StatusOr<ECPoint> ECGroup::GetPointByHashingToCurveSha256(
162     const std::string& m) const {
163   BigNum x = context_->RandomOracleSha256(m, curve_params_.p);
164   while (true) {
165     auto status_or_point = GetPointByHashingToCurveInternal(x);
166     if (status_or_point.ok()) {
167       return status_or_point;
168     }
169     x = context_->RandomOracleSha256(x.ToBytes(), curve_params_.p);
170   }
171 }
172 
GetPointByHashingToCurveSha512(const std::string & m) const173 StatusOr<ECPoint> ECGroup::GetPointByHashingToCurveSha512(
174     const std::string& m) const {
175   BigNum x = context_->RandomOracleSha512(m, curve_params_.p);
176   while (true) {
177     auto status_or_point = GetPointByHashingToCurveInternal(x);
178     if (status_or_point.ok()) {
179       return status_or_point;
180     }
181     x = context_->RandomOracleSha512(x.ToBytes(), curve_params_.p);
182   }
183 }
184 
// Evaluates the right-hand side of the curve equation:
// returns (x^3 + a*x + b) mod p.
BigNum ECGroup::ComputeYSquare(const BigNum& x) const {
  BigNum x_cubed = x.Exp(context_->Three());
  BigNum a_times_x = curve_params_.a * x;
  BigNum unreduced = x_cubed + a_times_x + curve_params_.b;
  return unreduced.Mod(curve_params_.p);
}
189 
IsValid(const ECPoint & point) const190 bool ECGroup::IsValid(const ECPoint& point) const {
191   if (!IsOnCurve(point) || IsAtInfinity(point)) {
192     return false;
193   }
194   return true;
195 }
196 
IsOnCurve(const ECPoint & point) const197 bool ECGroup::IsOnCurve(const ECPoint& point) const {
198   return 1 == EC_POINT_is_on_curve(group_.get(), point.point_.get(),
199                                    context_->GetBnCtx());
200 }
201 
IsAtInfinity(const ECPoint & point) const202 bool ECGroup::IsAtInfinity(const ECPoint& point) const {
203   return 1 == EC_POINT_is_at_infinity(group_.get(), point.point_.get());
204 }
205 
IsSquare(const BigNum & q) const206 bool ECGroup::IsSquare(const BigNum& q) const {
207   return q.ModExp(p_minus_one_over_two_, curve_params_.p).IsOne();
208 }
209 
GetFixedGenerator() const210 StatusOr<ECPoint> ECGroup::GetFixedGenerator() const {
211   const EC_POINT* ssl_generator = EC_GROUP_get0_generator(group_.get());
212   EC_POINT* dup_ssl_generator = EC_POINT_dup(ssl_generator, group_.get());
213   if (dup_ssl_generator == nullptr) {
214     return InternalError(OpenSSLErrorString());
215   }
216   return ECPoint(group_.get(), context_->GetBnCtx(),
217                  ECPoint::ECPointPtr(dup_ssl_generator));
218 }
219 
GetRandomGenerator() const220 StatusOr<ECPoint> ECGroup::GetRandomGenerator() const {
221   ASSIGN_OR_RETURN(ECPoint generator, GetFixedGenerator());
222   return generator.Mul(context_->GenerateRandBetween(context_->One(), order_));
223 }
224 
CreateECPoint(const BigNum & x,const BigNum & y) const225 StatusOr<ECPoint> ECGroup::CreateECPoint(const BigNum& x,
226                                          const BigNum& y) const {
227   ECPoint point = ECPoint(group_.get(), context_->GetBnCtx(), x, y);
228   if (!IsValid(point)) {
229     return InvalidArgumentError(
230         "ECGroup::CreateECPoint(x,y) - The point is not valid.");
231   }
232   return std::move(point);
233 }
234 
CreateECPoint(const std::string & bytes) const235 StatusOr<ECPoint> ECGroup::CreateECPoint(const std::string& bytes) const {
236   auto raw_ec_point_ptr = EC_POINT_new(group_.get());
237   if (raw_ec_point_ptr == nullptr) {
238     return InternalError("ECGroup::CreateECPoint: Failed to create point.");
239   }
240   ECPoint::ECPointPtr point(raw_ec_point_ptr);
241   if (EC_POINT_oct2point(group_.get(), point.get(),
242                          reinterpret_cast<const unsigned char*>(bytes.data()),
243                          bytes.size(), context_->GetBnCtx()) != 1) {
244     return InvalidArgumentError(
245         "ECGroup::CreateECPoint(string) - Could not decode point.\n" + OpenSSLErrorString());
246   }
247 
248   ECPoint ec_point(group_.get(), context_->GetBnCtx(), std::move(point));
249   if (!IsValid(ec_point)) {
250     return InvalidArgumentError(
251         "ECGroup::CreateECPoint(string) - Decoded point is not valid.");
252   }
253   return std::move(ec_point);
254 }
255 
GetPointAtInfinity() const256 StatusOr<ECPoint> ECGroup::GetPointAtInfinity() const {
257   EC_POINT* new_point = EC_POINT_new(group_.get());
258   if (new_point == nullptr) {
259     return InternalError(
260         "ECGroup::GetPointAtInfinity() - Could not create new point.");
261   }
262   ECPoint::ECPointPtr point(new_point);
263   if (EC_POINT_set_to_infinity(group_.get(), point.get()) != 1) {
264     return InternalError(
265         "ECGroup::GetPointAtInfinity() - Could not get point at infinity.");
266   }
267   ECPoint ec_point(group_.get(), context_->GetBnCtx(), std::move(point));
268   return std::move(ec_point);
269 }
270 
271 }  // namespace private_join_and_compute
272