1 /*----------------------------------------------------------------------------
2 ADOL-C -- Automatic Differentiation by Overloading in C++
3 File: medipacksupport.cpp
4 Revision: $Id$
5
6 Copyright (c) Max Sagebaum
7
8 This file is part of ADOL-C. This software is provided as open source.
9 Any use, reproduction, or distribution of the software constitutes
10 recipient's acceptance of the terms of the accompanying license file.
11
12 ----------------------------------------------------------------------------*/
13
14 #include "taping_p.h"
15 #include "oplate.h"
16 #include "adolc/adouble.h"
17
18 #ifdef ADOLC_MEDIPACK_SUPPORT
19
#include <cassert>
#include <vector>

#include "medipacksupport_p.h"
22
23 struct AdolcMeDiAdjointInterface : public medi::AdjointInterface {
24
25 double** adjointBase;
26 double* primalBase;
27 int vecSize;
28
AdolcMeDiAdjointInterfaceAdolcMeDiAdjointInterface29 AdolcMeDiAdjointInterface(double** adjointBase, double* primalBase, int vecSize) :
30 adjointBase(adjointBase),
31 primalBase(primalBase),
32 vecSize(vecSize) {}
33
computeElementsAdolcMeDiAdjointInterface34 int computeElements(int elements) const {
35 return elements * vecSize;
36 }
37
getVectorSizeAdolcMeDiAdjointInterface38 int getVectorSize() const {
39 return vecSize;
40 }
41
createAdjointTypeBufferAdolcMeDiAdjointInterface42 inline void createAdjointTypeBuffer(void* &buf, size_t size) const {
43 buf = (void*)new double[size*vecSize];
44 }
45
deleteAdjointTypeBufferAdolcMeDiAdjointInterface46 inline void deleteAdjointTypeBuffer(void* &b) const {
47 if(NULL != b) {
48 double* buf = (double*)b;
49 delete [] buf;
50 b = NULL;
51 }
52 }
53
createPrimalTypeBufferAdolcMeDiAdjointInterface54 inline void createPrimalTypeBuffer(void* &buf, size_t size) const {
55 buf = (void*)new double[size];
56 }
57
deletePrimalTypeBufferAdolcMeDiAdjointInterface58 inline void deletePrimalTypeBuffer(void* &b) const {
59 if(NULL != b) {
60 double* buf = (double*)b;
61 delete [] buf;
62 b = NULL;
63 }
64 }
65
getAdjointsAdolcMeDiAdjointInterface66 inline void getAdjoints(const void* i, void* a, int elements) const {
67 double* adjoints = (double*)a;
68 int* indices = (int*)i;
69
70 for(int pos = 0; pos < elements; ++pos) {
71 for(int dim = 0; dim < vecSize; ++dim) {
72 adjoints[calcIndex(pos, dim)] = adjointBase[dim][indices[pos]];
73 adjointBase[dim][indices[pos]] = 0.0;
74 }
75 }
76 }
77
updateAdjointsAdolcMeDiAdjointInterface78 inline void updateAdjoints(const void* i, const void* a, int elements) const {
79 double* adjoints = (double*)a;
80 int* indices = (int*)i;
81
82 for(int pos = 0; pos < elements; ++pos) {
83 for(int dim = 0; dim < vecSize; ++dim) {
84 adjointBase[dim][indices[pos]] += adjoints[calcIndex(pos, dim)];
85 }
86 }
87 }
88
setPrimalsAdolcMeDiAdjointInterface89 inline void setPrimals(const void* i, const void* p, int elements) const {
90 double* primals = (double*)p;
91 int* indices = (int*)i;
92
93 for(int pos = 0; pos < elements; ++pos) {
94 primalBase[indices[pos]] = primals[pos];
95 }
96 }
97
getPrimalsAdolcMeDiAdjointInterface98 inline void getPrimals(const void* i, const void* p, int elements) const {
99 double* primals = (double*)p;
100 int* indices = (int*)i;
101
102 for(int pos = 0; pos < elements; ++pos) {
103 primals[pos] = primalBase[indices[pos]];
104 }
105 }
106
combineAdjointsAdolcMeDiAdjointInterface107 inline void combineAdjoints(void* b, const int elements, const int ranks) const {
108 double* buf = (double*)b;
109 for(int curRank = 1; curRank < ranks; ++curRank) {
110 for(int curPos = 0; curPos < elements; ++curPos) {
111 for(int dim = 0; dim < vecSize; ++dim) {
112 buf[calcIndex(curPos, dim)] += buf[calcIndex(elements * curRank + curPos, dim)];
113 }
114 }
115 }
116 }
117
118 private:
119
calcIndexAdolcMeDiAdjointInterface120 inline int calcIndex(int pos, int dim) const {
121 return pos * vecSize + dim;
122 }
123 };
124
125 struct AdolcMediStatic {
126 typedef std::vector<medi::HandleBase*> HandleVector;
127 std::vector<HandleVector*> tapeHandles;
128
~AdolcMediStaticAdolcMediStatic129 ~AdolcMediStatic() {
130 for(size_t i = 0; i < tapeHandles.size(); ++i) {
131 if(nullptr != tapeHandles[i]) {
132 clearHandles(*tapeHandles[i]);
133
134 delete tapeHandles[i];
135 tapeHandles[i] = nullptr;
136 }
137 }
138 }
139
getTapeVectorAdolcMediStatic140 HandleVector& getTapeVector(short tapeId) {
141 return *tapeHandles[tapeId];
142 }
143
callHandleReverseAdolcMediStatic144 void callHandleReverse(short tapeId, locint index, AdolcMeDiAdjointInterface& interface) {
145 HandleVector& handleVec = getTapeVector(tapeId);
146
147 medi::HandleBase* handle = handleVec[index];
148
149 handle->funcReverse(handle, &interface);
150 }
151
callHandleForwardAdolcMediStatic152 void callHandleForward(short tapeId, locint index, AdolcMeDiAdjointInterface& interface) {
153 HandleVector& handleVec = getTapeVector(tapeId);
154
155 medi::HandleBase* handle = handleVec[index];
156
157 handle->funcForward(handle, &interface);
158 }
159
callHandlePrimalAdolcMediStatic160 void callHandlePrimal(short tapeId, locint index, AdolcMeDiAdjointInterface& interface) {
161 HandleVector& handleVec = getTapeVector(tapeId);
162
163 medi::HandleBase* handle = handleVec[index];
164
165 handle->funcPrimal(handle, &interface);
166 }
167
initTapeAdolcMediStatic168 void initTape(short tapeId) {
169 if((size_t)tapeId >= tapeHandles.size()) {
170 tapeHandles.resize(tapeId + 1, nullptr);
171 }
172
173 if(nullptr == tapeHandles[tapeId]) {
174 tapeHandles[tapeId] = new HandleVector();
175 } else {
176 clearHandles(*tapeHandles[tapeId]);
177 }
178 }
179
freeTapeAdolcMediStatic180 void freeTape(short tapeId) {
181 if((size_t)tapeId < tapeHandles.size() && nullptr != tapeHandles[tapeId]) {
182 clearHandles(*tapeHandles[tapeId]);
183 }
184 }
185
clearHandlesAdolcMediStatic186 void clearHandles(HandleVector& handles) {
187 for(size_t i = 0; i < handles.size(); ++i) {
188 medi::HandleBase* h = handles[i];
189
190 delete h;
191 }
192
193 handles.resize(0);
194 }
195
addHandleAdolcMediStatic196 locint addHandle(short tapeId, medi::HandleBase* handle) {
197 HandleVector& vector = getTapeVector(tapeId);
198
199 locint index = (locint)vector.size();
200 vector.push_back(handle);
201 return index;
202 }
203 };
204
205 AdolcMediStatic* adolcMediStatic;
206
/**
 * Record a MeDiPack handle on the current ADOL-C tape.
 *
 * Writes a medi_call operation, followed by the position of `h` in the
 * handle registry of the tape identified by ADOL_CURRENT_TAPE_INFOS.tapeID.
 * The registry takes ownership of `h`: it is deleted when that tape's
 * handles are cleared.
 *
 * NOTE(review): assumes adolcMediStatic has been created and the tape
 * initialized (mediInitTape) before recording starts — confirm in callers.
 */
