1 // Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #ifndef GEMMLOWP_META_BASE_H_
16 #define GEMMLOWP_META_BASE_H_
17
18 #include <cassert>
19 #include <cstdint>
20
21 #include "../internal/common.h"
22
23 namespace gemmlowp {
24 namespace meta {
25
// Rounds |value| up to a multiple of the compile-time constant |align|.
//
// For value >= 0 the result is the smallest multiple of |align| that is not
// less than |value|. Integer division truncates toward zero, so that round-up
// guarantee does not extend to negative inputs. The intermediate sum
// value + align - 1 must fit in an int.
//
// |align| is used as a divisor, so it has to be strictly positive; this is
// enforced at compile time.
template <int align>
inline int AlignTo(int value) {
  static_assert(align > 0, "AlignTo: alignment must be a positive integer.");
  return ((value + align - 1) / align) * align;
}
30
// Runtime-alignment overload: rounds |value| up to a multiple of |align|.
//
// |align| is the divisor of an integer division and therefore must be
// non-zero. For a positive |align| and value >= 0 the result is the smallest
// multiple of |align| that is not less than |value|.
inline int AlignTo(int align, int value) {
  // Number of |align|-sized blocks needed to cover |value|.
  const int block_count = (value + align - 1) / align;
  return block_count * align;
}
34
// Pairs one kernel parameter object with one output-stream parameter object
// so that both can be passed around as a single value.
template <typename Kernel_, typename OutputStream_>
struct FusedKernelParams {
  // Template arguments re-exported under their unsuffixed names.
  using Kernel = Kernel_;
  using OutputStream = OutputStream_;

  Kernel kernel;               // Kernel-specific parameters.
  OutputStream output_stream;  // Output-stream-specific parameters.
};
44
45 template <typename InType_, typename OutType_, typename LeftStream_,
46 typename RightStream_, typename Kernel_, typename OutputStream_>
47 struct GemmParams {
48 public:
49 typedef InType_ InType;
50 typedef OutType_ OutType;
51 typedef LeftStream_ LeftStream;
52 typedef RightStream_ RightStream;
53 typedef Kernel_ Kernel;
54 typedef OutputStream_ OutputStream;
55
56 typedef FusedKernelParams<Kernel, OutputStream> FusedKernel;
57
58 // Common parameters.
59
60 int m;
61 int n;
62 int k;
63
64 const InType* lhs;
65 const InType* rhs;
66 OutType* result;
67 std::uint8_t* scratch;
68
69 // Specialized parameters.
70
71 LeftStream left_stream;
72 RightStream right_stream;
73 FusedKernel fused_kernel;
74 };
75
// Primary template describing the static interface of a stream over InType
// elements. Only declarations appear here; no member is defined in this
// header section (definitions are presumably supplied by specializations
// elsewhere -- confirm).
template <typename InType, int lanes_count, int pack_size, int leftovers,
          typename StreamParams>
class Stream {
 public:
  // Reads from |in| and writes to |out|, configured by |params|.
  static void Pack(const InType* in,
                   const StreamParams& params,
                   InType* out);

  // Advance and stride queries for the unpacked and the packed layouts.
  // Each one takes the stream parameters and returns an int.
  static int UnpackedAdvance(const StreamParams& params);
  static int PackedAdvance(const StreamParams& params);
  static int UnpackedStride(const StreamParams& params);
  static int PackedStride(const StreamParams& params);
};
90
// Primary template declaring static helpers that operate on a stream
// parameter object of type StreamType. Declarations only; nothing is defined
// in this header section.
template <typename InType, typename StreamType>
class StreamUtil {
 public:
  // Returns a pointer derived from |source| using the two offset arguments
  // and |params|.
  static const InType* Offset(const StreamType& params,
                              const InType* source,
                              int offset_stride,
                              int offset_advance);

  // Returns an int computed from |params| and |lanes|; judging by the name it
  // is a scratch-space requirement -- confirm against the definitions.
  static int Scratch(const StreamType& params, int lanes);
};
99
100 template <typename InType, typename OutType, typename Kernel,
101 typename OutputStream, int kernel_m, int kernel_n, int pack_size>
102 class MulKernel {
103 public:
104 static void Multiply(const InType* lhs, const InType* rhs,
105 const FusedKernelParams<Kernel, OutputStream>& params,
106 OutType* result);
107 };
108
// Parameter bundle for a one-dimensional transform: input, output and scratch
// pointers plus a kernel-specific parameter object. The template arguments
// are re-exported as member types.
template <typename InType_, typename OutType_, typename Kernel_>
struct Transform1DParams {
  using InType = InType_;
  using OutType = OutType_;
  using Kernel = Kernel_;

  const InType* input;    // Read-only source data.
  OutType* output;        // Writable destination.
  std::uint8_t* scratch;  // Byte buffer; ownership and size are not visible here.

  Kernel kernel;  // Kernel-specific parameters.
};
121
// Primary template declaring the static entry point of a one-dimensional
// transform kernel, parameterized on element types, the kernel parameter type
// and two integer constants (kernel_size, leftovers). Declaration only;
// nothing is defined in this header section.
template <typename InType, typename OutType, typename Kernel, int kernel_size,
          int leftovers>
class Transform1DKernel {
 public:
  // Reads from |input|, is configured by |params|, and writes through
  // |output|.
  static void Transform(const InType* input,
                        const Kernel& params,
                        OutType* output);
};
129
// Primary template declaring static helpers for one-dimensional transforms
// described by a parameter object of type Transform. Declarations only;
// nothing is defined in this header section.
template <typename InType, typename OutType, typename Transform>
class Transform1DUtil {
 public:
  // Returns an int derived from |params|; per the name, a compute-cost
  // estimate (its unit is not visible here).
  static int EstimateComputeCost(const Transform& params);

  // Returns an input pointer derived from |input| and |offset|.
  static const InType* OffsetInput(const Transform& params,
                                   const InType* input,
                                   int offset);

  // Returns an output pointer derived from |output| and |offset|.
  static OutType* OffsetOutput(const Transform& params,
                               OutType* output,
                               int offset);
};
141
142 } // namespace meta
143 } // namespace gemmlowp
144
145 #endif // GEMMLOWP_META_BASE_H_
146