/*
 * Copyright (c) 2005-2019, NumPy Developers.
 *
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are
 * met:
 *
 *  * Redistributions of source code must retain the above copyright
 *     notice, this list of conditions and the following disclaimer.
 *
 *  * Redistributions in binary form must reproduce the above
 *     copyright notice, this list of conditions and the following
 *     disclaimer in the documentation and/or other materials provided
 *     with the distribution.
 *
 *  * Neither the name of the NumPy Developers nor the names of any
 *     contributors may be used to endorse or promote products derived
 *     from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

/*!
 * \file np_einsum_op-inl.h
 * \brief Function definition of numpy-compatible einsum operator
 * modified by Haozheng Fan(@hzfan) from:
 * https://github.com/numpy/numpy/blob/master/numpy/core/src/multiarray/einsum.c.src
 */

#ifndef MXNET_OPERATOR_NUMPY_NP_EINSUM_OP_INL_H_
#define MXNET_OPERATOR_NUMPY_NP_EINSUM_OP_INL_H_

#include <mxnet/operator_util.h>
#include <string>
#include <vector>
#include <algorithm>
#include "./np_tensordot_op-inl.h"
#include "./np_einsum_path_op-inl.h"
#include "../../common/static_array.h"
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"

namespace mxnet {
namespace op {

#define NPY_MAXDIMS 16
#define NPY_MAXARGS 16

inline TShape get_stride(const TShape& shape) {
  int ndim = shape.ndim(), prod = 1;
  TShape stride = TShape(ndim, -1);
  for (int i = ndim - 1; i >= 0; i--) {
    stride[i] = shape[i] > 1 ? prod : 0;
    prod = prod * shape[i];
  }
  return stride;
}
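
/*
 * Example for get_stride (illustrative, not part of the original source):
 * get_stride({3, 1, 4}) returns {4, 0, 1}.  Size-1 axes get stride 0 so
 * that they broadcast when the kernels below index operands with
 * dot(index, stride).
 */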

inline TShape pad(const TShape& shape, int odim) {
  int ndim = shape.ndim();
  CHECK_GE(odim, ndim);
  TShape ret(odim, 1);
  for (int idim = 0; idim < ndim; ++idim) {
    ret[idim] = shape[idim];
  }
  return ret;
}
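
/*
 * Example for pad (illustrative): pad({2, 3}, 4) returns {2, 3, 1, 1},
 * i.e. the shape is right-padded with ones up to odim dimensions.
 */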

/*
 * Parses the subscripts for one operand into an output of 'ndim'
 * labels. The resulting 'op_labels' array will have:
 *  - the ASCII code of the label for the first occurrence of a label;
 *  - the (negative) offset to the first occurrence of the label for
 *    repeated labels;
 *  - zero for broadcast dimensions, if subscripts has an ellipsis.
 * For example:
 *  - subscripts="abbcbc",  ndim=6 -> op_labels=[97, 98, -1, 99, -3, -2]
 *  - subscripts="ab...bc", ndim=6 -> op_labels=[97, 98, 0, 0, -3, 99]
 */
inline int parse_operand_subscripts(const char *subscripts, int length,
                                    int ndim, int iop, char *op_labels,
                                    char *label_counts, int *min_label, int *max_label) {
  using namespace mxnet_op;
  int i;
  int idim = 0;
  int ellipsis = -1;

  /* Process all labels for this operand */
  for (i = 0; i < length; ++i) {
    int label = subscripts[i];

    /* A proper label for an axis. */
    if (label > 0 && isalpha(label)) {
      /* Check we don't exceed the operator dimensions. */
      CHECK(idim < ndim)
        << "einstein sum subscripts string contains "
        << "too many subscripts for operand "
        << iop;

      op_labels[idim++] = label;
      if (label < *min_label) {
        *min_label = label;
      }
      if (label > *max_label) {
        *max_label = label;
      }
      label_counts[label]++;
    } else if (label == '.') {
      /* The beginning of the ellipsis. */
      /* Check it's a proper ellipsis. */
      CHECK(!(ellipsis != -1 || i + 2 >= length
              || subscripts[++i] != '.' || subscripts[++i] != '.'))
        << "einstein sum subscripts string contains a "
        << "'.' that is not part of an ellipsis ('...') "
        << "in operand "
        << iop;

      ellipsis = idim;
    } else {
      CHECK(label == ' ')
        << "invalid subscript '" << static_cast<char>(label)
        << "' in einstein sum "
        << "subscripts string, subscripts must "
        << "be letters";
    }
  }

  /* No ellipsis found, labels must match dimensions exactly. */
  if (ellipsis == -1) {
    CHECK(idim == ndim)
      << "operand has more dimensions than subscripts "
      << "given in einstein sum, but no '...' ellipsis "
      << "provided to broadcast the extra dimensions.";
  } else if (idim < ndim) {
    /* Ellipsis found, may have to add broadcast dimensions. */
    /* Move labels after ellipsis to the end. */
    for (i = 0; i < idim - ellipsis; ++i) {
      op_labels[ndim - i - 1] = op_labels[idim - i - 1];
    }
    /* Set all broadcast dimensions to zero. */
    for (i = 0; i < ndim - idim; ++i) {
      op_labels[ellipsis + i] = 0;
    }
  }

  /*
   * Find any labels duplicated for this operand, and turn them
   * into negative offsets to the axis to merge with.
   *
   * In C, the char type may be signed or unsigned, but with
   * twos complement arithmetic the char is ok either way here, and
   * later where it matters the char is cast to a signed char.
   */
  for (idim = 0; idim < ndim - 1; ++idim) {
    int label = op_labels[idim];
    /* If it is a proper label, find any duplicates of it. */
    if (label > 0) {
      /* Search for the next matching label. */
      char *next = reinterpret_cast<char*>(memchr(op_labels + idim + 1, label, ndim - idim - 1));

      while (next != nullptr) {
        /* The offset from next to op_labels[idim] (negative). */
        *next = static_cast<char>((op_labels + idim) - next);
        /* Search for the next matching label. */
        next = reinterpret_cast<char*>(memchr(next + 1, label, op_labels + ndim - 1 - next));
      }
    }
  }
  return 0;
}

/*
 * Parses the subscripts for the output operand into an output that
 * includes 'ndim_broadcast' unlabeled dimensions, and returns the total
 * number of output dimensions, or -1 if there is an error. Similarly
 * to parse_operand_subscripts, the 'out_labels' array will have, for
 * each dimension:
 *  - the ASCII code of the corresponding label;
 *  - zero for broadcast dimensions, if subscripts has an ellipsis.
 */
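/*
 * Example (illustrative): for the output subscripts "...ij" with
 * ndim_broadcast=2 (and assuming 'i' and 'j' appear in the inputs),
 * out_labels becomes [0, 0, 'i', 'j'] and the function returns 4.
 */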
inline int parse_output_subscripts(const char *subscripts, int length,
                                   int ndim_broadcast,
                                   const char *label_counts, char *out_labels) {
  using namespace mxnet_op;
  int i, bdim;
  int ndim = 0;
  int ellipsis = 0;

  /* Process all the output labels. */
  for (i = 0; i < length; ++i) {
    int label = subscripts[i];

    /* A proper label for an axis. */
    if (label > 0 && isalpha(label)) {
      /* Check that it doesn't occur again. */
      CHECK(memchr(subscripts + i + 1, label, length - i - 1) == nullptr)
        << "einstein sum subscripts string includes "
        << "output subscript '" << static_cast<char>(label)
        << "' multiple times";

      /* Check that it was used in the inputs. */
      CHECK(label_counts[label] != 0)
        << "einstein sum subscripts string included "
        << "output subscript '" << static_cast<char>(label)
        << "' which never appeared "
        << "in an input";

      /* Check that there is room in out_labels for this label. */
      CHECK(ndim < NPY_MAXDIMS)
        << "einstein sum subscripts string contains "
        << "too many subscripts in the output";

      out_labels[ndim++] = label;
    } else if (label == '.') {
      /* The beginning of the ellipsis. */
      /* Check it is a proper ellipsis. */
      CHECK(!(ellipsis || i + 2 >= length
              || subscripts[++i] != '.' || subscripts[++i] != '.'))
        << "einstein sum subscripts string "
        << "contains a '.' that is not part of "
        << "an ellipsis ('...') in the output";

      /* Check there is room in out_labels for broadcast dims. */
      CHECK(ndim + ndim_broadcast <= NPY_MAXDIMS)
        << "einstein sum subscripts string contains "
        << "too many subscripts in the output";

      ellipsis = 1;
      for (bdim = 0; bdim < ndim_broadcast; ++bdim) {
        out_labels[ndim++] = 0;
      }
    } else {
      CHECK(label == ' ')
        << "invalid subscript '" << static_cast<char>(label)
        << "' in einstein sum "
        << "subscripts string, subscripts must "
        << "be letters";
    }
  }

  /* If no ellipsis was found there should be no broadcast dimensions. */
  CHECK(!(!ellipsis && ndim_broadcast > 0))
    << "output has more dimensions than subscripts "
    << "given in einstein sum, but no '...' ellipsis "
    << "provided to broadcast the extra dimensions.";

  return ndim;
}

inline void get_combined_dims_view(const TBlob& op, int iop,
                                   char *labels,
                                   TShape* newshape,
                                   TShape* newstride) {
  using namespace mxnet_op;
  int idim, ndim, icombine, combineoffset;
  int icombinemap[NPY_MAXDIMS];
  int newdim;

  const TShape& shape = op.shape_;
  TShape stride = get_stride(shape);
  ndim = op.shape_.ndim();
  newdim = newshape->ndim();

  /* Initialize the dimensions and strides to zero */
  for (idim = 0; idim < newdim; ++idim) {
    (*newshape)[idim] = 0;
    (*newstride)[idim] = 0;
  }

  /* Copy the dimensions and strides, except when collapsing */
  icombine = 0;
  for (idim = 0; idim < ndim; ++idim) {
    /*
     * The char type may be either signed or unsigned, we
     * need it to be signed here.
     */
    int label = (signed char)labels[idim];
    /* If this label says to merge axes, get the actual label */
    if (label < 0) {
      combineoffset = label;
      label = labels[idim+label];
    } else {
      combineoffset = 0;
      if (icombine != idim) {
        labels[icombine] = labels[idim];
      }
      icombinemap[idim] = icombine;
    }
    /* If the label is 0, it's an unlabeled broadcast dimension */
    if (label == 0) {
      (*newshape)[icombine] = shape[idim];
      (*newstride)[icombine] = stride[idim];
    } else {
      /* Update the combined axis dimensions and strides */
      int i = icombinemap[idim + combineoffset];
      CHECK(!(combineoffset < 0 && (*newshape)[i] != 0 &&
              (*newshape)[i] != shape[idim]))
        << "dimensions in operand " << iop
        << " for collapsing index '" << label
        << "' don't match (" << static_cast<int>((*newshape)[i])
        << " != " << shape[idim] << ")";
      (*newshape)[i] = shape[idim];
      (*newstride)[i] += stride[idim];
    }

    /* If the label didn't say to combine axes, increment dest i */
    if (combineoffset == 0) {
      icombine++;
    }
  }
}
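
/*
 * Example (illustrative): for a 3x3 operand with subscripts "ii"
 * (op_labels = ['i', -1]), the combined view has newshape={3} and
 * newstride={4}, i.e. it walks the main diagonal of the matrix.
 */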

inline static int prepare_op_axes(int ndim, int iop, char *labels,
                                  int *axes, int ndim_iter, char *iter_labels) {
  using namespace mxnet_op;
  int i, label, ibroadcast;

  ibroadcast = ndim-1;
  for (i = ndim_iter-1; i >= 0; --i) {
    label = iter_labels[i];
    /*
     * If it's an unlabeled broadcast dimension, choose
     * the next broadcast dimension from the operand.
     */
    if (label == 0) {
      while (ibroadcast >= 0 && labels[ibroadcast] != 0) {
        --ibroadcast;
      }
      /*
       * If we used up all the operand broadcast dimensions,
       * extend it with a "newaxis"
       */
      if (ibroadcast < 0) {
        axes[i] = -1;
      } else {
        /* Otherwise map to the broadcast axis */
        axes[i] = ibroadcast;
        --ibroadcast;
      }
    } else {
      /* It's a labeled dimension, find the matching one */
      char *match = reinterpret_cast<char*>(memchr(labels, label, ndim));
      /* If the op doesn't have the label, broadcast it */
      if (match == nullptr) {
        axes[i] = -1;
      } else {
        /* Otherwise use it */
        axes[i] = match - labels;
      }
    }
  }
  return 0;
}
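
/*
 * Example (illustrative): with iter_labels = ['i', 'j', 'k'] and an
 * operand labeled "ki" (labels = ['k', 'i']), axes becomes [1, -1, 0]:
 * 'i' maps to operand axis 1, 'j' is broadcast (-1), 'k' maps to axis 0.
 */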

struct NumpyEinsumParam: public dmlc::Parameter<NumpyEinsumParam> {
  int num_args;
  int optimize;
  std::string subscripts;
  DMLC_DECLARE_PARAMETER(NumpyEinsumParam) {
    DMLC_DECLARE_FIELD(num_args)
      .set_lower_bound(1)
      .describe("Number of input arrays.");
    DMLC_DECLARE_FIELD(subscripts)
      .set_default("")
      .describe("Specifies the subscripts for summation as comma separated list"
      " of subscript labels. An implicit (classical Einstein summation) calculation"
      " is performed unless the explicit indicator '->' is included as well as"
      " subscript labels of the precise output form.");
    DMLC_DECLARE_FIELD(optimize)
      .set_default(0);
  }
};

class EinsumOp {
 public:
  int num_args;
  int optimize;
  std::string subscripts;
  std::shared_ptr<NDArray> tempspace;
  std::vector<Step> paths;
  explicit EinsumOp(int num_args, int optimize, std::string subscripts) {
    this->num_args = num_args;
    this->optimize = optimize;
    this->subscripts = subscripts;
  }
};  // class EinsumOp

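/*
 * Each kernel instance below computes one element of the forward output,
 * or one element of one operand's gradient in the backward pass: it sums,
 * over the reduce-shape index space, the product of the participating
 * operands (times out_grad when back == true), using the precomputed
 * strides to locate the matching elements of each operand.
 */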
template<int dimension, int req, bool back, typename AType>
struct numpy_einsum {
  template<typename DType>
  MSHADOW_XINLINE static void Map(index_t i, DType* out,
                                  common::StaticArray<DType*, NPY_MAXARGS> op,
                                  mshadow::Shape<dimension> oshape,
                                  common::StaticArray<mshadow::Shape<dimension>,
                                                      NPY_MAXARGS> ostride,
                                  mshadow::Shape<dimension> reduceshape,
                                  common::StaticArray<mshadow::Shape<dimension>,
                                                      NPY_MAXARGS> rstride,
                                  int nop,
                                  int iop0,
                                  const DType* out_grad) {
    using namespace mxnet_op;
    mshadow::Shape<dimension> oidx = unravel(i, oshape);
    i = back ? dot(oidx, ostride[iop0]) : i;
    if (req == kWriteTo) {
      out[i] = (DType)0;
    }
    for (int rdim = 0; rdim < dimension; ++rdim) {
      if (reduceshape[rdim] == 0) {
        return;
      }
    }
    mshadow::Shape<dimension> ridx = unravel(0, reduceshape);
    AType sum = 0;
    do {
      AType tmp = back ? static_cast<AType>(out_grad[dot(oidx, ostride[nop]) +
                                                     dot(ridx, rstride[nop])]) : (AType)1;
      for (int iop = 0; iop < nop; ++iop) {
        if (iop != iop0) {
          index_t k = dot(oidx, ostride[iop]) + dot(ridx, rstride[iop]);
          tmp = tmp * static_cast<AType>(op[iop][k]);
        }
      }
      sum = sum + tmp;
    } while (inc(&ridx, reduceshape));
    out[i] = out[i] + static_cast<DType>(sum);
  }
};

template<typename xpu, bool back>
inline void NumpyEinsumProcess(const std::vector<TBlob>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<TBlob>& outputs,
                               const char *subscripts, int nop,
                               const OpContext& ctx) {
  using namespace mxnet_op;

  /* nop+1 (+1 is for the output) must fit in NPY_MAXARGS */
  CHECK(nop < NPY_MAXARGS)
    << "too many operands provided to einstein sum function";
  CHECK(nop >= 1)
    << "not enough operands provided to einstein sum function";

  /* Step 1: Parse the subscripts string into label_counts and op_labels */
  int iop, idim, min_label = 127, max_label = 0;
  char label_counts[128], op_labels[NPY_MAXARGS][NPY_MAXDIMS];
  memset(label_counts, 0, sizeof(label_counts));
  for (iop = 0; iop < nop; ++iop) {
    int length = static_cast<int>(strcspn(subscripts, ",-"));

    CHECK(!(iop == nop - 1 && subscripts[length] == ','))
      << "more operands provided to einstein sum function "
      << "than specified in the subscripts string";
    CHECK(!(iop < nop-1 && subscripts[length] != ','))
      << "fewer operands provided to einstein sum function "
      << "than specified in the subscripts string";
    CHECK_GE(parse_operand_subscripts(subscripts, length,
                                      inputs[iop + back].shape_.ndim(),
                                      iop, op_labels[iop], label_counts,
                                      &min_label, &max_label), 0);

    /* Move subscripts to the start of the labels for the next op */
    subscripts += length;
    if (iop < nop - 1) {
      subscripts++;
    }
  }

  /*
   * Find the number of broadcast dimensions, which is the maximum
   * number of labels == 0 in an op_labels array.
   */
  int ndim_broadcast = 0;
  for (iop = 0; iop < nop; ++iop) {
    int count_zeros = 0;
    int ndim;
    char *labels = op_labels[iop];

    ndim = inputs[iop + back].shape_.ndim();
    for (idim = 0; idim < ndim; ++idim) {
      if (labels[idim] == 0) {
        ++count_zeros;
      }
    }

    if (count_zeros > ndim_broadcast) {
      ndim_broadcast = count_zeros;
    }
  }

  /*
   * If there is no output signature, fill output_labels and ndim_output
   * using each label that appeared once, in alphabetical order.
   */
  int label, ndim_output;
  char output_labels[NPY_MAXDIMS];
  if (subscripts[0] == '\0') {
    /* If no output was specified, always broadcast left, as usual. */
    for (ndim_output = 0; ndim_output < ndim_broadcast; ++ndim_output) {
      output_labels[ndim_output] = 0;
    }
    for (label = min_label; label <= max_label; ++label) {
      if (label_counts[label] == 1) {
        CHECK(ndim_output < NPY_MAXDIMS)
          << "einstein sum subscript string has too many "
          << "distinct labels";
        output_labels[ndim_output++] = label;
      }
    }
  } else {
    CHECK(subscripts[0] == '-' && subscripts[1] == '>')
      << "einstein sum subscript string does not "
      << "contain proper '->' output specified";
    subscripts += 2;

    /* Parse the output subscript string. */
    ndim_output = parse_output_subscripts(subscripts, strlen(subscripts),
                                          ndim_broadcast, label_counts,
                                          output_labels);
    CHECK_GE(ndim_output, 0);
  }

  /*
   * Step 2:
   * Process all the input ops, combining dimensions into their
   * diagonal where specified.
   */
  std::vector<TShape> opshape(nop), opstride_true(nop);
  for (iop = 0; iop < nop; ++iop) {
    char *labels = op_labels[iop];
    int combine, ndim;

    ndim = inputs[iop + back].shape_.ndim();

    /*
     * Check whether any dimensions need to be combined
     *
     * The char type may be either signed or unsigned, we
     * need it to be signed here.
     */
    combine = 0;
    for (idim = 0; idim < ndim; ++idim) {
      if ((signed char)labels[idim] < 0) {
        combine++;
      }
    }

    /* If any dimensions are combined, create a view which combines them */
    if (combine) {
      TShape tshape(ndim - combine, -1);
      TShape tstride(ndim - combine, -1);
      get_combined_dims_view(inputs[iop + back], iop, labels,
                             &tshape, &tstride);
      opshape[iop] = tshape;
      opstride_true[iop] = tstride;
    } else {
      /* No combining needed */
      opshape[iop] = inputs[iop + back].shape_;
      opstride_true[iop] = get_stride(opshape[iop]);
    }
  }

  /*
   * Step 3:
   * Set up the labels for the iterator (output + combined labels).
   * Can just share the output_labels memory, because iter_labels
   * is output_labels with some more labels appended.
   */
  char *iter_labels = output_labels;
  int ndim_iter = ndim_output;
  for (label = min_label; label <= max_label; ++label) {
    if (label_counts[label] > 0 &&
        memchr(output_labels, label, ndim_output) == nullptr) {
      CHECK(ndim_iter < NPY_MAXDIMS)
        << "too many subscripts in einsum";
      iter_labels[ndim_iter++] = label;
    }
  }

  /* Step 4: Set up the op_axes for the iterator */
  TShape itershape(ndim_iter, -1);
  std::vector<TShape> iterstride(nop + 1, TShape(ndim_iter, 0));
  TShape oshape = back ? inputs[0].shape_ : outputs[0].shape_;
  TShape ostride_true = get_stride(oshape);
  TShape reduceshape;
  std::vector<TShape> remainshape(nop);
  int op_axes_arrays[NPY_MAXARGS][NPY_MAXDIMS];
  int *op_axes[NPY_MAXARGS];

  for (iop = 0; iop < nop; ++iop) {
    op_axes[iop] = op_axes_arrays[iop];
    CHECK_GE(prepare_op_axes(opshape[iop].ndim(), iop, op_labels[iop],
             op_axes[iop], ndim_iter, iter_labels), 0);
    for (idim = 0; idim < ndim_iter; idim++) {
      if (op_axes[iop][idim] != -1) {
        iterstride[iop][idim] = opstride_true[iop][op_axes[iop][idim]];
        if (itershape[idim] != -1) {
          if (itershape[idim] == 1) {
            itershape[idim] = opshape[iop][op_axes[iop][idim]];
          }
        } else {
          itershape[idim] = opshape[iop][op_axes[iop][idim]];
        }
      }
    }
  }
  for (idim = 0; idim < ndim_output; ++idim) {
    iterstride[nop][idim] = ostride_true[idim];
  }
  reduceshape = TShape(ndim_iter - ndim_output, 0);
  for (idim = ndim_output; idim < ndim_iter; ++idim) {
    reduceshape[idim - ndim_output] = itershape[idim];
  }
  for (iop = 0; iop < nop; iop++) {
    std::vector<size_t> rsh;
    for (idim = 0; idim < ndim_iter; idim++) {
      if (op_axes_arrays[iop][idim] == -1 ||
          itershape[idim] != opshape[iop][op_axes_arrays[iop][idim]]) {
        rsh.push_back(itershape[idim]);
      }
    }
    remainshape[iop] = TShape(rsh.begin(), rsh.end());
  }

  // exclude the 0-dim case
  if (ndim_iter == 0) {
    ndim_iter = 1;
  }
  itershape = pad(itershape, ndim_iter);
  for (iop = 0; iop <= nop; ++iop) {
    iterstride[iop] = pad(iterstride[iop], ndim_iter);
  }
  oshape = pad(oshape, ndim_iter);
  reduceshape = pad(reduceshape, ndim_iter);
  for (iop = 0; iop < nop; ++iop) {
    opshape[iop] = pad(opshape[iop], ndim_iter);
    remainshape[iop] = pad(remainshape[iop], ndim_iter);
  }

  if (!back) {
    if (oshape.Size() == 0) {
      return;
    }
    const TBlob &out_data = outputs[0];
    MXNET_ACC_TYPE_SWITCH(out_data.type_flag_, DType, AType, {
      mxnet::common::StaticArray<DType*, NPY_MAXARGS> op;
      for (iop = 0; iop < nop; ++iop) {
        op[iop] = inputs[iop].dptr<DType>();
      }
      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
        MXNET_NDIM_SWITCH_EX(ndim_iter, dimension, {
          mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> ostride_arr;
          mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> rstride_arr;
          for (iop = 0; iop < nop; ++iop) {
            mshadow::Shape<dimension> otmp, rtmp;
            for (idim = 0; idim < dimension; ++idim) {
              otmp[idim] = idim < ndim_output ? iterstride[iop][idim] : 1;
              rtmp[idim] = idim < dimension - ndim_output ? iterstride[iop][idim + ndim_output] : 1;
            }
            ostride_arr[iop] = otmp;
            rstride_arr[iop] = rtmp;
          }
          Kernel<numpy_einsum<dimension, req_type, 0, AType>,
                 xpu>::Launch(ctx.get_stream<xpu>(),
                              oshape.Size(),
                              out_data.dptr<DType>(),
                              op,
                              oshape.get<dimension>(),
                              ostride_arr,
                              reduceshape.get<dimension>(),
                              rstride_arr,
                              nop,
                              -1,
                              (DType*)nullptr);
        })
      })
    })
  } else {
    if (oshape.Size() == 0) {
      for (iop = 0; iop < nop; ++iop) {
        const TBlob& out_data = outputs[iop];
        if (opshape[iop].Size() > 0) {
          MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
            MXNET_ASSIGN_REQ_SWITCH(req[iop], req_type, {
              if (req_type == kWriteTo) {
                out_data.FlatTo1D<xpu, DType>(ctx.get_stream<xpu>()) = 0;
              }
            })
          })
        }
      }
      return;
    }
    for (int i = 0; i < nop; ++i) {
      const TBlob &out_data = outputs[i];
      const TBlob &out_grad = inputs[0];
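      /*
       * Split the iterator strides for gradient operand i: iterator axes
       * that map onto operand i's own (non-broadcast) dimensions go into
       * opstride, while reduced axes and axes broadcast against operand i
       * go into remainstride, which the kernel sums over.
       */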
      std::vector<TShape> opstride(nop + 1, TShape(ndim_iter, 0));
      std::vector<TShape> remainstride(nop + 1, TShape(ndim_iter, 0));
      for (iop = 0; iop <= nop; ++iop) {
        int j = 0;
        for (idim = 0; idim < ndim_iter; ++idim) {
          if (op_axes_arrays[i][idim] == -1 ||
              (iop != nop && opshape[i][op_axes_arrays[i][idim]] == 1 &&
              op_axes_arrays[iop][idim] != -1 &&
              opshape[iop][op_axes_arrays[iop][idim]] != 1)) {
            remainstride[iop][j++] = iterstride[iop][idim];
          } else {
            opstride[iop][op_axes_arrays[i][idim]] = iterstride[iop][idim];
          }
        }
      }
      MXNET_ACC_TYPE_SWITCH(out_data.type_flag_, DType, AType, {
        mxnet::common::StaticArray<DType*, NPY_MAXARGS> op;
        for (iop = 0; iop < nop; ++iop) {
          op[iop] = inputs[iop + back].dptr<DType>();
        }
        MXNET_ASSIGN_REQ_SWITCH(req[i], req_type, {
          MXNET_NDIM_SWITCH_EX(ndim_iter, dimension, {
            mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> opstride_arr;
            mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> remainstride_arr;
            for (iop = 0; iop <= nop; ++iop) {
              opstride_arr[iop] = opstride[iop].get<dimension>();
              remainstride_arr[iop] = remainstride[iop].get<dimension>();
            }
            Kernel<numpy_einsum<dimension, req_type, 1, AType>,
                   xpu>::Launch(ctx.get_stream<xpu>(),
                                opshape[i].Size(),
                                out_data.dptr<DType>(),
                                op,
                                opshape[i].get<dimension>(),
                                opstride_arr,
                                remainshape[i].get<dimension>(),
                                remainstride_arr,
                                nop,
                                i,
                                out_grad.dptr<DType>());
          })
        })
      })
    }
  }
}

template<typename xpu>
inline void NumpyEinsumForward(const OpStatePtr& state_ptr,
                               const OpContext& ctx,
                               const std::vector<TBlob>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mxnet_op;
  EinsumOp& state = state_ptr.get_state<EinsumOp>();
  int num_args = state.num_args;
  int optimize = state.optimize;
  const char* subscripts = state.subscripts.c_str();
  Stream<xpu> *s = ctx.get_stream<xpu>();
  CHECK_EQ(inputs.size(), num_args);
  CHECK_EQ(outputs.size(), 1U);
  if (optimize == 0) {
    NumpyEinsumProcess<xpu, 0>(inputs, req, outputs, subscripts, num_args, ctx);
    return;
  }
  std::vector<Step>& paths = state.paths;
  std::vector<std::vector<int> > pos;
  std::string string_repr;
  paths = einsum_path(state.subscripts, inputs, true, ctx.run_ctx, &pos, &string_repr);
  int paths_len = paths.size();
  size_t temp_space_size = 0, max_temp_space_size = 0;
  std::vector<TBlob> operands(inputs), tmp_operands, temp_space_vec(paths_len - 1);
  for (int i = 0; i + 1 < paths_len; ++i) {
    temp_space_size += paths[i].oshape.Size();
  }
  for (int i = 0; i < paths_len; ++i) {
    max_temp_space_size = std::max(max_temp_space_size, paths[i].oshape.Size());
  }
  temp_space_size += max_temp_space_size;
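  /*
   * Temporary-space layout: the first max_temp_space_size elements form a
   * scratch area sized to the largest intermediate (used when a BLAS step
   * also needs an einsum reshape), and the remaining elements hold the
   * intermediate result of each contraction step, laid out back to back.
   */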
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    state.tempspace.reset<NDArray>(new NDArray(TShape(Shape1(temp_space_size)),
                                               ctx.run_ctx.ctx,
                                               false,
                                               outputs[0].type_flag_));
    Tensor<xpu, 1, DType> temp_space = state.tempspace->data().FlatTo1D<xpu, DType>();
    size_t begin = max_temp_space_size;
    for (int i = 0; i < paths_len - 1; ++i) {
      TBlob tblob = TBlob(temp_space.Slice(begin, begin + paths[i].oshape.Size()));
      temp_space_vec[i] = tblob.reshape(paths[i].oshape);
      begin = begin + paths[i].oshape.Size();
    }
    for (int i = 0; i < paths_len; ++i) {
      tmp_operands.clear();

      // We remove inds from right to left
      for (const int& p : paths[i].contract_inds) {
        tmp_operands.push_back(operands[p]);
        operands.erase(operands.begin() + p);
      }
      bool handle_out = (i == paths_len - 1);
      // Call tensordot if still possible
      if (paths[i].do_blas) {
        // Contract!
        if (paths[i].do_einsum || handle_out) {
          TBlob max_temp_space = TBlob(temp_space.Slice(0, paths[i].tshape.Size()));
          max_temp_space.FlatTo1D<xpu, DType>(s) = 0;
          max_temp_space = max_temp_space.reshape(paths[i].tshape);
          size_t tensordot_tempspace_size =
            TensordotWorkspaceSize<xpu>(paths[i].left_pos,
                                        paths[i].right_pos,
                                        tmp_operands[0],
                                        tmp_operands[1],
                                        max_temp_space,
                                        std::vector<OpReqType>{OpReqType::kWriteTo});
          Tensor<xpu, 1, char> tensordot_tempspace =
            ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(tensordot_tempspace_size), s);
          TensordotImpl<xpu>(paths[i].left_pos,
                             paths[i].right_pos,
                             ctx,
                             tmp_operands[0],
                             tmp_operands[1],
                             max_temp_space,
                             std::vector<OpReqType>{OpReqType::kWriteTo},
                             tensordot_tempspace);
          NumpyEinsumProcess<xpu, 0>(std::vector<TBlob>{max_temp_space},
            handle_out ? req : std::vector<OpReqType>{OpReqType::kWriteTo},
            handle_out ? outputs : std::vector<TBlob>{temp_space_vec[i]},
            paths[i].blas2einsum_str.c_str(),
            1, ctx);
        } else {
          size_t tensordot_tempspace_size =
            TensordotWorkspaceSize<xpu>(paths[i].left_pos,
                                        paths[i].right_pos,
                                        tmp_operands[0],
                                        tmp_operands[1],
                                        temp_space_vec[i],
                                        std::vector<OpReqType>{OpReqType::kWriteTo});
          Tensor<xpu, 1, char> tensordot_tempspace = ctx.requested[0].get_space_typed<xpu, 1, char>(
            Shape1(tensordot_tempspace_size), s);
          TensordotImpl<xpu>(paths[i].left_pos,
                             paths[i].right_pos,
                             ctx,
                             tmp_operands[0],
                             tmp_operands[1],
                             temp_space_vec[i],
                             std::vector<OpReqType>{OpReqType::kWriteTo},
                             tensordot_tempspace);
        }
      } else {
        NumpyEinsumProcess<xpu, 0>(tmp_operands,
        handle_out ? req : std::vector<OpReqType>{OpReqType::kWriteTo},
        handle_out ? outputs : std::vector<TBlob>{temp_space_vec[i]},
        paths[i].einsum_str.c_str(), tmp_operands.size(), ctx);
      }
      if (!handle_out) {
        operands.push_back(temp_space_vec[i]);
      }
    }
  });
}

template<typename xpu>
inline void NumpyEinsumBackward(const OpStatePtr& state_ptr,
                                const OpContext& ctx,
                                const std::vector<TBlob>& inputs,
                                const std::vector<OpReqType>& req,
                                const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow_op;
  const EinsumOp& state = state_ptr.get_state<EinsumOp>();
  int num_args = state.num_args;
  int optimize = state.optimize;
  const char* subscripts = state.subscripts.c_str();
  Stream<xpu> *s = ctx.get_stream<xpu>();
  CHECK_EQ(inputs.size(), 1 + num_args);
  CHECK_EQ(outputs.size(), num_args);
  if (optimize == 0) {
    NumpyEinsumProcess<xpu, 1>(inputs, req, outputs, subscripts, num_args, ctx);
    return;
  }
  // calculate temporary space size for temp_grad
  const std::vector<Step>& paths = state.paths;
  int paths_len = paths.size();
  size_t temp_space_size = 0, max_temp_space_size = 0;
  for (int i = 0; i < paths_len - 1; ++i) {
    temp_space_size += paths[i].oshape.Size();
  }
  for (int i = 0; i < paths_len; ++i) {
    max_temp_space_size = std::max(max_temp_space_size, paths[i].oshape.Size());
  }
  temp_space_size += max_temp_space_size;
  // replay the forward process
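  // op_idx[i] records which operands are live before contraction step i:
  // positive entries are 1-based positions into `inputs` (original operands,
  // since inputs[0] is the output gradient), and non-positive entries refer
  // to the intermediate result produced by step -entry.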
  std::vector<std::vector<int> > op_idx(paths_len + 1);
  for (int i = 0; i <= paths_len; ++i) {
    if (i == 0) {
      op_idx[i].reserve(num_args);
      for (int j = 0; j < num_args; ++j) {
        op_idx[i].push_back(j + 1);
      }
    } else {
      op_idx[i] = op_idx[i - 1];
      // We remove inds from right to left
      for (const int& p : paths[i - 1].contract_inds) {
        op_idx[i].erase(op_idx[i].begin() + p);
      }
      op_idx[i].push_back(-static_cast<int>(i - 1));
    }
  }
  // calculate temporary space size for tensordot
  size_t tensordot_max_tempspace_size = 0;
  size_t begin_tensordot_tempspace = 0;
  std::vector<TBlob> temp_inputs, temp_outputs;
  std::vector<OpReqType> temp_req;
  std::vector<size_t> tensordot_tempspace_size;
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    for (int i = 0; i < paths_len; ++i) {
      temp_inputs.clear();
      temp_outputs.clear();
      temp_req.clear();
      bool handle_out = (i == paths_len - 1);

      if (handle_out) {
        temp_inputs.push_back(inputs[0]);
      } else {
        temp_inputs.push_back(TBlob((DType*)nullptr,
                                    paths[i].oshape,
                                    xpu::kDevMask));
      }
      for (auto p : paths[i].contract_inds) {
        int idx = op_idx[i][p];
        if (idx >= 1) {
          temp_inputs.push_back(inputs[idx]);
          temp_outputs.push_back(outputs[idx - 1]);
          temp_req.push_back(req[idx - 1]);
        } else {
          temp_inputs.push_back(TBlob((DType*)nullptr,
                                      paths[-idx].oshape,
                                      xpu::kDevMask));
          temp_outputs.push_back(TBlob((DType*)nullptr,
                                      paths[-idx].oshape,
                                      xpu::kDevMask));
          temp_req.push_back(OpReqType::kWriteTo);
        }
      }
      size_t cur_tensordot_tempspace_size = 0;
      if (paths[i].do_blas) {
        if (paths[i].do_einsum) {
          cur_tensordot_tempspace_size =
            TensordotBackwardWorkspaceSize<xpu>(paths[i].left_pos,
                                                paths[i].right_pos,
                                                TBlob((DType*)nullptr,
                                                      paths[i].tshape,
                                                      xpu::kDevMask),
                                                temp_inputs[1],
                                                temp_inputs[2],
                                                temp_outputs[0],
                                                temp_outputs[1],
                                                temp_req);
        } else {
          cur_tensordot_tempspace_size =
            TensordotBackwardWorkspaceSize<xpu>(paths[i].left_pos,
                                                paths[i].right_pos,
                                                temp_inputs[0],
                                                temp_inputs[1],
                                                temp_inputs[2],
                                                temp_outputs[0],
                                                temp_outputs[1],
                                                temp_req);
        }
      }
      tensordot_tempspace_size.push_back(cur_tensordot_tempspace_size);
      tensordot_max_tempspace_size = std::max(tensordot_max_tempspace_size,
                                              cur_tensordot_tempspace_size);
    }
    begin_tensordot_tempspace = temp_space_size;
    temp_space_size += (tensordot_max_tempspace_size + sizeof(DType) - 1) / sizeof(DType);
  });
  // allocate temporary space and propagate
  std::vector<TBlob> temp_grad(paths_len - 1), temp_data(paths_len - 1);
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    // allocate temporary space for gradients of intermediate results
    Tensor<xpu, 1, DType> temp_space = ctx.requested[0].get_space_typed<xpu, 1, DType>
      (Shape1(temp_space_size), s);
    size_t begin = max_temp_space_size;
    for (int i = 0; i + 1 < paths_len; ++i) {
      TBlob tblob = TBlob(temp_space.Slice(begin, begin + paths[i].oshape.Size()));
      temp_grad[i] = tblob.reshape(paths[i].oshape);
      begin = begin + paths[i].oshape.Size();
    }

    // reinterpret ndarray for intermediate results
    Tensor<xpu, 1, DType> ndarray_space = state.tempspace->data().FlatTo1D<xpu, DType>();
    begin = max_temp_space_size;
    for (int i = 0; i + 1 < paths_len; ++i) {
      TBlob tblob = TBlob(ndarray_space.Slice(begin, begin + paths[i].oshape.Size()));
      temp_data[i] = tblob.reshape(paths[i].oshape);
      begin = begin + paths[i].oshape.Size();
    }
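    // temp_data[i] aliases the intermediate result that the forward pass
    // saved in state.tempspace, so each backward step can re-read the
    // inputs of the corresponding forward contraction.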

    // go through the paths in the reversed order
    for (int i = paths_len - 1; i >= 0; i--) {
      temp_inputs.clear();
      temp_outputs.clear();
      temp_req.clear();
      bool handle_out = (i == paths_len - 1);

      if (handle_out) {
        temp_inputs.push_back(inputs[0]);
      } else {
        temp_inputs.push_back(temp_grad[i]);
      }
      for (auto p : paths[i].contract_inds) {
        int idx = op_idx[i][p];
        if (idx >= 1) {
          temp_inputs.push_back(inputs[idx]);
          temp_outputs.push_back(outputs[idx - 1]);
          temp_req.push_back(req[idx - 1]);
        } else {
          temp_inputs.push_back(temp_data[-idx]);
          temp_outputs.push_back(temp_grad[-idx]);
          temp_req.push_back(OpReqType::kWriteTo);
        }
      }
      if (paths[i].do_blas) {
        CHECK_EQ(temp_inputs.size(), 3U);
        CHECK_EQ(temp_outputs.size(), 2U);
        CHECK_EQ(temp_req.size(), 2U);
        Tensor<xpu, 1, DType> tensordot_tempspace = temp_space.Slice(begin_tensordot_tempspace,
                                                                     temp_space_size);
        Tensor<xpu, 1, char> char_tempspace =
          Tensor<xpu, 1, char>(reinterpret_cast<char*>(tensordot_tempspace.dptr_),
                                                       Shape1(tensordot_tempspace_size[i]),
                                                       tensordot_tempspace.stream_);
        if (paths[i].do_einsum) {
          TBlob max_temp_space = TBlob(temp_space.Slice(0, paths[i].tshape.Size()));
          max_temp_space = max_temp_space.reshape(paths[i].tshape);
          NumpyEinsumProcess<xpu, 0>(std::vector<TBlob>{temp_inputs[0]},
                                     std::vector<OpReqType>{kWriteTo},
                                     std::vector<TBlob>{max_temp_space},
                                     paths[i].einsum2blas_str.c_str(),
                                     1, ctx);
          TensordotBackwardImpl<xpu>(paths[i].left_pos, paths[i].right_pos, ctx,
                                     max_temp_space, temp_inputs[1], temp_inputs[2],
                                     temp_outputs[0], temp_outputs[1], temp_req, char_tempspace);
        } else {
          TensordotBackwardImpl<xpu>(paths[i].left_pos, paths[i].right_pos, ctx,
                                     temp_inputs[0], temp_inputs[1], temp_inputs[2],
                                     temp_outputs[0], temp_outputs[1], temp_req, char_tempspace);
        }
      } else {
        NumpyEinsumProcess<xpu, 1>(temp_inputs, temp_req, temp_outputs,
                                   paths[i].einsum_str.c_str(),
                                   temp_outputs.size(),
                                   ctx);
      }
    }
  });
}

}  // namespace op
}  // namespace mxnet

#endif  // MXNET_OPERATOR_NUMPY_NP_EINSUM_OP_INL_H_