1 /*
2 * Copyright (c) 2005-2019, NumPy Developers.
3 *
4 * All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions are
8 * met:
9 *
10 * * Redistributions of source code must retain the above copyright
11 * notice, this list of conditions and the following disclaimer.
12 *
13 * * Redistributions in binary form must reproduce the above
14 * copyright notice, this list of conditions and the following
15 * disclaimer in the documentation and/or other materials provided
16 * with the distribution.
17 *
18 * * Neither the name of the NumPy Developers nor the names of any
19 * contributors may be used to endorse or promote products derived
20 * from this software without specific prior written permission.
21 *
22 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
25 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
26 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 */
34
35 /*!
36 * \file np_einsum_op-inl.h
37 * \brief Function definition of numpy-compatible einsum operator
38 * modified by Haozheng Fan(@hzfan) from:
39 * https://github.com/numpy/numpy/blob/master/numpy/core/src/multiarray/einsum.c.src
40 */
41
42 #ifndef MXNET_OPERATOR_NUMPY_NP_EINSUM_OP_INL_H_
43 #define MXNET_OPERATOR_NUMPY_NP_EINSUM_OP_INL_H_
44
#include <mxnet/operator_util.h>
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "./np_tensordot_op-inl.h"
#include "./np_einsum_path_op-inl.h"
#include "../../common/static_array.h"
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
56
57 namespace mxnet {
58 namespace op {
59
60 #define NPY_MAXDIMS 16
61 #define NPY_MAXARGS 16
62
/*!
 * \brief Compute row-major (C-order) element strides for a shape.
 *
 * Axes with extent <= 1 are given stride 0 so that indexing along them
 * broadcasts for free.
 */
inline TShape get_stride(const TShape& shape) {
  const int ndim = shape.ndim();
  TShape stride(ndim, -1);
  int acc = 1;
  for (int axis = ndim; axis-- > 0; ) {
    stride[axis] = (shape[axis] > 1) ? acc : 0;
    acc *= shape[axis];
  }
  return stride;
}
72
/*!
 * \brief Right-pad a shape with trailing 1s up to 'odim' dimensions.
 *
 * Requires odim >= shape.ndim(); the leading dimensions are copied verbatim.
 */
inline TShape pad(const TShape& shape, int odim) {
  const int ndim = shape.ndim();
  CHECK_GE(odim, ndim);
  TShape padded(odim, 1);
  for (int i = 0; i < ndim; ++i) {
    padded[i] = shape[i];
  }
  return padded;
}
82
83 /*
84 * Parses the subscripts for one operand into an output of 'ndim'
85 * labels. The resulting 'op_labels' array will have:
86 * - the ASCII code of the label for the first occurrence of a label;
87 * - the (negative) offset to the first occurrence of the label for
88 * repeated labels;
89 * - zero for broadcast dimensions, if subscripts has an ellipsis.
90 * For example:
91 * - subscripts="abbcbc", ndim=6 -> op_labels=[97, 98, -1, 99, -3, -2]
92 * - subscripts="ab...bc", ndim=6 -> op_labels=[97, 98, 0, 0, -3, 99]
93 */
/*
 * \param subscripts    the label string for this operand (not NUL-terminated here)
 * \param length        number of characters of 'subscripts' to consume
 * \param ndim          number of dimensions of the operand
 * \param iop           operand index (only used in error messages)
 * \param op_labels     out: one entry per operand dimension (see comment above)
 * \param label_counts  in/out: per-ASCII-code occurrence counts across all operands
 * \param min_label     in/out: smallest label code seen so far
 * \param max_label     in/out: largest label code seen so far
 * \return 0 on success (errors abort via CHECK)
 */
inline int parse_operand_subscripts(const char *subscripts, int length,
                                    int ndim, int iop, char *op_labels,
                                    char *label_counts, int *min_label, int *max_label) {
  using namespace mxnet_op;
  int i;
  int idim = 0;
  /* Position in op_labels where the ellipsis starts; -1 while none seen. */
  int ellipsis = -1;

  /* Process all labels for this operand */
  for (i = 0; i < length; ++i) {
    int label = subscripts[i];

    /* A proper label for an axis. */
    if (label > 0 && isalpha(label)) {
      /* Check we don't exceed the operator dimensions. */
      CHECK(idim < ndim)
        << "einstein sum subscripts string contains "
        << "too many subscripts for operand "
        << iop;

      op_labels[idim++] = label;
      /* Track the label range so later passes can iterate [min, max]. */
      if (label < *min_label) {
        *min_label = label;
      }
      if (label > *max_label) {
        *max_label = label;
      }
      label_counts[label]++;
    } else if (label == '.') {
      /* The beginning of the ellipsis. */
      /* Check it's a proper ellipsis: only one allowed, and the next two
       * characters must also be '.' (the ++i's consume them). */
      CHECK(!(ellipsis != -1 || i + 2 >= length
              || subscripts[++i] != '.' || subscripts[++i] != '.'))
        << "einstein sum subscripts string contains a "
        << "'.' that is not part of an ellipsis ('...') "
        << "in operand "
        << iop;

      ellipsis = idim;
    } else {
      CHECK(label == ' ')
        << "invalid subscript '" << static_cast<char>(label)
        << "' in einstein sum "
        << "subscripts string, subscripts must "
        << "be letters";
    }
  }

  /* No ellipsis found, labels must match dimensions exactly. */
  if (ellipsis == -1) {
    CHECK(idim == ndim)
      << "operand has more dimensions than subscripts "
      << "given in einstein sum, but no '...' ellipsis "
      << "provided to broadcast the extra dimensions.";
  } else if (idim < ndim) {
    /* Ellipsis found, may have to add broadcast dimensions. */
    /* Move labels after ellipsis to the end. */
    for (i = 0; i < idim - ellipsis; ++i) {
      op_labels[ndim - i - 1] = op_labels[idim - i - 1];
    }
    /* Set all broadcast dimensions to zero. */
    for (i = 0; i < ndim - idim; ++i) {
      op_labels[ellipsis + i] = 0;
    }
  }

  /*
   * Find any labels duplicated for this operand, and turn them
   * into negative offsets to the axis to merge with.
   *
   * In C, the char type may be signed or unsigned, but with
   * twos complement arithmetic the char is ok either way here, and
   * later where it matters the char is cast to a signed char.
   */
  for (idim = 0; idim < ndim - 1; ++idim) {
    int label = op_labels[idim];
    /* If it is a proper label, find any duplicates of it. */
    if (label > 0) {
      /* Search for the next matching label. */
      char *next = reinterpret_cast<char*>(memchr(op_labels + idim + 1, label, ndim - idim - 1));

      while (next != nullptr) {
        /* The offset from next to op_labels[idim] (negative). */
        *next = static_cast<char>((op_labels + idim) - next);
        /* Search for the next matching label. */
        next = reinterpret_cast<char*>(memchr(next + 1, label, op_labels + ndim - 1 - next));
      }
    }
  }
  return 0;
}
185
186 /*
187 * Parses the subscripts for the output operand into an output that
188 * includes 'ndim_broadcast' unlabeled dimensions, and returns the total
189 * number of output dimensions, or -1 if there is an error. Similarly
190 * to parse_operand_subscripts, the 'out_labels' array will have, for
191 * each dimension:
192 * - the ASCII code of the corresponding label;
193 * - zero for broadcast dimensions, if subscripts has an ellipsis.
194 */
/*
 * \param subscripts     the output label string (after the "->")
 * \param length         number of characters of 'subscripts' to consume
 * \param ndim_broadcast number of unlabeled broadcast dims the ellipsis expands to
 * \param label_counts   per-ASCII-code occurrence counts gathered from the inputs
 * \param out_labels     out: one entry per output dimension (see comment above)
 * \return the number of output dimensions (errors abort via CHECK)
 */
inline int parse_output_subscripts(const char *subscripts, int length,
                                   int ndim_broadcast,
                                   const char *label_counts, char *out_labels) {
  using namespace mxnet_op;
  int i, bdim;
  int ndim = 0;
  /* Flag: has the ellipsis been seen yet (at most one allowed)? */
  int ellipsis = 0;

  /* Process all the output labels. */
  for (i = 0; i < length; ++i) {
    int label = subscripts[i];

    /* A proper label for an axis. */
    if (label > 0 && isalpha(label)) {
      /* Check that it doesn't occur again. */
      CHECK(memchr(subscripts + i + 1, label, length - i - 1) == nullptr)
        << "einstein sum subscripts string includes "
        << "output subscript '" << static_cast<char>(label)
        << "' multiple times";

      /* Check that it was used in the inputs. */
      CHECK(label_counts[label] != 0)
        << "einstein sum subscripts string included "
        << "output subscript '" << static_cast<char>(label)
        << "' which never appeared "
        << "in an input";

      /* Check that there is room in out_labels for this label. */
      CHECK(ndim < NPY_MAXDIMS)
        << "einstein sum subscripts string contains "
        << "too many subscripts in the output";

      out_labels[ndim++] = label;
    } else if (label == '.') {
      /* The beginning of the ellipsis. */
      /* Check it is a proper ellipsis (the ++i's consume the other dots). */
      CHECK(!(ellipsis || i + 2 >= length
              || subscripts[++i] != '.' || subscripts[++i] != '.'))
        << "einstein sum subscripts string "
        << "contains a '.' that is not part of "
        << "an ellipsis ('...') in the output";

      /* Check there is room in out_labels for broadcast dims. */
      CHECK(ndim + ndim_broadcast <= NPY_MAXDIMS)
        << "einstein sum subscripts string contains "
        << "too many subscripts in the output";

      ellipsis = 1;
      /* Expand the ellipsis into ndim_broadcast zero (unlabeled) dims. */
      for (bdim = 0; bdim < ndim_broadcast; ++bdim) {
        out_labels[ndim++] = 0;
      }
    } else {
      CHECK(label == ' ')
        << "invalid subscript '" << static_cast<char>(label)
        << "' in einstein sum "
        << "subscripts string, subscripts must "
        << "be letters";
    }
  }

  /* If no ellipsis was found there should be no broadcast dimensions. */
  CHECK(!(!ellipsis && ndim_broadcast > 0))
    << "output has more dimensions than subscripts "
    << "given in einstein sum, but no '...' ellipsis "
    << "provided to broadcast the extra dimensions.";

  return ndim;
}
263
/*
 * Computes a "view" (shape + strides) of operand 'iop' in which axes that
 * share a label are collapsed onto their diagonal: each group of repeated
 * labels becomes a single axis whose stride is the sum of the original
 * strides. On input *newshape / *newstride must already be sized to
 * ndim - (number of repeated axes); they are filled in here.
 */
inline void get_combined_dims_view(const TBlob& op, int iop,
                                   char *labels,
                                   TShape* newshape,
                                   TShape* newstride) {
  using namespace mxnet_op;
  int idim, ndim, icombine, combineoffset;
  /* Maps an original axis index to its (combined) destination axis. */
  int icombinemap[NPY_MAXDIMS];
  int newdim;

  const TShape& shape = op.shape_;
  TShape stride = get_stride(shape);
  ndim = op.shape_.ndim();
  newdim = newshape->ndim();

  /* Initialize the dimensions and strides to zero */
  for (idim = 0; idim < newdim; ++idim) {
    (*newshape)[idim] = 0;
    (*newstride)[idim] = 0;
  }

  /* Copy the dimensions and strides, except when collapsing */
  icombine = 0;
  for (idim = 0; idim < ndim; ++idim) {
    /*
     * The char type may be either signed or unsigned, we
     * need it to be signed here.
     */
    int label = (signed char)labels[idim];
    /* If this label says to merge axes, get the actual label */
    if (label < 0) {
      /* Negative labels are offsets to the first occurrence (see
       * parse_operand_subscripts). */
      combineoffset = label;
      label = labels[idim+label];
    } else {
      combineoffset = 0;
      if (icombine != idim) {
        labels[icombine] = labels[idim];
      }
      icombinemap[idim] = icombine;
    }
    /* If the label is 0, it's an unlabeled broadcast dimension */
    if (label == 0) {
      (*newshape)[icombine] = shape[idim];
      (*newstride)[icombine] = stride[idim];
    } else {
      /* Update the combined axis dimensions and strides */
      int i = icombinemap[idim + combineoffset];
      /* Diagonal extraction requires the repeated axes to have equal extents. */
      CHECK(!(combineoffset < 0 && (*newshape)[i] != 0 &&
              (*newshape)[i] != shape[idim]))
        << "dimensions in operand " << iop
        << " for collapsing index '" << label
        << "' don't match (" << static_cast<int>((*newshape)[i])
        << " != " << shape[idim] << ")";
      (*newshape)[i] = shape[idim];
      /* Summing the strides makes indexing walk the diagonal. */
      (*newstride)[i] += stride[idim];
    }

    /* If the label didn't say to combine axes, increment dest i */
    if (combineoffset == 0) {
      icombine++;
    }
  }
}
326
prepare_op_axes(int ndim,int iop,char * labels,int * axes,int ndim_iter,char * iter_labels)327 inline static int prepare_op_axes(int ndim, int iop, char *labels,
328 int *axes, int ndim_iter, char *iter_labels) {
329 using namespace mxnet_op;
330 int i, label, ibroadcast;
331
332 ibroadcast = ndim-1;
333 for (i = ndim_iter-1; i >= 0; --i) {
334 label = iter_labels[i];
335 /*
336 * If it's an unlabeled broadcast dimension, choose
337 * the next broadcast dimension from the operand.
338 */
339 if (label == 0) {
340 while (ibroadcast >= 0 && labels[ibroadcast] != 0) {
341 --ibroadcast;
342 }
343 /*
344 * If we used up all the operand broadcast dimensions,
345 * extend it with a "newaxis"
346 */
347 if (ibroadcast < 0) {
348 axes[i] = -1;
349 } else {
350 /* Otherwise map to the broadcast axis */
351 axes[i] = ibroadcast;
352 --ibroadcast;
353 }
354 } else {
355 /* It's a labeled dimension, find the matching one */
356 char *match = reinterpret_cast<char*>(memchr(labels, label, ndim));
357 /* If the op doesn't have the label, broadcast it */
358 if (match == nullptr) {
359 axes[i] = -1;
360 } else {
361 /* Otherwise use it */
362 axes[i] = match - labels;
363 }
364 }
365 }
366 return 0;
367 }
368
/*! \brief Operator parameters for the numpy-compatible einsum operator. */
struct NumpyEinsumParam: public dmlc::Parameter<NumpyEinsumParam> {
  int num_args;            // number of input arrays
  int optimize;            // nonzero selects the optimized contraction path
  std::string subscripts;  // raw einsum subscripts specification
  DMLC_DECLARE_PARAMETER(NumpyEinsumParam) {
    DMLC_DECLARE_FIELD(num_args)
    .set_lower_bound(1)
    .describe("Number of input arrays.");
    DMLC_DECLARE_FIELD(subscripts)
    .set_default("")
    .describe("Specifies the subscripts for summation as comma separated list"
              " of subscript labels. An implicit (classical Einstein summation) calculation"
              " is performed unless the explicit indicator '->' is included as well as"
              " subscript labels of the precise output form.");
    DMLC_DECLARE_FIELD(optimize)
    .set_default(0);
  }
};
387
388 class EinsumOp {
389 public:
390 int num_args;
391 int optimize;
392 std::string subscripts;
393 std::shared_ptr<NDArray> tempspace;
394 std::vector<Step> paths;
EinsumOp(int num_args,int optimize,std::string subscripts)395 explicit EinsumOp(int num_args, int optimize, std::string subscripts) {
396 this->num_args = num_args;
397 this->optimize = optimize;
398 this->subscripts = subscripts;
399 }
400 }; // class EinsumOp
401
/*!
 * \brief Elementwise kernel: computes one element of the einsum result
 *        (back == false) or of one operand's gradient (back == true).
 *
 * \tparam dimension rank of the padded iteration space
 * \tparam req       write request type (kWriteTo zero-initializes first)
 * \tparam back      false = forward product-sum, true = backward pass
 * \tparam AType     accumulation type (may be wider than DType)
 *
 * i indexes the destination; 'oshape'/'ostride' describe the output part
 * of the iteration space, 'reduceshape'/'rstride' the summed part.
 * 'iop0' is the operand being differentiated (-1 in forward); 'out_grad'
 * is the incoming gradient (only read when back == true; index nop of the
 * stride arrays belongs to it).
 */
template<int dimension, int req, bool back, typename AType>
struct numpy_einsum{
  template<typename DType>
  MSHADOW_XINLINE static void Map(index_t i, DType* out,
                                  common::StaticArray<DType*, NPY_MAXARGS> op,
                                  mshadow::Shape<dimension> oshape,
                                  common::StaticArray<mshadow::Shape<dimension>,
                                                      NPY_MAXARGS> ostride,
                                  mshadow::Shape<dimension> reduceshape,
                                  common::StaticArray<mshadow::Shape<dimension>,
                                                      NPY_MAXARGS> rstride,
                                  int nop,
                                  int iop0,
                                  const DType* out_grad) {
    using namespace mxnet_op;
    /* Multi-index of this thread's destination element. */
    mshadow::Shape<dimension> oidx = unravel(i, oshape);
    /* Backward writes through operand iop0's strides instead of linearly. */
    i = back ? dot(oidx, ostride[iop0]) : i;
    if (req == kWriteTo) {
      out[i] = (DType)0;
    }
    /* An empty reduction domain contributes nothing. */
    for (int rdim = 0; rdim < dimension; ++rdim) {
      if (reduceshape[rdim] == 0) {
        return;
      }
    }
    /* Accumulate in AType over every point of the reduction domain. */
    mshadow::Shape<dimension> ridx = unravel(0, reduceshape);
    AType sum = 0;
    do {
      /* Backward seeds each term with the incoming gradient; forward with 1. */
      AType tmp = back ? static_cast<AType>(out_grad[dot(oidx, ostride[nop]) +
                                            dot(ridx, rstride[nop])]): (AType)1;
      for (int iop = 0; iop < nop; ++iop) {
        if (iop != iop0) {
          index_t k = dot(oidx, ostride[iop]) + dot(ridx, rstride[iop]);
          tmp = tmp * static_cast<AType>(op[iop][k]);
        }
      }
      sum = sum + tmp;
    }while (inc(&ridx, reduceshape));
    out[i] = out[i] + static_cast<DType>(sum);
  }
};
443
/*!
 * \brief Reference (non-path-optimized) einsum implementation, ported from
 *        NumPy's einsum.c: parses 'subscripts', builds an iteration space
 *        whose first ndim_output dims are the output labels and whose
 *        remaining dims are the reduced labels, then launches the
 *        numpy_einsum kernel.
 *
 * When back == false: inputs are the nop operands and outputs[0] receives
 * the einsum result. When back == true: inputs[0] is the output gradient,
 * inputs[1..nop] are the forward operands, and outputs[0..nop-1] receive
 * the operand gradients (one kernel launch per operand).
 */
template<typename xpu, bool back>
inline void NumpyEinsumProcess(const std::vector<TBlob>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<TBlob>& outputs,
                               const char *subscripts, int nop,
                               const OpContext& ctx) {
  using namespace mxnet_op;

  /* nop+1 (+1 is for the output) must fit in NPY_MAXARGS */
  CHECK(nop < NPY_MAXARGS)
    << "too many operands provided to einstein sum function";
  CHECK(nop >= 1)
    << "not enough operands provided to einstein sum function";

  /* Step 1: Parse the subscripts string into label_counts and op_labels */
  int iop, idim, min_label = 127, max_label = 0;
  char label_counts[128], op_labels[NPY_MAXARGS][NPY_MAXDIMS];
  memset(label_counts, 0, sizeof(label_counts));
  for (iop = 0; iop < nop; ++iop) {
    /* Each operand's subscripts end at a ',' or at the "->" separator. */
    int length = static_cast<int>(strcspn(subscripts, ",-"));

    CHECK(!(iop == nop - 1 && subscripts[length] == ','))
      << "more operands provided to einstein sum function "
      << "than specified in the subscripts string";
    CHECK(!(iop < nop-1 && subscripts[length] != ','))
      << "fewer operands provided to einstein sum function "
      << "than specified in the subscripts string";
    /* In the backward pass inputs[0] is the output gradient, so operand
     * iop lives at inputs[iop + back]. */
    CHECK_GE(parse_operand_subscripts(subscripts, length,
                                      inputs[iop + back].shape_.ndim(),
                                      iop, op_labels[iop], label_counts,
                                      &min_label, &max_label), 0);

    /* Move subscripts to the start of the labels for the next op */
    subscripts += length;
    if (iop < nop - 1) {
      subscripts++;
    }
  }

  /*
   * Find the number of broadcast dimensions, which is the maximum
   * number of labels == 0 in an op_labels array.
   */
  int ndim_broadcast = 0;
  for (iop = 0; iop < nop; ++iop) {
    int count_zeros = 0;
    int ndim;
    char *labels = op_labels[iop];

    ndim = inputs[iop + back].shape_.ndim();
    for (idim = 0; idim < ndim; ++idim) {
      if (labels[idim] == 0) {
        ++count_zeros;
      }
    }

    if (count_zeros > ndim_broadcast) {
      ndim_broadcast = count_zeros;
    }
  }

  /*
   * If there is no output signature, fill output_labels and ndim_output
   * using each label that appeared once, in alphabetical order.
   */
  int label, ndim_output;
  char output_labels[NPY_MAXDIMS];
  if (subscripts[0] == '\0') {
    /* If no output was specified, always broadcast left, as usual. */
    for (ndim_output = 0; ndim_output < ndim_broadcast; ++ndim_output) {
      output_labels[ndim_output] = 0;
    }
    for (label = min_label; label <= max_label; ++label) {
      if (label_counts[label] == 1) {
        CHECK(ndim_output < NPY_MAXDIMS)
          << "einstein sum subscript string has too many "
          << "distinct labels";
        output_labels[ndim_output++] = label;
      }
    }
  } else {
    CHECK(subscripts[0] == '-' && subscripts[1] == '>')
      << "einstein sum subscript string does not "
      << "contain proper '->' output specified";
    subscripts += 2;

    /* Parse the output subscript string. */
    ndim_output = parse_output_subscripts(subscripts, strlen(subscripts),
                                          ndim_broadcast, label_counts,
                                          output_labels);
    CHECK_GE(ndim_output, 0);
  }

  /*
   * Step 2:
   * Process all the input ops, combining dimensions into their
   * diagonal where specified.
   */
  std::vector<TShape> opshape(nop), opstride_true(nop);
  for (iop = 0; iop < nop; ++iop) {
    char *labels = op_labels[iop];
    int combine, ndim;

    ndim = inputs[iop + back].shape_.ndim();

    /*
     * Check whether any dimensions need to be combined
     *
     * The char type may be either signed or unsigned, we
     * need it to be signed here.
     */
    combine = 0;
    for (idim = 0; idim < ndim; ++idim) {
      if ((signed char)labels[idim] < 0) {
        combine++;
      }
    }

    /* If any dimensions are combined, create a view which combines them */
    if (combine) {
      TShape tshape(ndim - combine, -1);
      TShape tstride(ndim - combine, -1);
      get_combined_dims_view(inputs[iop + back], iop, labels,
                             &tshape, &tstride);
      opshape[iop] = tshape;
      opstride_true[iop] = tstride;
    } else {
      /* No combining needed */
      opshape[iop] = inputs[iop + back].shape_;
      opstride_true[iop] = get_stride(opshape[iop]);
    }
  }

  /*
   * Step 3:
   * Set up the labels for the iterator (output + combined labels).
   * Can just share the output_labels memory, because iter_labels
   * is output_labels with some more labels appended.
   */
  char *iter_labels = output_labels;
  int ndim_iter = ndim_output;
  for (label = min_label; label <= max_label; ++label) {
    if (label_counts[label] > 0 &&
        memchr(output_labels, label, ndim_output) == nullptr) {
      CHECK(ndim_iter < NPY_MAXDIMS)
        << "too many subscripts in einsum";
      iter_labels[ndim_iter++] = label;
    }
  }

  /* Step 4: Set up the op_axes for the iterator */
  TShape itershape(ndim_iter, -1);
  std::vector<TShape> iterstride(nop + 1, TShape(ndim_iter, 0));
  TShape oshape = back ? inputs[0].shape_ : outputs[0].shape_;
  TShape ostride_true = get_stride(oshape);
  TShape reduceshape;
  std::vector<TShape> remainshape(nop);
  int op_axes_arrays[NPY_MAXARGS][NPY_MAXDIMS];
  int *op_axes[NPY_MAXARGS];

  for (iop = 0; iop < nop; ++iop) {
    op_axes[iop] = op_axes_arrays[iop];
    CHECK_GE(prepare_op_axes(opshape[iop].ndim(), iop, op_labels[iop],
                             op_axes[iop], ndim_iter, iter_labels), 0);
    for (idim = 0; idim < ndim_iter; idim++) {
      if (op_axes[iop][idim] != -1) {
        iterstride[iop][idim] = opstride_true[iop][op_axes[iop][idim]];
        /* The extent of an iterator dim is taken from the first operand
         * mapped onto it; a 1 may later be replaced by a larger extent
         * from another operand (broadcasting). */
        if (itershape[idim] != -1) {
          if (itershape[idim] == 1) {
            itershape[idim] = opshape[iop][op_axes[iop][idim]];
          }
        } else {
          itershape[idim] = opshape[iop][op_axes[iop][idim]];
        }
      }
    }
  }
  /* Slot nop of iterstride holds the output's (or output gradient's) strides. */
  for (idim = 0; idim < ndim_output; ++idim) {
    iterstride[nop][idim] = ostride_true[idim];
  }
  /* Iterator dims beyond ndim_output are the ones being summed over. */
  reduceshape = TShape(ndim_iter - ndim_output, 0);
  for (idim = ndim_output; idim < ndim_iter; ++idim) {
    reduceshape[idim - ndim_output] = itershape[idim];
  }
  /* remainshape[iop]: extents of the iterator dims over which operand iop
   * is broadcast (missing axis or extent mismatch); used by backward. */
  for (iop = 0; iop < nop; iop++) {
    std::vector<size_t> rsh;
    for (idim = 0; idim < ndim_iter; idim++) {
      if (op_axes_arrays[iop][idim] == -1 ||
          itershape[idim] != opshape[iop][op_axes_arrays[iop][idim]]) {
        rsh.push_back(itershape[idim]);
      }
    }
    remainshape[iop] = TShape(rsh.begin(), rsh.end());
  }

  // exclude the 0-dim case
  if (ndim_iter == 0) {
    ndim_iter = 1;
  }
  /* Pad everything to ndim_iter dims so shapes/strides fit the fixed-rank
   * kernel instantiation chosen by MXNET_NDIM_SWITCH_EX below. */
  itershape = pad(itershape, ndim_iter);
  for (iop = 0; iop <= nop; ++iop) {
    iterstride[iop] = pad(iterstride[iop], ndim_iter);
  }
  oshape = pad(oshape, ndim_iter);
  reduceshape = pad(reduceshape, ndim_iter);
  for (iop = 0; iop < nop; ++iop) {
    opshape[iop] = pad(opshape[iop], ndim_iter);
    remainshape[iop] = pad(remainshape[iop], ndim_iter);
  }

  if (!back) {
    /* Forward: a single kernel launch computes the whole output. */
    if (oshape.Size() == 0) {
      return;
    }
    const TBlob &out_data = outputs[0];
    MXNET_ACC_TYPE_SWITCH(out_data.type_flag_, DType, AType, {
      mxnet::common::StaticArray<DType*, NPY_MAXARGS> op;
      for (iop = 0; iop < nop; ++iop) {
        op[iop] = inputs[iop].dptr<DType>();
      }
      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
        MXNET_NDIM_SWITCH_EX(ndim_iter, dimension, {
          mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> ostride_arr;
          mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> rstride_arr;
          for (iop = 0; iop < nop; ++iop) {
            /* Split each operand's iterator strides into the output part
             * (first ndim_output dims) and the reduction part. */
            mshadow::Shape<dimension> otmp, rtmp;
            for (idim = 0; idim < dimension; ++idim) {
              otmp[idim] = idim < ndim_output ? iterstride[iop][idim] : 1;
              rtmp[idim] = idim < dimension - ndim_output ? iterstride[iop][idim + ndim_output] : 1;
            }
            ostride_arr[iop] = otmp;
            rstride_arr[iop] = rtmp;
          }
          Kernel<numpy_einsum<dimension, req_type, 0, AType>,
                 xpu>::Launch(ctx.get_stream<xpu>(),
                              oshape.Size(),
                              out_data.dptr<DType>(),
                              op,
                              oshape.get<dimension>(),
                              ostride_arr,
                              reduceshape.get<dimension>(),
                              rstride_arr,
                              nop,
                              -1,
                              (DType*)nullptr);
        })
      })
    })
  } else {
    /* Backward: one kernel launch per operand gradient. */
    if (oshape.Size() == 0) {
      /* Zero-sized output gradient: just clear the requested gradients. */
      for (iop = 0; iop < nop; ++iop) {
        const TBlob& out_data = outputs[iop];
        if (opshape[iop].Size() > 0) {
          MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
            MXNET_ASSIGN_REQ_SWITCH(req[iop], req_type, {
              if (req_type == kWriteTo) {
                out_data.FlatTo1D<xpu, DType>(ctx.get_stream<xpu>()) = 0;
              }
            })
          })
        }
      }
      return;
    }
    for (int i = 0; i < nop; ++i) {
      const TBlob &out_data = outputs[i];
      const TBlob &out_grad = inputs[0];
      /* opstride: iterator strides remapped onto operand i's own axes;
       * remainstride: strides for the dims operand i is broadcast over
       * (those must be summed when forming its gradient). */
      std::vector<TShape> opstride(nop + 1, TShape(ndim_iter, 0));
      std::vector<TShape> remainstride(nop + 1, TShape(ndim_iter, 0));
      for (iop = 0; iop <= nop; ++iop) {
        int j = 0;
        for (idim = 0; idim < ndim_iter; ++idim) {
          if (op_axes_arrays[i][idim] == -1 ||
              (iop != nop && opshape[i][op_axes_arrays[i][idim]] == 1 &&
               op_axes_arrays[iop][idim] != -1 &&
               opshape[iop][op_axes_arrays[iop][idim]] != 1)) {
            remainstride[iop][j++] = iterstride[iop][idim];
          } else {
            opstride[iop][op_axes_arrays[i][idim]] = iterstride[iop][idim];
          }
        }
      }
      MXNET_ACC_TYPE_SWITCH(out_data.type_flag_, DType, AType, {
        mxnet::common::StaticArray<DType*, NPY_MAXARGS> op;
        for (iop = 0; iop < nop; ++iop) {
          op[iop] = inputs[iop + back].dptr<DType>();
        }
        MXNET_ASSIGN_REQ_SWITCH(req[i], req_type, {
          MXNET_NDIM_SWITCH_EX(ndim_iter, dimension, {
            mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> opstride_arr;
            mxnet::common::StaticArray<mshadow::Shape<dimension>, NPY_MAXARGS> remainstride_arr;
            for (iop = 0; iop <= nop; ++iop) {
              opstride_arr[iop] = opstride[iop].get<dimension>();
              remainstride_arr[iop] = remainstride[iop].get<dimension>();
            }
            Kernel<numpy_einsum<dimension, req_type, 1, AType>,
                   xpu>::Launch(ctx.get_stream<xpu>(),
                                opshape[i].Size(),
                                out_data.dptr<DType>(),
                                op,
                                opshape[i].get<dimension>(),
                                opstride_arr,
                                remainshape[i].get<dimension>(),
                                remainstride_arr,
                                nop,
                                i,
                                out_grad.dptr<DType>());
          })
        })
      })
    }
  }
}
757
/*!
 * \brief Forward pass of numpy-compatible einsum.
 *
 * With optimize == 0 the whole expression goes through a single
 * NumpyEinsumProcess call. Otherwise a pairwise contraction path is
 * computed via einsum_path, cached in the op state for backward, and each
 * step is executed either with tensordot (BLAS) or with NumpyEinsumProcess,
 * writing intermediates into a persistent scratch NDArray.
 */
