1 // ida.cpp - originally written and placed in the public domain by Wei Dai
2 
3 #include "pch.h"
4 #include "config.h"
5 
6 #include "ida.h"
7 #include "stdcpp.h"
8 #include "algebra.h"
9 #include "polynomi.h"
10 #include "polynomi.cpp"
11 
12 NAMESPACE_BEGIN(CryptoPP)
13 
// RevIt: a reverse iterator over `const byte *`. PaddingRemover::Put2 uses it to scan a
// buffer backwards for the last nonzero byte. The spelling differs per compiler/standard
// library because some older implementations do not provide the single-parameter
// std::reverse_iterator template.
#if (defined(_MSC_VER) && (_MSC_VER < 1400)) && !defined(__MWERKS__)
	// VC60 and VC7 workaround: built-in reverse_iterator has two template parameters, Dinkumware only has one
	typedef std::reverse_bidirectional_iterator<const byte *, const byte> RevIt;
#elif defined(_RWSTD_NO_CLASS_PARTIAL_SPEC)
	// Presumably the Rogue Wave library (_RWSTD_ prefix) built without class partial
	// specialization, whose reverse_iterator needs the category and value type spelled out.
	typedef std::reverse_iterator<const byte *, std::random_access_iterator_tag, const byte> RevIt;
#else
	typedef std::reverse_iterator<const byte *> RevIt;
