/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file np_tensorinv-inl.h
 * \brief Implementation of the tensor inverse operator (np.linalg.tensorinv)
 */
#ifndef MXNET_OPERATOR_NUMPY_LINALG_NP_TENSORINV_INL_H_
#define MXNET_OPERATOR_NUMPY_LINALG_NP_TENSORINV_INL_H_

#include <mxnet/operator_util.h>
#include <vector>
#include "../../operator_common.h"
#include "../../mshadow_op.h"
#include "../../tensor/la_op.h"
#include "../../tensor/la_op-inl.h"

namespace mxnet {
namespace op {

using namespace mshadow;

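/*! \brief Parameters of the tensorinv operator. */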
struct TensorinvParam : public dmlc::Parameter<TensorinvParam> {
  int ind;
  DMLC_DECLARE_PARAMETER(TensorinvParam) {
    DMLC_DECLARE_FIELD(ind)
      .set_default(2)
      .describe("Number of first indices that are involved in the inverse sum.");
  }
};

template<typename xpu>
void TensorinvOpForward(const nnvm::NodeAttrs& attrs,
                        const OpContext& ctx,
                        const std::vector<TBlob>& inputs,
                        const std::vector<OpReqType>& req,
                        const std::vector<TBlob>& outputs) {
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);

  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
  const mxnet::TBlob& a_tblob = inputs[0];
  const mxnet::TBlob& inv_a_tblob = outputs[0];
  const mxnet::TShape& a_shape = a_tblob.shape_;
  CHECK_EQ(inv_a_tblob.type_flag_, a_tblob.type_flag_)
      << "Tensorinv only supports input and output with the same type.";
  MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
    const int ind = nnvm::get<TensorinvParam>(attrs.parsed).ind;
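    // Flatten a into a matrix: the first `ind` axes form one side and the
    // remaining axes the other. np.linalg.tensorinv requires
    // prod(a.shape[:ind]) == prod(a.shape[ind:]), so the result is square.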
    dim_t prod_front = 1, prod_back = 1;
    if (ind < a_shape.ndim()) {
      for (int i = 0; i < ind; ++i) {
        prod_front *= a_shape[i];
      }
      for (int i = ind; i < a_shape.ndim(); ++i) {
        prod_back *= a_shape[i];
      }
    } else {
      for (int i = 0; i < a_shape.ndim(); ++i) {
        prod_front *= a_shape[i];
      }
    }
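    // View input and output as a batch of one square matrix and invert it
    // with the shared la_op inverse kernel.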
    Tensor<xpu, 3, OType> A =
      a_tblob.get_with_shape<xpu, 3, OType>(Shape3(1, prod_back, prod_front), s);
    Tensor<xpu, 3, OType> inv_A =
      inv_a_tblob.get_with_shape<xpu, 3, OType>(Shape3(1, prod_back, prod_front), s);
    inverse::op(A, inv_A, ctx, attrs);
  });
}

template<typename xpu>
void TensorinvOpBackward(const nnvm::NodeAttrs& attrs,
                         const OpContext& ctx,
                         const std::vector<TBlob>& inputs,
                         const std::vector<OpReqType>& req,
                         const std::vector<TBlob>& outputs) {
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);

  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
  const TBlob& out_grad = inputs[0];
  const TBlob& inv_a = inputs[1];
  const TBlob& grad_a = outputs[0];
  const TShape& inv_a_shape = inv_a.shape_;
  MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
    const int axes = nnvm::get<TensorinvParam>(attrs.parsed).ind;
    CHECK_LE(inv_a_shape.ndim(), 6U)
      << "tensorinv backward only supports tensors with ndim <= 6";
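    // Backward of the matrix inverse: for W = inv(A),
    // dL/dA = -W^T * (dL/dW) * W^T, applied to the flattened matrices below.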
    if (axes < inv_a_shape.ndim()) {
      const int axes1 = inv_a_shape.ndim() - axes, axes2 = axes;
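      // inv_a has shape a.shape[ind:] + a.shape[:ind]; rotating its last
      // `axes` dims back to the front recovers the original input shape.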
      TShape inv_a_transpose_shape(inv_a_shape.ndim(), -1);
      for (int i = 0; i < axes; ++i) {
        inv_a_transpose_shape[i] = inv_a_shape[i + inv_a_shape.ndim() - axes];
      }
      for (int i = axes; i < inv_a_shape.ndim(); ++i) {
        inv_a_transpose_shape[i] = inv_a_shape[i - axes];
      }
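      // temp_shape repeats the first `axes` dims twice: it holds the
      // intermediate product inv_a^T * out_grad, of flattened size a1 x a1.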
      TShape temp_shape(2 * axes, -1);
      for (int i = 0; i < axes; ++i) {
        temp_shape[i] = inv_a_transpose_shape[i];
        temp_shape[i + axes] = inv_a_transpose_shape[i];
      }
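      // Scratch buffer for the intermediate product, taken from the
      // operator's requested temp-space resource.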
      Tensor<xpu, 1, char> workspace =
        ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(temp_shape.Size() * sizeof(OType)),
                                                       ctx.get_stream<xpu>());
      TBlob temp_tblob =
        TBlob(reinterpret_cast<OType*>(workspace.dptr_), temp_shape, xpu::kDevMask);
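      // a1 = prod(a.shape[:ind]), a2 = prod(a.shape[ind:]); squareness of
      // the flattened matrix implies a1 == a2.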
      dim_t a1 = 1, a2 = 1;
      for (int i = 0; i < axes2; ++i) {
        a1 *= inv_a_transpose_shape[i];
      }
      for (int i = 0; i < axes1; ++i) {
        a2 *= inv_a_shape[i];
      }
      Tensor<xpu, 3, OType> inv_a_tensor =
        inv_a.get_with_shape<xpu, 3, OType>(Shape3(1, a2, a1), s);
      Tensor<xpu, 3, OType> out_grad_tensor =
        out_grad.get_with_shape<xpu, 3, OType>(Shape3(1, a2, a1), s);
      Tensor<xpu, 3, OType> temp_tensor =
        temp_tblob.get_with_shape<xpu, 3, OType>(Shape3(1, a1, a1), s);
      Tensor<xpu, 3, OType> grad_a_tensor =
        grad_a.get_with_shape<xpu, 3, OType>(Shape3(1, a1, a2), s);
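      // Two GEMMs realize grad_a = -inv_a^T * out_grad * inv_a^T:
      // temp = inv_a^T * out_grad, then grad_a = -temp * inv_a^T.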
      gemm2::op(inv_a_tensor, out_grad_tensor, temp_tensor, OType(1), true, false, s);
      gemm2::op(temp_tensor, inv_a_tensor, grad_a_tensor, OType(-1), false, true, s);
    } else {  // axes >= inv_a_shape.ndim()
      dim_t a = 1;
      for (int i = 0; i < inv_a_shape.ndim(); ++i) {
        a *= inv_a_shape[i];
      }
      // With ind >= a.ndim, prod(a.shape[ind:]) is empty (== 1), so the
      // squareness requirement forces the flattened matrix to be 1 x 1.
      CHECK_EQ(a, 1U)
        << "a shape must be square, i.e., prod(a.shape[:ind]) == prod(a.shape[ind:]).";
      Tensor<xpu, 1, OType> inv_a_tensor =
        inv_a.get_with_shape<xpu, 1, OType>(Shape1(1), s);
      Tensor<xpu, 1, OType> out_grad_tensor =
        out_grad.get_with_shape<xpu, 1, OType>(Shape1(1), s);
      Tensor<xpu, 1, OType> grad_a_tensor =
        grad_a.get_with_shape<xpu, 1, OType>(Shape1(1), s);
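      // Scalar case of the same formula: d(1/a) = -(1/a) * dout * (1/a).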
      ASSIGN_DISPATCH(grad_a_tensor, kWriteTo,
        OType(-1) * inv_a_tensor * out_grad_tensor * inv_a_tensor);
    }
  });
}

}  // namespace op
}  // namespace mxnet

#endif  // MXNET_OPERATOR_NUMPY_LINALG_NP_TENSORINV_INL_H_