1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT license.
3 
4 using Microsoft.Research.SEAL;
5 using Microsoft.VisualStudio.TestTools.UnitTesting;
6 using System;
7 using System.Collections.Generic;
8 using System.Numerics;
9 
namespace SEALNetTest
{
    [TestClass]
    public class CKKSEncoderTests
    {
        /// <summary>
        /// Builds a CKKS <see cref="SEALContext"/> for small, test-only parameters.
        /// </summary>
        /// <param name="polyModulusDegree">
        /// Polynomial modulus degree N. It is also passed to <see cref="CoeffModulus.Create"/>,
        /// matching how every test in this class built its parameters.
        /// </param>
        /// <param name="bitSizes">Bit sizes of the coefficient-modulus primes.</param>
        /// <returns>A context with the modulus chain not expanded and no security level enforced.</returns>
        /// <remarks>
        /// <see cref="SecLevelType.None"/> is used because degrees as small as 64 are presumably
        /// rejected under any standard security level.
        /// </remarks>
        private static SEALContext CreateContext(ulong polyModulusDegree, int[] bitSizes)
        {
            EncryptionParameters parms = new EncryptionParameters(SchemeType.CKKS)
            {
                PolyModulusDegree = polyModulusDegree,
                CoeffModulus = CoeffModulus.Create(polyModulusDegree, bitSizes)
            };

            return new SEALContext(parms,
                expandModChain: false,
                secLevel: SecLevelType.None);
        }

        /// <summary>
        /// Encodes a single double with an explicit scale and checks that every slot
        /// decodes back to that value.
        /// </summary>
        [TestMethod]
        public void EncodeDecodeDoubleTest()
        {
            SEALContext context = CreateContext(64, new int[] { 40, 40, 40, 40 });
            CKKSEncoder encoder = new CKKSEncoder(context);

            // N = 64 gives N/2 = 32 slots.
            Assert.AreEqual(32ul, encoder.SlotCount);

            Plaintext plain = new Plaintext();
            List<Complex> result = new List<Complex>();
            double delta = 1 << 16;
            double value = 10d;

            encoder.Encode(value, delta, plain);
            encoder.Decode(plain, result);

            // A scalar is replicated into all slots, so verify every slot rather than
            // only a hard-coded prefix of them.
            int slotCount = (int)encoder.SlotCount;
            Assert.AreEqual(slotCount, result.Count);
            for (int i = 0; i < slotCount; i++)
            {
                Assert.AreEqual(value, result[i].Real, 0.5, "slot " + i);
            }
        }

        /// <summary>
        /// Encodes a single integer (no explicit scale) and checks that every slot
        /// decodes back to that value.
        /// </summary>
        [TestMethod]
        public void EncodeDecodeUlongTest()
        {
            int slots = 32;
            SEALContext context = CreateContext((ulong)slots * 2, new int[] { 40, 40, 40, 40 });
            CKKSEncoder encoder = new CKKSEncoder(context);
            Assert.AreEqual((ulong)slots, encoder.SlotCount);

            Plaintext plain = new Plaintext();
            List<Complex> result = new List<Complex>();

            long value = 15;
            encoder.Encode(value, plain);
            encoder.Decode(plain, result);

            Assert.AreEqual(slots, result.Count);
            for (int i = 0; i < slots; i++)
            {
                Assert.AreEqual((double)value, result[i].Real, 0.5, "slot " + i);
            }
        }

        /// <summary>
        /// Encodes a single complex number and checks both of its components in the first slot.
        /// </summary>
        [TestMethod]
        public void EncodeDecodeComplexTest()
        {
            SEALContext context = CreateContext(64, new int[] { 40, 40, 40, 40 });
            CKKSEncoder encoder = new CKKSEncoder(context);

            Plaintext plain = new Plaintext();
            Complex value = new Complex(3.1415, 2.71828);

            encoder.Encode(value, scale: Math.Pow(2, 20), destination: plain);

            List<Complex> result = new List<Complex>();
            encoder.Decode(plain, result);

            // Decoding yields one entry per slot, not merely a non-empty list.
            Assert.AreEqual((int)encoder.SlotCount, result.Count);
            Assert.AreEqual(3.1415, result[0].Real, delta: 0.0001);
            Assert.AreEqual(2.71828, result[0].Imaginary, delta: 0.0001);
        }

        /// <summary>
        /// Encodes a full vector of pseudo-random real values (one per slot) and checks
        /// that each slot round-trips.
        /// </summary>
        [TestMethod]
        public void EncodeDecodeVectorTest()
        {
            int slots = 32;
            SEALContext context = CreateContext((ulong)slots * 2, new int[] { 60, 60, 60, 60 });
            CKKSEncoder encoder = new CKKSEncoder(context);

            List<Complex> values = new List<Complex>(slots);

            // Fixed seed so that any failure is reproducible.
            Random rnd = new Random(20190725);
            int dataBound = 1 << 30;
            double delta = 1ul << 40;

            for (int i = 0; i < slots; i++)
            {
                values.Add(new Complex(rnd.Next() % dataBound, 0));
            }

            Plaintext plain = new Plaintext();
            encoder.Encode(values, delta, plain);

            List<Complex> result = new List<Complex>();
            encoder.Decode(plain, result);

            Assert.AreEqual(slots, result.Count);
            for (int i = 0; i < slots; i++)
            {
                Assert.AreEqual(values[i].Real, result[i].Real, 0.5, "slot " + i);
            }
        }

