1 //=================================================================================================
2 /*!
3 // \file blaze/math/expressions/DMatSoftmaxExpr.h
4 // \brief Header file for the dense matrix softmax expression
5 //
6 // Copyright (C) 2012-2020 Klaus Iglberger - All Rights Reserved
7 //
8 // This file is part of the Blaze library. You can redistribute it and/or modify it under
9 // the terms of the New (Revised) BSD License. Redistribution and use in source and binary
10 // forms, with or without modification, are permitted provided that the following conditions
11 // are met:
12 //
13 // 1. Redistributions of source code must retain the above copyright notice, this list of
14 // conditions and the following disclaimer.
15 // 2. Redistributions in binary form must reproduce the above copyright notice, this list
16 // of conditions and the following disclaimer in the documentation and/or other materials
17 // provided with the distribution.
18 // 3. Neither the names of the Blaze development group nor the names of its contributors
19 // may be used to endorse or promote products derived from this software without specific
20 // prior written permission.
21 //
22 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
23 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
24 // OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
25 // SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
26 // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
27 // TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
28 // BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
29 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30 // ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
31 // DAMAGE.
32 */
33 //=================================================================================================
34
35 #ifndef _BLAZE_MATH_EXPRESSIONS_DMATSOFTMAXEXPR_H_
36 #define _BLAZE_MATH_EXPRESSIONS_DMATSOFTMAXEXPR_H_
37
38
39 //*************************************************************************************************
40 // Includes
41 //*************************************************************************************************
42
43 #include <blaze/math/expressions/DenseMatrix.h>
44 #include <blaze/math/ReductionFlag.h>
45 #include <blaze/math/views/Check.h>
46 #include <blaze/math/views/Column.h>
47 #include <blaze/math/views/Row.h>
48
49
50 namespace blaze {
51
52 //=================================================================================================
53 //
54 // GLOBAL FUNCTIONS
55 //
56 //=================================================================================================
57
58 //*************************************************************************************************
59 /*!\brief Computes the softmax function for the given dense matrix.
60 // \ingroup dense_matrix
61 //
62 // \param dm The given dense matrix for the softmax computation.
63 // \return The resulting matrix.
64 //
65 // This function computes the softmax function (i.e. the normalized exponential function) for
66 // the given dense matrix \a dm (see also https://en.wikipedia.org/wiki/Softmax_function). The
67 // resulting dense matrix consists of real values in the range (0..1], which add up to 1.
68
69 \code
70 // Creating the matrix
71 // ( 1 2 3 )
72 // ( 4 1 2 )
73 // ( 3 4 1 )
74 blaze::StaticMatrix<double,3UL,3UL> A{ { 1.0, 2.0, 3.0 }
75 , { 4.0, 1.0, 2.0 }
76 , { 3.0, 4.0, 1.0 } };
77
78 // Computing the total softmax of A (sum(B) == 1)
79 // ( 0.0157764 0.0428847 0.116573 )
80 // ( 0.316878 0.0157764 0.0428847 )
81 // ( 0.116573 0.316878 0.0157764 )
82 blaze::StaticMatrix<double,3UL,3UL> B;
83 B = softmax( A );
84 \endcode
85 */
86 template< typename MT // Type of the dense matrix
87 , bool SO > // Storage order
softmax(const DenseMatrix<MT,SO> & dm)88 auto softmax( const DenseMatrix<MT,SO>& dm )
89 {
90 auto tmp( evaluate( exp( *dm ) ) );
91 const auto scalar( sum( tmp ) );
92 tmp /= scalar;
93 return tmp;
94 }
95 //*************************************************************************************************
96
97
98 //*************************************************************************************************
99 /*!\brief Computes the row-/columnwise softmax function for the given dense matrix.
100 // \ingroup dense_matrix
101 //
102 // \param dm The given dense matrix for the softmax computation.
103 // \return The resulting matrix.
104 //
105 // This function computes the row-/columnwise softmax function (i.e. the normalized exponential
106 // function) for the given dense matrix \a dm (see also https://en.wikipedia.org/wiki/Softmax_function).
107 // The resulting dense matrix consists of real values in the range (0..1], which add up to the
108 // numbers of rows or columns, respectively.
109
110 \code
111 // Creating the matrix
112 // ( 1 2 3 )
113 // ( 4 1 2 )
114 // ( 3 4 1 )
115 blaze::StaticMatrix<double,3UL,3UL> A{ { 1.0, 2.0, 3.0 }
116 , { 4.0, 1.0, 2.0 }
117 , { 3.0, 4.0, 1.0 } };
118
119 // Computing the rowwise softmax of A (sum(B) == 3)
120 // ( 0.0900306 0.244728 0.665241 )
121 // ( 0.843795 0.0420101 0.114195 )
122 // ( 0.259496 0.705385 0.035119 )
123 blaze::StaticMatrix<double,3UL,3UL> B;
124 B = softmax<rowwise>( A );
125
126 // Computing the columnwise softmax of A (sum(C) == 3)
127 // ( 0.035119 0.114195 0.665241 )
128 // ( 0.705385 0.0420101 0.244728 )
129 // ( 0.259496 0.843795 0.0900306 )
130 blaze::StaticMatrix<double,3UL,3UL> C;
131 C = softmax<columnwise>( A );
132 \endcode
133 */
134 template< ReductionFlag RF // Reduction flag
135 , typename MT // Type of the dense matrix
136 , bool SO > // Storage order
softmax(const DenseMatrix<MT,SO> & dm)137 auto softmax( const DenseMatrix<MT,SO>& dm )
138 {
139 auto tmp( evaluate( exp( *dm ) ) );
140
141 if( RF == rowwise ) {
142 for( size_t i=0UL; i<tmp.rows(); ++i ) {
143 auto r = row( tmp, i, unchecked );
144 const auto scalar( sum( r ) );
145 r /= scalar;
146 }
147 }
148 else {
149 for( size_t j=0UL; j<tmp.columns(); ++j ) {
150 auto c = column( tmp, j, unchecked );
151 const auto scalar( sum( c ) );
152 c /= scalar;
153 }
154 }
155
156 return tmp;
157 }
158 //*************************************************************************************************
159
160 } // namespace blaze
161
162 #endif
163