1 /*----------------------------------------------------------------------------
2  ADOL-C -- Automatic Differentiation by Overloading in C++
3  File:     medipacksupport.cpp
4  Revision: $Id$
5 
6  Copyright (c) Max Sagebaum
7 
8  This file is part of ADOL-C. This software is provided as open source.
9  Any use, reproduction, or distribution of the software constitutes
10  recipient's acceptance of the terms of the accompanying license file.
11 
12 ----------------------------------------------------------------------------*/
13 
14 #include "taping_p.h"
15 #include "oplate.h"
16 #include "adolc/adouble.h"
17 
18 #ifdef ADOLC_MEDIPACK_SUPPORT
19 
20 #include <vector>
21 #include "medipacksupport_p.h"
22 
23 struct AdolcMeDiAdjointInterface : public medi::AdjointInterface {
24 
25     double** adjointBase;
26     double* primalBase;
27     int vecSize;
28 
AdolcMeDiAdjointInterfaceAdolcMeDiAdjointInterface29     AdolcMeDiAdjointInterface(double** adjointBase, double* primalBase, int vecSize) :
30       adjointBase(adjointBase),
31       primalBase(primalBase),
32       vecSize(vecSize) {}
33 
computeElementsAdolcMeDiAdjointInterface34     int computeElements(int elements) const {
35       return elements * vecSize;
36     }
37 
getVectorSizeAdolcMeDiAdjointInterface38     int getVectorSize() const {
39       return vecSize;
40     }
41 
createAdjointTypeBufferAdolcMeDiAdjointInterface42     inline void createAdjointTypeBuffer(void* &buf, size_t size) const {
43       buf = (void*)new double[size*vecSize];
44     }
45 
deleteAdjointTypeBufferAdolcMeDiAdjointInterface46     inline void deleteAdjointTypeBuffer(void* &b) const {
47       if(NULL != b) {
48         double* buf = (double*)b;
49         delete [] buf;
50         b = NULL;
51       }
52     }
53 
createPrimalTypeBufferAdolcMeDiAdjointInterface54     inline void createPrimalTypeBuffer(void* &buf, size_t size) const {
55       buf = (void*)new double[size];
56     }
57 
deletePrimalTypeBufferAdolcMeDiAdjointInterface58     inline void deletePrimalTypeBuffer(void* &b) const {
59       if(NULL != b) {
60         double* buf = (double*)b;
61         delete [] buf;
62         b = NULL;
63       }
64     }
65 
getAdjointsAdolcMeDiAdjointInterface66     inline void getAdjoints(const void* i, void* a, int elements) const {
67       double* adjoints = (double*)a;
68       int* indices = (int*)i;
69 
70       for(int pos = 0; pos < elements; ++pos) {
71         for(int dim = 0; dim < vecSize; ++dim) {
72           adjoints[calcIndex(pos, dim)] = adjointBase[dim][indices[pos]];
73           adjointBase[dim][indices[pos]] = 0.0;
74         }
75       }
76     }
77 
updateAdjointsAdolcMeDiAdjointInterface78     inline void updateAdjoints(const void* i, const void* a, int elements) const {
79       double* adjoints = (double*)a;
80       int* indices = (int*)i;
81 
82       for(int pos = 0; pos < elements; ++pos) {
83         for(int dim = 0; dim < vecSize; ++dim) {
84           adjointBase[dim][indices[pos]] += adjoints[calcIndex(pos, dim)];
85         }
86       }
87     }
88 
setPrimalsAdolcMeDiAdjointInterface89     inline void setPrimals(const void* i, const void* p, int elements) const {
90       double* primals = (double*)p;
91       int* indices = (int*)i;
92 
93       for(int pos = 0; pos < elements; ++pos) {
94         primalBase[indices[pos]] = primals[pos];
95       }
96     }
97 
getPrimalsAdolcMeDiAdjointInterface98     inline void getPrimals(const void* i, const void* p, int elements) const {
99       double* primals = (double*)p;
100       int* indices = (int*)i;
101 
102       for(int pos = 0; pos < elements; ++pos) {
103         primals[pos] = primalBase[indices[pos]];
104       }
105     }
106 
combineAdjointsAdolcMeDiAdjointInterface107     inline void combineAdjoints(void* b, const int elements, const int ranks) const {
108       double* buf = (double*)b;
109       for(int curRank = 1; curRank < ranks; ++curRank) {
110         for(int curPos = 0; curPos < elements; ++curPos) {
111           for(int dim = 0; dim < vecSize; ++dim) {
112             buf[calcIndex(curPos, dim)] += buf[calcIndex(elements * curRank + curPos, dim)];
113           }
114         }
115       }
116     }
117 
118   private:
119 
calcIndexAdolcMeDiAdjointInterface120     inline int calcIndex(int pos, int dim) const {
121       return pos * vecSize + dim;
122     }
123 };
124 
125 struct AdolcMediStatic {
126     typedef std::vector<medi::HandleBase*> HandleVector;
127     std::vector<HandleVector*> tapeHandles;
128 
~AdolcMediStaticAdolcMediStatic129     ~AdolcMediStatic() {
130       for(size_t i = 0; i < tapeHandles.size(); ++i) {
131         if(nullptr != tapeHandles[i]) {
132           clearHandles(*tapeHandles[i]);
133 
134           delete tapeHandles[i];
135           tapeHandles[i] = nullptr;
136         }
137       }
138     }
139 
getTapeVectorAdolcMediStatic140     HandleVector& getTapeVector(short tapeId) {
141       return *tapeHandles[tapeId];
142     }
143 
callHandleReverseAdolcMediStatic144     void callHandleReverse(short tapeId, locint index, AdolcMeDiAdjointInterface& interface) {
145       HandleVector& handleVec = getTapeVector(tapeId);
146 
147       medi::HandleBase* handle = handleVec[index];
148 
149       handle->funcReverse(handle, &interface);
150     }
151 
callHandleForwardAdolcMediStatic152     void callHandleForward(short tapeId, locint index, AdolcMeDiAdjointInterface& interface) {
153       HandleVector& handleVec = getTapeVector(tapeId);
154 
155       medi::HandleBase* handle = handleVec[index];
156 
157       handle->funcForward(handle, &interface);
158     }
159 
callHandlePrimalAdolcMediStatic160     void callHandlePrimal(short tapeId, locint index, AdolcMeDiAdjointInterface& interface) {
161       HandleVector& handleVec = getTapeVector(tapeId);
162 
163       medi::HandleBase* handle = handleVec[index];
164 
165       handle->funcPrimal(handle, &interface);
166     }
167 
initTapeAdolcMediStatic168     void initTape(short tapeId) {
169       if((size_t)tapeId >= tapeHandles.size()) {
170         tapeHandles.resize(tapeId + 1, nullptr);
171       }
172 
173       if(nullptr == tapeHandles[tapeId]) {
174         tapeHandles[tapeId] = new HandleVector();
175       } else {
176         clearHandles(*tapeHandles[tapeId]);
177       }
178     }
179 
freeTapeAdolcMediStatic180     void freeTape(short tapeId) {
181       if((size_t)tapeId < tapeHandles.size() && nullptr != tapeHandles[tapeId]) {
182         clearHandles(*tapeHandles[tapeId]);
183       }
184     }
185 
clearHandlesAdolcMediStatic186     void clearHandles(HandleVector& handles) {
187       for(size_t i = 0; i < handles.size(); ++i) {
188         medi::HandleBase* h = handles[i];
189 
190         delete h;
191       }
192 
193       handles.resize(0);
194     }
195 
addHandleAdolcMediStatic196     locint addHandle(short tapeId, medi::HandleBase* handle) {
197       HandleVector& vector = getTapeVector(tapeId);
198 
199       locint index = (locint)vector.size();
200       vector.push_back(handle);
201       return index;
202     }
203 };
204 
205 AdolcMediStatic* adolcMediStatic;
206 
mediAddHandle(medi::HandleBase * h)207 void mediAddHandle(medi::HandleBase* h) {
208   ADOLC_OPENMP_THREAD_NUMBER;
209   ADOLC_OPENMP_GET_THREAD_NUMBER;
210 
211   // do not need to check trace flag, this is included in the handle check
212   put_op(medi_call);
213   locint index = adolcMediStatic->addHandle(ADOLC_CURRENT_TAPE_INFOS.tapeID, h);
214 
215   ADOLC_PUT_LOCINT(index);
216 }
217 
mediCallHandleReverse(short tapeId,locint index,double * primalVec,double ** adjointVec,int vecSize)218 void mediCallHandleReverse(short tapeId, locint index, double* primalVec, double** adjointVec, int vecSize) {
219   AdolcMeDiAdjointInterface interface(adjointVec, primalVec, vecSize);
220 
221   adolcMediStatic->callHandleReverse(tapeId, index, interface);
222 }
223 
mediCallHandleForward(short tapeId,locint index,double * primalVec,double ** adjointVec,int vecSize)224 void mediCallHandleForward(short tapeId, locint index, double* primalVec, double** adjointVec, int vecSize) {
225   AdolcMeDiAdjointInterface interface(adjointVec, primalVec, vecSize);
226 
227   adolcMediStatic->callHandleForward(tapeId, index, interface);
228 }
229 
mediCallHandlePrimal(short tapeId,locint index,double * primalVec)230 void mediCallHandlePrimal(short tapeId, locint index, double* primalVec) {
231   AdolcMeDiAdjointInterface interface(nullptr, primalVec, 1);
232 
233   adolcMediStatic->callHandlePrimal(tapeId, index, interface);
234 }
235 
mediInitTape(short tapeId)236 void mediInitTape(short tapeId) {
237   if(NULL == adolcMediStatic) {
238     mediInitStatic();
239   }
240   adolcMediStatic->initTape(tapeId);
241 }
242 
mediInitStatic()243 void mediInitStatic() {
244   adolcMediStatic = new AdolcMediStatic();
245 }
246 
mediFinalizeStatic()247 void mediFinalizeStatic() {
248   delete adolcMediStatic;
249 }
250 
// Out-of-class definitions of AdolcTool's static data members (the class is
// presumably declared in medipacksupport_p.h). They have static storage
// duration and no initializer here. NOTE(review): their actual values are
// assigned elsewhere; confirm where they are set up before first use.
MPI_Datatype AdolcTool::MpiType;
MPI_Datatype AdolcTool::ModifiedMpiType;
MPI_Datatype AdolcTool::PrimalMpiType;
MPI_Datatype AdolcTool::AdjointMpiType;
AdolcTool::MediType* AdolcTool::MPI_TYPE;
medi::AMPI_Datatype AdolcTool::MPI_INT_TYPE;

// Operator helper for reductions on adouble. The `int` argument matches the
// int index buffers used by AdolcMeDiAdjointInterface. NOTE(review): the
// meaning of the remaining double template arguments (modified/primal/adjoint
// types) is presumed; confirm against medi::FunctionHelper.
medi::OperatorHelper<medi::FunctionHelper<adouble, double, double, int, double, AdolcTool>> AdolcTool::operatorHelper;
259 
260 #endif
261