// Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 #ifndef GEMMLOWP_META_BASE_H_
16 #define GEMMLOWP_META_BASE_H_
17 
18 #include <cassert>
19 #include <cstdint>
20 
21 #include "../internal/common.h"
22 
23 namespace gemmlowp {
24 namespace meta {
25 
// Rounds |value| up to the nearest multiple of the compile-time constant
// |align| and returns the result.
//
// The computation uses truncating integer division, so the result is a true
// round-up only for non-negative |value|. The intermediate sum
// |value + align - 1| must fit in an int.
template <int align>
inline int AlignTo(int value) {
  // A zero alignment would divide by zero and a negative one has no meaningful
  // interpretation, so reject both when the template is instantiated.
  static_assert(align > 0, "AlignTo requires a strictly positive alignment.");
  return ((value + align - 1) / align) * align;
}
30 
// Rounds |value| up to the nearest multiple of the runtime value |align| and
// returns the result.
//
// |align| must be strictly positive. The computation uses truncating integer
// division, so the result is a true round-up only for non-negative |value|.
// The intermediate sum |value + align - 1| must fit in an int.
inline int AlignTo(int align, int value) {
  // Catch a zero (division by zero) or negative alignment in debug builds.
  assert(align > 0);
  return ((value + align - 1) / align) * align;
}
34 
// Bundles one |Kernel_| value with one |OutputStream_| value.
//
// The template arguments are re-exported as the member typedefs |Kernel| and
// |OutputStream|. In this header the bundle is the type of
// GemmParams::fused_kernel and of the |params| argument of
// MulKernel::Multiply.
template <typename Kernel_, typename OutputStream_>
struct FusedKernelParams {
 public:
  typedef Kernel_ Kernel;
  typedef OutputStream_ OutputStream;

  // Kernel parameter object.
  Kernel kernel;
  // Output stream parameter object.
  OutputStream output_stream;
};
44 
45 template <typename InType_, typename OutType_, typename LeftStream_,
46           typename RightStream_, typename Kernel_, typename OutputStream_>
47 struct GemmParams {
48  public:
49   typedef InType_ InType;
50   typedef OutType_ OutType;
51   typedef LeftStream_ LeftStream;
52   typedef RightStream_ RightStream;
53   typedef Kernel_ Kernel;
54   typedef OutputStream_ OutputStream;
55 
56   typedef FusedKernelParams<Kernel, OutputStream> FusedKernel;
57 
58   // Common parameters.
59 
60   int m;
61   int n;
62   int k;
63 
64   const InType* lhs;
65   const InType* rhs;
66   OutType* result;
67   std::uint8_t* scratch;
68 
69   // Specialized parameters.
70 
71   LeftStream left_stream;
72   RightStream right_stream;
73   FusedKernel fused_kernel;
74 };
75 
// Primary template for stream packing operations, parameterized on the element
// type, |lanes_count|, |pack_size|, |leftovers| and a stream parameter type.
//
// Only declarations appear here: every member is a static function with no
// definition in this header.
// NOTE(review): presumably definitions or specializations are supplied
// elsewhere -- confirm.
template <typename InType, int lanes_count, int pack_size, int leftovers,
          typename StreamParams>
class Stream {
 public:
  // Reads from |in| and writes to |out|, as configured by |params|.
  static void Pack(const InType* in, const StreamParams& params, InType* out);

  static int UnpackedAdvance(const StreamParams& params);

  static int PackedAdvance(const StreamParams& params);

  static int UnpackedStride(const StreamParams& params);

  static int PackedStride(const StreamParams& params);
};
90 
// Primary template for helper operations on a stream parameter type.
//
// Only declarations appear here: both members are static functions with no
// definition in this header.
// NOTE(review): presumably specialized per stream type elsewhere -- confirm.
template <typename InType, typename StreamType>
class StreamUtil {
 public:
  // Returns a pointer derived from |source| using the two offsets.
  // NOTE(review): the exact offset arithmetic is not visible here.
  static const InType* Offset(const StreamType& params, const InType* source,
                              int offset_stride, int offset_advance);

  // NOTE(review): presumably the scratch size needed for |lanes| lanes; units
  // are not visible here -- confirm.
  static int Scratch(const StreamType& params, int lanes);
};
99 
100 template <typename InType, typename OutType, typename Kernel,
101           typename OutputStream, int kernel_m, int kernel_n, int pack_size>
102 class MulKernel {
103  public:
104   static void Multiply(const InType* lhs, const InType* rhs,
105                        const FusedKernelParams<Kernel, OutputStream>& params,
106                        OutType* result);
107 };
108 
// Parameter bundle for a 1D transform.
//
// The template arguments are re-exported as the member typedefs |InType|,
// |OutType| and |Kernel|.
template <typename InType_, typename OutType_, typename Kernel_>
struct Transform1DParams {
  typedef InType_ InType;
  typedef OutType_ OutType;
  typedef Kernel_ Kernel;

  // Read-only input buffer.
  const InType* input;
  // Writable output buffer.
  OutType* output;
  // Byte-addressed scratch buffer.
  // NOTE(review): required size and alignment are not visible in this header.
  std::uint8_t* scratch;

  // Kernel parameter object.
  Kernel kernel;
};
121 
// Primary template for a 1D transform kernel, parameterized on the input and
// output element types, the kernel parameter type, and the integers
// |kernel_size| and |leftovers|.
//
// Only the declaration of Transform appears here; no definition is provided in
// this header.
// NOTE(review): presumably specialized per kernel elsewhere -- confirm.
template <typename InType, typename OutType, typename Kernel, int kernel_size,
          int leftovers>
class Transform1DKernel {
 public:
  // Takes a read-only |input| buffer, the kernel |params| and a writable
  // |output| buffer.
  static void Transform(const InType* input, const Kernel& params,
                        OutType* output);
};
129 
// Primary template for helper operations on a 1D transform parameter type.
//
// Only declarations appear here: every member is a static function with no
// definition in this header.
// NOTE(review): presumably specialized per transform elsewhere -- confirm.
template <typename InType, typename OutType, typename Transform>
class Transform1DUtil {
 public:
  // NOTE(review): the unit of the returned cost estimate is not visible here.
  static int EstimateComputeCost(const Transform& params);

  // Returns a read-only input pointer derived from |input| and |offset|.
  static const InType* OffsetInput(const Transform& params, const InType* input,
                                   int offset);

  // Returns a writable output pointer derived from |output| and |offset|.
  static OutType* OffsetOutput(const Transform& params, OutType* output,
                               int offset);
};
141 
142 }  // namespace meta
143 }  // namespace gemmlowp
144 
145 #endif  // GEMMLOWP_META_BASE_H_
146