1 /*******************************************************************************
2 * Copyright 2015-2016 Juan Francisco Crespo Galán
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *   http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 ******************************************************************************/
16 
#include "fx/ConvolverReader.h"
#include "Exception.h"

#include <cstring>
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <new>
24 
25 AUD_NAMESPACE_BEGIN
ConvolverReader(std::shared_ptr<IReader> reader,std::shared_ptr<ImpulseResponse> ir,std::shared_ptr<ThreadPool> threadPool,std::shared_ptr<FFTPlan> plan)26 ConvolverReader::ConvolverReader(std::shared_ptr<IReader> reader, std::shared_ptr<ImpulseResponse> ir, std::shared_ptr<ThreadPool> threadPool, std::shared_ptr<FFTPlan> plan) :
27 	m_reader(reader), m_ir(ir), m_N(plan->getSize()), m_eosReader(false), m_eosTail(false), m_inChannels(reader->getSpecs().channels), m_irChannels(ir->getSpecs().channels), m_threadPool(threadPool), m_position(0)
28 {
29 	m_nChannelThreads = std::min((int)threadPool->getNumOfThreads(), m_inChannels);
30 	m_futures.resize(m_nChannelThreads);
31 
32 	int irLength = m_ir->getLength();
33 	if(m_irChannels != 1 && m_irChannels != m_inChannels)
34 		AUD_THROW(StateException, "The impulse response and the sound must either have the same amount of channels or the impulse response must be mono");
35 	if(m_reader->getSpecs().rate != m_ir->getSpecs().rate)
36 		AUD_THROW(StateException, "The sound and the impulse response. must have the same rate");
37 
38 	m_M = m_L = m_N / 2;
39 
40 	if(m_irChannels > 1)
41 		for(int i = 0; i < m_inChannels; i++)
42 			m_convolvers.push_back(std::unique_ptr<Convolver>(new Convolver(ir->getChannel(i), irLength, m_threadPool, plan)));
43 	else
44 		for(int i = 0; i < m_inChannels; i++)
45 			m_convolvers.push_back(std::unique_ptr<Convolver>(new Convolver(ir->getChannel(0), irLength, m_threadPool, plan)));
46 
47 	for(int i = 0; i < m_inChannels; i++)
48 		m_vecInOut.push_back((sample_t*)std::malloc(m_L*sizeof(sample_t)));
49 	m_outBuffer = (sample_t*)std::malloc(m_L*m_inChannels*sizeof(sample_t));
50 	m_outBufLen = m_eOutBufLen = m_outBufferPos = m_L*m_inChannels;
51 }
52 
~ConvolverReader()53 ConvolverReader::~ConvolverReader()
54 {
55 	std::free(m_outBuffer);
56 	for(int i = 0; i < m_inChannels; i++)
57 		std::free(m_vecInOut[i]);
58 }
59 
isSeekable() const60 bool ConvolverReader::isSeekable() const
61 {
62 	return m_reader->isSeekable();
63 }
64 
seek(int position)65 void ConvolverReader::seek(int position)
66 {
67 	m_position = position;
68 	m_reader->seek(position);
69 	for(int i = 0; i < m_inChannels; i++)
70 		m_convolvers[i]->reset();
71 	m_eosTail = false;
72 	m_eosReader = false;
73 	m_outBufferPos = m_eOutBufLen = m_outBufLen;
74 }
75 
getLength() const76 int ConvolverReader::getLength() const
77 {
78 	return m_reader->getLength();
79 }
80 
getPosition() const81 int ConvolverReader::getPosition() const
82 {
83 	return m_position;
84 }
85 
getSpecs() const86 Specs ConvolverReader::getSpecs() const
87 {
88 	return m_reader->getSpecs();
89 }
90 
read(int & length,bool & eos,sample_t * buffer)91 void ConvolverReader::read(int& length, bool& eos, sample_t* buffer)
92 {
93 	if(length <= 0)
94 	{
95 		length = 0;
96 		eos = (m_eosTail && m_outBufferPos >= m_eOutBufLen);
97 		return;
98 	}
99 	eos = false;
100 	int writePos = 0;
101 	do
102 	{
103 		int bufRest = m_eOutBufLen - m_outBufferPos;
104 		int writeLength = std::min((length*m_inChannels) - writePos, m_eOutBufLen + bufRest);
105 		if(bufRest < writeLength || (m_eOutBufLen == 0 && m_eosTail))
106 		{
107 			if(bufRest > 0)
108 				std::memcpy(buffer + writePos, m_outBuffer + m_outBufferPos, bufRest*sizeof(sample_t));
109 			if(!m_eosTail)
110 			{
111 				loadBuffer();
112 				int len = std::min(std::abs(writeLength - bufRest), m_eOutBufLen);
113 				std::memcpy(buffer + writePos + bufRest, m_outBuffer, len*sizeof(sample_t));
114 				m_outBufferPos = len;
115 				writeLength = std::min((length*m_inChannels) - writePos, m_eOutBufLen + bufRest);
116 			}
117 			else
118 			{
119 				m_outBufferPos += bufRest;
120 				length = (writePos + bufRest) / m_inChannels;
121 				eos = true;
122 				return;
123 			}
124 		}
125 		else
126 		{
127 			std::memcpy(buffer + writePos, m_outBuffer + m_outBufferPos, writeLength*sizeof(sample_t));
128 			m_outBufferPos += writeLength;
129 		}
130 		writePos += writeLength;
131 	} while(writePos < length*m_inChannels);
132 	m_position += length;
133 }
134 
loadBuffer()135 void ConvolverReader::loadBuffer()
136 {
137 	m_lastLengthIn = m_L;
138 	m_reader->read(m_lastLengthIn, m_eosReader, m_outBuffer);
139 	if(!m_eosReader || m_lastLengthIn>0)
140 	{
141 		divideByChannel(m_outBuffer, m_lastLengthIn*m_inChannels);
142 		int len = m_lastLengthIn;
143 
144 		for(int i = 0; i < m_futures.size(); i++)
145 			m_futures[i] = m_threadPool->enqueue(&ConvolverReader::threadFunction, this, i, true);
146 		for(auto &fut : m_futures)
147 			len = fut.get();
148 
149 		joinByChannel(0, len);
150 		m_eOutBufLen = len*m_inChannels;
151 	}
152 	else if(!m_eosTail)
153 	{
154 		int len = m_lastLengthIn = m_L;
155 		for(int i = 0; i < m_futures.size(); i++)
156 			m_futures[i] = m_threadPool->enqueue(&ConvolverReader::threadFunction, this, i, false);
157 		for(auto &fut : m_futures)
158 			len = fut.get();
159 
160 		joinByChannel(0, len);
161 		m_eOutBufLen = len*m_inChannels;
162 	}
163 }
164 
divideByChannel(const sample_t * buffer,int len)165 void ConvolverReader::divideByChannel(const sample_t* buffer, int len)
166 {
167 	int k = 0;
168 	for(int i = 0; i < len; i += m_inChannels)
169 	{
170 		for(int j = 0; j < m_inChannels; j++)
171 			m_vecInOut[j][k] = buffer[i + j];
172 		k++;
173 	}
174 }
175 
joinByChannel(int start,int len)176 void ConvolverReader::joinByChannel(int start, int len)
177 {
178 	int k = 0;
179 	for(int i = 0; i < len*m_inChannels; i += m_inChannels)
180 	{
181 		for(int j = 0; j < m_vecInOut.size(); j++)
182 			m_outBuffer[i + j + start] = m_vecInOut[j][k];
183 		k++;
184 	}
185 }
186 
threadFunction(int id,bool input)187 int ConvolverReader::threadFunction(int id, bool input)
188 {
189 	int share = std::ceil((float)m_inChannels / (float)m_nChannelThreads);
190 	int start = id*share;
191 	int end = std::min(start + share, m_inChannels);
192 
193 	int l=m_lastLengthIn;
194 	for(int i = start; i < end; i++)
195 		if(input)
196 			m_convolvers[i]->getNext(m_vecInOut[i], m_vecInOut[i], l, m_eosTail);
197 		else
198 			m_convolvers[i]->getNext(nullptr, m_vecInOut[i], l, m_eosTail);
199 
200 	return l;
201 }
202 
203 AUD_NAMESPACE_END
204