1
2 // =================================================================================================
3 // This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
4 // project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
5 // width of 100 characters per line.
6 //
7 // Author(s):
8 // Cedric Nugteren <www.cedricnugteren.nl>
9 //
10 // This file implements the Xtrmm class (see the header for information about the class).
11 //
12 // =================================================================================================
13
14 #include "routines/level3/xtrmm.hpp"
15
16 #include <string>
17 #include <vector>
18
19 namespace clblast {
20 // =================================================================================================
21
22 // Constructor: forwards to base class constructor
23 template <typename T>
Xtrmm(Queue & queue,EventPointer event,const std::string & name)24 Xtrmm<T>::Xtrmm(Queue &queue, EventPointer event, const std::string &name):
25 Xgemm<T>(queue, event, name) {
26 }
27
28 // =================================================================================================
29
30 // The main routine
31 template <typename T>
DoTrmm(const Layout layout,const Side side,const Triangle triangle,const Transpose a_transpose,const Diagonal diagonal,const size_t m,const size_t n,const T alpha,const Buffer<T> & a_buffer,const size_t a_offset,const size_t a_ld,const Buffer<T> & b_buffer,const size_t b_offset,const size_t b_ld)32 void Xtrmm<T>::DoTrmm(const Layout layout, const Side side, const Triangle triangle,
33 const Transpose a_transpose, const Diagonal diagonal,
34 const size_t m, const size_t n,
35 const T alpha,
36 const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
37 const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld) {
38
39 // Makes sure all dimensions are larger than zero
40 if ((m == 0) || (n == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
41
42 // Computes the k dimension. This is based on whether or not matrix is A (on the left)
43 // or B (on the right) in the Xgemm routine.
44 auto k = (side == Side::kLeft) ? m : n;
45
46 // Checks for validity of the triangular A matrix
47 TestMatrixA(k, k, a_buffer, a_offset, a_ld);
48
49 // Checks for validity of the input/output B matrix
50 const auto b_one = (layout == Layout::kRowMajor) ? n : m;
51 const auto b_two = (layout == Layout::kRowMajor) ? m : n;
52 TestMatrixB(b_one, b_two, b_buffer, b_offset, b_ld);
53
54 // Creates a copy of B to avoid overwriting input in GEMM while computing output
55 const auto b_size = (b_ld * (b_two - 1) + b_one + b_offset);
56 auto b_buffer_copy = Buffer<T>(context_, b_size);
57 b_buffer.CopyTo(queue_, b_size, b_buffer_copy);
58
59 // Determines which kernel to run based on the layout (the Xgemm kernel assumes column-major as
60 // default) and on whether we are dealing with an upper or lower triangle of the triangular matrix
61 bool is_upper = ((triangle == Triangle::kUpper && layout != Layout::kRowMajor) ||
62 (triangle == Triangle::kLower && layout == Layout::kRowMajor));
63 auto kernel_name = (is_upper) ? "TriaUpperToSquared" : "TriaLowerToSquared";
64
65 // Determines whether or not the triangular matrix is unit-diagonal
66 auto unit_diagonal = (diagonal == Diagonal::kUnit) ? true : false;
67
68 // Temporary buffer for a copy of the triangular matrix
69 auto temp_triangular = Buffer<T>(context_, k*k);
70
71 // Creates a general matrix from the triangular matrix to be able to run the regular Xgemm
72 // routine afterwards
73 auto kernel = Kernel(program_, kernel_name);
74
75 // Sets the arguments for the triangular-to-squared kernel
76 kernel.SetArgument(0, static_cast<int>(k));
77 kernel.SetArgument(1, static_cast<int>(a_ld));
78 kernel.SetArgument(2, static_cast<int>(a_offset));
79 kernel.SetArgument(3, a_buffer());
80 kernel.SetArgument(4, static_cast<int>(k));
81 kernel.SetArgument(5, static_cast<int>(k));
82 kernel.SetArgument(6, static_cast<int>(0));
83 kernel.SetArgument(7, temp_triangular());
84 kernel.SetArgument(8, static_cast<int>(unit_diagonal));
85
86 // Uses the common padding kernel's thread configuration. This is allowed, since the
87 // triangular-to-squared kernel uses the same parameters.
88 auto global = std::vector<size_t>{Ceil(CeilDiv(k, db_["PAD_WPTX"]), db_["PAD_DIMX"]),
89 Ceil(CeilDiv(k, db_["PAD_WPTY"]), db_["PAD_DIMY"])};
90 auto local = std::vector<size_t>{db_["PAD_DIMX"], db_["PAD_DIMY"]};
91 auto kernelEvent = Event();
92 RunKernel(kernel, queue_, device_, global, local, kernelEvent.pointer());
93
94 // Synchronize now: 'DoGemm' does not accept a list of events to wait for
95 kernelEvent.WaitForCompletion();
96
97 // Runs the regular Xgemm code with either "B := alpha*A*B" or ...
98 if (side == Side::kLeft) {
99 DoGemm(layout, a_transpose, Transpose::kNo,
100 m, n, k,
101 alpha,
102 temp_triangular, 0, k,
103 b_buffer_copy, b_offset, b_ld,
104 ConstantZero<T>(),
105 b_buffer, b_offset, b_ld);
106 }
107
108 // ... with "B := alpha*B*A". Note that A and B are now reversed.
109 else {
110 try {
111 DoGemm(layout, Transpose::kNo, a_transpose,
112 m, n, k,
113 alpha,
114 b_buffer_copy, b_offset, b_ld,
115 temp_triangular, 0, k,
116 ConstantZero<T>(),
117 b_buffer, b_offset, b_ld);
118 } catch (BLASError &e) {
119 // A and B are now reversed, so also reverse the error codes returned from the Xgemm routine
120 switch(e.status()) {
121 case StatusCode::kInvalidMatrixA: throw BLASError(StatusCode::kInvalidMatrixB, e.details());
122 case StatusCode::kInvalidMatrixB: throw BLASError(StatusCode::kInvalidMatrixA, e.details());
123 case StatusCode::kInvalidLeadDimA: throw BLASError(StatusCode::kInvalidLeadDimB, e.details());
124 case StatusCode::kInvalidLeadDimB: throw BLASError(StatusCode::kInvalidLeadDimA, e.details());
125 case StatusCode::kInsufficientMemoryA: throw BLASError(StatusCode::kInsufficientMemoryB, e.details());
126 case StatusCode::kInsufficientMemoryB: throw BLASError(StatusCode::kInsufficientMemoryA, e.details());
127 default: throw;
128 }
129 }
130 }
131 }
132
133 // =================================================================================================
134
135 // Compiles the templated class
136 template class Xtrmm<half>;
137 template class Xtrmm<float>;
138 template class Xtrmm<double>;
139 template class Xtrmm<float2>;
140 template class Xtrmm<double2>;
141
142 // =================================================================================================
143 } // namespace clblast
144