1 // filters.cpp - originally written and placed in the public domain by Wei Dai
2 
3 #include "pch.h"
4 #include "config.h"
5 
6 #if CRYPTOPP_MSC_VERSION
7 # pragma warning(disable: 4100 4189 4355)
8 #endif
9 
10 #if CRYPTOPP_GCC_DIAGNOSTIC_AVAILABLE
11 # pragma GCC diagnostic ignored "-Wunused-value"
12 #endif
13 
14 #ifndef CRYPTOPP_IMPORTS
15 
16 #include "filters.h"
17 #include "mqueue.h"
18 #include "fltrimpl.h"
19 #include "argnames.h"
20 #include "smartptr.h"
21 #include "stdcpp.h"
22 #include "misc.h"
23 
NAMESPACE_BEGIN(CryptoPP)24 NAMESPACE_BEGIN(CryptoPP)
25 
26 Filter::Filter(BufferedTransformation *attachment)
27 	: m_attachment(attachment), m_inputPosition(0), m_continueAt(0)
28 {
29 }
30 
NewDefaultAttachment() const31 BufferedTransformation * Filter::NewDefaultAttachment() const
32 {
33 	return new MessageQueue;
34 }
35 
AttachedTransformation()36 BufferedTransformation * Filter::AttachedTransformation()
37 {
38 	if (m_attachment.get() == NULLPTR)
39 		m_attachment.reset(NewDefaultAttachment());
40 	return m_attachment.get();
41 }
42 
AttachedTransformation() const43 const BufferedTransformation *Filter::AttachedTransformation() const
44 {
45 	if (m_attachment.get() == NULLPTR)
46 		const_cast<Filter *>(this)->m_attachment.reset(NewDefaultAttachment());
47 	return m_attachment.get();
48 }
49 
Detach(BufferedTransformation * newOut)50 void Filter::Detach(BufferedTransformation *newOut)
51 {
52 	m_attachment.reset(newOut);
53 }
54 
Insert(Filter * filter)55 void Filter::Insert(Filter *filter)
56 {
57 	filter->m_attachment.reset(m_attachment.release());
58 	m_attachment.reset(filter);
59 }
60 
CopyRangeTo2(BufferedTransformation & target,lword & begin,lword end,const std::string & channel,bool blocking) const61 size_t Filter::CopyRangeTo2(BufferedTransformation &target, lword &begin, lword end, const std::string &channel, bool blocking) const
62 {
63 	return AttachedTransformation()->CopyRangeTo2(target, begin, end, channel, blocking);
64 }
65 
TransferTo2(BufferedTransformation & target,lword & transferBytes,const std::string & channel,bool blocking)66 size_t Filter::TransferTo2(BufferedTransformation &target, lword &transferBytes, const std::string &channel, bool blocking)
67 {
68 	return AttachedTransformation()->TransferTo2(target, transferBytes, channel, blocking);
69 }
70 
Initialize(const NameValuePairs & parameters,int propagation)71 void Filter::Initialize(const NameValuePairs &parameters, int propagation)
72 {
73 	m_inputPosition = m_continueAt = 0;
74 	IsolatedInitialize(parameters);
75 	PropagateInitialize(parameters, propagation);
76 }
77 
/// \brief Flushes this filter, then propagates the flush to the attachment.
/// \details The switch on m_continueAt lets a call that previously returned
///  true resume where it stopped. Case 0 starts with IsolatedFlush. Case 1
///  resumes at the downstream OutputFlush, which sets m_continueAt to 1 when the
///  attachment reports it is not done. The cases deliberately fall through.
/// \return true when a step reports it could not complete, false otherwise
bool Filter::Flush(bool hardFlush, int propagation, bool blocking)
{
	switch (m_continueAt)
	{
	case 0:
		// flush this filter's own state first
		if (IsolatedFlush(hardFlush, blocking))
			return true;
		// fall through
	case 1:
		// then flush downstream; output site 1 is recorded if it must be retried
		if (OutputFlush(1, hardFlush, propagation, blocking))
			return true;
		// fall through
	default: ;
	}
	return false;
}
94 
/// \brief Ends a message series in this filter and, when
///  ShouldPropagateMessageSeriesEnd() allows, in the attachment.
/// \details It has the same resumable structure as Flush: m_continueAt selects
///  where to restart, and each case deliberately falls through to the next.
/// \return true when a step reports it could not complete, false otherwise
bool Filter::MessageSeriesEnd(int propagation, bool blocking)
{
	switch (m_continueAt)
	{
	case 0:
		// end the series in this filter first
		if (IsolatedMessageSeriesEnd(blocking))
			return true;
		// fall through
	case 1:
		// then downstream, unless this filter opts out of propagating it
		if (ShouldPropagateMessageSeriesEnd() && OutputMessageSeriesEnd(1, propagation, blocking))
			return true;
		// fall through
	default: ;
	}
	return false;
}
111 
PropagateInitialize(const NameValuePairs & parameters,int propagation)112 void Filter::PropagateInitialize(const NameValuePairs &parameters, int propagation)
113 {
114 	if (propagation)
115 		AttachedTransformation()->Initialize(parameters, propagation-1);
116 }
117 
OutputModifiable(int outputSite,byte * inString,size_t length,int messageEnd,bool blocking,const std::string & channel)118 size_t Filter::OutputModifiable(int outputSite, byte *inString, size_t length, int messageEnd, bool blocking, const std::string &channel)
119 {
120 	if (messageEnd)
121 		messageEnd--;
122 	size_t result = AttachedTransformation()->ChannelPutModifiable2(channel, inString, length, messageEnd, blocking);
123 	m_continueAt = result ? outputSite : 0;
124 	return result;
125 }
126 
Output(int outputSite,const byte * inString,size_t length,int messageEnd,bool blocking,const std::string & channel)127 size_t Filter::Output(int outputSite, const byte *inString, size_t length, int messageEnd, bool blocking, const std::string &channel)
128 {
129 	if (messageEnd)
130 		messageEnd--;
131 	size_t result = AttachedTransformation()->ChannelPut2(channel, inString, length, messageEnd, blocking);
132 	m_continueAt = result ? outputSite : 0;
133 	return result;
134 }
135 
OutputFlush(int outputSite,bool hardFlush,int propagation,bool blocking,const std::string & channel)136 bool Filter::OutputFlush(int outputSite, bool hardFlush, int propagation, bool blocking, const std::string &channel)
137 {
138 	if (propagation && AttachedTransformation()->ChannelFlush(channel, hardFlush, propagation-1, blocking))
139 	{
140 		m_continueAt = outputSite;
141 		return true;
142 	}
143 	m_continueAt = 0;
144 	return false;
145 }
146 
OutputMessageSeriesEnd(int outputSite,int propagation,bool blocking,const std::string & channel)147 bool Filter::OutputMessageSeriesEnd(int outputSite, int propagation, bool blocking, const std::string &channel)
148 {
149 	if (propagation && AttachedTransformation()->ChannelMessageSeriesEnd(channel, propagation-1, blocking))
150 	{
151 		m_continueAt = outputSite;
152 		return true;
153 	}
154 	m_continueAt = 0;
155 	return false;
156 }
157 
158 // *************************************************************
159 
ResetMeter()160 void MeterFilter::ResetMeter()
161 {
162 	m_currentMessageBytes = m_totalBytes = m_currentSeriesMessages = m_totalMessages = m_totalMessageSeries = 0;
163 	m_rangesToSkip.clear();
164 }
165 
AddRangeToSkip(unsigned int message,lword position,lword size,bool sortNow)166 void MeterFilter::AddRangeToSkip(unsigned int message, lword position, lword size, bool sortNow)
167 {
168 	MessageRange r = {message, position, size};
169 	m_rangesToSkip.push_back(r);
170 	if (sortNow)
171 		std::sort(m_rangesToSkip.begin(), m_rangesToSkip.end());
172 }
173 
/// \brief Passes input through while counting bytes and messages and dropping
///  the byte ranges registered with AddRangeToSkip.
/// \details Returns 0 immediately when the meter is not transparent.
///  FILTER_BEGIN, FILTER_OUTPUT_MAYBE_MODIFIABLE and FILTER_END_NO_MESSAGE_END
///  are macros from fltrimpl.h. Output sites 1 and 2 presumably let a
///  non-blocking Put resume at the same output call (NOTE(review): confirm in
///  fltrimpl.h). That is why progress is kept in members (m_begin, m_length)
///  rather than in locals.
size_t MeterFilter::PutMaybeModifiable(byte *begin, size_t length, int messageEnd, bool blocking, bool modifiable)
{
	if (!m_transparent)
		return 0;

	// declared before FILTER_BEGIN, presumably so a resumed call does not jump over its declaration
	size_t t;
	FILTER_BEGIN;

	m_begin = begin;
	m_length = length;

	while (m_length > 0 || messageEnd)
	{
		// Does the next pending skip range start inside the data we still hold?
		if (m_length > 0  && !m_rangesToSkip.empty() && m_rangesToSkip.front().message == m_totalMessages && m_currentMessageBytes + m_length > m_rangesToSkip.front().position)
		{
			// Emit the bytes that precede the skipped range.
			FILTER_OUTPUT_MAYBE_MODIFIABLE(1, m_begin, t = (size_t)SaturatingSubtract(m_rangesToSkip.front().position, m_currentMessageBytes), false, modifiable);

			CRYPTOPP_ASSERT(t < m_length);
			m_begin = PtrAdd(m_begin, t);
			m_length -= t;
			m_currentMessageBytes += t;
			m_totalBytes += t;

			// Consume (without output) as much of the range as is present; the
			// range is retired once its end has been reached.
			if (m_currentMessageBytes + m_length < m_rangesToSkip.front().position + m_rangesToSkip.front().size)
				t = m_length;
			else
			{
				t = (size_t)SaturatingSubtract(m_rangesToSkip.front().position + m_rangesToSkip.front().size, m_currentMessageBytes);
				CRYPTOPP_ASSERT(t <= m_length);
				m_rangesToSkip.pop_front();
			}

			m_begin = PtrAdd(m_begin, t);
			m_length -= t;
			m_currentMessageBytes += t;
			m_totalBytes += t;
		}
		else
		{
			// No skip range applies: emit the rest, passing messageEnd along.
			FILTER_OUTPUT_MAYBE_MODIFIABLE(2, m_begin, m_length, messageEnd, modifiable);

			m_currentMessageBytes += m_length;
			m_totalBytes += m_length;
			m_length = 0;

			// At a message boundary, roll the per-message counters.
			if (messageEnd)
			{
				m_currentMessageBytes = 0;
				m_currentSeriesMessages++;
				m_totalMessages++;
				messageEnd = false;
			}
		}
	}

	FILTER_END_NO_MESSAGE_END;
}
231 
Put2(const byte * begin,size_t length,int messageEnd,bool blocking)232 size_t MeterFilter::Put2(const byte *begin, size_t length, int messageEnd, bool blocking)
233 {
234 	return PutMaybeModifiable(const_cast<byte *>(begin), length, messageEnd, blocking, false);
235 }
236 
PutModifiable2(byte * begin,size_t length,int messageEnd,bool blocking)237 size_t MeterFilter::PutModifiable2(byte *begin, size_t length, int messageEnd, bool blocking)
238 {
239 	return PutMaybeModifiable(begin, length, messageEnd, blocking, true);
240 }
241 
IsolatedMessageSeriesEnd(bool blocking)242 bool MeterFilter::IsolatedMessageSeriesEnd(bool blocking)
243 {
244 	CRYPTOPP_UNUSED(blocking);
245 	m_currentMessageBytes = 0;
246 	m_currentSeriesMessages = 0;
247 	m_totalMessageSeries++;
248 	return false;
249 }
250 
251 // *************************************************************
252 
ResetQueue(size_t blockSize,size_t maxBlocks)253 void FilterWithBufferedInput::BlockQueue::ResetQueue(size_t blockSize, size_t maxBlocks)
254 {
255 	m_buffer.New(blockSize * maxBlocks);
256 	m_blockSize = blockSize;
257 	m_maxBlocks = maxBlocks;
258 	m_size = 0;
259 	m_begin = m_buffer;
260 }
261 
GetBlock()262 byte *FilterWithBufferedInput::BlockQueue::GetBlock()
263 {
264 	if (m_size >= m_blockSize)
265 	{
266 		byte *ptr = m_begin;
267 		if ((m_begin = PtrAdd(m_begin, m_blockSize)) == m_buffer.end())
268 			m_begin = m_buffer;
269 		m_size -= m_blockSize;
270 		return ptr;
271 	}
272 	else
273 		return NULLPTR;
274 }
275 
GetContigousBlocks(size_t & numberOfBytes)276 byte *FilterWithBufferedInput::BlockQueue::GetContigousBlocks(size_t &numberOfBytes)
277 {
278 	numberOfBytes = STDMIN(numberOfBytes, STDMIN<size_t>(PtrDiff(m_buffer.end(), m_begin), m_size));
279 	byte *ptr = m_begin;
280 	m_begin = PtrAdd(m_begin, numberOfBytes);
281 	m_size -= numberOfBytes;
282 	if (m_size == 0 || m_begin == m_buffer.end())
283 		m_begin = m_buffer;
284 	return ptr;
285 }
286 
GetAll(byte * outString)287 size_t FilterWithBufferedInput::BlockQueue::GetAll(byte *outString)
288 {
289 	// Avoid passing NULL pointer to memcpy
290 	if (!outString) return 0;
291 
292 	size_t size = m_size;
293 	size_t numberOfBytes = m_maxBlocks*m_blockSize;
294 	const byte *ptr = GetContigousBlocks(numberOfBytes);
295 	memcpy(outString, ptr, numberOfBytes);
296 	memcpy(PtrAdd(outString, numberOfBytes), m_begin, m_size);
297 	m_size = 0;
298 	return size;
299 }
300 
Put(const byte * inString,size_t length)301 void FilterWithBufferedInput::BlockQueue::Put(const byte *inString, size_t length)
302 {
303 	// Avoid passing NULL pointer to memcpy
304 	if (!inString || !length) return;
305 
306 	CRYPTOPP_ASSERT(m_size + length <= m_buffer.size());
307 	byte *end = (m_size < static_cast<size_t>(PtrDiff(m_buffer.end(), m_begin)) ?
308 		PtrAdd(m_begin, m_size) : PtrAdd(m_begin, m_size - m_buffer.size()));
309 	size_t len = STDMIN(length, size_t(m_buffer.end()-end));
310 	memcpy(end, inString, len);
311 	if (len < length)
312 		memcpy(m_buffer, PtrAdd(inString, len), length-len);
313 	m_size += length;
314 }
315 
FilterWithBufferedInput(BufferedTransformation * attachment)316 FilterWithBufferedInput::FilterWithBufferedInput(BufferedTransformation *attachment)
317 	: Filter(attachment), m_firstSize(SIZE_MAX), m_blockSize(0), m_lastSize(SIZE_MAX), m_firstInputDone(false)
318 {
319 }
320 
FilterWithBufferedInput(size_t firstSize,size_t blockSize,size_t lastSize,BufferedTransformation * attachment)321 FilterWithBufferedInput::FilterWithBufferedInput(size_t firstSize, size_t blockSize, size_t lastSize, BufferedTransformation *attachment)
322 	: Filter(attachment), m_firstSize(firstSize), m_blockSize(blockSize), m_lastSize(lastSize), m_firstInputDone(false)
323 {
324 	if (m_firstSize == SIZE_MAX || m_blockSize < 1 || m_lastSize == SIZE_MAX)
325 		throw InvalidArgument("FilterWithBufferedInput: invalid buffer size");
326 
327 	m_queue.ResetQueue(1, m_firstSize);
328 }
329 
IsolatedInitialize(const NameValuePairs & parameters)330 void FilterWithBufferedInput::IsolatedInitialize(const NameValuePairs &parameters)
331 {
332 	InitializeDerivedAndReturnNewSizes(parameters, m_firstSize, m_blockSize, m_lastSize);
333 	if (m_firstSize == SIZE_MAX || m_blockSize < 1 || m_lastSize == SIZE_MAX)
334 		throw InvalidArgument("FilterWithBufferedInput: invalid buffer size");
335 	m_queue.ResetQueue(1, m_firstSize);
336 	m_firstInputDone = false;
337 }
338 
IsolatedFlush(bool hardFlush,bool blocking)339 bool FilterWithBufferedInput::IsolatedFlush(bool hardFlush, bool blocking)
340 {
341 	if (!blocking)
342 		throw BlockingInputOnly("FilterWithBufferedInput");
343 
344 	if (hardFlush)
345 		ForceNextPut();
346 	FlushDerived();
347 
348 	return false;
349 }
350 
/// \brief Core buffering logic shared by Put2 and PutModifiable2.
/// \details Accumulates input in m_queue and releases it in the shape the
///  derived class requested through InitializeDerivedAndReturnNewSizes:
///  - exactly m_firstSize bytes go to FirstPut;
///  - after that, data goes to NextPutModifiable / NextPutMaybeModifiable in
///    multiples of m_blockSize;
///  - at least m_lastSize bytes are retained (once that many have arrived), so
///    LastPut receives the tail at message end.
///  If the message ends before m_firstSize bytes have arrived, FirstPut is
///  skipped (unless m_firstSize is 0) and everything queued goes to LastPut.
/// \param modifiable whether inString may be handed on as modifiable data
/// \return always 0
/// \throws BlockingInputOnly when blocking is false
size_t FilterWithBufferedInput::PutMaybeModifiable(byte *inString, size_t length, int messageEnd, bool blocking, bool modifiable)
{
	if (!blocking)
		throw BlockingInputOnly("FilterWithBufferedInput");

	if (length != 0)
	{
		// bytes not yet released: already queued plus the new input
		size_t newLength = m_queue.CurrentSize() + length;

		if (!m_firstInputDone && newLength >= m_firstSize)
		{
			// Top the queue up to exactly m_firstSize bytes and hand them to FirstPut,
			// then resize the queue for steady-state block buffering.
			size_t len = m_firstSize - m_queue.CurrentSize();
			m_queue.Put(inString, len);
			FirstPut(m_queue.GetContigousBlocks(m_firstSize));
			CRYPTOPP_ASSERT(m_queue.CurrentSize() == 0);
			m_queue.ResetQueue(m_blockSize, (2*m_blockSize+m_lastSize-2)/m_blockSize);

			inString = PtrAdd(inString, len);
			newLength -= m_firstSize;
			m_firstInputDone = true;
		}

		if (m_firstInputDone)
		{
			if (m_blockSize == 1)
			{
				// Byte granularity: release everything except the last m_lastSize
				// bytes, draining queued data before new input to keep order.
				while (newLength > m_lastSize && m_queue.CurrentSize() > 0)
				{
					size_t len = newLength - m_lastSize;
					byte *ptr = m_queue.GetContigousBlocks(len);
					NextPutModifiable(ptr, len);
					newLength -= len;
				}

				if (newLength > m_lastSize)
				{
					size_t len = newLength - m_lastSize;
					NextPutMaybeModifiable(inString, len, modifiable);
					inString = PtrAdd(inString, len);
					newLength -= len;
				}
			}
			else
			{
				// 1) drain whole queued blocks
				while (newLength >= m_blockSize + m_lastSize && m_queue.CurrentSize() >= m_blockSize)
				{
					NextPutModifiable(m_queue.GetBlock(), m_blockSize);
					newLength -= m_blockSize;
				}

				// 2) complete a partially queued block from the new input
				if (newLength >= m_blockSize + m_lastSize && m_queue.CurrentSize() > 0)
				{
					CRYPTOPP_ASSERT(m_queue.CurrentSize() < m_blockSize);
					size_t len = m_blockSize - m_queue.CurrentSize();
					m_queue.Put(inString, len);
					inString = PtrAdd(inString, len);
					NextPutModifiable(m_queue.GetBlock(), m_blockSize);
					newLength -= m_blockSize;
				}

				// 3) pass whole blocks straight from the input without copying them
				if (newLength >= m_blockSize + m_lastSize)
				{
					size_t len = RoundDownToMultipleOf(newLength - m_lastSize, m_blockSize);
					NextPutMaybeModifiable(inString, len, modifiable);
					inString = PtrAdd(inString, len);
					newLength -= len;
				}
			}
		}

		// Queue the unconsumed remainder of the input.
		m_queue.Put(inString, newLength - m_queue.CurrentSize());
	}

	if (messageEnd)
	{
		if (!m_firstInputDone && m_firstSize==0)
			FirstPut(NULLPTR);

		// Everything still queued becomes the final piece handed to LastPut.
		SecByteBlock temp(m_queue.CurrentSize());
		m_queue.GetAll(temp);
		LastPut(temp, temp.size());

		// Get ready for the next message.
		m_firstInputDone = false;
		m_queue.ResetQueue(1, m_firstSize);

		// Cast to void to suppress Coverity finding
		(void)Output(1, NULLPTR, 0, messageEnd, blocking);
	}
	return 0;
}
441 
ForceNextPut()442 void FilterWithBufferedInput::ForceNextPut()
443 {
444 	if (!m_firstInputDone)
445 		return;
446 
447 	if (m_blockSize > 1)
448 	{
449 		while (m_queue.CurrentSize() >= m_blockSize)
450 			NextPutModifiable(m_queue.GetBlock(), m_blockSize);
451 	}
452 	else
453 	{
454 		size_t len;
455 		while ((len = m_queue.CurrentSize()) > 0)
456 			NextPutModifiable(m_queue.GetContigousBlocks(len), len);
457 	}
458 }
459 
NextPutMultiple(const byte * inString,size_t length)460 void FilterWithBufferedInput::NextPutMultiple(const byte *inString, size_t length)
461 {
462 	CRYPTOPP_ASSERT(m_blockSize > 1);	// m_blockSize = 1 should always override this function
463 	while (length > 0)
464 	{
465 		CRYPTOPP_ASSERT(length >= m_blockSize);
466 		NextPutSingle(inString);
467 		inString = PtrAdd(inString, m_blockSize);
468 		length -= m_blockSize;
469 	}
470 }
471 
472 // *************************************************************
473 
Initialize(const NameValuePairs & parameters,int propagation)474 void Redirector::Initialize(const NameValuePairs &parameters, int propagation)
475 {
476 	m_target = parameters.GetValueWithDefault("RedirectionTargetPointer", (BufferedTransformation*)NULLPTR);
477 	m_behavior = parameters.GetIntValueWithDefault("RedirectionBehavior", PASS_EVERYTHING);
478 
479 	if (m_target && GetPassSignals())
480 		m_target->Initialize(parameters, propagation);
481 }
482 
483 // *************************************************************
484 
ProxyFilter(BufferedTransformation * filter,size_t firstSize,size_t lastSize,BufferedTransformation * attachment)485 ProxyFilter::ProxyFilter(BufferedTransformation *filter, size_t firstSize, size_t lastSize, BufferedTransformation *attachment)
486 	: FilterWithBufferedInput(firstSize, 1, lastSize, attachment), m_filter(filter)
487 {
488 	if (m_filter.get())
489 		m_filter->Attach(new OutputProxy(*this, false));
490 }
491 
IsolatedFlush(bool hardFlush,bool blocking)492 bool ProxyFilter::IsolatedFlush(bool hardFlush, bool blocking)
493 {
494 	return m_filter.get() ? m_filter->Flush(hardFlush, -1, blocking) : false;
495 }
496 
SetFilter(Filter * filter)497 void ProxyFilter::SetFilter(Filter *filter)
498 {
499 	m_filter.reset(filter);
500 	if (filter)
501 	{
502 		OutputProxy *proxy;
503 		member_ptr<OutputProxy> temp(proxy = new OutputProxy(*this, false));
504 		m_filter->TransferAllTo(*proxy);
505 		m_filter->Attach(temp.release());
506 	}
507 }
508 
NextPutMultiple(const byte * s,size_t len)509 void ProxyFilter::NextPutMultiple(const byte *s, size_t len)
510 {
511 	if (m_filter.get())
512 		m_filter->Put(s, len);
513 }
514 
NextPutModifiable(byte * s,size_t len)515 void ProxyFilter::NextPutModifiable(byte *s, size_t len)
516 {
517 	if (m_filter.get())
518 		m_filter->PutModifiable(s, len);
519 }
520 
521 // *************************************************************
522 
IsolatedInitialize(const NameValuePairs & parameters)523 void RandomNumberSink::IsolatedInitialize(const NameValuePairs &parameters)
524 {
525 	parameters.GetRequiredParameter("RandomNumberSink", "RandomNumberGeneratorPointer", m_rng);
526 }
527 
Put2(const byte * begin,size_t length,int messageEnd,bool blocking)528 size_t RandomNumberSink::Put2(const byte *begin, size_t length, int messageEnd, bool blocking)
529 {
530 	CRYPTOPP_UNUSED(messageEnd); CRYPTOPP_UNUSED(blocking);
531 	m_rng->IncorporateEntropy(begin, length);
532 	return 0;
533 }
534 
Put2(const byte * begin,size_t length,int messageEnd,bool blocking)535 size_t ArraySink::Put2(const byte *begin, size_t length, int messageEnd, bool blocking)
536 {
537 	CRYPTOPP_UNUSED(messageEnd); CRYPTOPP_UNUSED(blocking);
538 
539 	// Avoid passing NULL pointer to memcpy. Using memmove due to
540 	//  Valgrind finding on overlapping buffers.
541 	size_t copied = 0;
542 	if (m_buf && begin)
543 	{
544 		copied = STDMIN(length, SaturatingSubtract(m_size, m_total));
545 		memmove(PtrAdd(m_buf, m_total), begin, copied);
546 	}
547 	m_total += copied;
548 	return length - copied;
549 }
550 
CreatePutSpace(size_t & size)551 byte * ArraySink::CreatePutSpace(size_t &size)
552 {
553 	size = SaturatingSubtract(m_size, m_total);
554 	return PtrAdd(m_buf, m_total);
555 }
556 
IsolatedInitialize(const NameValuePairs & parameters)557 void ArraySink::IsolatedInitialize(const NameValuePairs &parameters)
558 {
559 	ByteArrayParameter array;
560 	if (!parameters.GetValue(Name::OutputBuffer(), array))
561 		throw InvalidArgument("ArraySink: missing OutputBuffer argument");
562 	m_buf = array.begin();
563 	m_size = array.size();
564 }
565 
Put2(const byte * begin,size_t length,int messageEnd,bool blocking)566 size_t ArrayXorSink::Put2(const byte *begin, size_t length, int messageEnd, bool blocking)
567 {
568 	CRYPTOPP_UNUSED(messageEnd); CRYPTOPP_UNUSED(blocking);
569 
570 	// Avoid passing NULL pointer to xorbuf
571 	size_t copied = 0;
572 	if (m_buf && begin)
573 	{
574 		copied = STDMIN(length, SaturatingSubtract(m_size, m_total));
575 		xorbuf(PtrAdd(m_buf, m_total), begin, copied);
576 	}
577 	m_total += copied;
578 	return length - copied;
579 }
580 
581 // *************************************************************
582 
StreamTransformationFilter(StreamTransformation & c,BufferedTransformation * attachment,BlockPaddingScheme padding)583 StreamTransformationFilter::StreamTransformationFilter(StreamTransformation &c, BufferedTransformation *attachment, BlockPaddingScheme padding)
584 	: FilterWithBufferedInput(attachment), m_cipher(c), m_padding(DEFAULT_PADDING)
585 {
586 	CRYPTOPP_ASSERT(c.MinLastBlockSize() == 0 || c.MinLastBlockSize() > c.MandatoryBlockSize());
587 
588 	const bool authenticatedFilter = dynamic_cast<AuthenticatedSymmetricCipher *>(&c) != NULLPTR;
589 	if (authenticatedFilter)
590 		throw InvalidArgument("StreamTransformationFilter: please use AuthenticatedEncryptionFilter and AuthenticatedDecryptionFilter for AuthenticatedSymmetricCipher");
591 
592 	// InitializeDerivedAndReturnNewSizes may override some of these
593 	m_mandatoryBlockSize = m_cipher.MandatoryBlockSize();
594 	m_optimalBufferSize = m_cipher.OptimalBlockSize();
595 	m_isSpecial = m_cipher.IsLastBlockSpecial() && m_mandatoryBlockSize > 1;
596 	m_reservedBufferSize = STDMAX(2*m_mandatoryBlockSize, m_optimalBufferSize);
597 
598 	FilterWithBufferedInput::IsolatedInitialize(
599 		MakeParameters
600 			(Name::BlockPaddingScheme(), padding));
601 }
602 
StreamTransformationFilter(StreamTransformation & c,BufferedTransformation * attachment,BlockPaddingScheme padding,bool authenticated)603 StreamTransformationFilter::StreamTransformationFilter(StreamTransformation &c, BufferedTransformation *attachment, BlockPaddingScheme padding, bool authenticated)
604 	: FilterWithBufferedInput(attachment), m_cipher(c), m_padding(DEFAULT_PADDING)
605 {
606 	const bool authenticatedFilter = dynamic_cast<AuthenticatedSymmetricCipher *>(&c) != NULLPTR;
607 	if (!authenticatedFilter)
608 	{
609 		CRYPTOPP_ASSERT(c.MinLastBlockSize() == 0 || c.MinLastBlockSize() > c.MandatoryBlockSize());
610 	}
611 
612 	if (authenticatedFilter && !authenticated)
613 		throw InvalidArgument("StreamTransformationFilter: please use AuthenticatedEncryptionFilter and AuthenticatedDecryptionFilter for AuthenticatedSymmetricCipher");
614 
615 	// InitializeDerivedAndReturnNewSizes may override some of these
616 	m_mandatoryBlockSize = m_cipher.MandatoryBlockSize();
617 	m_optimalBufferSize = m_cipher.OptimalBlockSize();
618 	m_isSpecial = m_cipher.IsLastBlockSpecial() && m_mandatoryBlockSize > 1;
619 	m_reservedBufferSize = STDMAX(2*m_mandatoryBlockSize, m_optimalBufferSize);
620 
621 	FilterWithBufferedInput::IsolatedInitialize(
622 		MakeParameters
623 			(Name::BlockPaddingScheme(), padding));
624 }
625 
LastBlockSize(StreamTransformation & c,BlockPaddingScheme padding)626 size_t StreamTransformationFilter::LastBlockSize(StreamTransformation &c, BlockPaddingScheme padding)
627 {
628 	if (c.MinLastBlockSize() > 0)
629 		return c.MinLastBlockSize();
630 	else if (c.MandatoryBlockSize() > 1 && !c.IsForwardTransformation() && padding != NO_PADDING && padding != ZEROS_PADDING)
631 		return c.MandatoryBlockSize();
632 
633 	return 0;
634 }
635 
InitializeDerivedAndReturnNewSizes(const NameValuePairs & parameters,size_t & firstSize,size_t & blockSize,size_t & lastSize)636 void StreamTransformationFilter::InitializeDerivedAndReturnNewSizes(const NameValuePairs &parameters, size_t &firstSize, size_t &blockSize, size_t &lastSize)
637 {
638 	BlockPaddingScheme padding = parameters.GetValueWithDefault(Name::BlockPaddingScheme(), DEFAULT_PADDING);
639 	bool isBlockCipher = (m_mandatoryBlockSize > 1 && m_cipher.MinLastBlockSize() == 0);
640 
641 	if (padding == DEFAULT_PADDING)
642 		m_padding = isBlockCipher ? PKCS_PADDING : NO_PADDING;
643 	else
644 		m_padding = padding;
645 
646 	if (!isBlockCipher)
647 	{
648 		if (m_padding == PKCS_PADDING)
649 			throw InvalidArgument("StreamTransformationFilter: PKCS_PADDING cannot be used with " + m_cipher.AlgorithmName());
650 		else if (m_padding == W3C_PADDING)
651 			throw InvalidArgument("StreamTransformationFilter: W3C_PADDING cannot be used with " + m_cipher.AlgorithmName());
652 		else if (m_padding == ONE_AND_ZEROS_PADDING)
653 			throw InvalidArgument("StreamTransformationFilter: ONE_AND_ZEROS_PADDING cannot be used with " + m_cipher.AlgorithmName());
654 	}
655 
656 	firstSize = 0;
657 	blockSize = m_mandatoryBlockSize;
658 	lastSize = LastBlockSize(m_cipher, m_padding);
659 }
660 
FirstPut(const byte * inString)661 void StreamTransformationFilter::FirstPut(const byte* inString)
662 {
663 	CRYPTOPP_UNUSED(inString);
664 	m_optimalBufferSize = STDMAX<unsigned int>(m_optimalBufferSize, RoundDownToMultipleOf(4096U, m_optimalBufferSize));
665 }
666 
NextPutMultiple(const byte * inString,size_t length)667 void StreamTransformationFilter::NextPutMultiple(const byte *inString, size_t length)
668 {
669 	if (!length)
670 		{return;}
671 
672 	const size_t s = m_cipher.MandatoryBlockSize();
673 	do
674 	{
675 		size_t len = m_optimalBufferSize;
676 		byte *space = HelpCreatePutSpace(*AttachedTransformation(), DEFAULT_CHANNEL, s, length, len);
677 		if (len < length)
678 		{
679 			if (len == m_optimalBufferSize)
680 				len -= m_cipher.GetOptimalBlockSizeUsed();
681 			len = RoundDownToMultipleOf(len, s);
682 		}
683 		else
684 			len = length;
685 		m_cipher.ProcessString(space, inString, len);
686 		AttachedTransformation()->PutModifiable(space, len);
687 		inString = PtrAdd(inString, len);
688 		length -= len;
689 	}
690 	while (length > 0);
691 }
692 
NextPutModifiable(byte * inString,size_t length)693 void StreamTransformationFilter::NextPutModifiable(byte *inString, size_t length)
694 {
695 	m_cipher.ProcessString(inString, length);
696 	AttachedTransformation()->PutModifiable(inString, length);
697 }
698 
/// \brief Processes the held-back tail of a message and applies or removes padding.
/// \details inString/length is everything FilterWithBufferedInput retained for
///  the end of the message:
///  - When encrypting with PKCS/W3C/one-and-zeros padding, it is a partial block.
///  - When decrypting with those schemes, it must be exactly one block.
///  - For the NO/ZEROS padding schemes it is whatever is left after whole blocks.
/// \throws InvalidDataFormat on unpadded plaintext that is not block aligned
/// \throws InvalidCiphertext on misaligned ciphertext or malformed padding
void StreamTransformationFilter::LastPut(const byte *inString, size_t length)
{
	// This block is new to StreamTransformationFilter. It is somewhat of a hack and was
	//  added for OCB mode; see GitHub Issue 515. The rub with OCB is that it's a block cipher
	//  and the last block size can be 0. However, "last block = 0" is not the 0 predicated
	//  in the original code. In the original code 0 means "nothing special" so
	//  DEFAULT_PADDING is applied. OCB's 0 literally means a final block size can be 0 or
	//  non-0; and no padding is needed in either case because OCB has its own scheme (see
	//  handling of P_* and A_*).
	// Stream ciphers have policy objects to convey how to operate the cipher. The Crypto++
	//  framework operates well when MinLastBlockSize() is 1. However, it did not appear to
	//  cover the OCB case either because we can't stream OCB. It needs full block sizes. In
	//  response we hacked an IsLastBlockSpecial(). When true StreamTransformationFilter
	//  defers to the mode for processing of the last block.
	// The behavior supplied when IsLastBlockSpecial() will likely have to evolve to capture
	//  more complex cases from different authenc modes. I suspect it will have to change
	//  from a simple bool to something that conveys more information, like "last block
	//  no padding" or "custom padding applied by cipher".
	// In some respect we have already hit the need for more information. For example, OCB
	//  calculates the checksum on the cipher text at the same time, so we don't need the
	//  disjoint behavior of calling "EncryptBlock" followed by a separate "AuthenticateBlock".
	//  Additional information may allow us to avoid the two separate calls.
	if (m_isSpecial)
	{
		const size_t leftOver = length % m_mandatoryBlockSize;
		byte* space = HelpCreatePutSpace(*AttachedTransformation(), DEFAULT_CHANNEL, m_reservedBufferSize);
		length -= leftOver;

		if (length)
		{
			// Process full blocks
			m_cipher.ProcessData(space, inString, length);
			AttachedTransformation()->Put(space, length);
			inString = PtrAdd(inString, length);
		}

		// ProcessLastBlock's return value is used as the number of output bytes
		if (leftOver)
		{
			// Process final partial block
			length = m_cipher.ProcessLastBlock(space, m_reservedBufferSize, inString, leftOver);
			AttachedTransformation()->Put(space, length);
		}
		else
		{
			// Process final empty block
			length = m_cipher.ProcessLastBlock(space, m_reservedBufferSize, NULLPTR, 0);
			AttachedTransformation()->Put(space, length);
		}

		return;
	}

	switch (m_padding)
	{
	case NO_PADDING:
	case ZEROS_PADDING:
		// Nothing left over means the data was block aligned; nothing to do.
		if (length > 0)
		{
			const size_t minLastBlockSize = m_cipher.MinLastBlockSize();
			const bool isForwardTransformation = m_cipher.IsForwardTransformation();

			if (isForwardTransformation && m_padding == ZEROS_PADDING && (minLastBlockSize == 0 || length < minLastBlockSize))
			{
				// do padding: zero-fill up to one block (or MinLastBlockSize) and process it
				size_t blockSize = STDMAX(minLastBlockSize, (size_t)m_mandatoryBlockSize);
				byte* space = HelpCreatePutSpace(*AttachedTransformation(), DEFAULT_CHANNEL, blockSize);
				if (inString) {memcpy(space, inString, length);}
				memset(PtrAdd(space, length), 0, blockSize - length);
				size_t used = m_cipher.ProcessLastBlock(space, blockSize, space, blockSize);
				AttachedTransformation()->Put(space, used);
			}
			else
			{
				// A plain block cipher cannot process a partial final block.
				if (minLastBlockSize == 0)
				{
					if (isForwardTransformation)
						throw InvalidDataFormat("StreamTransformationFilter: plaintext length is not a multiple of block size and NO_PADDING is specified");
					else
						throw InvalidCiphertext("StreamTransformationFilter: ciphertext length is not a multiple of block size");
				}

				// The cipher supports a short final block (MinLastBlockSize > 0); let it handle it.
				// NOTE(review): ZEROS_PADDING is not stripped on decryption anywhere in this function.
				byte* space = HelpCreatePutSpace(*AttachedTransformation(), DEFAULT_CHANNEL, length, m_optimalBufferSize);
				size_t used = m_cipher.ProcessLastBlock(space, length, inString, length);
				AttachedTransformation()->Put(space, used);
			}
		}
		break;

	case PKCS_PADDING:
	case W3C_PADDING:
	case ONE_AND_ZEROS_PADDING:
		unsigned int s;
		byte* space;
		s = m_mandatoryBlockSize;
		CRYPTOPP_ASSERT(s > 1);
		space = HelpCreatePutSpace(*AttachedTransformation(), DEFAULT_CHANNEL, s, m_optimalBufferSize);
		if (m_cipher.IsForwardTransformation())
		{
			// Encryption: pad the partial block out to exactly one block.
			CRYPTOPP_ASSERT(length < s);
			if (inString) {memcpy(space, inString, length);}
			if (m_padding == PKCS_PADDING)
			{
				// every pad byte holds the pad length
				CRYPTOPP_ASSERT(s < 256);
				byte pad = static_cast<byte>(s-length);
				memset(PtrAdd(space, length), pad, s-length);
			}
			else if (m_padding == W3C_PADDING)
			{
				// zeros, with the pad length in the final byte
				CRYPTOPP_ASSERT(s < 256);
				memset(PtrAdd(space, length), 0, s-length-1);
				space[s-1] = static_cast<byte>(s-length);
			}
			else
			{
				// one-and-zeros: 0x80 followed by zero bytes
				space[length] = 0x80;
				memset(PtrAdd(space, length+1), 0, s-length-1);
			}
			m_cipher.ProcessData(space, space, s);
			AttachedTransformation()->Put(space, s);
		}
		else
		{
			// Decryption: the held-back tail must be exactly one block; decrypt it,
			//  validate the padding and forward only the unpadded bytes.
			// NOTE(review): the checks below exit early with distinct messages and are
			//  not constant time; consider padding-oracle exposure if errors are observable.
			if (length != s)
				throw InvalidCiphertext("StreamTransformationFilter: ciphertext length is not a multiple of block size");
			m_cipher.ProcessData(space, inString, s);
			if (m_padding == PKCS_PADDING)
			{
				// pad length in range and all pad bytes equal to it
				byte pad = space[s-1];
				if (pad < 1 || pad > s || FindIfNot(PtrAdd(space, s-pad), PtrAdd(space, s), pad) != PtrAdd(space, s))
					throw InvalidCiphertext("StreamTransformationFilter: invalid PKCS #7 block padding found");
				length = s-pad;
			}
			else if (m_padding == W3C_PADDING)
			{
				// only the final length byte is checked; the other pad bytes are not inspected
				byte pad = space[s - 1];
				if (pad < 1 || pad > s)
					throw InvalidCiphertext("StreamTransformationFilter: invalid W3C block padding found");
				length = s - pad;
			}
			else
			{
				// strip trailing zeros, then require the 0x80 marker
				while (length > 1 && space[length-1] == 0)
					--length;
				if (space[--length] != 0x80)
					throw InvalidCiphertext("StreamTransformationFilter: invalid ones-and-zeros padding found");
			}
			AttachedTransformation()->Put(space, length);
		}
		break;

	default:
		CRYPTOPP_ASSERT(false);
	}
}
853 
854 // *************************************************************
855 
HashFilter(HashTransformation & hm,BufferedTransformation * attachment,bool putMessage,int truncatedDigestSize,const std::string & messagePutChannel,const std::string & hashPutChannel)856 HashFilter::HashFilter(HashTransformation &hm, BufferedTransformation *attachment, bool putMessage, int truncatedDigestSize, const std::string &messagePutChannel, const std::string &hashPutChannel)
857 	: m_hashModule(hm), m_putMessage(putMessage), m_digestSize(0), m_space(NULLPTR)
858 	, m_messagePutChannel(messagePutChannel), m_hashPutChannel(hashPutChannel)
859 {
860 	m_digestSize = truncatedDigestSize < 0 ? m_hashModule.DigestSize() : truncatedDigestSize;
861 	Detach(attachment);
862 }
863 
IsolatedInitialize(const NameValuePairs & parameters)864 void HashFilter::IsolatedInitialize(const NameValuePairs &parameters)
865 {
866 	m_putMessage = parameters.GetValueWithDefault(Name::PutMessage(), false);
867 	int s = parameters.GetIntValueWithDefault(Name::TruncatedDigestSize(), -1);
868 	m_digestSize = s < 0 ? m_hashModule.DigestSize() : s;
869 }
870 
// Hashes incoming data and, at message end, emits the (possibly truncated)
// digest on m_hashPutChannel.
// inString/length: data to hash; echoed on m_messagePutChannel when m_putMessage is set.
// messageEnd:      when nonzero, finalizes the hash and forwards this value with the digest.
// blocking:        passed through to the output steps.
// NOTE(review): FILTER_BEGIN / FILTER_OUTPUT3 / FILTER_END_NO_MESSAGE_END come
// from fltrimpl.h and appear to make this function resumable at output sites
// 1 and 2 after a blocked Output(); confirm there before reordering statements.
size_t HashFilter::Put2(const byte *inString, size_t length, int messageEnd, bool blocking)
{
	FILTER_BEGIN;
	// Output site 1: optional pass-through of the message itself.
	if (m_putMessage)
		FILTER_OUTPUT3(1, 0, inString, length, 0, m_messagePutChannel);
	if (inString && length)
		m_hashModule.Update(inString, length);
	if (messageEnd)
	{
		// Kept in its own scope, presumably so the local does not remain in scope at
		// the output-site label expanded by FILTER_OUTPUT3 below.
		{
			size_t size;
			// Ask the attachment for m_digestSize bytes of put space, then write the digest there.
			m_space = HelpCreatePutSpace(*AttachedTransformation(), m_hashPutChannel, m_digestSize, m_digestSize, size = m_digestSize);
			m_hashModule.TruncatedFinal(m_space, m_digestSize);
		}
		// Output site 2: emit the digest; m_space is a member, so it survives a blocked return.
		FILTER_OUTPUT3(2, 0, m_space, m_digestSize, messageEnd, m_hashPutChannel);
	}
	FILTER_END_NO_MESSAGE_END;
}
889 
890 // *************************************************************
891 
HashVerificationFilter(HashTransformation & hm,BufferedTransformation * attachment,word32 flags,int truncatedDigestSize)892 HashVerificationFilter::HashVerificationFilter(HashTransformation &hm, BufferedTransformation *attachment, word32 flags, int truncatedDigestSize)
893 	: FilterWithBufferedInput(attachment)
894 	, m_hashModule(hm), m_flags(0), m_digestSize(0), m_verified(false)
895 {
896 	FilterWithBufferedInput::IsolatedInitialize(
897 		MakeParameters
898 			(Name::HashVerificationFilterFlags(), flags)
899 			(Name::TruncatedDigestSize(), truncatedDigestSize));
900 }
901 
InitializeDerivedAndReturnNewSizes(const NameValuePairs & parameters,size_t & firstSize,size_t & blockSize,size_t & lastSize)902 void HashVerificationFilter::InitializeDerivedAndReturnNewSizes(const NameValuePairs &parameters, size_t &firstSize, size_t &blockSize, size_t &lastSize)
903 {
904 	m_flags = parameters.GetValueWithDefault(Name::HashVerificationFilterFlags(), (word32)DEFAULT_FLAGS);
905 	int s = parameters.GetIntValueWithDefault(Name::TruncatedDigestSize(), -1);
906 	m_digestSize = s < 0 ? m_hashModule.DigestSize() : s;
907 	m_verified = false;
908 	firstSize = m_flags & HASH_AT_BEGIN ? m_digestSize : 0;
909 	blockSize = 1;
910 	lastSize = m_flags & HASH_AT_BEGIN ? 0 : m_digestSize;
911 }
912 
FirstPut(const byte * inString)913 void HashVerificationFilter::FirstPut(const byte *inString)
914 {
915 	if (m_flags & HASH_AT_BEGIN)
916 	{
917 		m_expectedHash.New(m_digestSize);
918 		if (inString) {memcpy(m_expectedHash, inString, m_expectedHash.size());}
919 		if (m_flags & PUT_HASH)
920 			AttachedTransformation()->Put(inString, m_expectedHash.size());
921 	}
922 }
923 
NextPutMultiple(const byte * inString,size_t length)924 void HashVerificationFilter::NextPutMultiple(const byte *inString, size_t length)
925 {
926 	m_hashModule.Update(inString, length);
927 	if (m_flags & PUT_MESSAGE)
928 		AttachedTransformation()->Put(inString, length);
929 }
930 
LastPut(const byte * inString,size_t length)931 void HashVerificationFilter::LastPut(const byte *inString, size_t length)
932 {
933 	if (m_flags & HASH_AT_BEGIN)
934 	{
935 		CRYPTOPP_ASSERT(length == 0);
936 		m_verified = m_hashModule.TruncatedVerify(m_expectedHash, m_digestSize);
937 	}
938 	else
939 	{
940 		m_verified = (length==m_digestSize && m_hashModule.TruncatedVerify(inString, length));
941 		if (m_flags & PUT_HASH)
942 			AttachedTransformation()->Put(inString, length);
943 	}
944 
945 	if (m_flags & PUT_RESULT)
946 		AttachedTransformation()->Put(m_verified);
947 
948 	if ((m_flags & THROW_EXCEPTION) && !m_verified)
949 		throw HashVerificationFailed();
950 }
951 
952 // *************************************************************
953 
AuthenticatedEncryptionFilter(AuthenticatedSymmetricCipher & c,BufferedTransformation * attachment,bool putAAD,int truncatedDigestSize,const std::string & macChannel,BlockPaddingScheme padding)954 AuthenticatedEncryptionFilter::AuthenticatedEncryptionFilter(AuthenticatedSymmetricCipher &c, BufferedTransformation *attachment,
955 								bool putAAD, int truncatedDigestSize, const std::string &macChannel, BlockPaddingScheme padding)
956 	: StreamTransformationFilter(c, attachment, padding, true)
957 	, m_hf(c, new OutputProxy(*this, false), putAAD, truncatedDigestSize, AAD_CHANNEL, macChannel)
958 {
959 	CRYPTOPP_ASSERT(c.IsForwardTransformation());
960 }
961 
IsolatedInitialize(const NameValuePairs & parameters)962 void AuthenticatedEncryptionFilter::IsolatedInitialize(const NameValuePairs &parameters)
963 {
964 	m_hf.IsolatedInitialize(parameters);
965 	StreamTransformationFilter::IsolatedInitialize(parameters);
966 }
967 
ChannelCreatePutSpace(const std::string & channel,size_t & size)968 byte * AuthenticatedEncryptionFilter::ChannelCreatePutSpace(const std::string &channel, size_t &size)
969 {
970 	if (channel.empty())
971 		return StreamTransformationFilter::CreatePutSpace(size);
972 
973 	if (channel == AAD_CHANNEL)
974 		return m_hf.CreatePutSpace(size);
975 
976 	throw InvalidChannelName("AuthenticatedEncryptionFilter", channel);
977 }
978 
ChannelPut2(const std::string & channel,const byte * begin,size_t length,int messageEnd,bool blocking)979 size_t AuthenticatedEncryptionFilter::ChannelPut2(const std::string &channel, const byte *begin, size_t length, int messageEnd, bool blocking)
980 {
981 	if (channel.empty())
982 		return StreamTransformationFilter::Put2(begin, length, messageEnd, blocking);
983 
984 	if (channel == AAD_CHANNEL)
985 		return m_hf.Put2(begin, length, 0, blocking);
986 
987 	throw InvalidChannelName("AuthenticatedEncryptionFilter", channel);
988 }
989 
LastPut(const byte * inString,size_t length)990 void AuthenticatedEncryptionFilter::LastPut(const byte *inString, size_t length)
991 {
992 	StreamTransformationFilter::LastPut(inString, length);
993 	m_hf.MessageEnd();
994 }
995 
996 // *************************************************************
997 
AuthenticatedDecryptionFilter(AuthenticatedSymmetricCipher & c,BufferedTransformation * attachment,word32 flags,int truncatedDigestSize,BlockPaddingScheme padding)998 AuthenticatedDecryptionFilter::AuthenticatedDecryptionFilter(AuthenticatedSymmetricCipher &c, BufferedTransformation *attachment, word32 flags, int truncatedDigestSize, BlockPaddingScheme padding)
999 	: FilterWithBufferedInput(attachment)
1000 	, m_hashVerifier(c, new OutputProxy(*this, false))
1001 	, m_streamFilter(c, new OutputProxy(*this, false), padding, true)
1002 {
1003 	CRYPTOPP_ASSERT(!c.IsForwardTransformation() || c.IsSelfInverting());
1004 	FilterWithBufferedInput::IsolatedInitialize(
1005 		MakeParameters
1006 			(Name::BlockPaddingScheme(), padding)
1007 			(Name::AuthenticatedDecryptionFilterFlags(), flags)
1008 			(Name::TruncatedDigestSize(), truncatedDigestSize));
1009 }
1010 
InitializeDerivedAndReturnNewSizes(const NameValuePairs & parameters,size_t & firstSize,size_t & blockSize,size_t & lastSize)1011 void AuthenticatedDecryptionFilter::InitializeDerivedAndReturnNewSizes(const NameValuePairs &parameters, size_t &firstSize, size_t &blockSize, size_t &lastSize)
1012 {
1013 	word32 flags = parameters.GetValueWithDefault(Name::AuthenticatedDecryptionFilterFlags(), (word32)DEFAULT_FLAGS);
1014 
1015 	m_hashVerifier.Initialize(CombinedNameValuePairs(parameters, MakeParameters(Name::HashVerificationFilterFlags(), flags)));
1016 	m_streamFilter.Initialize(parameters);
1017 
1018 	firstSize = m_hashVerifier.m_firstSize;
1019 	blockSize = 1;
1020 	lastSize = m_hashVerifier.m_lastSize;
1021 }
1022 
ChannelCreatePutSpace(const std::string & channel,size_t & size)1023 byte * AuthenticatedDecryptionFilter::ChannelCreatePutSpace(const std::string &channel, size_t &size)
1024 {
1025 	if (channel.empty())
1026 		return m_streamFilter.CreatePutSpace(size);
1027 
1028 	if (channel == AAD_CHANNEL)
1029 		return m_hashVerifier.CreatePutSpace(size);
1030 
1031 	throw InvalidChannelName("AuthenticatedDecryptionFilter", channel);
1032 }
1033 
ChannelPut2(const std::string & channel,const byte * begin,size_t length,int messageEnd,bool blocking)1034 size_t AuthenticatedDecryptionFilter::ChannelPut2(const std::string &channel, const byte *begin, size_t length, int messageEnd, bool blocking)
1035 {
1036 	if (channel.empty())
1037 	{
1038 		if (m_lastSize > 0)
1039 			m_hashVerifier.ForceNextPut();
1040 		return FilterWithBufferedInput::Put2(begin, length, messageEnd, blocking);
1041 	}
1042 
1043 	if (channel == AAD_CHANNEL)
1044 		return m_hashVerifier.Put2(begin, length, 0, blocking);
1045 
1046 	throw InvalidChannelName("AuthenticatedDecryptionFilter", channel);
1047 }
1048 
FirstPut(const byte * inString)1049 void AuthenticatedDecryptionFilter::FirstPut(const byte *inString)
1050 {
1051 	m_hashVerifier.Put(inString, m_firstSize);
1052 }
1053 
NextPutMultiple(const byte * inString,size_t length)1054 void AuthenticatedDecryptionFilter::NextPutMultiple(const byte *inString, size_t length)
1055 {
1056 	m_streamFilter.Put(inString, length);
1057 }
1058 
LastPut(const byte * inString,size_t length)1059 void AuthenticatedDecryptionFilter::LastPut(const byte *inString, size_t length)
1060 {
1061 	m_streamFilter.MessageEnd();
1062 	m_hashVerifier.PutMessageEnd(inString, length);
1063 }
1064 
1065 // *************************************************************
1066 
IsolatedInitialize(const NameValuePairs & parameters)1067 void SignerFilter::IsolatedInitialize(const NameValuePairs &parameters)
1068 {
1069 	m_putMessage = parameters.GetValueWithDefault(Name::PutMessage(), false);
1070 	m_messageAccumulator.reset(m_signer.NewSignatureAccumulator(m_rng));
1071 }
1072 
Put2(const byte * inString,size_t length,int messageEnd,bool blocking)1073 size_t SignerFilter::Put2(const byte *inString, size_t length, int messageEnd, bool blocking)
1074 {
1075 	FILTER_BEGIN;
1076 	m_messageAccumulator->Update(inString, length);
1077 	if (m_putMessage)
1078 		FILTER_OUTPUT(1, inString, length, 0);
1079 	if (messageEnd)
1080 	{
1081 		m_buf.New(m_signer.SignatureLength());
1082 		m_signer.Sign(m_rng, m_messageAccumulator.release(), m_buf);
1083 		FILTER_OUTPUT(2, m_buf, m_buf.size(), messageEnd);
1084 		m_messageAccumulator.reset(m_signer.NewSignatureAccumulator(m_rng));
1085 	}
1086 	FILTER_END_NO_MESSAGE_END;
1087 }
1088 
SignatureVerificationFilter(const PK_Verifier & verifier,BufferedTransformation * attachment,word32 flags)1089 SignatureVerificationFilter::SignatureVerificationFilter(const PK_Verifier &verifier, BufferedTransformation *attachment, word32 flags)
1090 	: FilterWithBufferedInput(attachment)
1091 	, m_verifier(verifier), m_flags(0), m_verified(0)
1092 {
1093 	FilterWithBufferedInput::IsolatedInitialize(
1094 		MakeParameters
1095 			(Name::SignatureVerificationFilterFlags(), flags));
1096 }
1097 
InitializeDerivedAndReturnNewSizes(const NameValuePairs & parameters,size_t & firstSize,size_t & blockSize,size_t & lastSize)1098 void SignatureVerificationFilter::InitializeDerivedAndReturnNewSizes(const NameValuePairs &parameters, size_t &firstSize, size_t &blockSize, size_t &lastSize)
1099 {
1100 	m_flags = parameters.GetValueWithDefault(Name::SignatureVerificationFilterFlags(), (word32)DEFAULT_FLAGS);
1101 	m_messageAccumulator.reset(m_verifier.NewVerificationAccumulator());
1102 	size_t size = m_verifier.SignatureLength();
1103 	CRYPTOPP_ASSERT(size != 0);	// TODO: handle recoverable signature scheme
1104 	m_verified = false;
1105 	firstSize = m_flags & SIGNATURE_AT_BEGIN ? size : 0;
1106 	blockSize = 1;
1107 	lastSize = m_flags & SIGNATURE_AT_BEGIN ? 0 : size;
1108 }
1109 
FirstPut(const byte * inString)1110 void SignatureVerificationFilter::FirstPut(const byte *inString)
1111 {
1112 	if (m_flags & SIGNATURE_AT_BEGIN)
1113 	{
1114 		if (m_verifier.SignatureUpfront())
1115 			m_verifier.InputSignature(*m_messageAccumulator, inString, m_verifier.SignatureLength());
1116 		else
1117 		{
1118 			m_signature.New(m_verifier.SignatureLength());
1119 			if (inString) {memcpy(m_signature, inString, m_signature.size());}
1120 		}
1121 
1122 		if (m_flags & PUT_SIGNATURE)
1123 			AttachedTransformation()->Put(inString, m_signature.size());
1124 	}
1125 	else
1126 	{
1127 		CRYPTOPP_ASSERT(!m_verifier.SignatureUpfront());
1128 	}
1129 }
1130 
NextPutMultiple(const byte * inString,size_t length)1131 void SignatureVerificationFilter::NextPutMultiple(const byte *inString, size_t length)
1132 {
1133 	m_messageAccumulator->Update(inString, length);
1134 	if (m_flags & PUT_MESSAGE)
1135 		AttachedTransformation()->Put(inString, length);
1136 }
1137 
LastPut(const byte * inString,size_t length)1138 void SignatureVerificationFilter::LastPut(const byte *inString, size_t length)
1139 {
1140 	if (m_flags & SIGNATURE_AT_BEGIN)
1141 	{
1142 		CRYPTOPP_ASSERT(length == 0);
1143 		m_verifier.InputSignature(*m_messageAccumulator, m_signature, m_signature.size());
1144 		m_verified = m_verifier.VerifyAndRestart(*m_messageAccumulator);
1145 	}
1146 	else
1147 	{
1148 		m_verifier.InputSignature(*m_messageAccumulator, inString, length);
1149 		m_verified = m_verifier.VerifyAndRestart(*m_messageAccumulator);
1150 		if (m_flags & PUT_SIGNATURE)
1151 			AttachedTransformation()->Put(inString, length);
1152 	}
1153 
1154 	if (m_flags & PUT_RESULT)
1155 		AttachedTransformation()->Put(m_verified);
1156 
1157 	if ((m_flags & THROW_EXCEPTION) && !m_verified)
1158 		throw SignatureVerificationFailed();
1159 }
1160 
1161 // *************************************************************
1162 
PumpAll2(bool blocking)1163 size_t Source::PumpAll2(bool blocking)
1164 {
1165 	unsigned int messageCount = UINT_MAX;
1166 	do {
1167 		RETURN_IF_NONZERO(PumpMessages2(messageCount, blocking));
1168 	} while(messageCount == UINT_MAX);
1169 
1170 	return 0;
1171 }
1172 
GetNextMessage()1173 bool Store::GetNextMessage()
1174 {
1175 	if (!m_messageEnd && !AnyRetrievable())
1176 	{
1177 		m_messageEnd=true;
1178 		return true;
1179 	}
1180 	else
1181 		return false;
1182 }
1183 
CopyMessagesTo(BufferedTransformation & target,unsigned int count,const std::string & channel) const1184 unsigned int Store::CopyMessagesTo(BufferedTransformation &target, unsigned int count, const std::string &channel) const
1185 {
1186 	if (m_messageEnd || count == 0)
1187 		return 0;
1188 	else
1189 	{
1190 		CopyTo(target, ULONG_MAX, channel);
1191 		if (GetAutoSignalPropagation())
1192 			target.ChannelMessageEnd(channel, GetAutoSignalPropagation()-1);
1193 		return 1;
1194 	}
1195 }
1196 
StoreInitialize(const NameValuePairs & parameters)1197 void StringStore::StoreInitialize(const NameValuePairs &parameters)
1198 {
1199 	ConstByteArrayParameter array;
1200 	if (!parameters.GetValue(Name::InputBuffer(), array))
1201 		throw InvalidArgument("StringStore: missing InputBuffer argument");
1202 	m_store = array.begin();
1203 	m_length = array.size();
1204 	m_count = 0;
1205 }
1206 
TransferTo2(BufferedTransformation & target,lword & transferBytes,const std::string & channel,bool blocking)1207 size_t StringStore::TransferTo2(BufferedTransformation &target, lword &transferBytes, const std::string &channel, bool blocking)
1208 {
1209 	lword position = 0;
1210 	size_t blockedBytes = CopyRangeTo2(target, position, transferBytes, channel, blocking);
1211 	m_count += static_cast<size_t>(position);
1212 	transferBytes = position;
1213 	return blockedBytes;
1214 }
1215 
CopyRangeTo2(BufferedTransformation & target,lword & begin,lword end,const std::string & channel,bool blocking) const1216 size_t StringStore::CopyRangeTo2(BufferedTransformation &target, lword &begin, lword end, const std::string &channel, bool blocking) const
1217 {
1218 	size_t i = UnsignedMin(m_length, m_count+begin);
1219 	size_t len = UnsignedMin(m_length-i, end-begin);
1220 	size_t blockedBytes = target.ChannelPut2(channel, PtrAdd(m_store, i), len, 0, blocking);
1221 	if (!blockedBytes)
1222 		begin = PtrAdd(begin, len);
1223 	return blockedBytes;
1224 }
1225 
StoreInitialize(const NameValuePairs & parameters)1226 void RandomNumberStore::StoreInitialize(const NameValuePairs &parameters)
1227 {
1228 	parameters.GetRequiredParameter("RandomNumberStore", "RandomNumberGeneratorPointer", m_rng);
1229 	int length;
1230 	parameters.GetRequiredIntParameter("RandomNumberStore", "RandomNumberStoreSize", length);
1231 	m_length = length;
1232 }
1233 
TransferTo2(BufferedTransformation & target,lword & transferBytes,const std::string & channel,bool blocking)1234 size_t RandomNumberStore::TransferTo2(BufferedTransformation &target, lword &transferBytes, const std::string &channel, bool blocking)
1235 {
1236 	if (!blocking)
1237 		throw NotImplemented("RandomNumberStore: nonblocking transfer is not implemented by this object");
1238 
1239 	transferBytes = UnsignedMin(transferBytes, m_length - m_count);
1240 	m_rng->GenerateIntoBufferedTransformation(target, channel, transferBytes);
1241 	m_count += transferBytes;
1242 
1243 	return 0;
1244 }
1245 
CopyRangeTo2(BufferedTransformation & target,lword & begin,lword end,const std::string & channel,bool blocking) const1246 size_t NullStore::CopyRangeTo2(BufferedTransformation &target, lword &begin, lword end, const std::string &channel, bool blocking) const
1247 {
1248 	static const byte nullBytes[128] = {0};
1249 	while (begin < end)
1250 	{
1251 		size_t len = (size_t)STDMIN(end-begin, lword(128));
1252 		size_t blockedBytes = target.ChannelPut2(channel, nullBytes, len, 0, blocking);
1253 		if (blockedBytes)
1254 			return blockedBytes;
1255 		begin = PtrAdd(begin, len);
1256 	}
1257 	return 0;
1258 }
1259 
TransferTo2(BufferedTransformation & target,lword & transferBytes,const std::string & channel,bool blocking)1260 size_t NullStore::TransferTo2(BufferedTransformation &target, lword &transferBytes, const std::string &channel, bool blocking)
1261 {
1262 	lword begin = 0;
1263 	size_t blockedBytes = NullStore::CopyRangeTo2(target, begin, transferBytes, channel, blocking);
1264 	transferBytes = begin; m_size -= begin;
1265 	return blockedBytes;
1266 }
1267 
1268 NAMESPACE_END
1269 
1270 #endif
1271