1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file reduceto1d.h
 * \brief support for sum_rows, sumall_except_dim and reduce_except_dim
 * \author Tianqi Chen
 */
25 #ifndef MSHADOW_EXTENSION_REDUCETO1D_H_
26 #define MSHADOW_EXTENSION_REDUCETO1D_H_
27 #include "../extension.h"
28 namespace mshadow {
29 namespace expr {
30 /*!
31 * \brief reduction to 1 dimension tensor
32 * input: Tensor<Device,k>: ishape
33 * output: Tensor<Device,1> shape[0] = ishape[dimkeep];
34 *
35 * \tparam SrcExp type of expression to be reduced
36 * \tparam DType the data type of the scalar
37 * \tparam Reducer which reducer to use
38 * \tparam m_dimkeep which dimension to be kept, encoded with dimsrc - dimkeep
39 */
40 template<typename SrcExp, typename DType, typename Reducer, int m_dimkeep>
41 struct ReduceTo1DExp:
42 public Exp<ReduceTo1DExp<SrcExp, DType, Reducer, m_dimkeep>,
43 DType, type::kComplex> {
44 /*! \brief source operand */
45 const SrcExp &src_;
46 /*! \brief source operand, scale of the */
47 DType scale_;
48 /*! \brief construct a repmat expression from src and nrow */
ReduceTo1DExpReduceTo1DExp49 ReduceTo1DExp(const SrcExp& src, DType scale) : src_(src), scale_(scale) {}
50 };
51 /*!
52 * \brief a sum over all dimensions, except dimkeep
53 * \param exp input expression that must be a matrix Tensor<?,2>
54 * \return a expresion with type Tensor<Device,1>
55 * \tparam dimkeep the dimension that will be kept
56 * \tparam SrcExp expression
57 * \tparam etype type of expression
58 */
59 template<int dimkeep, typename SrcExp, typename DType, int etype>
60 inline ReduceTo1DExp<SrcExp, DType, red::sum,
61 ExpInfo<SrcExp>::kDim - dimkeep>
sumall_except_dim(const Exp<SrcExp,DType,etype> & exp)62 sumall_except_dim(const Exp<SrcExp, DType, etype> &exp) {
63 return ReduceTo1DExp<SrcExp, DType, red::sum,
64 ExpInfo<SrcExp>::kDim - dimkeep>(exp.self(), DType(1));
65 }
66 /*!
67 * \brief reduce over all dimensions, except dimkeep
68 * \param exp input expression that must be a matrix Tensor<?,2>
69 * \return a expresion with type Tensor<Device,1>
70 * \tparam dimkeep the dimension that will be kept
71 * \tparam SrcExp expression
72 * \tparam etype type of expression
73 */
74 template<int dimkeep, typename Reducer, typename SrcExp, typename DType, int etype>
75 inline ReduceTo1DExp<SrcExp, DType, Reducer,
76 ExpInfo<SrcExp>::kDim - dimkeep>
reduce_except_dim(const Exp<SrcExp,DType,etype> & exp)77 reduce_except_dim(const Exp<SrcExp, DType, etype> &exp) {
78 return ReduceTo1DExp<SrcExp, DType, Reducer,
79 ExpInfo<SrcExp>::kDim - dimkeep>(exp.self(), DType(1));
80 }
81 /*!
82 * \brief a expression that sum over rows of a matrix
83 * \param exp input expression that must be a matrix Tensor<?, 2>
84 * \return a expresion with type Tensor<Device, 1>
85 * \tparam SrcExp expression
86 * \tparam etype type of expression
87 */
88 template<typename SrcExp, typename DType, int etype>
89 inline ReduceTo1DExp<SrcExp, DType, red::sum, 1>
sum_rows(const Exp<SrcExp,DType,etype> & exp)90 sum_rows(const Exp<SrcExp, DType, etype> &exp) {
91 TypeCheckPass<ExpInfo<SrcExp>::kDim ==2>
92 ::Error_Expression_Does_Not_Meet_Dimension_Req();
93 return sumall_except_dim<1>(exp);
94 }
95 template<typename SV, typename Device, typename DType,
96 typename SrcExp, typename Reducer, int m_dimkeep>
97 struct ExpComplexEngine<SV,
98 Tensor<Device, 1, DType>,
99 ReduceTo1DExp<SrcExp, DType, Reducer, m_dimkeep>,
100 DType> {
101 static const int dimkeep = ExpInfo<SrcExp>::kDim - m_dimkeep;
102 inline static void Eval(Tensor<Device, 1, DType> *dst,
103 const ReduceTo1DExp<SrcExp, DType,
104 Reducer, m_dimkeep> &exp) {
105 TypeCheckPass<m_dimkeep != 1>
106 ::Error_Expression_Does_Not_Meet_Dimension_Req();
107 MapReduceKeepHighDim<SV, Reducer, dimkeep>(dst, exp.src_, exp.scale_);
108 }
109 };
110 template<typename SV, typename Device, typename DType,
111 typename SrcExp, typename Reducer>
112 struct ExpComplexEngine<SV,
113 Tensor<Device, 1, DType>,
114 ReduceTo1DExp<SrcExp, DType, Reducer, 1>, DType> {
115 inline static void Eval(Tensor<Device, 1, DType> *dst,
116 const ReduceTo1DExp<SrcExp, DType, Reducer, 1> &exp) {
117 MapReduceKeepLowest<SV, Reducer>(dst, exp.src_, exp.scale_);
118 }
119 };
120 } // namespace expr
121 } // namespace mshadow
122 #endif // MSHADOW_EXTENSION_REDUCETO1D_H_
123