-
Notifications
You must be signed in to change notification settings - Fork 765
/
EaCx.cpp
230 lines (203 loc) · 8.18 KB
/
EaCx.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
/* Copyright (C) 2012-2017 IBM Corp.
* This program is Licensed under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. See accompanying LICENSE file.
*/
/* EaCx.cpp - Encoding/decoding and data-movement for encrypted complex data
*/
#include <algorithm>
#include <type_traits>
#include "zzX.h"
#include "EncryptedArray.h"
#include "timing.h"
#include "clonedPtr.h"
#include "norms.h"
#include "debugging.h"
NTL_CLIENT
// The complex number with real part 0 and imaginary part 1 (the imaginary
// unit); file-local helper for the linear-polynomial and encode-i routines.
static constexpr cx_double the_imaginary_i(0.0, 1.0);
void EncryptedArrayCx::decrypt(const Ctxt& ctxt,
const FHESecKey& sKey, vector<cx_double>& ptxt) const
{
//OLD: assert(&getContext() == &ctxt.getContext());
helib::assertEq(&getContext(), &ctxt.getContext(), "Cannot decrypt with non-matching contextx");
NTL::ZZX pp;
sKey.Decrypt(pp, ctxt);
// convert to zzX, if the pp is too big, scale it down
long nBits = NTL::MaxBits(pp) - NTL_SP_NBITS;
zzX zpp(INIT_SIZE, deg(pp)+1);
double factor;
if (nBits<=0) { // convert to zzX, double
for (long i=0; i<lsize(zpp); i++)
conv(zpp[i], pp[i]);
factor = to_double(ctxt.getRatFactor());
} else { // scale and then convert to zzX, double
for (long i=0; i<lsize(zpp); i++)
conv(zpp[i], pp[i]>>nBits);
factor = to_double(ctxt.getRatFactor()/power2_xdouble(nBits));
}
canonicalEmbedding(ptxt, zpp, getPAlgebra()); // decode without scaling
for (cx_double& cx : ptxt) // divide by the factor
cx /= factor;
}
// Rotate the slots of ctxt along dimension i by amt positions.
//
// The context of ctxt must be the same object as this array's context,
// and dimension i must be native (both are checked; the checks raise via
// helib::assertEq / helib::assertTrue). The amount is reduced modulo the
// size of the dimension, and a net rotation of zero leaves ctxt untouched.
// The 'dc' argument is not used by this implementation.
void EncryptedArrayCx::rotate1D(Ctxt& ctxt, long i, long amt, bool dc) const
{
  (void)dc; // unused here; kept in the signature for callers

  // The message names the operation actually being attempted (rotation).
  helib::assertEq(&getContext(), &ctxt.getContext(), "Cannot rotate with non-matching context");
  helib::assertTrue(nativeDimension(i), "Rotation in " + std::to_string(i) + " is not a native operation");

  const long ord = sizeOfDimension(i);
  // DIRT: assumes division w/ remainder follows C++11 and C99 rules
  // (the remainder takes the sign of amt, so it may still be negative)
  amt %= ord;
  if (amt == 0) return; // nothing to do

  ctxt.smartAutomorph(getPAlgebra().genToPow(i, amt));
}
// Shift by k positions along the i'th dimension with zero fill, where a
// negative k denotes a shift in the opposite direction.
// This operation is not implemented for EncryptedArrayCx: every call
// throws helib::LogicError and none of the arguments are examined.
void EncryptedArrayCx::shift1D(Ctxt& ctxt,
                               long i,
                               long k) const
{
  (void)ctxt;
  (void)i;
  (void)k;
  throw helib::LogicError("EncryptedArrayCx::shift1D not implemented");
}
// Only linear arrays are supported for approximate numbers, so the
// general rotate (and likewise shift) is the same as its 1D counterpart
// applied to dimension 0.
void EncryptedArrayCx::rotate(Ctxt& ctxt, long amt) const
{
  const long onlyDimension = 0;
  const bool dcFlag = true;
  rotate1D(ctxt, onlyDimension, amt, dcFlag);
}
// General shift by amt: forwarded to shift1D on dimension 0, the only
// dimension of a linear array.
void EncryptedArrayCx::shift(Ctxt& ctxt, long amt) const
{
  const long onlyDimension = 0;
  shift1D(ctxt, onlyDimension, amt);
}
// Encode a vector of complex slot values into the polynomial ptxt.
//
// useThisSize is the size by which the scaling factor is divided:
//  - a negative value means "measure it": the largest magnitude found in
//    'array' is used instead;
//  - if the size is (or ends up) zero or less, 1.0 is used.
// The scaling factor is encodeScalingFactor(precision)/useThisSize. It
// ensures that encode/decode introduce less than 1/precision error; with
// precision=0 the error bound defaults to 2^{-almod.getR()}.
//
// Returns the scaling factor that was handed to embedInSlots.
double EncryptedArrayCx::encode(zzX& ptxt, const vector<cx_double>& array,
                                double useThisSize, long precision) const
{
  if (useThisSize < 0) {
    // No size given: take the largest absolute value among the slots
    for (const cx_double& slot : array) {
      const double magnitude = std::abs(slot);
      if (magnitude > useThisSize)
        useThisSize = magnitude;
    }
  }
  if (useThisSize <= 0)
    useThisSize = 1.0;

  const double factor = encodeScalingFactor(precision)/useThisSize;
  embedInSlots(ptxt, array, getPAlgebra(), factor);
  return factor;
}
// Encode a single real number as a constant polynomial in ptxt.
//
// When useThisSize is zero or less, roundedSize(num) is used in its place.
// The scaling factor is encodeScalingFactor(precision)/useThisSize. It
// ensures that encode/decode introduce less than 1/precision error; with
// precision=0 it defaults to encodeScalingFactor(), corresponding to a
// precision error bound of 2^{-almod.getR()}.
//
// Returns the scaling factor by which num was multiplied before rounding.
double EncryptedArrayCx::encode(zzX& ptxt, double num,
                                double useThisSize, long precision) const
{
  if (useThisSize <= 0)
    useThisSize = roundedSize(num);

  const double factor = encodeScalingFactor(precision)/useThisSize;

  // The single (constant-term) coefficient: num scaled and rounded
  const long constantCoeff = static_cast<long>(round(num*factor));
  resize(ptxt, 1, constantCoeff);
  return factor;
}
// Encode the imaginary unit i, placed in every slot, into ptxt.
// The vector encode is called with an explicit size of 1.0 and the given
// precision; its scaling factor is returned.
double EncryptedArrayCx::encodei(zzX& ptxt, long precision) const
{
  const vector<cx_double> allSlotsI(size(), the_imaginary_i);
  const double unitSize = 1.0;
  return this->encode(ptxt, allSlotsI, unitSize, precision);
}
// Return the cached encoding of i, computing it on first use.
// An empty iEncoded is taken to mean "not yet initialized"; in that case
// it is filled in by encodei (called with its default precision).
// NOTE(review): the cache is written through a const_cast from a const
// method - presumably iEncoded is not declared mutable; confirm that no
// truly-const EncryptedArrayCx object ever reaches this path.
const zzX& EncryptedArrayCx::getiEncoded() const
{
  if (lsize(iEncoded) > 0)
    return iEncoded; // already encoded

  // temporarily suspend const-ness to fill in the cache
  zzX& writableCache = const_cast<zzX&>(iEncoded);
  encodei(writableCache);
  return iEncoded;
}
// Decode the polynomial ptxt into complex slot values.
// The polynomial is passed through the canonical embedding into 'array',
// and every resulting entry is then divided by 'scaling'.
// A non-positive scaling is rejected via
// helib::assertTrue<helib::InvalidArgument>.
void EncryptedArrayCx::decode(vector<cx_double>& array, const zzX& ptxt, double scaling) const
{
  helib::assertTrue<helib::InvalidArgument>(scaling>0, "Scaling must be positive to decode");

  canonicalEmbedding(array, ptxt, getPAlgebra());
  std::for_each(array.begin(), array.end(),
                [scaling](cx_double& slot) { slot /= scaling; });
}
// Fill 'array' with size() pseudo-random complex numbers inside a circle
// of radius rad around the origin. A radius of exactly 0 is replaced by 1.0.
//
// Each entry is built in polar form from 32 random bits: the low 16 bits
// give the radius as rad*sqrt(u) with u in [0,1), the high 16 bits give
// the angle as a fraction of 2*pi.
void EncryptedArrayCx::random(vector<cx_double>& array, double rad) const
{
  const double twoPi = 8 * std::atan(1);
  if (rad==0) rad = 1.0; // radius

  resize(array, size()); // allocate space
  for (auto& x : array) {
    // RandomBits_long(32) is uniform over 0..2^32-1. The previously used
    // RandomLen_long(32) yields a number with *exactly* 32 bits, i.e. bit 31
    // is always set, which forced the 16-bit angle field to be >= 0x8000
    // and so restricted theta to [pi, 2pi) - only half of the circle.
    const long bits = NTL::RandomBits_long(32);
    const long radiusBits = bits & 0xffff;
    const long angleBits = (bits>>16) & 0xffff;

    const double r = std::sqrt(radiusBits)/256.0;       // sqrt(uniform[0,1])
    const double theta = twoPi * angleBits / 65536.0;   // uniform(0,2pi)
    x = std::polar(rad*r, theta);
  }
}
// Replace c by (an encryption of) its real part, using the identity
// c + conj(c) = 2*real(c) followed by a multiplication by 0.5.
void EncryptedArrayCx::extractRealPart(Ctxt& c) const
{
  Ctxt conjugate(c);          // a copy of c ...
  conjugate.complexConj();    // ... turned into its complex conjugate
  c += conjugate;             // now holds 2*real(c)
  c.multByConstantCKKS(0.5);  // halve it
}
// Replace c by (an encryption of) its imaginary part, using
// conj(c) - c = -2*i*imaginary(c), then multiplying by i and by 0.5.
//
// Note: If called with iDcrtPtr==nullptr, it will perform FFT's when
// encoding i as a DoubleCRT object. If called with iDcrtPtr!=nullptr,
// it assumes that the pointer refers to an object that encodes i. If the
// primeSet of the given DoubleCRT is missing some of the moduli in
// c.getPrimeSet(), many extra FFTs/iFFTs will be called.
void EncryptedArrayCx::extractImPart(Ctxt& c, DoubleCRT* iDcrtPtr) const
{
  // Local holder for an encoding of i, filled in only if none was supplied
  DoubleCRT localI(getContext(), IndexSet::emptySet());

  {
    Ctxt original(c);  // saved copy; goes out of scope at the closing brace
    c.complexConj();   // c is now conj(c)
    c -= original;     // c is now conj(c) - c = -2*i*imaginary(c)
  }

  if (iDcrtPtr == nullptr) {
    // No encoding of i was given: build one over the primes of c.
    // FFT is a low-level DoubleCRT procedure to initialize an existing
    // object with a given PrimeSet and a given polynomial.
    localI.addPrimes(c.getPrimeSet());
    const zzX& encodedI = getiEncoded();
    localI.FFT(encodedI, c.getPrimeSet());
    iDcrtPtr = &localI;
  }

  c.multByConstantCKKS(*iDcrtPtr); // times i
  c.multByConstantCKKS(0.5);       // times one half
}
// Build the two encoded coefficients of the linear map
// L(z) = x*z + y*conjugate(z), given the images of 1 and of i under L
// (the same map is used in every slot).
//
// C is resized to two entries: C[0] encodes x and C[1] encodes y, each
// replicated across all slots. Both are encoded with the same size
// (roundedSize of the larger of |x|,|y|) and the same precision.
// Returns the scaling factor reported by the encoding of C[1].
double EncryptedArrayCx::buildLinPolyCoeffs(vector<zzX>& C,
                                            const cx_double& oneImage,
                                            const cx_double& iImage,
                                            long precision) const
{
  resize(C,2); // allocate space

  // The constants x,y such that L(z) = x*z + y*conjugate(z)
  const cx_double iTimesImage = the_imaginary_i*iImage;
  const cx_double x = (oneImage - iTimesImage)*0.5;
  const cx_double y = (oneImage + iTimesImage)*0.5;

  // A common size for both encodings
  const double msize = roundedSize(std::max(std::abs(x), std::abs(y)));

  // Encode x,y in zzX objects
  const long nSlots = size();
  const vector<cx_double> xSlots(nSlots, x); // x in all the slots
  encode(C[0], xSlots, msize, precision);

  const vector<cx_double> ySlots(nSlots, y); // y in all the slots
  return encode(C[1], ySlots, msize, precision);
}
// Build the two encoded coefficients of per-slot linear maps
// L_j(z) = x[j]*z + y[j]*conjugate(z), given for each slot j the image of 1
// (oneImages[j]) and the image of i (iImages[j]).
//
// C is resized to two entries: C[0] encodes the vector x and C[1] encodes
// the vector y. Both are encoded with the same size (roundedSize of the
// largest |x[j]|,|y[j]|) and the same precision.
//
// oneImages and iImages must each hold at least size() entries; otherwise
// helib::InvalidArgument is raised (via helib::assertTrue) before C is
// modified.
// Returns the scaling factor reported by the encoding of C[1], in line
// with the single-map overload. (Previously the function fell off its end
// without returning a value, which is undefined behavior for a function
// declared to return double.)
double EncryptedArrayCx::buildLinPolyCoeffs(vector<zzX>& C,
                                            const vector<cx_double>& oneImages,
                                            const vector<cx_double>& iImages,
                                            long precision) const
{
  const long nSlots = size();

  // Both inputs are indexed over all the slots below
  helib::assertTrue<helib::InvalidArgument>(
      static_cast<long>(oneImages.size()) >= nSlots &&
      static_cast<long>(iImages.size()) >= nSlots,
      "buildLinPolyCoeffs: oneImages and iImages must have an entry for every slot");

  resize(C,2); // allocate space

  // Compute the constants x,y such that L(z) = x*z + y*conjugate(z)
  vector<cx_double> x(nSlots);
  vector<cx_double> y(nSlots);
  double msize = 0.0;
  for (long j=0; j<nSlots; j++) {
    x[j] = (oneImages[j] - the_imaginary_i*iImages[j])*0.5;
    y[j] = (oneImages[j] + the_imaginary_i*iImages[j])*0.5;
    msize = std::max(msize, std::max(std::abs(x[j]), std::abs(y[j])));
  }

  // Encode x,y in zzX objects
  msize = roundedSize(msize);
  encode(C[0], x, msize, precision);
  return encode(C[1], y, msize, precision);
}