void mediAddHandle(medi::HandleBase* h) {
  // Presumably declare/fetch the thread number used by ADOLC_CURRENT_TAPE_INFOS
  // in OpenMP builds; the macro definitions are not visible here.
  ADOLC_OPENMP_THREAD_NUMBER;
  ADOLC_OPENMP_GET_THREAD_NUMBER;

  // do not need to check trace flag, this is included in the handle check
  put_op(medi_call);
  locint index = adolcMediStatic->addHandle(ADOLC_CURRENT_TAPE_INFOS.tapeID, h);

  // The registry index is the operand of the medi_call operation written above.
  ADOLC_PUT_LOCINT(index);
}
217
mediCallHandleReverse(short tapeId,locint index,double * primalVec,double ** adjointVec,int vecSize)218 void mediCallHandleReverse(short tapeId, locint index, double* primalVec, double** adjointVec, int vecSize) {
219 AdolcMeDiAdjointInterface interface(adjointVec, primalVec, vecSize);
220
221 adolcMediStatic->callHandleReverse(tapeId, index, interface);
222 }
223
mediCallHandleForward(short tapeId,locint index,double * primalVec,double ** adjointVec,int vecSize)224 void mediCallHandleForward(short tapeId, locint index, double* primalVec, double** adjointVec, int vecSize) {
225 AdolcMeDiAdjointInterface interface(adjointVec, primalVec, vecSize);
226
227 adolcMediStatic->callHandleForward(tapeId, index, interface);
228 }
229
mediCallHandlePrimal(short tapeId,locint index,double * primalVec)230 void mediCallHandlePrimal(short tapeId, locint index, double* primalVec) {
231 AdolcMeDiAdjointInterface interface(nullptr, primalVec, 1);
232
233 adolcMediStatic->callHandlePrimal(tapeId, index, interface);
234 }
235
mediInitTape(short tapeId)236 void mediInitTape(short tapeId) {
237 if(NULL == adolcMediStatic) {
238 mediInitStatic();
239 }
240 adolcMediStatic->initTape(tapeId);
241 }
242
mediInitStatic()243 void mediInitStatic() {
244 adolcMediStatic = new AdolcMediStatic();
245 }
246
mediFinalizeStatic()247 void mediFinalizeStatic() {
248 delete adolcMediStatic;
249 }
250
// Out-of-class definitions of AdolcTool's static data members. The class
// itself is declared outside this file (presumably in medipacksupport_p.h —
// confirm). No initializers are given here, so the values are presumably
// assigned by an initialization routine elsewhere — verify before relying on
// them during static initialization.
MPI_Datatype AdolcTool::MpiType;
MPI_Datatype AdolcTool::ModifiedMpiType;
MPI_Datatype AdolcTool::PrimalMpiType;
MPI_Datatype AdolcTool::AdjointMpiType;
AdolcTool::MediType* AdolcTool::MPI_TYPE;
medi::AMPI_Datatype AdolcTool::MPI_INT_TYPE;

// NOTE(review): the `int` template argument presumably has to match the int
// index buffers read by AdolcMeDiAdjointInterface — confirm against MeDiPack.
medi::OperatorHelper<medi::FunctionHelper<adouble, double, double, int, double, AdolcTool>> AdolcTool::operatorHelper;
259
260 #endif
261