1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file expr_scalar-inl.h
 * \brief definitions of expression operators that take a scalar operand
 *  This file is included several times, each time with the macro
 *  MSHADOW_SCALAR_ defined as a different scalar type.
 *
 * DO NOT add #pragma once or a real include guard.
 * \author Tianqi Chen, Bing Xu
 */
// A real include guard would break the repeated inclusion described above;
// the define/undef pair below exists only to satisfy cpplint.
29 #ifndef MSHADOW_EXPR_SCALAR_INL_H_
30 #define MSHADOW_EXPR_SCALAR_INL_H_
31 // undef the guard so it can be included multiple times
32 #undef MSHADOW_EXPR_SCALAR_INL_H_
33
34 namespace mshadow {
35 namespace expr {
36 // DotExp
37 /*! \brief dot operator def */
38 template<typename TA, typename TB, bool ltrans, bool rtrans>
39 inline DotExp<TA, TB, ltrans, rtrans, MSHADOW_SCALAR_>
40 operator*(const DotExp<TA, TB, ltrans, rtrans, MSHADOW_SCALAR_> &lhs,
41 MSHADOW_SCALAR_ rhs) {
42 return DotExp<TA, TB, ltrans, rtrans,
43 MSHADOW_SCALAR_>(lhs.lhs_, lhs.rhs_, lhs.scale_ * rhs);
44 }
45 /*! \brief scale of dot operation */
46 template<typename TA, typename TB, bool ltrans, bool rtrans>
47 inline DotExp<TA, TB, ltrans, rtrans, MSHADOW_SCALAR_>
48 operator*(MSHADOW_SCALAR_ lhs,
49 const DotExp<TA, TB, ltrans, rtrans, MSHADOW_SCALAR_> &rhs) {
50 return DotExp<TA, TB, ltrans, rtrans,
51 MSHADOW_SCALAR_>(rhs.lhs_, rhs.rhs_, rhs.scale_ * lhs);
52 }
53
54 /*! \brief operator overload */
55 template<typename E, typename DType, typename R, int d>
56 inline ReduceTo1DExp<E, DType, R, d>
57 operator*(const ReduceTo1DExp<E, DType, R, d> &e, MSHADOW_SCALAR_ scale) {
58 return ReduceTo1DExp<E, DType, R, d>(e.src_, e.scale_ * scale);
59 }
60 /*! \brief operator overload */
61 template<typename E, typename DType, typename R, int d>
62 inline ReduceTo1DExp<E, DType, R, d>
63 operator*(MSHADOW_SCALAR_ scale, const ReduceTo1DExp<E, DType, R, d> &e) {
64 return ReduceTo1DExp<E, DType, R, d>(e.src_, e.scale_ * scale);
65 }
66
67 /*! \brief operator overload for const */
68 template<typename OP, typename TA, int ta>
69 inline BinaryMapExp<OP, TA, ScalarExp<MSHADOW_SCALAR_>,
70 MSHADOW_SCALAR_, (ta|type::kMapper)>
F(const Exp<TA,MSHADOW_SCALAR_,ta> & lhs,const ScalarExp<MSHADOW_SCALAR_> & rhs)71 F(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs, const ScalarExp<MSHADOW_SCALAR_> &rhs) {
72 return MakeExp<OP>(lhs, rhs);
73 }
74 /*! \brief operator overload for const */
75 template<typename OP, typename TB, int tb>
76 inline BinaryMapExp<OP, ScalarExp<MSHADOW_SCALAR_>, TB,
77 MSHADOW_SCALAR_, (tb|type::kMapper)>
F(const ScalarExp<MSHADOW_SCALAR_> & lhs,const Exp<TB,MSHADOW_SCALAR_,tb> & rhs)78 F(const ScalarExp<MSHADOW_SCALAR_> &lhs, const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
79 return MakeExp<OP>(lhs, rhs);
80 }
81 /*! \brief operator overload for const */
82 template<typename OP>
83 inline BinaryMapExp<OP, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
84 MSHADOW_SCALAR_, (1|type::kMapper)>
F(const ScalarExp<MSHADOW_SCALAR_> & lhs,const ScalarExp<MSHADOW_SCALAR_> & rhs)85 F(const ScalarExp<MSHADOW_SCALAR_> &lhs, const ScalarExp<MSHADOW_SCALAR_> &rhs) {
86 return MakeExp<OP>(lhs, rhs);
87 }
88 // constant operators
89 /*! \brief operator overload */
90 template<typename TA, int ta>
91 inline BinaryMapExp<op::plus, TA, ScalarExp<MSHADOW_SCALAR_>,
92 MSHADOW_SCALAR_, (ta|type::kMapper)>
93 operator+(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs,
94 const ScalarExp<MSHADOW_SCALAR_> &rhs) {
95 return MakeExp<op::plus>(lhs, rhs);
96 }
97 /*! \brief operator overload */
98 template<typename TA, int ta>
99 inline BinaryMapExp<op::minus, TA, ScalarExp<MSHADOW_SCALAR_>,
100 MSHADOW_SCALAR_, (ta|type::kMapper)>
101 operator-(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs,
102 const ScalarExp<MSHADOW_SCALAR_> &rhs) {
103 return MakeExp<op::minus>(lhs, rhs);
104 }
105 /*! \brief operator overload */
106 template<typename TA, int ta>
107 inline BinaryMapExp<op::mul, TA, ScalarExp<MSHADOW_SCALAR_>,
108 MSHADOW_SCALAR_, (ta|type::kMapper)>
109 operator*(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs,
110 const ScalarExp<MSHADOW_SCALAR_> &rhs) {
111 return MakeExp<op::mul>(lhs, rhs);
112 }
113 /*! \brief operator overload */
114 template<typename TA, int ta>
115 inline BinaryMapExp<op::div, TA, ScalarExp<MSHADOW_SCALAR_>,
116 MSHADOW_SCALAR_, (ta|type::kMapper)>
117 operator/(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs,
118 const ScalarExp<MSHADOW_SCALAR_> &rhs) {
119 return MakeExp<op::div>(lhs, rhs);
120 }
121 // constant operators 2
122 /*! \brief operator overload */
123 template<typename TB, int tb>
124 inline BinaryMapExp<op::plus, ScalarExp<MSHADOW_SCALAR_>, TB,
125 MSHADOW_SCALAR_, (tb|type::kMapper)>
126 operator+(const ScalarExp<MSHADOW_SCALAR_> &lhs,
127 const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
128 return MakeExp<op::plus>(lhs, rhs);
129 }
130 /*! \brief operator overload */
131 template<typename TB, int tb>
132 inline BinaryMapExp<op::minus, ScalarExp<MSHADOW_SCALAR_>, TB,
133 MSHADOW_SCALAR_, (tb|type::kMapper)>
134 operator-(const ScalarExp<MSHADOW_SCALAR_> &lhs,
135 const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
136 return MakeExp<op::minus>(lhs, rhs);
137 }
138 /*! \brief operator overload */
139 template<typename TB, int tb>
140 inline BinaryMapExp<op::mul, ScalarExp<MSHADOW_SCALAR_>, TB,
141 MSHADOW_SCALAR_, (tb|type::kMapper)>
142 operator*(const ScalarExp<MSHADOW_SCALAR_> &lhs,
143 const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
144 return MakeExp<op::mul>(lhs, rhs);
145 }
146 /*! \brief operator overload */
147 template<typename TB, int tb>
148 inline BinaryMapExp<op::div, ScalarExp<MSHADOW_SCALAR_>, TB,
149 MSHADOW_SCALAR_, (tb|type::kMapper)>
150 operator/(const ScalarExp<MSHADOW_SCALAR_> &lhs, const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
151 return MakeExp<op::div>(lhs, rhs);
152 }
153 // constant operators 3
154 /*! \brief operator overload */
155 inline BinaryMapExp<op::plus, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
156 MSHADOW_SCALAR_, (1|type::kMapper)>
157 operator+(const ScalarExp<MSHADOW_SCALAR_> &lhs,
158 const ScalarExp<MSHADOW_SCALAR_> &rhs) {
159 return MakeExp<op::plus>(lhs, rhs);
160 }
161 /*! \brief operator overload */
162 inline BinaryMapExp<op::minus, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
163 MSHADOW_SCALAR_, (1|type::kMapper)>
164 operator-(const ScalarExp<MSHADOW_SCALAR_> &lhs,
165 const ScalarExp<MSHADOW_SCALAR_> &rhs) {
166 return MakeExp<op::minus>(lhs, rhs);
167 }
168 /*! \brief operator overload */
169 inline BinaryMapExp<op::mul, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
170 MSHADOW_SCALAR_, (1|type::kMapper)>
171 operator*(const ScalarExp<MSHADOW_SCALAR_> &lhs,
172 const ScalarExp<MSHADOW_SCALAR_> &rhs) {
173 return MakeExp<op::mul>(lhs, rhs);
174 }
175 /*! \brief operator overload */
176 inline BinaryMapExp<op::div, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
177 MSHADOW_SCALAR_, (1|type::kMapper)>
178 operator/(const ScalarExp<MSHADOW_SCALAR_> &lhs, const ScalarExp<MSHADOW_SCALAR_> &rhs) {
179 return MakeExp<op::div>(lhs, rhs);
180 }
181 } // namespace expr
182 } // namespace mshadow
183 #endif // MSHADOW_EXPR_SCALAR_INL_H_
184