/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19
/*!
 * \file implicit_gemm.h
 * \brief support for implicit GEMM: a lazily evaluated matrix-product expression
 * \author Tianqi Chen
 */
25 #ifndef MSHADOW_EXTENSION_IMPLICIT_GEMM_H_
26 #define MSHADOW_EXTENSION_IMPLICIT_GEMM_H_
27
28 #include "../extension.h"
29 #include "../packet-inl.h"
30
31 namespace mshadow {
32 namespace expr {
33 /*!
34 * \brief Matrix multiplication.
35 * \tparam LhsExp type of lhs expression
36 * \tparam LhsExp type of rhs expression
37 * \tparam DType the type of elements
38 */
39 template<typename LhsExp, typename RhsExp, typename DType>
40 struct ImplicitGEMMExp:
41 public Exp<ImplicitGEMMExp<LhsExp, RhsExp, DType>,
42 DType, type::kChainer> {
43 /*! \brief lhs operand */
44 const LhsExp &lhs_;
45 /*! \brief rhs operand */
46 const RhsExp &rhs_;
47 /*! \brief internal production size*/
48 index_t prod_size_;
49 /*! \brief the shape of this expression */
50 Shape<2> shape_;
51 /*! \brief constructor */
ImplicitGEMMExpImplicitGEMMExp52 ImplicitGEMMExp(const LhsExp &lhs, const RhsExp &rhs)
53 : lhs_(lhs), rhs_(rhs) {
54 Shape<2> slhs = ShapeCheck<2, LhsExp>::Check(lhs_);
55 Shape<2> srhs = ShapeCheck<2, RhsExp>::Check(rhs_);
56 this->shape_ = mshadow::Shape2(slhs[0], srhs[1]);
57 prod_size_ = slhs[1];
58 }
59 };
60
61
62 template<typename LhsExp, typename RhsExp, typename DType, int e1, int e2>
63 inline ImplicitGEMMExp<LhsExp, RhsExp, DType>
implicit_dot(const Exp<LhsExp,DType,e1> & lhs,const Exp<RhsExp,DType,e2> & rhs)64 implicit_dot(const Exp<LhsExp, DType, e1> &lhs,
65 const Exp<RhsExp, DType, e2> &rhs) {
66 TypeCheckPass<ExpInfo<LhsExp>::kDim == 2 && ExpInfo<RhsExp>::kDim == 2>
67 ::Error_Expression_Does_Not_Meet_Dimension_Req();
68 return ImplicitGEMMExp<LhsExp, RhsExp, DType>(lhs.self(), rhs.self());
69 }
70
//----------------------
// Execution plan
//----------------------
74 template<typename LhsExp, typename RhsExp, typename DType>
75 struct Plan<ImplicitGEMMExp<LhsExp, RhsExp, DType>, DType> {
76 public:
77 explicit Plan(const ImplicitGEMMExp<LhsExp, RhsExp, DType> &e)
78 : lhs_(MakePlan(e.lhs_)),
79 rhs_(MakePlan(e.rhs_)),
80 prod_size_(e.prod_size_),
81 prod_size_lower_align_(packet::LowerAlign<DType, MSHADOW_DEFAULT_PACKET>(e.prod_size_)) {
82 }
83
84 MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
85 typedef packet::Packet<DType> Packet;
86 Packet sum = Packet::Fill(0);
87
88 const size_t packetSize = Packet::size;
89 DType lhs_temp[packetSize], rhs_temp[packetSize];
90
91 for (index_t i = 0; i < prod_size_lower_align_; i += packetSize) {
92 // unroll
93 for (index_t j = 0; j < packetSize; ++j) {
94 lhs_temp[j] = lhs_.Eval(y, i + j);
95 }
96 for (index_t j = 0; j < packetSize; ++j) {
97 rhs_temp[j] = rhs_.Eval(i + j, x);
98 }
99 sum = sum + Packet::LoadUnAligned(lhs_temp) * Packet::LoadUnAligned(rhs_temp);
100 }
101 DType ret_result = sum.Sum();
102
103 for (index_t i = prod_size_lower_align_; i < prod_size_; ++i) {
104 ret_result += lhs_.Eval(y, i) * rhs_.Eval(i, x);
105 }
106 return ret_result;
107 }
108
109 private:
110 expr::Plan<LhsExp, DType> lhs_;
111 expr::Plan<RhsExp, DType> rhs_;
112 const index_t prod_size_;
113 const index_t prod_size_lower_align_;
114 };
115
116 template<typename LhsExp, typename RhsExp, typename DType>
117 inline Plan<ImplicitGEMMExp<LhsExp, RhsExp, DType>, DType>
118 MakePlan(const ImplicitGEMMExp<LhsExp, RhsExp, DType> &exp) {
119 return Plan<ImplicitGEMMExp<LhsExp, RhsExp, DType>, DType>(exp);
120 }
121
122
123 template<int dim, typename LhsExp, typename RhsExp, typename DType>
124 struct ShapeCheck<dim, ImplicitGEMMExp<LhsExp, RhsExp, DType> > {
125 inline static Shape<dim>
126 Check(const ImplicitGEMMExp<LhsExp, RhsExp, DType> &t) {
127 CHECK(dim == 2)
128 << "ImplicitGEMMExp only support 2 dimension";
129 Shape<dim> shape1 = ShapeCheck<dim, LhsExp>::Check(t.lhs_);
130 Shape<dim> shape2 = ShapeCheck<dim, RhsExp>::Check(t.rhs_);
131 CHECK_EQ(shape1[1], shape2[0])
132 << "implicit_dot The matrix shape do not match";
133 return t.shape_;
134 }
135 };
136
137 template<typename LhsExp, typename RhsExp, typename DType>
138 struct ExpInfo<ImplicitGEMMExp<LhsExp, RhsExp, DType> > {
139 static const int kDim = 2;
140 static const int kDevMask = ExpInfo<LhsExp>::kDevMask & ExpInfo<RhsExp>::kDevMask;
141 };
142
143 } // namespace expr
144 } // namespace mshadow
145 #endif // MSHADOW_EXTENSION_IMPLICIT_GEMM_H_
146
147