1 /*
2  * Copyright 2019 Google Inc.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     https://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "third_party/private-join-and-compute/src/crypto/ec_commutative_cipher.h"
17 
18 #include <utility>
19 
20 #include "third_party/private-join-and-compute/src/util/status.inc"
21 
22 namespace private_join_and_compute {
23 
ECCommutativeCipher(std::unique_ptr<Context> context,ECGroup group,BigNum private_key,HashType hash_type)24 ECCommutativeCipher::ECCommutativeCipher(std::unique_ptr<Context> context,
25                                          ECGroup group, BigNum private_key,
26                                          HashType hash_type)
27     : context_(std::move(context)),
28       group_(std::move(group)),
29       private_key_(std::move(private_key)),
30       private_key_inverse_(private_key_.ModInverse(group_.GetOrder())),
31       hash_type_(hash_type) {}
32 
ValidateHashType(HashType hash_type)33 bool ECCommutativeCipher::ValidateHashType(HashType hash_type) {
34   return (hash_type == SHA256 || hash_type == SHA512);
35 }
36 
37 StatusOr<std::unique_ptr<ECCommutativeCipher>>
CreateWithNewKey(int curve_id,HashType hash_type)38 ECCommutativeCipher::CreateWithNewKey(int curve_id, HashType hash_type) {
39   std::unique_ptr<Context> context(new Context);
40   ASSIGN_OR_RETURN(ECGroup group, ECGroup::Create(curve_id, context.get()));
41   if (!ECCommutativeCipher::ValidateHashType(hash_type)) {
42     return InvalidArgumentError("Invalid hash type.");
43   }
44   BigNum private_key = group.GeneratePrivateKey();
45   return std::unique_ptr<ECCommutativeCipher>(new ECCommutativeCipher(
46       std::move(context), std::move(group), std::move(private_key), hash_type));
47 }
48 
49 StatusOr<std::unique_ptr<ECCommutativeCipher>>
CreateFromKey(int curve_id,const std::string & key_bytes,HashType hash_type)50 ECCommutativeCipher::CreateFromKey(int curve_id, const std::string& key_bytes,
51                                    HashType hash_type) {
52   std::unique_ptr<Context> context(new Context);
53   ASSIGN_OR_RETURN(ECGroup group, ECGroup::Create(curve_id, context.get()));
54   if (!ECCommutativeCipher::ValidateHashType(hash_type)) {
55     return InvalidArgumentError("Invalid hash type.");
56   }
57   BigNum private_key = context->CreateBigNum(key_bytes);
58   auto status = group.CheckPrivateKey(private_key);
59   if (!status.ok()) {
60     return status;
61   }
62   return std::unique_ptr<ECCommutativeCipher>(new ECCommutativeCipher(
63       std::move(context), std::move(group), std::move(private_key), hash_type));
64 }
65 
Encrypt(const std::string & plaintext) const66 StatusOr<std::string> ECCommutativeCipher::Encrypt(
67     const std::string& plaintext) const {
68   StatusOr<ECPoint> status_or_point;
69   if (hash_type_ == SHA512) {
70     status_or_point = group_.GetPointByHashingToCurveSha512(plaintext);
71   } else if (hash_type_ == SHA256) {
72     status_or_point = group_.GetPointByHashingToCurveSha256(plaintext);
73   } else {
74     return InvalidArgumentError("Invalid hash type.");
75   }
76 
77   if (!status_or_point.ok()) {
78     return status_or_point.status();
79   }
80   ASSIGN_OR_RETURN(ECPoint encrypted_point,
81                    Encrypt(status_or_point.ValueOrDie()));
82   return encrypted_point.ToBytesCompressed();
83 }
84 
ReEncrypt(const std::string & ciphertext) const85 StatusOr<std::string> ECCommutativeCipher::ReEncrypt(
86     const std::string& ciphertext) const {
87   ASSIGN_OR_RETURN(ECPoint point, group_.CreateECPoint(ciphertext));
88   ASSIGN_OR_RETURN(ECPoint reencrypted_point, Encrypt(point));
89   return reencrypted_point.ToBytesCompressed();
90 }
91 
Encrypt(const ECPoint & point) const92 StatusOr<ECPoint> ECCommutativeCipher::Encrypt(const ECPoint& point) const {
93   return point.Mul(private_key_);
94 }
95 
Decrypt(const std::string & ciphertext) const96 StatusOr<std::string> ECCommutativeCipher::Decrypt(
97     const std::string& ciphertext) const {
98   ASSIGN_OR_RETURN(ECPoint point, group_.CreateECPoint(ciphertext));
99   ASSIGN_OR_RETURN(ECPoint decrypted_point, point.Mul(private_key_inverse_));
100   return decrypted_point.ToBytesCompressed();
101 }
102 
GetPrivateKeyBytes() const103 std::string ECCommutativeCipher::GetPrivateKeyBytes() const {
104   return private_key_.ToBytes();
105 }
106 
107 }  // namespace private_join_and_compute
108