1 
2 // =================================================================================================
3 // This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
4 // project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
5 // width of 100 characters per line.
6 //
7 // Author(s):
8 //   Cedric Nugteren <www.cedricnugteren.nl>
9 //
10 // This file implements the Xtrmm class (see the header for information about the class).
11 //
12 // =================================================================================================
13 
14 #include "routines/level3/xtrmm.hpp"
15 
16 #include <string>
17 #include <vector>
18 
19 namespace clblast {
20 // =================================================================================================
21 
22 // Constructor: forwards to base class constructor
23 template <typename T>
Xtrmm(Queue & queue,EventPointer event,const std::string & name)24 Xtrmm<T>::Xtrmm(Queue &queue, EventPointer event, const std::string &name):
25     Xgemm<T>(queue, event, name) {
26 }
27 
28 // =================================================================================================
29 
30 // The main routine
31 template <typename T>
DoTrmm(const Layout layout,const Side side,const Triangle triangle,const Transpose a_transpose,const Diagonal diagonal,const size_t m,const size_t n,const T alpha,const Buffer<T> & a_buffer,const size_t a_offset,const size_t a_ld,const Buffer<T> & b_buffer,const size_t b_offset,const size_t b_ld)32 void Xtrmm<T>::DoTrmm(const Layout layout, const Side side, const Triangle triangle,
33                       const Transpose a_transpose, const Diagonal diagonal,
34                       const size_t m, const size_t n,
35                       const T alpha,
36                       const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
37                       const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld) {
38 
39   // Makes sure all dimensions are larger than zero
40   if ((m == 0) || (n == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
41 
42   // Computes the k dimension. This is based on whether or not matrix is A (on the left)
43   // or B (on the right) in the Xgemm routine.
44   auto k = (side == Side::kLeft) ? m : n;
45 
46   // Checks for validity of the triangular A matrix
47   TestMatrixA(k, k, a_buffer, a_offset, a_ld);
48 
49   // Checks for validity of the input/output B matrix
50   const auto b_one = (layout == Layout::kRowMajor) ? n : m;
51   const auto b_two = (layout == Layout::kRowMajor) ? m : n;
52   TestMatrixB(b_one, b_two, b_buffer, b_offset, b_ld);
53 
54   // Creates a copy of B to avoid overwriting input in GEMM while computing output
55   const auto b_size = (b_ld * (b_two - 1) + b_one + b_offset);
56   auto b_buffer_copy = Buffer<T>(context_, b_size);
57   b_buffer.CopyTo(queue_, b_size, b_buffer_copy);
58 
59   // Determines which kernel to run based on the layout (the Xgemm kernel assumes column-major as
60   // default) and on whether we are dealing with an upper or lower triangle of the triangular matrix
61   bool is_upper = ((triangle == Triangle::kUpper && layout != Layout::kRowMajor) ||
62                    (triangle == Triangle::kLower && layout == Layout::kRowMajor));
63   auto kernel_name = (is_upper) ? "TriaUpperToSquared" : "TriaLowerToSquared";
64 
65   // Determines whether or not the triangular matrix is unit-diagonal
66   auto unit_diagonal = (diagonal == Diagonal::kUnit) ? true : false;
67 
68   // Temporary buffer for a copy of the triangular matrix
69   auto temp_triangular = Buffer<T>(context_, k*k);
70 
71   // Creates a general matrix from the triangular matrix to be able to run the regular Xgemm
72   // routine afterwards
73   auto kernel = Kernel(program_, kernel_name);
74 
75   // Sets the arguments for the triangular-to-squared kernel
76   kernel.SetArgument(0, static_cast<int>(k));
77   kernel.SetArgument(1, static_cast<int>(a_ld));
78   kernel.SetArgument(2, static_cast<int>(a_offset));
79   kernel.SetArgument(3, a_buffer());
80   kernel.SetArgument(4, static_cast<int>(k));
81   kernel.SetArgument(5, static_cast<int>(k));
82   kernel.SetArgument(6, static_cast<int>(0));
83   kernel.SetArgument(7, temp_triangular());
84   kernel.SetArgument(8, static_cast<int>(unit_diagonal));
85 
86   // Uses the common padding kernel's thread configuration. This is allowed, since the
87   // triangular-to-squared kernel uses the same parameters.
88   auto global = std::vector<size_t>{Ceil(CeilDiv(k, db_["PAD_WPTX"]), db_["PAD_DIMX"]),
89                                     Ceil(CeilDiv(k, db_["PAD_WPTY"]), db_["PAD_DIMY"])};
90   auto local = std::vector<size_t>{db_["PAD_DIMX"], db_["PAD_DIMY"]};
91   auto kernelEvent = Event();
92   RunKernel(kernel, queue_, device_, global, local, kernelEvent.pointer());
93 
94   // Synchronize now: 'DoGemm' does not accept a list of events to wait for
95   kernelEvent.WaitForCompletion();
96 
97   // Runs the regular Xgemm code with either "B := alpha*A*B" or ...
98   if (side == Side::kLeft) {
99     DoGemm(layout, a_transpose, Transpose::kNo,
100            m, n, k,
101            alpha,
102            temp_triangular, 0, k,
103            b_buffer_copy, b_offset, b_ld,
104            ConstantZero<T>(),
105            b_buffer, b_offset, b_ld);
106   }
107 
108   // ... with "B := alpha*B*A". Note that A and B are now reversed.
109   else {
110     try {
111       DoGemm(layout, Transpose::kNo, a_transpose,
112              m, n, k,
113              alpha,
114              b_buffer_copy, b_offset, b_ld,
115              temp_triangular, 0, k,
116              ConstantZero<T>(),
117              b_buffer, b_offset, b_ld);
118     } catch (BLASError &e) {
119       // A and B are now reversed, so also reverse the error codes returned from the Xgemm routine
120       switch(e.status()) {
121         case StatusCode::kInvalidMatrixA:      throw BLASError(StatusCode::kInvalidMatrixB, e.details());
122         case StatusCode::kInvalidMatrixB:      throw BLASError(StatusCode::kInvalidMatrixA, e.details());
123         case StatusCode::kInvalidLeadDimA:     throw BLASError(StatusCode::kInvalidLeadDimB, e.details());
124         case StatusCode::kInvalidLeadDimB:     throw BLASError(StatusCode::kInvalidLeadDimA, e.details());
125         case StatusCode::kInsufficientMemoryA: throw BLASError(StatusCode::kInsufficientMemoryB, e.details());
126         case StatusCode::kInsufficientMemoryB: throw BLASError(StatusCode::kInsufficientMemoryA, e.details());
127         default:                               throw;
128       }
129     }
130   }
131 }
132 
133 // =================================================================================================
134 
135 // Compiles the templated class
// Explicit instantiation definitions: they make the compiler generate the Xtrmm members defined
// in this file for each element type listed here. The types half, float2 and double2 are not
// declared in this file. NOTE(review): float2/double2 are presumably the complex types — confirm.
136 template class Xtrmm<half>;
137 template class Xtrmm<float>;
138 template class Xtrmm<double>;
139 template class Xtrmm<float2>;
140 template class Xtrmm<double2>;
141 
142 // =================================================================================================
143 } // namespace clblast
144