/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19 
/*!
 * \file reduceto1d.h
 * \brief support for sum_rows, sumall_except_dim and reduce_except_dim
 * \author Tianqi Chen
 */
25 #ifndef MSHADOW_EXTENSION_REDUCETO1D_H_
26 #define MSHADOW_EXTENSION_REDUCETO1D_H_
27 #include "../extension.h"
28 namespace mshadow {
29 namespace expr {
30 /*!
31  * \brief reduction to 1 dimension tensor
32  * input: Tensor<Device,k>: ishape
33  * output: Tensor<Device,1> shape[0] = ishape[dimkeep];
34  *
35  * \tparam SrcExp type of expression to be reduced
36  * \tparam DType the data type of the scalar
37  * \tparam Reducer which reducer to use
38  * \tparam m_dimkeep which dimension to be kept, encoded with dimsrc - dimkeep
39  */
40 template<typename SrcExp, typename DType, typename Reducer, int m_dimkeep>
41 struct ReduceTo1DExp:
42       public Exp<ReduceTo1DExp<SrcExp, DType, Reducer, m_dimkeep>,
43                  DType, type::kComplex> {
44   /*! \brief source operand */
45   const SrcExp &src_;
46   /*! \brief source operand, scale of the  */
47   DType scale_;
48   /*! \brief construct a repmat expression from src and nrow */
ReduceTo1DExpReduceTo1DExp49   ReduceTo1DExp(const SrcExp& src, DType scale) : src_(src), scale_(scale) {}
50 };
51 /*!
52  * \brief a sum over all dimensions, except dimkeep
53  * \param exp input expression that must be a matrix Tensor<?,2>
54  * \return a expresion with type Tensor<Device,1>
55  * \tparam dimkeep the dimension that will be kept
56  * \tparam SrcExp expression
57  * \tparam etype type of expression
58  */
59 template<int dimkeep,  typename SrcExp, typename DType, int etype>
60 inline ReduceTo1DExp<SrcExp, DType, red::sum,
61                      ExpInfo<SrcExp>::kDim - dimkeep>
sumall_except_dim(const Exp<SrcExp,DType,etype> & exp)62 sumall_except_dim(const Exp<SrcExp, DType, etype> &exp) {
63   return ReduceTo1DExp<SrcExp, DType, red::sum,
64                        ExpInfo<SrcExp>::kDim - dimkeep>(exp.self(), DType(1));
65 }
66 /*!
67  * \brief reduce over all dimensions, except dimkeep
68  * \param exp input expression that must be a matrix Tensor<?,2>
69  * \return a expresion with type Tensor<Device,1>
70  * \tparam dimkeep the dimension that will be kept
71  * \tparam SrcExp expression
72  * \tparam etype type of expression
73  */
74 template<int dimkeep, typename Reducer, typename SrcExp, typename DType, int etype>
75 inline ReduceTo1DExp<SrcExp, DType, Reducer,
76                      ExpInfo<SrcExp>::kDim - dimkeep>
reduce_except_dim(const Exp<SrcExp,DType,etype> & exp)77 reduce_except_dim(const Exp<SrcExp, DType, etype> &exp) {
78   return ReduceTo1DExp<SrcExp, DType, Reducer,
79                        ExpInfo<SrcExp>::kDim - dimkeep>(exp.self(), DType(1));
80 }
81 /*!
82  * \brief a expression that sum over rows of a matrix
83  * \param exp input expression that must be a matrix Tensor<?, 2>
84  * \return a expresion with type Tensor<Device, 1>
85  * \tparam SrcExp expression
86  * \tparam etype type of expression
87  */
88 template<typename SrcExp, typename DType, int etype>
89 inline ReduceTo1DExp<SrcExp, DType, red::sum, 1>
sum_rows(const Exp<SrcExp,DType,etype> & exp)90 sum_rows(const Exp<SrcExp, DType, etype> &exp) {
91   TypeCheckPass<ExpInfo<SrcExp>::kDim ==2>
92       ::Error_Expression_Does_Not_Meet_Dimension_Req();
93   return sumall_except_dim<1>(exp);
94 }
95 template<typename SV, typename Device, typename DType,
96          typename SrcExp, typename Reducer, int m_dimkeep>
97 struct ExpComplexEngine<SV,
98                         Tensor<Device, 1, DType>,
99                         ReduceTo1DExp<SrcExp, DType, Reducer, m_dimkeep>,
100                         DType> {
101   static const int dimkeep = ExpInfo<SrcExp>::kDim - m_dimkeep;
102   inline static void Eval(Tensor<Device, 1, DType> *dst,
103                           const ReduceTo1DExp<SrcExp, DType,
104                                               Reducer, m_dimkeep> &exp) {
105     TypeCheckPass<m_dimkeep != 1>
106         ::Error_Expression_Does_Not_Meet_Dimension_Req();
107     MapReduceKeepHighDim<SV, Reducer, dimkeep>(dst, exp.src_, exp.scale_);
108   }
109 };
110 template<typename SV, typename Device, typename DType,
111          typename SrcExp, typename Reducer>
112 struct ExpComplexEngine<SV,
113                         Tensor<Device, 1, DType>,
114                         ReduceTo1DExp<SrcExp, DType, Reducer, 1>, DType> {
115   inline static void Eval(Tensor<Device, 1, DType> *dst,
116                           const ReduceTo1DExp<SrcExp, DType, Reducer, 1> &exp) {
117     MapReduceKeepLowest<SV, Reducer>(dst, exp.src_, exp.scale_);
118   }
119 };
120 }  // namespace expr
121 }  // namespace mshadow
122 #endif  // MSHADOW_EXTENSION_REDUCETO1D_H_
123