1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file np_tensorinv-inl.h
 * \brief Forward and backward implementations of the numpy tensorinv operator
23 */
24 #ifndef MXNET_OPERATOR_NUMPY_LINALG_NP_TENSORINV_INL_H_
25 #define MXNET_OPERATOR_NUMPY_LINALG_NP_TENSORINV_INL_H_
26
27 #include <mxnet/operator_util.h>
28 #include <vector>
29 #include "../../operator_common.h"
30 #include "../../mshadow_op.h"
31 #include "../../tensor/la_op.h"
32 #include "../../tensor/la_op-inl.h"
33
34 namespace mxnet {
35 namespace op {
36
37 using namespace mshadow;
38
39 struct TensorinvParam : public dmlc::Parameter<TensorinvParam> {
40 int ind;
DMLC_DECLARE_PARAMETERTensorinvParam41 DMLC_DECLARE_PARAMETER(TensorinvParam) {
42 DMLC_DECLARE_FIELD(ind)
43 .set_default(2)
44 .describe("Number of first indices that are involved in the inverse sum.");
45 }
46 };
47
48 template<typename xpu>
TensorinvOpForward(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)49 void TensorinvOpForward(const nnvm::NodeAttrs& attrs,
50 const OpContext& ctx,
51 const std::vector<TBlob>& inputs,
52 const std::vector<OpReqType>& req,
53 const std::vector<TBlob>& outputs) {
54 CHECK_EQ(inputs.size(), 1U);
55 CHECK_EQ(outputs.size(), 1U);
56 CHECK_EQ(req.size(), 1U);
57
58 mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
59 const mxnet::TBlob& a_tblob = inputs[0];
60 const mxnet::TBlob& inv_a_tblob = outputs[0];
61 const mxnet::TShape& a_shape = a_tblob.shape_;
62 CHECK_EQ(inv_a_tblob.type_flag_, a_tblob.type_flag_)
63 << "Binary function only support input/output with the same type.";
64 MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
65 const int ind = nnvm::get<TensorinvParam>(attrs.parsed).ind;
66 dim_t prod_front = 1, prod_back = 1;
67 if (ind < a_shape.ndim()) {
68 for (int i = 0; i < ind; ++i) {
69 prod_front *= a_shape[i];
70 }
71 for (int i = ind; i < a_shape.ndim(); ++i) {
72 prod_back *= a_shape[i];
73 }
74 } else {
75 for (int i = 0; i < a_shape.ndim(); ++i) {
76 prod_front *= a_shape[i];
77 }
78 }
79 Tensor<xpu, 3, OType> A =
80 a_tblob.get_with_shape<xpu, 3, OType>(Shape3(1, prod_back, prod_front), s);
81 Tensor<xpu, 3, OType> inv_A =
82 inv_a_tblob.get_with_shape<xpu, 3, OType>(Shape3(1, prod_back, prod_front), s);
83 inverse::op(A, inv_A, ctx, attrs);
84 });
85 }
86
87 template<typename xpu>
TensorinvOpBackward(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)88 void TensorinvOpBackward(const nnvm::NodeAttrs& attrs,
89 const OpContext& ctx,
90 const std::vector<TBlob>& inputs,
91 const std::vector<OpReqType>& req,
92 const std::vector<TBlob>& outputs) {
93 CHECK_EQ(inputs.size(), 2U);
94 CHECK_EQ(outputs.size(), 1U);
95 CHECK_EQ(req.size(), 1U);
96
97 mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
98 const TBlob& out_grad = inputs[0];
99 const TBlob& inv_a = inputs[1];
100 const TBlob& grad_a = outputs[0];
101 const TShape& inv_a_shape = inv_a.shape_;
102 MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
103 const int axes = nnvm::get<TensorinvParam>(attrs.parsed).ind;
104 CHECK_LE(inv_a_shape.ndim(), 6U)
105 << "tensorinv backward only support tensor's dimension <= 6";
106 if (axes < inv_a_shape.ndim()) {
107 const int axes1 = inv_a_shape.ndim() - axes, axes2 = axes;
108 TShape inv_a_transpose_shape(inv_a_shape.ndim(), -1);
109 for (int i = 0; i < axes; ++i) {
110 inv_a_transpose_shape[i] = inv_a_shape[i + inv_a_shape.ndim() - axes];
111 }
112 for (int i = axes; i < inv_a_shape.ndim(); ++i) {
113 inv_a_transpose_shape[i] = inv_a_shape[i - axes];
114 }
115 TShape temp_shape(2 * axes, -1);
116 for (int i = 0; i < axes; ++i) {
117 temp_shape[i] = inv_a_transpose_shape[i];
118 temp_shape[i + axes] = inv_a_transpose_shape[i];
119 }
120 Tensor<xpu, 1, char> workspace =
121 ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(temp_shape.Size() * sizeof(OType)),
122 ctx.get_stream<xpu>());
123 TBlob temp_tblob =
124 TBlob(reinterpret_cast<OType*>(workspace.dptr_), temp_shape, xpu::kDevMask);
125 dim_t a1 = 1, a2 = 1;
126 for (int i = 0; i < axes2; ++i) {
127 a1 *= inv_a_transpose_shape[i];
128 }
129 for (int i = 0; i < axes1; ++i) {
130 a2 *= inv_a_shape[i];
131 }
132 Tensor<xpu, 3, OType> inv_a_tensor =
133 inv_a.get_with_shape<xpu, 3, OType>(Shape3(1, a2, a1), s);
134 Tensor<xpu, 3, OType> out_grad_tensor =
135 out_grad.get_with_shape<xpu, 3, OType>(Shape3(1, a2, a1), s);
136 Tensor<xpu, 3, OType> temp_tensor =
137 temp_tblob.get_with_shape<xpu, 3, OType>(Shape3(1, a1, a1), s);
138 Tensor<xpu, 3, OType> grad_a_tensor =
139 grad_a.get_with_shape<xpu, 3, OType>(Shape3(1, a1, a2), s);
140 gemm2::op(inv_a_tensor, out_grad_tensor, temp_tensor, OType(1), true, false, s);
141 gemm2::op(temp_tensor, inv_a_tensor, grad_a_tensor, OType(-1), false, true, s);
142 } else { // axes >= inv_a_shape.ndim()
143 dim_t a = 1;
144 for (int i = 0; i < inv_a_shape.ndim(); ++i) {
145 a *= inv_a_shape[i];
146 }
147 // check again
148 CHECK_EQ(a, 1U)
149 << "a shape must be square, i. e., prod(a.shape[:ind]) == prod(a.shape[ind:]).";
150 Tensor<xpu, 1, OType> inv_a_tensor =
151 inv_a.get_with_shape<xpu, 1, OType>(Shape1(1), s);
152 Tensor<xpu, 1, OType> out_grad_tensor =
153 out_grad.get_with_shape<xpu, 1, OType>(Shape1(1), s);
154 Tensor<xpu, 1, OType> grad_a_tensor =
155 grad_a.get_with_shape<xpu, 1, OType>(Shape1(1), s);
156 ASSIGN_DISPATCH(grad_a_tensor, kWriteTo,
157 OType(-1) * inv_a_tensor * out_grad_tensor * inv_a_tensor);
158 }
159 });
160 }
161
162 } // namespace op
163 } // namespace mxnet
164
165 #endif // MXNET_OPERATOR_NUMPY_LINALG_NP_TENSORINV_INL_H_
166