1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT license.
3 
4 #include "seal/seal.h"
5 #include "seal/util/rlwe.h"
6 #include "bench.h"
7 
8 using namespace benchmark;
9 using namespace sealbench;
10 using namespace seal;
11 using namespace std;
12 
13 /**
14 This file defines benchmarks for CKKS-specific HE primitives.
15 */
16 
17 namespace sealbench
18 {
bm_ckks_encrypt_secret(State & state,shared_ptr<BMEnv> bm_env)19     void bm_ckks_encrypt_secret(State &state, shared_ptr<BMEnv> bm_env)
20     {
21         vector<Ciphertext> &ct = bm_env->ct();
22         Plaintext &pt = bm_env->pt()[0];
23         for (auto _ : state)
24         {
25             state.PauseTiming();
26             bm_env->randomize_pt_ckks(pt);
27 
28             state.ResumeTiming();
29             bm_env->encryptor()->encrypt_symmetric(pt, ct[2]);
30         }
31     }
32 
bm_ckks_encrypt_public(State & state,shared_ptr<BMEnv> bm_env)33     void bm_ckks_encrypt_public(State &state, shared_ptr<BMEnv> bm_env)
34     {
35         vector<Ciphertext> &ct = bm_env->ct();
36         Plaintext &pt = bm_env->pt()[0];
37         for (auto _ : state)
38         {
39             state.PauseTiming();
40             bm_env->randomize_pt_ckks(pt);
41 
42             state.ResumeTiming();
43             bm_env->encryptor()->encrypt(pt, ct[2]);
44         }
45     }
46 
bm_ckks_decrypt(State & state,shared_ptr<BMEnv> bm_env)47     void bm_ckks_decrypt(State &state, shared_ptr<BMEnv> bm_env)
48     {
49         vector<Ciphertext> &ct = bm_env->ct();
50         Plaintext &pt = bm_env->pt()[0];
51         for (auto _ : state)
52         {
53             state.PauseTiming();
54             bm_env->randomize_ct_ckks(ct[0]);
55 
56             state.ResumeTiming();
57             bm_env->decryptor()->decrypt(ct[0], pt);
58         }
59     }
60 
bm_ckks_encode_double(State & state,shared_ptr<BMEnv> bm_env)61     void bm_ckks_encode_double(State &state, shared_ptr<BMEnv> bm_env)
62     {
63         vector<double> &msg = bm_env->msg_double();
64         Plaintext &pt = bm_env->pt()[0];
65         parms_id_type parms_id = bm_env->context().first_parms_id();
66         double scale = bm_env->safe_scale();
67         for (auto _ : state)
68         {
69             state.PauseTiming();
70             bm_env->randomize_message_double(msg);
71 
72             state.ResumeTiming();
73             bm_env->ckks_encoder()->encode(msg, parms_id, scale, pt);
74         }
75     }
76 
bm_ckks_decode_double(State & state,shared_ptr<BMEnv> bm_env)77     void bm_ckks_decode_double(State &state, shared_ptr<BMEnv> bm_env)
78     {
79         vector<double> &msg = bm_env->msg_double();
80         Plaintext &pt = bm_env->pt()[0];
81         for (auto _ : state)
82         {
83             state.PauseTiming();
84             bm_env->randomize_pt_ckks(pt);
85 
86             state.ResumeTiming();
87             bm_env->ckks_encoder()->decode(pt, msg);
88         }
89     }
90 
bm_ckks_add_ct(State & state,shared_ptr<BMEnv> bm_env)91     void bm_ckks_add_ct(State &state, shared_ptr<BMEnv> bm_env)
92     {
93         vector<Ciphertext> &ct = bm_env->ct();
94         double scale = bm_env->safe_scale();
95         for (auto _ : state)
96         {
97             state.PauseTiming();
98             bm_env->randomize_ct_ckks(ct[0]);
99             ct[0].scale() = scale;
100             bm_env->randomize_ct_ckks(ct[1]);
101             ct[1].scale() = scale;
102             state.ResumeTiming();
103             Ciphertext res;
104             bm_env->evaluator()->add(ct[0], ct[1], res);
105         }
106     }
107 
bm_ckks_add_pt(State & state,shared_ptr<BMEnv> bm_env)108     void bm_ckks_add_pt(State &state, shared_ptr<BMEnv> bm_env)
109     {
110         vector<Ciphertext> &ct = bm_env->ct();
111         Plaintext &pt = bm_env->pt()[0];
112         double scale = bm_env->safe_scale();
113         for (auto _ : state)
114         {
115             state.PauseTiming();
116             bm_env->randomize_ct_ckks(ct[0]);
117             ct[0].scale() = scale;
118             bm_env->randomize_pt_ckks(pt);
119             pt.scale() = scale;
120 
121             state.ResumeTiming();
122             bm_env->evaluator()->add_plain(ct[0], pt, ct[2]);
123         }
124     }
125 
bm_ckks_negate(State & state,shared_ptr<BMEnv> bm_env)126     void bm_ckks_negate(State &state, shared_ptr<BMEnv> bm_env)
127     {
128         vector<Ciphertext> &ct = bm_env->ct();
129         double scale = bm_env->safe_scale();
130         for (auto _ : state)
131         {
132             state.PauseTiming();
133             bm_env->randomize_ct_ckks(ct[0]);
134             ct[0].scale() = scale;
135 
136             state.ResumeTiming();
137             bm_env->evaluator()->negate(ct[0], ct[2]);
138         }
139     }
140 
bm_ckks_sub_ct(State & state,shared_ptr<BMEnv> bm_env)141     void bm_ckks_sub_ct(State &state, shared_ptr<BMEnv> bm_env)
142     {
143         vector<Ciphertext> &ct = bm_env->ct();
144         double scale = bm_env->safe_scale();
145         for (auto _ : state)
146         {
147             state.PauseTiming();
148             bm_env->randomize_ct_ckks(ct[0]);
149             ct[0].scale() = scale;
150             bm_env->randomize_ct_ckks(ct[1]);
151             ct[1].scale() = scale;
152 
153             state.ResumeTiming();
154             bm_env->evaluator()->sub(ct[0], ct[1], ct[2]);
155         }
156     }
157 
bm_ckks_sub_pt(State & state,shared_ptr<BMEnv> bm_env)158     void bm_ckks_sub_pt(State &state, shared_ptr<BMEnv> bm_env)
159     {
160         vector<Ciphertext> &ct = bm_env->ct();
161         Plaintext &pt = bm_env->pt()[0];
162         double scale = bm_env->safe_scale();
163         for (auto _ : state)
164         {
165             state.PauseTiming();
166             bm_env->randomize_ct_ckks(ct[0]);
167             ct[0].scale() = scale;
168             bm_env->randomize_pt_ckks(pt);
169             pt.scale() = scale;
170 
171             state.ResumeTiming();
172             bm_env->evaluator()->sub_plain(ct[0], pt, ct[2]);
173         }
174     }
175 
bm_ckks_mul_ct(State & state,shared_ptr<BMEnv> bm_env)176     void bm_ckks_mul_ct(State &state, shared_ptr<BMEnv> bm_env)
177     {
178         vector<Ciphertext> &ct = bm_env->ct();
179         double scale = bm_env->safe_scale();
180         for (auto _ : state)
181         {
182             state.PauseTiming();
183             bm_env->randomize_ct_ckks(ct[0]);
184             ct[0].scale() = scale;
185             bm_env->randomize_ct_ckks(ct[1]);
186             ct[1].scale() = scale;
187 
188             state.ResumeTiming();
189             bm_env->evaluator()->multiply(ct[0], ct[1], ct[2]);
190         }
191     }
192 
bm_ckks_mul_pt(State & state,shared_ptr<BMEnv> bm_env)193     void bm_ckks_mul_pt(State &state, shared_ptr<BMEnv> bm_env)
194     {
195         vector<Ciphertext> &ct = bm_env->ct();
196         Plaintext &pt = bm_env->pt()[0];
197         double scale = bm_env->safe_scale();
198         for (auto _ : state)
199         {
200             state.PauseTiming();
201             bm_env->randomize_ct_ckks(ct[0]);
202             ct[0].scale() = scale;
203             bm_env->randomize_pt_ckks(pt);
204             pt.scale() = scale;
205 
206             state.ResumeTiming();
207             bm_env->evaluator()->multiply_plain(ct[0], pt, ct[2]);
208         }
209     }
210 
bm_ckks_square(State & state,shared_ptr<BMEnv> bm_env)211     void bm_ckks_square(State &state, shared_ptr<BMEnv> bm_env)
212     {
213         vector<Ciphertext> &ct = bm_env->ct();
214         double scale = bm_env->safe_scale();
215         for (auto _ : state)
216         {
217             state.PauseTiming();
218             bm_env->randomize_ct_ckks(ct[0]);
219             ct[0].scale() = scale;
220             bm_env->randomize_ct_ckks(ct[1]);
221             ct[1].scale() = scale;
222 
223             state.ResumeTiming();
224             bm_env->evaluator()->square(ct[0], ct[2]);
225         }
226     }
227 
bm_ckks_rescale_inplace(State & state,shared_ptr<BMEnv> bm_env)228     void bm_ckks_rescale_inplace(State &state, shared_ptr<BMEnv> bm_env)
229     {
230         vector<Ciphertext> &ct = bm_env->ct();
231         double scale = bm_env->safe_scale() * pow(2.0, 20);
232         for (auto _ : state)
233         {
234             state.PauseTiming();
235             bm_env->randomize_ct_ckks(ct[0]);
236             ct[0].scale() = scale;
237 
238             state.ResumeTiming();
239             bm_env->evaluator()->rescale_to_next_inplace(ct[0]);
240         }
241     }
242 
bm_ckks_relin_inplace(State & state,shared_ptr<BMEnv> bm_env)243     void bm_ckks_relin_inplace(State &state, shared_ptr<BMEnv> bm_env)
244     {
245         Ciphertext ct;
246         for (auto _ : state)
247         {
248             state.PauseTiming();
249             ct.resize(bm_env->context(), size_t(3));
250             bm_env->randomize_ct_ckks(ct);
251 
252             state.ResumeTiming();
253             bm_env->evaluator()->relinearize_inplace(ct, bm_env->rlk());
254         }
255     }
256 
bm_ckks_rotate(State & state,shared_ptr<BMEnv> bm_env)257     void bm_ckks_rotate(State &state, shared_ptr<BMEnv> bm_env)
258     {
259         vector<Ciphertext> &ct = bm_env->ct();
260         for (auto _ : state)
261         {
262             state.PauseTiming();
263             bm_env->randomize_ct_ckks(ct[0]);
264 
265             state.ResumeTiming();
266             bm_env->evaluator()->rotate_vector(ct[0], 1, bm_env->glk(), ct[2]);
267         }
268     }
269 } // namespace sealbench
270