        /// <summary>
        /// Encodes a short vector of doubles and checks the encoded values as well as
        /// the zero padding in the remaining slots.
        /// </summary>
        [TestMethod]
        public void EncodeDecodeVectorDoubleTest()
        {
            SEALContext context = CreateContext(64, new int[] { 30, 30 });
            CKKSEncoder encoder = new CKKSEncoder(context);
            Plaintext plain = new Plaintext();

            double[] values = new double[] { 0.1, 2.3, 34.4 };
            encoder.Encode(values, scale: Math.Pow(2, 20), destination: plain);

            List<double> result = new List<double>();
            encoder.Decode(plain, result);

            Assert.AreEqual((int)encoder.SlotCount, result.Count);
            Assert.AreEqual(0.1, result[0], delta: 0.001);
            Assert.AreEqual(2.3, result[1], delta: 0.001);
            Assert.AreEqual(34.4, result[2], delta: 0.001);

            // Slots past the end of the input vector are expected to decode to zero.
            for (int i = values.Length; i < result.Count; i++)
            {
                Assert.AreEqual(0.0, result[i], 0.001, "slot " + i);
            }
        }

        /// <summary>
        /// Checks the argument validation of the encoder: null arguments throw
        /// <see cref="ArgumentNullException"/>; invalid parms ids, negative scales and
        /// plaintexts that are not valid for the context throw <see cref="ArgumentException"/>.
        /// </summary>
        [TestMethod]
        public void ExceptionsTest()
        {
            SEALContext context = CreateContext(64, new int[] { 30, 30 });
            CKKSEncoder encoder = new CKKSEncoder(context);
            List<double> vald = new List<double>();
            List<double> vald_null = null;
            List<Complex> valc = new List<Complex>();
            List<Complex> valc_null = null;
            Plaintext plain = new Plaintext();
            Plaintext plain_null = null;
            MemoryPoolHandle pool = MemoryManager.GetPool(MMProfOpt.ForceGlobal);
            Complex complex = new Complex(1, 2);

            // Construct without touching the captured 'encoder' used by the assertions below.
            Utilities.AssertThrows<ArgumentNullException>(() => new CKKSEncoder(null));

            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Encode(vald, ParmsId.Zero, 10.0, plain_null));
            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Encode(vald, null, 10.0, plain));
            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Encode(vald_null, ParmsId.Zero, 10.0, plain));
            Utilities.AssertThrows<ArgumentException>(() => encoder.Encode(vald, ParmsId.Zero, 10.0, plain, pool));

            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Encode(valc, ParmsId.Zero, 10.0, plain_null));
            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Encode(valc, null, 10.0, plain));
            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Encode(valc_null, ParmsId.Zero, 10.0, plain));
            Utilities.AssertThrows<ArgumentException>(() => encoder.Encode(valc, ParmsId.Zero, 10.0, plain, pool));

            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Encode(vald, 10.0, plain_null));
            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Encode(vald_null, 10.0, plain));
            Utilities.AssertThrows<ArgumentException>(() => encoder.Encode(vald, -10.0, plain, pool));

            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Encode(valc, 10.0, plain_null));
            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Encode(valc_null, 10.0, plain));
            Utilities.AssertThrows<ArgumentException>(() => encoder.Encode(valc, -10.0, plain, pool));

            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Encode(10.0, ParmsId.Zero, 20.0, plain_null));
            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Encode(10.0, null, 20.0, plain));
            Utilities.AssertThrows<ArgumentException>(() => encoder.Encode(10.0, ParmsId.Zero, 20.0, plain, pool));

            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Encode(10.0, 20.0, plain_null));
            Utilities.AssertThrows<ArgumentException>(() => encoder.Encode(10.0, -20.0, plain, pool));

            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Encode(complex, ParmsId.Zero, 10.0, plain_null));
            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Encode(complex, null, 10.0, plain));
            Utilities.AssertThrows<ArgumentException>(() => encoder.Encode(complex, ParmsId.Zero, 10.0, plain, pool));

            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Encode(complex, 10.0, plain_null));
            Utilities.AssertThrows<ArgumentException>(() => encoder.Encode(complex, -10.0, plain, pool));

            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Encode(10, ParmsId.Zero, plain_null));
            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Encode(10, null, plain));
            Utilities.AssertThrows<ArgumentException>(() => encoder.Encode(10, ParmsId.Zero, plain));

            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Encode(10, plain_null));

            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Decode(plain, vald_null));
            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Decode(plain_null, vald));
            Utilities.AssertThrows<ArgumentException>(() => encoder.Decode(plain, vald, pool));

            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Decode(plain, valc_null));
            Utilities.AssertThrows<ArgumentNullException>(() => encoder.Decode(plain_null, valc));
            Utilities.AssertThrows<ArgumentException>(() => encoder.Decode(plain, valc, pool));
        }
    }
}
232