template<typename xpu>
inline void NumpyEinsumForward(const OpStatePtr& state_ptr,
                               const OpContext& ctx,
                               const std::vector<TBlob>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mxnet_op;
  EinsumOp& state = state_ptr.get_state<EinsumOp>();
  int num_args = state.num_args;
  int optimize = state.optimize;
  const char* subscripts = state.subscripts.c_str();
  Stream<xpu> *s = ctx.get_stream<xpu>();
  CHECK_EQ(inputs.size(), num_args);
  CHECK_EQ(outputs.size(), 1U);
  if (optimize == 0) {
    /* Unoptimized path: a single direct einsum evaluation. */
    NumpyEinsumProcess<xpu, 0>(inputs, req, outputs, subscripts, num_args, ctx);
    return;
  }
  /* Compute the contraction path; cached in 'state' for backward replay. */
  std::vector<Step>& paths = state.paths;
  std::vector<std::vector<int> > pos;
  std::string string_repr;
  paths = einsum_path(state.subscripts, inputs, true, ctx.run_ctx, &pos, &string_repr);
  int paths_len = paths.size();
  /* Scratch layout: one slab per intermediate result, preceded by a slab
   * of the largest intermediate used as tensordot staging space. */
  size_t temp_space_size = 0, max_temp_space_size = 0;
  std::vector<TBlob> operands(inputs), tmp_operands, temp_space_vec(paths_len - 1);
  for (int i = 0; i + 1 < paths_len; ++i) {
    temp_space_size += paths[i].oshape.Size();
  }
  for (int i = 0; i < paths_len; ++i) {
    max_temp_space_size = std::max(max_temp_space_size, paths[i].oshape.Size());
  }
  temp_space_size += max_temp_space_size;
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    /* Persistent (op-state) scratch so backward can reuse the intermediates. */
    state.tempspace.reset<NDArray>(new NDArray(TShape(Shape1(temp_space_size)),
                                               ctx.run_ctx.ctx,
                                               false,
                                               outputs[0].type_flag_));
    Tensor<xpu, 1, DType> temp_space = state.tempspace->data().FlatTo1D<xpu, DType>();
    /* Intermediate slabs start after the staging region. */
    size_t begin = max_temp_space_size;
    for (int i = 0; i < paths_len - 1; ++i) {
      TBlob tblob = TBlob(temp_space.Slice(begin, begin + paths[i].oshape.Size()));
      temp_space_vec[i] = tblob.reshape(paths[i].oshape);
      begin = begin + paths[i].oshape.Size();
    }
    for (int i = 0; i < paths_len; ++i) {
      tmp_operands.clear();

      // We remove inds from right to left
      for (const int& p : paths[i].contract_inds) {
        tmp_operands.push_back(operands[p]);
        operands.erase(operands.begin() + p);
      }
      bool handle_out = (i == paths_len - 1);
      // Call tensordot if still possible
      if (paths[i].do_blas) {
        // Contract!
        if (paths[i].do_einsum || handle_out) {
          /* tensordot writes into the staging slab; a follow-up einsum
           * reshuffles it into the step output (or the final output). */
          TBlob max_temp_space = TBlob(temp_space.Slice(0, paths[i].tshape.Size()));
          max_temp_space.FlatTo1D<xpu, DType>(s) = 0;
          max_temp_space = max_temp_space.reshape(paths[i].tshape);
          size_t tensordot_tempspace_size =
            TensordotWorkspaceSize<xpu>(paths[i].left_pos,
                                        paths[i].right_pos,
                                        tmp_operands[0],
                                        tmp_operands[1],
                                        max_temp_space,
                                        std::vector<OpReqType>{OpReqType::kWriteTo});
          Tensor<xpu, 1, char> tensordot_tempspace =
            ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(tensordot_tempspace_size), s);
          TensordotImpl<xpu>(paths[i].left_pos,
                             paths[i].right_pos,
                             ctx,
                             tmp_operands[0],
                             tmp_operands[1],
                             max_temp_space,
                             std::vector<OpReqType>{OpReqType::kWriteTo},
                             tensordot_tempspace);
          NumpyEinsumProcess<xpu, 0>(std::vector<TBlob>{max_temp_space},
                                     handle_out ? req : std::vector<OpReqType>{OpReqType::kWriteTo},
                                     handle_out ? outputs : std::vector<TBlob>{temp_space_vec[i]},
                                     paths[i].blas2einsum_str.c_str(),
                                     1, ctx);
        } else {
          /* Pure tensordot step: write straight into the step's slab. */
          size_t tensordot_tempspace_size =
            TensordotWorkspaceSize<xpu>(paths[i].left_pos,
                                        paths[i].right_pos,
                                        tmp_operands[0],
                                        tmp_operands[1],
                                        temp_space_vec[i],
                                        std::vector<OpReqType>{OpReqType::kWriteTo});
          Tensor<xpu, 1, char> tensordot_tempspace = ctx.requested[0].get_space_typed<xpu, 1, char>(
            Shape1(tensordot_tempspace_size), s);
          TensordotImpl<xpu>(paths[i].left_pos,
                             paths[i].right_pos,
                             ctx,
                             tmp_operands[0],
                             tmp_operands[1],
                             temp_space_vec[i],
                             std::vector<OpReqType>{OpReqType::kWriteTo},
                             tensordot_tempspace);
        }
      } else {
        /* Step not expressible as BLAS: run it as a (small) einsum. */
        NumpyEinsumProcess<xpu, 0>(tmp_operands,
                                   handle_out ? req : std::vector<OpReqType>{OpReqType::kWriteTo},
                                   handle_out ? outputs : std::vector<TBlob>{temp_space_vec[i]},
                                   paths[i].einsum_str.c_str(), tmp_operands.size(), ctx);
      }
      if (!handle_out) {
        /* The step's result becomes an operand for subsequent steps. */
        operands.push_back(temp_space_vec[i]);
      }
    }
  });
}
872
/*!
 * \brief Backward pass of the numpy-compatible einsum operator.
 *
 * Layout of the argument vectors (established by the CHECK_EQs below):
 *   - inputs[0]            : gradient of the einsum output (head gradient)
 *   - inputs[1..num_args]  : the original forward operands
 *   - outputs[k]           : gradient w.r.t. forward operand k (num_args of them)
 *
 * When state.optimize == 0 the whole subscript expression is differentiated in
 * one NumpyEinsumProcess<xpu, 1> call.  Otherwise the pairwise contraction
 * path recorded in state.paths during the forward pass is replayed: forward
 * intermediate results are re-read from state.tempspace, and the path steps
 * are walked in reverse order, back-propagating through each contraction
 * (either a BLAS tensordot or a generic einsum step).
 */
