1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file implicit_gemm.h
22  * \brief support for implicit GEMM operation
23  * \author Tianqi Chen
24  */
25 #ifndef MSHADOW_EXTENSION_IMPLICIT_GEMM_H_
26 #define MSHADOW_EXTENSION_IMPLICIT_GEMM_H_
27 
28 #include "../extension.h"
29 #include "../packet-inl.h"
30 
31 namespace mshadow {
32 namespace expr {
33 /*!
34  * \brief Matrix multiplication.
35  * \tparam LhsExp type of lhs expression
36  * \tparam LhsExp type of rhs expression
37  * \tparam DType the type of elements
38  */
39 template<typename LhsExp, typename RhsExp, typename DType>
40 struct ImplicitGEMMExp:
41       public Exp<ImplicitGEMMExp<LhsExp, RhsExp, DType>,
42                  DType, type::kChainer> {
43   /*! \brief lhs operand */
44   const LhsExp &lhs_;
45   /*! \brief rhs operand */
46   const RhsExp &rhs_;
47   /*! \brief internal production size*/
48   index_t prod_size_;
49   /*! \brief the shape of this expression */
50   Shape<2> shape_;
51   /*! \brief constructor */
ImplicitGEMMExpImplicitGEMMExp52   ImplicitGEMMExp(const LhsExp &lhs, const RhsExp &rhs)
53       : lhs_(lhs), rhs_(rhs) {
54     Shape<2> slhs = ShapeCheck<2, LhsExp>::Check(lhs_);
55     Shape<2> srhs = ShapeCheck<2, RhsExp>::Check(rhs_);
56     this->shape_ = mshadow::Shape2(slhs[0], srhs[1]);
57     prod_size_ = slhs[1];
58   }
59 };
60 
61 
62 template<typename LhsExp, typename RhsExp, typename DType, int e1, int e2>
63 inline ImplicitGEMMExp<LhsExp, RhsExp, DType>
implicit_dot(const Exp<LhsExp,DType,e1> & lhs,const Exp<RhsExp,DType,e2> & rhs)64 implicit_dot(const Exp<LhsExp, DType, e1> &lhs,
65              const Exp<RhsExp, DType, e2> &rhs) {
66   TypeCheckPass<ExpInfo<LhsExp>::kDim == 2 && ExpInfo<RhsExp>::kDim == 2>
67       ::Error_Expression_Does_Not_Meet_Dimension_Req();
68   return ImplicitGEMMExp<LhsExp, RhsExp, DType>(lhs.self(), rhs.self());
69 }
70 
71 //----------------------
72 // Execution plan
73 //----------------------
74 template<typename LhsExp, typename RhsExp, typename DType>
75 struct Plan<ImplicitGEMMExp<LhsExp, RhsExp, DType>, DType> {
76  public:
77   explicit Plan(const ImplicitGEMMExp<LhsExp, RhsExp, DType> &e)
78       : lhs_(MakePlan(e.lhs_)),
79         rhs_(MakePlan(e.rhs_)),
80         prod_size_(e.prod_size_),
81         prod_size_lower_align_(packet::LowerAlign<DType, MSHADOW_DEFAULT_PACKET>(e.prod_size_)) {
82   }
83 
84   MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
85     typedef packet::Packet<DType> Packet;
86     Packet sum = Packet::Fill(0);
87 
88     const size_t packetSize = Packet::size;
89     DType lhs_temp[packetSize], rhs_temp[packetSize];
90 
91     for (index_t i = 0; i < prod_size_lower_align_; i += packetSize) {
92       // unroll
93       for (index_t j = 0; j < packetSize; ++j) {
94         lhs_temp[j] = lhs_.Eval(y, i + j);
95       }
96       for (index_t j = 0; j < packetSize; ++j) {
97         rhs_temp[j] = rhs_.Eval(i + j, x);
98       }
99       sum = sum + Packet::LoadUnAligned(lhs_temp) * Packet::LoadUnAligned(rhs_temp);
100     }
101     DType ret_result = sum.Sum();
102 
103     for (index_t i =  prod_size_lower_align_; i < prod_size_; ++i) {
104       ret_result += lhs_.Eval(y, i) * rhs_.Eval(i, x);
105     }
106     return ret_result;
107   }
108 
109  private:
110   expr::Plan<LhsExp, DType> lhs_;
111   expr::Plan<RhsExp, DType> rhs_;
112   const index_t prod_size_;
113   const index_t prod_size_lower_align_;
114 };
115 
116 template<typename LhsExp, typename RhsExp, typename DType>
117 inline Plan<ImplicitGEMMExp<LhsExp, RhsExp, DType>, DType>
118 MakePlan(const ImplicitGEMMExp<LhsExp, RhsExp, DType> &exp) {
119   return Plan<ImplicitGEMMExp<LhsExp, RhsExp, DType>, DType>(exp);
120 }
121 
122 
123 template<int dim, typename LhsExp, typename RhsExp, typename DType>
124 struct ShapeCheck<dim, ImplicitGEMMExp<LhsExp, RhsExp, DType> > {
125   inline static Shape<dim>
126   Check(const ImplicitGEMMExp<LhsExp, RhsExp, DType> &t) {
127     CHECK(dim == 2)
128         << "ImplicitGEMMExp only support 2 dimension";
129     Shape<dim> shape1 = ShapeCheck<dim, LhsExp>::Check(t.lhs_);
130     Shape<dim> shape2 = ShapeCheck<dim, RhsExp>::Check(t.rhs_);
131     CHECK_EQ(shape1[1], shape2[0])
132       << "implicit_dot The matrix shape do  not match";
133     return t.shape_;
134   }
135 };
136 
137 template<typename LhsExp, typename RhsExp, typename DType>
138 struct ExpInfo<ImplicitGEMMExp<LhsExp, RhsExp, DType> > {
139   static const int kDim = 2;
140   static const int kDevMask = ExpInfo<LhsExp>::kDevMask & ExpInfo<RhsExp>::kDevMask;
141 };
142 
143 }  // namespace expr
144 }  // namespace mshadow
145 #endif  // MSHADOW_EXTENSION_IMPLICIT_GEMM_H_
146 
147