1 // mqueue.cpp - originally written and placed in the public domain by Wei Dai
2 
3 #include "pch.h"
4 
5 #ifndef CRYPTOPP_IMPORTS
6 
7 #include "mqueue.h"
8 
NAMESPACE_BEGIN(CryptoPP)9 NAMESPACE_BEGIN(CryptoPP)
10 
11 MessageQueue::MessageQueue(unsigned int nodeSize)
12 	: m_queue(nodeSize), m_lengths(1, 0U), m_messageCounts(1, 0U)
13 {
14 }
15 
CopyRangeTo2(BufferedTransformation & target,lword & begin,lword end,const std::string & channel,bool blocking) const16 size_t MessageQueue::CopyRangeTo2(BufferedTransformation &target, lword &begin, lword end, const std::string &channel, bool blocking) const
17 {
18 	if (begin >= MaxRetrievable())
19 		return 0;
20 
21 	return m_queue.CopyRangeTo2(target, begin, STDMIN(MaxRetrievable(), end), channel, blocking);
22 }
23 
TransferTo2(BufferedTransformation & target,lword & transferBytes,const std::string & channel,bool blocking)24 size_t MessageQueue::TransferTo2(BufferedTransformation &target, lword &transferBytes, const std::string &channel, bool blocking)
25 {
26 	transferBytes = STDMIN(MaxRetrievable(), transferBytes);
27 	size_t blockedBytes = m_queue.TransferTo2(target, transferBytes, channel, blocking);
28 	m_lengths.front() -= transferBytes;
29 	return blockedBytes;
30 }
31 
GetNextMessage()32 bool MessageQueue::GetNextMessage()
33 {
34 	if (NumberOfMessages() > 0 && !AnyRetrievable())
35 	{
36 		m_lengths.pop_front();
37 		if (m_messageCounts[0] == 0 && m_messageCounts.size() > 1)
38 			m_messageCounts.pop_front();
39 		return true;
40 	}
41 	else
42 		return false;
43 }
44 
CopyMessagesTo(BufferedTransformation & target,unsigned int count,const std::string & channel) const45 unsigned int MessageQueue::CopyMessagesTo(BufferedTransformation &target, unsigned int count, const std::string &channel) const
46 {
47 	ByteQueue::Walker walker(m_queue);
48 	std::deque<lword>::const_iterator it = m_lengths.begin();
49 	unsigned int i;
50 	for (i=0; i<count && it != --m_lengths.end(); ++i, ++it)
51 	{
52 		walker.TransferTo(target, *it, channel);
53 		if (GetAutoSignalPropagation())
54 			target.ChannelMessageEnd(channel, GetAutoSignalPropagation()-1);
55 	}
56 	return i;
57 }
58 
swap(MessageQueue & rhs)59 void MessageQueue::swap(MessageQueue &rhs)
60 {
61 	m_queue.swap(rhs.m_queue);
62 	m_lengths.swap(rhs.m_lengths);
63 }
64 
Spy(size_t & contiguousSize) const65 const byte * MessageQueue::Spy(size_t &contiguousSize) const
66 {
67 	const byte *result = m_queue.Spy(contiguousSize);
68 	contiguousSize = UnsignedMin(contiguousSize, MaxRetrievable());
69 	return result;
70 }
71 
72 // *************************************************************
73 
MapChannel(const std::string & channel) const74 unsigned int EqualityComparisonFilter::MapChannel(const std::string &channel) const
75 {
76 	if (channel == m_firstChannel)
77 		return 0;
78 	else if (channel == m_secondChannel)
79 		return 1;
80 	else
81 		return 2;
82 }
83 
ChannelPut2(const std::string & channel,const byte * inString,size_t length,int messageEnd,bool blocking)84 size_t EqualityComparisonFilter::ChannelPut2(const std::string &channel, const byte *inString, size_t length, int messageEnd, bool blocking)
85 {
86 	if (!blocking)
87 		throw BlockingInputOnly("EqualityComparisonFilter");
88 
89 	unsigned int i = MapChannel(channel);
90 
91 	if (i == 2)
92 		return Output(3, inString, length, messageEnd, blocking, channel);
93 	else if (m_mismatchDetected)
94 		return 0;
95 	else
96 	{
97 		MessageQueue &q1 = m_q[i], &q2 = m_q[1-i];
98 
99 		if (q2.AnyMessages() && q2.MaxRetrievable() < length)
100 			goto mismatch;
101 
102 		while (length > 0 && q2.AnyRetrievable())
103 		{
104 			size_t len = length;
105 			const byte *data = q2.Spy(len);
106 			len = STDMIN(len, length);
107 			if (memcmp(inString, data, len) != 0)
108 				goto mismatch;
109 			inString += len;
110 			length -= len;
111 			q2.Skip(len);
112 		}
113 
114 		q1.Put(inString, length);
115 
116 		if (messageEnd)
117 		{
118 			if (q2.AnyRetrievable())
119 				goto mismatch;
120 			else if (q2.AnyMessages())
121 				q2.GetNextMessage();
122 			else if (q2.NumberOfMessageSeries() > 0)
123 				goto mismatch;
124 			else
125 				q1.MessageEnd();
126 		}
127 
128 		return 0;
129 
130 mismatch:
131 		return HandleMismatchDetected(blocking);
132 	}
133 }
134 
ChannelMessageSeriesEnd(const std::string & channel,int propagation,bool blocking)135 bool EqualityComparisonFilter::ChannelMessageSeriesEnd(const std::string &channel, int propagation, bool blocking)
136 {
137 	unsigned int i = MapChannel(channel);
138 
139 	if (i == 2)
140 	{
141 		OutputMessageSeriesEnd(4, propagation, blocking, channel);
142 		return false;
143 	}
144 	else if (m_mismatchDetected)
145 		return false;
146 	else
147 	{
148 		MessageQueue &q1 = m_q[i], &q2 = m_q[1-i];
149 
150 		if (q2.AnyRetrievable() || q2.AnyMessages())
151 			goto mismatch;
152 		else if (q2.NumberOfMessageSeries() > 0)
153 			return Output(2, (const byte *)"\1", 1, 0, blocking) != 0;
154 		else
155 			q1.MessageSeriesEnd();
156 
157 		return false;
158 
159 mismatch:
160 		return HandleMismatchDetected(blocking);
161 	}
162 }
163 
HandleMismatchDetected(bool blocking)164 bool EqualityComparisonFilter::HandleMismatchDetected(bool blocking)
165 {
166 	m_mismatchDetected = true;
167 	if (m_throwIfNotEqual)
168 		throw MismatchDetected();
169 	const byte b[1] = {0};
170 	return Output(1, b, 1, 0, blocking) != 0;
171 }
172 
173 NAMESPACE_END
174 
175 #endif
176