template<typename xpu>
inline void NumpyEinsumBackward(const OpStatePtr& state_ptr,
                                const OpContext& ctx,
                                const std::vector<TBlob>& inputs,
                                const std::vector<OpReqType>& req,
                                const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow_op;
  const EinsumOp& state = state_ptr.get_state<EinsumOp>();
  int num_args = state.num_args;
  int optimize = state.optimize;
  const char* subscripts = state.subscripts.c_str();
  Stream<xpu> *s = ctx.get_stream<xpu>();
  // inputs = {head gradient} + {forward operands}; one output gradient per operand.
  CHECK_EQ(inputs.size(), 1 + num_args);
  CHECK_EQ(outputs.size(), num_args);
  if (optimize == 0) {
    // Unoptimized path: differentiate the full expression in a single call.
    NumpyEinsumProcess<xpu, 1>(inputs, req, outputs, subscripts, num_args, ctx);
    return;
  }
  // calculate temporary space size for temp_grad
  const std::vector<Step>& paths = state.paths;
  int paths_len = paths.size();
  // temp_space_size accumulates (in DType elements) room for the gradients of
  // every intermediate result (one per path step except the last, whose
  // "gradient" is inputs[0]); max_temp_space_size reserves an extra scratch
  // region big enough for any single step's output, used below for the
  // einsum2blas pre-processing buffer.
  size_t temp_space_size = 0, max_temp_space_size = 0;
  for (int i = 0; i < paths_len - 1; ++i) {
    temp_space_size += paths[i].oshape.Size();
  }
  for (int i = 0; i < paths_len; ++i) {
    max_temp_space_size = std::max(max_temp_space_size, paths[i].oshape.Size());
  }
  temp_space_size += max_temp_space_size;
  // replay the forward process
  // op_idx[i] records, for each operand alive before step i, where it came
  // from: a value j >= 1 means original forward operand j-1 (i.e. inputs[j]),
  // a value <= 0 means the intermediate result of step -value
  // (temp_data[-value] / temp_grad[-value]).
  std::vector<std::vector<int> > op_idx(paths_len + 1);
  for (int i = 0; i <= paths_len; ++i) {
    if (i == 0) {
      op_idx[i].reserve(num_args);
      for (int j = 0; j < num_args; ++j) {
        op_idx[i].push_back(j + 1);
      }
    } else {
      op_idx[i] = op_idx[i - 1];
      // We remove inds from right to left
      for (const int& p : paths[i - 1].contract_inds) {
        op_idx[i].erase(op_idx[i].begin() + p);
      }
      // The contraction result of step i-1 joins the operand list, encoded as
      // a non-positive index.
      op_idx[i].push_back(-static_cast<int>(i - 1));
    }
  }
  // calculate temporary space size for tensordot
  // Dry run over all steps: build shape-only TBlobs (null data pointers —
  // presumably only shapes/types are inspected by the workspace-size query;
  // the pointers are never dereferenced here) to ask TensordotBackward how
  // many bytes of scratch each BLAS step will need.
  size_t tensordot_max_tempspace_size = 0;
  size_t begin_tensordot_tempspace = 0;
  std::vector<TBlob> temp_inputs, temp_outputs;
  std::vector<OpReqType> temp_req;
  std::vector<size_t> tensordot_tempspace_size;
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    for (int i = 0; i < paths_len; ++i) {
      temp_inputs.clear();
      temp_outputs.clear();
      temp_req.clear();
      // The last step's output gradient is the head gradient inputs[0].
      bool handle_out = (i == paths_len - 1);

      if (handle_out) {
        temp_inputs.push_back(inputs[0]);
      } else {
        temp_inputs.push_back(TBlob((DType*)nullptr,
                                    paths[i].oshape,
                                    xpu::kDevMask));
      }
      for (auto p : paths[i].contract_inds) {
        int idx = op_idx[i][p];
        if (idx >= 1) {
          // Original forward operand: gradient goes to outputs[idx - 1].
          temp_inputs.push_back(inputs[idx]);
          temp_outputs.push_back(outputs[idx - 1]);
          temp_req.push_back(req[idx - 1]);
        } else {
          // Intermediate result of step -idx: placeholder shapes only.
          temp_inputs.push_back(TBlob((DType*)nullptr,
                                      paths[-idx].oshape,
                                      xpu::kDevMask));
          temp_outputs.push_back(TBlob((DType*)nullptr,
                                       paths[-idx].oshape,
                                       xpu::kDevMask));
          temp_req.push_back(OpReqType::kWriteTo);
        }
      }
      size_t cur_tensordot_tempspace_size = 0;
      if (paths[i].do_blas) {
        if (paths[i].do_einsum) {
          // Step was einsum-preprocessed into a tensordot on tshape.
          cur_tensordot_tempspace_size =
            TensordotBackwardWorkspaceSize<xpu>(paths[i].left_pos,
                                                paths[i].right_pos,
                                                TBlob((DType*)nullptr,
                                                      paths[i].tshape,
                                                      xpu::kDevMask),
                                                temp_inputs[1],
                                                temp_inputs[2],
                                                temp_outputs[0],
                                                temp_outputs[1],
                                                temp_req);
        } else {
          cur_tensordot_tempspace_size =
            TensordotBackwardWorkspaceSize<xpu>(paths[i].left_pos,
                                                paths[i].right_pos,
                                                temp_inputs[0],
                                                temp_inputs[1],
                                                temp_inputs[2],
                                                temp_outputs[0],
                                                temp_outputs[1],
                                                temp_req);
        }
      }
      tensordot_tempspace_size.push_back(cur_tensordot_tempspace_size);
      tensordot_max_tempspace_size = std::max(tensordot_max_tempspace_size,
                                              cur_tensordot_tempspace_size);
    }
    // The tensordot scratch region (sized in bytes, hence the round-up
    // conversion to DType elements) is appended after the gradient region.
    begin_tensordot_tempspace = temp_space_size;
    temp_space_size += (tensordot_max_tempspace_size + sizeof(DType) - 1) / sizeof(DType);
  });
  // allocate temporary space and propagate
  std::vector<TBlob> temp_grad(paths_len - 1), temp_data(paths_len - 1);
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    // allocate temporary space for gradients of intermediate results
    // Workspace layout (DType elements):
    //   [0, max_temp_space_size)                         scratch (einsum2blas)
    //   [max_temp_space_size, begin_tensordot_tempspace) intermediate grads
    //   [begin_tensordot_tempspace, temp_space_size)     tensordot workspace
    Tensor<xpu, 1, DType> temp_space = ctx.requested[0].get_space_typed<xpu, 1, DType>
        (Shape1(temp_space_size), s);
    size_t begin = max_temp_space_size;
    for (int i = 0; i + 1 < paths_len; ++i) {
      TBlob tblob = TBlob(temp_space.Slice(begin, begin + paths[i].oshape.Size()));
      temp_grad[i] = tblob.reshape(paths[i].oshape);
      begin = begin + paths[i].oshape.Size();
    }

    // reinterprete ndarray for intermediate results
    // state.tempspace holds the forward pass's intermediates; it uses the
    // same layout, so intermediate i lives at the same offset as its grad.
    Tensor<xpu, 1, DType> ndarray_space = state.tempspace->data().FlatTo1D<xpu, DType>();
    begin = max_temp_space_size;
    for (int i = 0; i + 1 < paths_len; ++i) {
      TBlob tblob = TBlob(ndarray_space.Slice(begin, begin + paths[i].oshape.Size()));
      temp_data[i] = tblob.reshape(paths[i].oshape);
      begin = begin + paths[i].oshape.Size();
    }

    // go through the paths in the reversed order
    for (int i = paths_len - 1; i >= 0; i--) {
      temp_inputs.clear();
      temp_outputs.clear();
      temp_req.clear();
      bool handle_out = (i == paths_len - 1);

      // temp_inputs[0] = gradient flowing into this step's output;
      // the last step receives the head gradient directly.
      if (handle_out) {
        temp_inputs.push_back(inputs[0]);
      } else {
        temp_inputs.push_back(temp_grad[i]);
      }
      for (auto p : paths[i].contract_inds) {
        int idx = op_idx[i][p];
        if (idx >= 1) {
          temp_inputs.push_back(inputs[idx]);
          temp_outputs.push_back(outputs[idx - 1]);
          temp_req.push_back(req[idx - 1]);
        } else {
          temp_inputs.push_back(temp_data[-idx]);
          temp_outputs.push_back(temp_grad[-idx]);
          temp_req.push_back(OpReqType::kWriteTo);
        }
      }
      if (paths[i].do_blas) {
        // Tensordot backward: exactly two contracted operands.
        CHECK_EQ(temp_inputs.size(), 3U);
        CHECK_EQ(temp_outputs.size(), 2U);
        CHECK_EQ(temp_req.size(), 2U);
        Tensor<xpu, 1, DType> tensordot_tempspace = temp_space.Slice(begin_tensordot_tempspace,
                                                                     temp_space_size);
        // Tensordot's workspace sizes were computed in bytes; view the tail
        // region as raw chars of the exact size this step requested.
        Tensor<xpu, 1, char> char_tempspace =
          Tensor<xpu, 1, char>(reinterpret_cast<char*>(tensordot_tempspace.dptr_),
                               Shape1(tensordot_tempspace_size[i]),
                               tensordot_tempspace.stream_);
        if (paths[i].do_einsum) {
          // Re-run the forward einsum2blas reshaping of the output gradient
          // into tshape (in the scratch region) before tensordot backward.
          TBlob max_temp_space = TBlob(temp_space.Slice(0, paths[i].tshape.Size()));
          max_temp_space = max_temp_space.reshape(paths[i].tshape);
          NumpyEinsumProcess<xpu, 0>(std::vector<TBlob>{temp_inputs[0]},
                                     std::vector<OpReqType>{kWriteTo},
                                     std::vector<TBlob>{max_temp_space},
                                     paths[i].einsum2blas_str.c_str(),
                                     1, ctx);
          TensordotBackwardImpl<xpu>(paths[i].left_pos, paths[i].right_pos, ctx,
                                     max_temp_space, temp_inputs[1], temp_inputs[2],
                                     temp_outputs[0], temp_outputs[1], temp_req, char_tempspace);
        } else {
          TensordotBackwardImpl<xpu>(paths[i].left_pos, paths[i].right_pos, ctx,
                                     temp_inputs[0], temp_inputs[1], temp_inputs[2],
                                     temp_outputs[0], temp_outputs[1], temp_req, char_tempspace);
        }
      } else {
        // Non-BLAS step: differentiate the step's einsum expression directly.
        NumpyEinsumProcess<xpu, 1>(temp_inputs, temp_req, temp_outputs,
                                   paths[i].einsum_str.c_str(),
                                   temp_outputs.size(),
                                   ctx);
      }
    }
  });
}
1070
1071 } // namespace op
1072 } // namespace mxnet
1073
1074 #endif // MXNET_OPERATOR_NUMPY_NP_EINSUM_OP_INL_H_
1075