1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file expr_scalar-inl.h
22  * \brief definitions of operators in expression with respect to scalar
 *  This file is included several times, each time with the macro MSHADOW_SCALAR_
 *  defined as a different scalar type.
24  *
25  * DO NOT add pragma once or macro guard
26  * \author Tianqi Chen, Bing Xu
27  */
// A real include guard would be harmful here; this one exists only to satisfy cpplint.
29 #ifndef MSHADOW_EXPR_SCALAR_INL_H_
30 #define MSHADOW_EXPR_SCALAR_INL_H_
31 // undef the guard so it can be included multiple times
32 #undef MSHADOW_EXPR_SCALAR_INL_H_
33 
34 namespace mshadow {
35 namespace expr {
// scaling of DotExp and ReduceTo1DExp by a scalar
37 /*! \brief dot operator def */
38 template<typename TA, typename TB, bool ltrans, bool rtrans>
39 inline DotExp<TA, TB, ltrans, rtrans, MSHADOW_SCALAR_>
40 operator*(const DotExp<TA, TB, ltrans, rtrans, MSHADOW_SCALAR_> &lhs,
41           MSHADOW_SCALAR_ rhs) {
42   return DotExp<TA, TB, ltrans, rtrans,
43                 MSHADOW_SCALAR_>(lhs.lhs_, lhs.rhs_, lhs.scale_ * rhs);
44 }
45 /*! \brief scale of dot operation */
46 template<typename TA, typename TB, bool ltrans, bool rtrans>
47 inline DotExp<TA, TB, ltrans, rtrans, MSHADOW_SCALAR_>
48 operator*(MSHADOW_SCALAR_ lhs,
49           const DotExp<TA, TB, ltrans, rtrans, MSHADOW_SCALAR_> &rhs) {
50   return DotExp<TA, TB, ltrans, rtrans,
51                 MSHADOW_SCALAR_>(rhs.lhs_, rhs.rhs_, rhs.scale_ * lhs);
52 }
53 
54 /*! \brief operator overload */
55 template<typename E, typename DType, typename R, int d>
56 inline ReduceTo1DExp<E, DType, R, d>
57 operator*(const ReduceTo1DExp<E, DType, R, d> &e, MSHADOW_SCALAR_ scale) {
58   return ReduceTo1DExp<E, DType, R, d>(e.src_, e.scale_ * scale);
59 }
60 /*! \brief operator overload */
61 template<typename E, typename DType, typename R, int d>
62 inline ReduceTo1DExp<E, DType, R, d>
63 operator*(MSHADOW_SCALAR_ scale, const ReduceTo1DExp<E, DType, R, d> &e) {
64   return ReduceTo1DExp<E, DType, R, d>(e.src_, e.scale_ * scale);
65 }
66 
67 /*! \brief operator overload for const */
68 template<typename OP, typename TA, int ta>
69 inline BinaryMapExp<OP, TA, ScalarExp<MSHADOW_SCALAR_>,
70                     MSHADOW_SCALAR_, (ta|type::kMapper)>
F(const Exp<TA,MSHADOW_SCALAR_,ta> & lhs,const ScalarExp<MSHADOW_SCALAR_> & rhs)71 F(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs, const ScalarExp<MSHADOW_SCALAR_> &rhs) {
72   return MakeExp<OP>(lhs, rhs);
73 }
74 /*! \brief operator overload for const */
75 template<typename OP, typename TB, int tb>
76 inline BinaryMapExp<OP, ScalarExp<MSHADOW_SCALAR_>, TB,
77                     MSHADOW_SCALAR_, (tb|type::kMapper)>
F(const ScalarExp<MSHADOW_SCALAR_> & lhs,const Exp<TB,MSHADOW_SCALAR_,tb> & rhs)78 F(const ScalarExp<MSHADOW_SCALAR_> &lhs, const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
79   return MakeExp<OP>(lhs, rhs);
80 }
81 /*! \brief operator overload for const */
82 template<typename OP>
83 inline BinaryMapExp<OP, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
84                     MSHADOW_SCALAR_, (1|type::kMapper)>
F(const ScalarExp<MSHADOW_SCALAR_> & lhs,const ScalarExp<MSHADOW_SCALAR_> & rhs)85 F(const ScalarExp<MSHADOW_SCALAR_> &lhs, const ScalarExp<MSHADOW_SCALAR_> &rhs) {
86   return MakeExp<OP>(lhs, rhs);
87 }
// arithmetic operators: expression (op) scalar
89 /*! \brief operator overload */
90 template<typename TA, int ta>
91 inline BinaryMapExp<op::plus, TA, ScalarExp<MSHADOW_SCALAR_>,
92                     MSHADOW_SCALAR_, (ta|type::kMapper)>
93 operator+(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs,
94           const ScalarExp<MSHADOW_SCALAR_> &rhs) {
95   return MakeExp<op::plus>(lhs, rhs);
96 }
97 /*! \brief operator overload */
98 template<typename TA, int ta>
99 inline BinaryMapExp<op::minus, TA, ScalarExp<MSHADOW_SCALAR_>,
100                     MSHADOW_SCALAR_, (ta|type::kMapper)>
101 operator-(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs,
102           const ScalarExp<MSHADOW_SCALAR_> &rhs) {
103   return MakeExp<op::minus>(lhs, rhs);
104 }
105 /*! \brief operator overload */
106 template<typename TA, int ta>
107 inline BinaryMapExp<op::mul, TA, ScalarExp<MSHADOW_SCALAR_>,
108                     MSHADOW_SCALAR_, (ta|type::kMapper)>
109 operator*(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs,
110           const ScalarExp<MSHADOW_SCALAR_> &rhs) {
111   return MakeExp<op::mul>(lhs, rhs);
112 }
113 /*! \brief operator overload */
114 template<typename TA, int ta>
115 inline BinaryMapExp<op::div, TA, ScalarExp<MSHADOW_SCALAR_>,
116                     MSHADOW_SCALAR_, (ta|type::kMapper)>
117 operator/(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs,
118           const ScalarExp<MSHADOW_SCALAR_> &rhs) {
119   return MakeExp<op::div>(lhs, rhs);
120 }
// arithmetic operators: scalar (op) expression
122 /*! \brief operator overload */
123 template<typename TB, int tb>
124 inline BinaryMapExp<op::plus, ScalarExp<MSHADOW_SCALAR_>, TB,
125                     MSHADOW_SCALAR_, (tb|type::kMapper)>
126 operator+(const ScalarExp<MSHADOW_SCALAR_> &lhs,
127           const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
128   return MakeExp<op::plus>(lhs, rhs);
129 }
130 /*! \brief operator overload */
131 template<typename TB, int tb>
132 inline BinaryMapExp<op::minus, ScalarExp<MSHADOW_SCALAR_>, TB,
133                     MSHADOW_SCALAR_, (tb|type::kMapper)>
134 operator-(const ScalarExp<MSHADOW_SCALAR_> &lhs,
135           const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
136   return MakeExp<op::minus>(lhs, rhs);
137 }
138 /*! \brief operator overload */
139 template<typename TB, int tb>
140 inline BinaryMapExp<op::mul, ScalarExp<MSHADOW_SCALAR_>, TB,
141                     MSHADOW_SCALAR_, (tb|type::kMapper)>
142 operator*(const ScalarExp<MSHADOW_SCALAR_> &lhs,
143           const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
144   return MakeExp<op::mul>(lhs, rhs);
145 }
146 /*! \brief operator overload */
147 template<typename TB, int tb>
148 inline BinaryMapExp<op::div, ScalarExp<MSHADOW_SCALAR_>, TB,
149                     MSHADOW_SCALAR_, (tb|type::kMapper)>
150 operator/(const ScalarExp<MSHADOW_SCALAR_> &lhs, const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
151   return MakeExp<op::div>(lhs, rhs);
152 }
// arithmetic operators: scalar (op) scalar
154 /*! \brief operator overload */
155 inline BinaryMapExp<op::plus, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
156                     MSHADOW_SCALAR_, (1|type::kMapper)>
157 operator+(const ScalarExp<MSHADOW_SCALAR_> &lhs,
158           const ScalarExp<MSHADOW_SCALAR_> &rhs) {
159   return MakeExp<op::plus>(lhs, rhs);
160 }
161 /*! \brief operator overload */
162 inline BinaryMapExp<op::minus, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
163                     MSHADOW_SCALAR_, (1|type::kMapper)>
164 operator-(const ScalarExp<MSHADOW_SCALAR_> &lhs,
165           const ScalarExp<MSHADOW_SCALAR_> &rhs) {
166   return MakeExp<op::minus>(lhs, rhs);
167 }
168 /*! \brief operator overload */
169 inline BinaryMapExp<op::mul, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
170                     MSHADOW_SCALAR_, (1|type::kMapper)>
171 operator*(const ScalarExp<MSHADOW_SCALAR_> &lhs,
172           const ScalarExp<MSHADOW_SCALAR_> &rhs) {
173   return MakeExp<op::mul>(lhs, rhs);
174 }
175 /*! \brief operator overload */
176 inline BinaryMapExp<op::div, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
177                     MSHADOW_SCALAR_, (1|type::kMapper)>
178 operator/(const ScalarExp<MSHADOW_SCALAR_> &lhs, const ScalarExp<MSHADOW_SCALAR_> &rhs) {
179   return MakeExp<op::div>(lhs, rhs);
180 }
181 }  // namespace expr
182 }  // namespace mshadow
183 #endif  // MSHADOW_EXPR_SCALAR_INL_H_
184