#endif
22 
IsolatedInitialize(const NameValuePairs & parameters)23 void RawIDA::IsolatedInitialize(const NameValuePairs &parameters)
24 {
25 	if (!parameters.GetIntValue("RecoveryThreshold", m_threshold))
26 		throw InvalidArgument("RawIDA: missing RecoveryThreshold argument");
27 
28 	CRYPTOPP_ASSERT(m_threshold > 0);
29 	if (m_threshold <= 0)
30 		throw InvalidArgument("RawIDA: RecoveryThreshold must be greater than 0");
31 
32 	m_lastMapPosition = m_inputChannelMap.end();
33 	m_channelsReady = 0;
34 	m_channelsFinished = 0;
35 	m_w.New(m_threshold);
36 	m_y.New(m_threshold);
37 	m_inputQueues.reserve(m_threshold);
38 
39 	m_outputChannelIds.clear();
40 	m_outputChannelIdStrings.clear();
41 	m_outputQueues.clear();
42 
43 	word32 outputChannelID;
44 	if (parameters.GetValue("OutputChannelID", outputChannelID))
45 		AddOutputChannel(outputChannelID);
46 	else
47 	{
48 		int nShares = parameters.GetIntValueWithDefault("NumberOfShares", m_threshold);
49 		CRYPTOPP_ASSERT(nShares > 0);
50 		if (nShares <= 0) {nShares = m_threshold;}
51 		for (unsigned int i=0; i< (unsigned int)(nShares); i++)
52 			AddOutputChannel(i);
53 	}
54 }
55 
InsertInputChannel(word32 channelId)56 unsigned int RawIDA::InsertInputChannel(word32 channelId)
57 {
58 	if (m_lastMapPosition != m_inputChannelMap.end())
59 	{
60 		if (m_lastMapPosition->first == channelId)
61 			goto skipFind;
62 		++m_lastMapPosition;
63 		if (m_lastMapPosition != m_inputChannelMap.end() && m_lastMapPosition->first == channelId)
64 			goto skipFind;
65 	}
66 	m_lastMapPosition = m_inputChannelMap.find(channelId);
67 
68 skipFind:
69 	if (m_lastMapPosition == m_inputChannelMap.end())
70 	{
71 		if (m_inputChannelIds.size() == size_t(m_threshold))
72 			return m_threshold;
73 
74 		m_lastMapPosition = m_inputChannelMap.insert(InputChannelMap::value_type(channelId, (unsigned int)m_inputChannelIds.size())).first;
75 		m_inputQueues.push_back(MessageQueue());
76 		m_inputChannelIds.push_back(channelId);
77 
78 		if (m_inputChannelIds.size() == size_t(m_threshold))
79 			PrepareInterpolation();
80 	}
81 	return m_lastMapPosition->second;
82 }
83 
LookupInputChannel(word32 channelId) const84 unsigned int RawIDA::LookupInputChannel(word32 channelId) const
85 {
86 	std::map<word32, unsigned int>::const_iterator it = m_inputChannelMap.find(channelId);
87 	if (it == m_inputChannelMap.end())
88 		return m_threshold;
89 	else
90 		return it->second;
91 }
92 
// Buffers `length` bytes from `inString` for input channel `channelId`.
// Channels beyond the first m_threshold distinct IDs are ignored (InsertInputChannel
// returns m_threshold for them). Whenever every registered channel has at least one full
// 32-bit word buffered, ProcessInputQueues() runs. When `messageEnd` completes the first
// message on the last of the m_threshold channels, the ready count is recomputed from
// whatever is left (including partial words) and the message is finalized.
void RawIDA::ChannelData(word32 channelId, const byte *inString, size_t length, bool messageEnd)
{
	int i = InsertInputChannel(channelId);
	if (i < m_threshold)
	{
		lword size = m_inputQueues[i].MaxRetrievable();
		m_inputQueues[i].Put(inString, length);
		// Count this channel as ready exactly once: when its buffer first reaches a full word32.
		if (size < 4 && size + length >= 4)
		{
			m_channelsReady++;
			if (m_channelsReady == size_t(m_threshold))
				ProcessInputQueues();
		}

		if (messageEnd)
		{
			m_inputQueues[i].MessageEnd();
			// Only the first completed message on this channel counts toward finishing.
			if (m_inputQueues[i].NumberOfMessages() == 1)
			{
				m_channelsFinished++;
				if (m_channelsFinished == size_t(m_threshold))
				{
					// In the finishing phase any leftover byte (not just a full word) makes a channel "ready".
					m_channelsReady = 0;
					for (i=0; i<m_threshold; i++)
						m_channelsReady += m_inputQueues[i].AnyRetrievable();
					ProcessInputQueues();
				}
			}
		}
	}
}
124 
InputBuffered(word32 channelId) const125 lword RawIDA::InputBuffered(word32 channelId) const
126 {
127 	int i = LookupInputChannel(channelId);
128 	return i < m_threshold ? m_inputQueues[i].MaxRetrievable() : 0;
129 }
130 
ComputeV(unsigned int i)131 void RawIDA::ComputeV(unsigned int i)
132 {
133 	if (i >= m_v.size())
134 	{
135 		m_v.resize(i+1);
136 		m_outputToInput.resize(i+1);
137 	}
138 
139 	m_outputToInput[i] = LookupInputChannel(m_outputChannelIds[i]);
140 	if (m_outputToInput[i] == size_t(m_threshold) && i * size_t(m_threshold) <= 1000*1000)
141 	{
142 		m_v[i].resize(m_threshold);
143 		PrepareBulkPolynomialInterpolationAt(m_gf32, m_v[i].begin(), m_outputChannelIds[i], &(m_inputChannelIds[0]), m_w.begin(), m_threshold);
144 	}
145 }
146 
AddOutputChannel(word32 channelId)147 void RawIDA::AddOutputChannel(word32 channelId)
148 {
149 	m_outputChannelIds.push_back(channelId);
150 	m_outputChannelIdStrings.push_back(WordToString(channelId));
151 	m_outputQueues.push_back(ByteQueue());
152 	if (m_inputChannelIds.size() == size_t(m_threshold))
153 		ComputeV((unsigned int)m_outputChannelIds.size() - 1);
154 }
155 
PrepareInterpolation()156 void RawIDA::PrepareInterpolation()
157 {
158 	CRYPTOPP_ASSERT(m_inputChannelIds.size() == size_t(m_threshold));
159 	PrepareBulkPolynomialInterpolation(m_gf32, m_w.begin(), &(m_inputChannelIds[0]), (unsigned int)(m_threshold));
160 	for (unsigned int i=0; i<m_outputChannelIds.size(); i++)
161 		ComputeV(i);
162 }
163 
// Each round reads one 32-bit word from every input queue into m_y and appends one word to
// every output queue:
//   - an output whose ID matches an input channel copies that input's word;
//   - otherwise the word is interpolated, using coefficients precomputed by ComputeV()
//     when present (m_v[i]) or computed on the fly into m_u.
// Before all channels have finished their message, rounds continue while every channel
// still has a full word (or a completed message). Afterwards they continue while any
// channel has data left. Produced output is then flushed. At the end of a message, message
// ends are emitted and all per-message input state is reset.
void RawIDA::ProcessInputQueues()
{
	bool finished = (m_channelsFinished == size_t(m_threshold));
	unsigned int i;

	while (finished ? m_channelsReady > 0 : m_channelsReady == size_t(m_threshold))
	{
		m_channelsReady = 0;
		for (i=0; i<size_t(m_threshold); i++)
		{
			MessageQueue &queue = m_inputQueues[i];
			// NOTE(review): in the finishing phase a queue may hold fewer than 4 bytes; the
			// value left in m_y[i] then depends on GetWord32's short-read behavior — confirm.
			queue.GetWord32(m_y[i]);

			// Recount readiness for the next round using the phase-appropriate criterion.
			if (finished)
				m_channelsReady += queue.AnyRetrievable();
			else
				m_channelsReady += queue.NumberOfMessages() > 0 || queue.MaxRetrievable() >= 4;
		}

		for (i=0; (unsigned int)i<m_outputChannelIds.size(); i++)
		{
			if (m_outputToInput[i] != size_t(m_threshold))
				m_outputQueues[i].PutWord32(m_y[m_outputToInput[i]]);
			else if (m_v[i].size() == size_t(m_threshold))
				m_outputQueues[i].PutWord32(BulkPolynomialInterpolateAt(m_gf32, m_y.begin(), m_v[i].begin(), m_threshold));
			else
			{
				// Coefficients were not precomputed (size budget in ComputeV): compute them now.
				m_u.resize(m_threshold);
				PrepareBulkPolynomialInterpolationAt(m_gf32, m_u.begin(), m_outputChannelIds[i], &(m_inputChannelIds[0]), m_w.begin(), m_threshold);
				m_outputQueues[i].PutWord32(BulkPolynomialInterpolateAt(m_gf32, m_y.begin(), m_u.begin(), m_threshold));
			}
		}
	}

	// Every round appends one word to each output queue, so queue 0 tells whether any output exists.
	if (m_outputChannelIds.size() > 0 && m_outputQueues[0].AnyRetrievable())
		FlushOutputQueues();

	if (finished)
	{
		OutputMessageEnds();

		m_channelsReady = 0;
		m_channelsFinished = 0;
		m_v.clear();

		// Move the input state into locals so the members are reset before leftovers are forwarded.
		std::vector<MessageQueue> inputQueues;
		std::vector<word32> inputChannelIds;

		inputQueues.swap(m_inputQueues);
		inputChannelIds.swap(m_inputChannelIds);
		m_inputChannelMap.clear();
		m_lastMapPosition = m_inputChannelMap.end();

		// Drop the completed message from each input queue and forward whatever was queued after it.
		// NOTE(review): the leftover data goes to the attached transformation under the input
		// channel's name rather than being re-fed into this object — confirm this is intended.
		for (i=0; i<size_t(m_threshold); i++)
		{
			inputQueues[i].GetNextMessage();
			inputQueues[i].TransferAllTo(*AttachedTransformation(), WordToString(inputChannelIds[i]));
		}
	}
}
224 
FlushOutputQueues()225 void RawIDA::FlushOutputQueues()
226 {
227 	for (unsigned int i=0; i<m_outputChannelIds.size(); i++)
228 		m_outputQueues[i].TransferAllTo(*AttachedTransformation(), m_outputChannelIdStrings[i]);
229 }
230 
OutputMessageEnds()231 void RawIDA::OutputMessageEnds()
232 {
233 	if (GetAutoSignalPropagation() != 0)
234 	{
235 		for (unsigned int i=0; i<m_outputChannelIds.size(); i++)
236 			AttachedTransformation()->ChannelMessageEnd(m_outputChannelIdStrings[i], GetAutoSignalPropagation()-1);
237 	}
238 }
239 
240 // ****************************************************************
241 
IsolatedInitialize(const NameValuePairs & parameters)242 void SecretSharing::IsolatedInitialize(const NameValuePairs &parameters)
243 {
244 	m_pad = parameters.GetValueWithDefault("AddPadding", true);
245 	m_ida.IsolatedInitialize(parameters);
246 }
247 
Put2(const byte * begin,size_t length,int messageEnd,bool blocking)248 size_t SecretSharing::Put2(const byte *begin, size_t length, int messageEnd, bool blocking)
249 {
250 	if (!blocking)
251 		throw BlockingInputOnly("SecretSharing");
252 
253 	SecByteBlock buf(UnsignedMin(256, length));
254 	unsigned int threshold = m_ida.GetThreshold();
255 	while (length > 0)
256 	{
257 		size_t len = STDMIN(length, buf.size());
258 		m_ida.ChannelData(0xffffffff, begin, len, false);
259 		for (unsigned int i=0; i<threshold-1; i++)
260 		{
261 			m_rng.GenerateBlock(buf, len);
262 			m_ida.ChannelData(i, buf, len, false);
263 		}
264 		length -= len;
265 		begin += len;
266 	}
267 
268 	if (messageEnd)
269 	{
270 		m_ida.SetAutoSignalPropagation(messageEnd-1);
271 		if (m_pad)
272 		{
273 			SecretSharing::Put(1);
274 			while (m_ida.InputBuffered(0xffffffff) > 0)
275 				SecretSharing::Put(0);
276 		}
277 		m_ida.ChannelData(0xffffffff, NULLPTR, 0, true);
278 		for (unsigned int i=0; i<m_ida.GetThreshold()-1; i++)
279 			m_ida.ChannelData(i, NULLPTR, 0, true);
280 	}
281 
282 	return 0;
283 }
284 
IsolatedInitialize(const NameValuePairs & parameters)285 void SecretRecovery::IsolatedInitialize(const NameValuePairs &parameters)
286 {
287 	m_pad = parameters.GetValueWithDefault("RemovePadding", true);
288 	RawIDA::IsolatedInitialize(CombinedNameValuePairs(parameters, MakeParameters("OutputChannelID", (word32)0xffffffff)));
289 }
290 
FlushOutputQueues()291 void SecretRecovery::FlushOutputQueues()
292 {
293 	if (m_pad)
294 		m_outputQueues[0].TransferTo(*AttachedTransformation(), m_outputQueues[0].MaxRetrievable()-4);
295 	else
296 		m_outputQueues[0].TransferTo(*AttachedTransformation());
297 }
298 
OutputMessageEnds()299 void SecretRecovery::OutputMessageEnds()
300 {
301 	if (m_pad)
302 	{
303 		PaddingRemover paddingRemover(new Redirector(*AttachedTransformation()));
304 		m_outputQueues[0].TransferAllTo(paddingRemover);
305 	}
306 
307 	if (GetAutoSignalPropagation() != 0)
308 		AttachedTransformation()->MessageEnd(GetAutoSignalPropagation()-1);
309 }
310 
311 // ****************************************************************
312 
IsolatedInitialize(const NameValuePairs & parameters)313 void InformationDispersal::IsolatedInitialize(const NameValuePairs &parameters)
314 {
315 	m_nextChannel = 0;
316 	m_pad = parameters.GetValueWithDefault("AddPadding", true);
317 	m_ida.IsolatedInitialize(parameters);
318 }
319 
Put2(const byte * begin,size_t length,int messageEnd,bool blocking)320 size_t InformationDispersal::Put2(const byte *begin, size_t length, int messageEnd, bool blocking)
321 {
322 	if (!blocking)
323 		throw BlockingInputOnly("InformationDispersal");
324 
325 	while (length--)
326 	{
327 		m_ida.ChannelData(m_nextChannel, begin, 1, false);
328 		begin++;
329 		m_nextChannel++;
330 		if (m_nextChannel == m_ida.GetThreshold())
331 			m_nextChannel = 0;
332 	}
333 
334 	if (messageEnd)
335 	{
336 		m_ida.SetAutoSignalPropagation(messageEnd-1);
337 		if (m_pad)
338 			InformationDispersal::Put(1);
339 		for (word32 i=0; i<m_ida.GetThreshold(); i++)
340 			m_ida.ChannelData(i, NULLPTR, 0, true);
341 	}
342 
343 	return 0;
344 }
345 
IsolatedInitialize(const NameValuePairs & parameters)346 void InformationRecovery::IsolatedInitialize(const NameValuePairs &parameters)
347 {
348 	m_pad = parameters.GetValueWithDefault("RemovePadding", true);
349 	RawIDA::IsolatedInitialize(parameters);
350 }
351 
FlushOutputQueues()352 void InformationRecovery::FlushOutputQueues()
353 {
354 	while (m_outputQueues[0].AnyRetrievable())
355 	{
356 		for (unsigned int i=0; i<m_outputChannelIds.size(); i++)
357 			m_outputQueues[i].TransferTo(m_queue, 1);
358 	}
359 
360 	if (m_pad)
361 		m_queue.TransferTo(*AttachedTransformation(), m_queue.MaxRetrievable()-4*m_threshold);
362 	else
363 		m_queue.TransferTo(*AttachedTransformation());
364 }
365 
OutputMessageEnds()366 void InformationRecovery::OutputMessageEnds()
367 {
368 	if (m_pad)
369 	{
370 		PaddingRemover paddingRemover(new Redirector(*AttachedTransformation()));
371 		m_queue.TransferAllTo(paddingRemover);
372 	}
373 
374 	if (GetAutoSignalPropagation() != 0)
375 		AttachedTransformation()->MessageEnd(GetAutoSignalPropagation()-1);
376 }
377 
Put2(const byte * begin,size_t length,int messageEnd,bool blocking)378 size_t PaddingRemover::Put2(const byte *begin, size_t length, int messageEnd, bool blocking)
379 {
380 	if (!blocking)
381 		throw BlockingInputOnly("PaddingRemover");
382 
383 	const byte *const end = begin + length;
384 
385 	if (m_possiblePadding)
386 	{
387 		size_t len = FindIfNot(begin, end, byte(0)) - begin;
388 		m_zeroCount += len;
389 		begin += len;
390 		if (begin == end)
391 			return 0;
392 
393 		AttachedTransformation()->Put(1);
394 		while (m_zeroCount--)
395 			AttachedTransformation()->Put(0);
396 		AttachedTransformation()->Put(*begin++);
397 		m_possiblePadding = false;
398 	}
399 
400 	const byte *x = FindIfNot(RevIt(end), RevIt(begin), byte(0)).base();
401 	if (x != begin && *(x-1) == 1)
402 	{
403 		AttachedTransformation()->Put(begin, x-begin-1);
404 		m_possiblePadding = true;
405 		m_zeroCount = end - x;
406 	}
407 	else
408 		AttachedTransformation()->Put(begin, end-begin);
409 
410 	if (messageEnd)
411 	{
412 		m_possiblePadding = false;
413 		Output(0, begin, length, messageEnd, blocking);
414 	}
415 	return 0;
416 }
417 
418 NAMESPACE_END
419