1 // Copyright (c) 2019 The Khronos Group Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "gmock/gmock.h"
16 #include "test/link/linker_fixture.h"
17 
18 namespace spvtools {
19 namespace {
20 
// Fixture for every test in this file: the MatchF macro below generates
// TEST_F(TypeMatch, ...) cases. ExpandAndMatch(), which those test bodies call
// unqualified, presumably comes from spvtest::LinkerTest
// (test/link/linker_fixture.h) -- confirm there.
using TypeMatch = spvtest::LinkerTest;
22 
// Basic (operand-free) types.
//
// Part<Kind>(DECO, ID) expands to the string literal of one SPIR-V type
// declaration whose result id is ID. The id is rendered through the
// decorator macro DECO. InstDeco yields the plain "%ID" assembly form, and
// CheckDecoRes yields a CHECK-pattern capture of ID.
#define PartInt(DECO, ID) DECO(ID) " = OpTypeInt 32 0"
#define PartFloat(DECO, ID) DECO(ID) " = OpTypeFloat 32"
#define PartOpaque(DECO, ID) DECO(ID) " = OpTypeOpaque \"bar\""
#define PartSampler(DECO, ID) DECO(ID) " = OpTypeSampler"
#define PartEvent(DECO, ID) DECO(ID) " = OpTypeEvent"
#define PartDeviceEvent(DECO, ID) DECO(ID) " = OpTypeDeviceEvent"
#define PartReserveId(DECO, ID) DECO(ID) " = OpTypeReserveId"
#define PartQueue(DECO, ID) DECO(ID) " = OpTypeQueue"
#define PartPipe(DECO, ID) DECO(ID) " = OpTypePipe ReadWrite"
#define PartPipeStorage(DECO, ID) DECO(ID) " = OpTypePipeStorage"
#define PartNamedBarrier(DECO, ID) DECO(ID) " = OpTypeNamedBarrier"
35 
// Compound types.
//
// Part<Kind>(RES, ARG, ID, ELEM) expands to the string literal of one type
// declaration with result id ID that refers to the previously declared type
// ELEM. RES renders the result id and ARG renders every referenced id. This
// lets one macro emit either the assembly instruction (InstDeco, InstDeco) or
// its CHECK pattern (CheckDecoRes, CheckDecoArg).
// PartArray also references the id `const`, which MatchF declares in every
// test module. PartSampledImage is not instantiated by any test in this file.
#define PartVector(RES, ARG, ID, ELEM) \
  RES(ID) " = OpTypeVector " ARG(ELEM) " 3"
#define PartMatrix(RES, ARG, ID, ELEM) \
  RES(ID) " = OpTypeMatrix " ARG(ELEM) " 4"
#define PartImage(RES, ARG, ID, ELEM) \
  RES(ID) " = OpTypeImage " ARG(ELEM) " 2D 0 0 0 0 Rgba32f"
#define PartSampledImage(RES, ARG, ID, ELEM) \
  RES(ID) " = OpTypeSampledImage " ARG(ELEM)
#define PartArray(RES, ARG, ID, ELEM) \
  RES(ID) " = OpTypeArray " ARG(ELEM) " " ARG(const)
#define PartRuntimeArray(RES, ARG, ID, ELEM) \
  RES(ID) " = OpTypeRuntimeArray " ARG(ELEM)
#define PartStruct(RES, ARG, ID, ELEM) \
  RES(ID) " = OpTypeStruct " ARG(ELEM) " " ARG(ELEM)
#define PartPointer(RES, ARG, ID, ELEM) \
  RES(ID) " = OpTypePointer Workgroup " ARG(ELEM)
#define PartFunction(RES, ARG, ID, ELEM) \
  RES(ID) " = OpTypeFunction " ARG(ELEM) " " ARG(ELEM)
47 
// Id decorators passed as the DECO / RES / ARG argument of the Part* macros:
//   CheckDecoRes(x) -> [[x:%\w+]]   defines the CHECK capture `x`
//   CheckDecoArg(x) -> [[x]]        refers back to an earlier capture `x`
//   InstDeco(x)     -> %x           the id as spelled in the assembly
#define CheckDecoRes(SYM) "[[" #SYM ":" "%\\w+]]"
#define CheckDecoArg(SYM) "[[" #SYM "]]"
#define InstDeco(SYM) "%" #SYM
51 
// MatchPart1(KIND, ID) and MatchPart2(KIND, ID, ELEM) emit one type
// declaration twice, each copy newline-terminated. The first copy is a
// "; CHECK:" line that captures ID; MatchPart2's CHECK line also references
// the earlier capture ELEM. The second copy is the plain assembly instruction.
// KIND selects the Part<KIND> macro by token pasting.
#define MatchPart1(KIND, ID)                   \
  "; CHECK: " Part##KIND(CheckDecoRes, ID) "\n" \
  Part##KIND(InstDeco, ID) "\n"
#define MatchPart2(KIND, ID, ELEM)                                   \
  "; CHECK: " Part##KIND(CheckDecoRes, CheckDecoArg, ID, ELEM) "\n" \
  Part##KIND(InstDeco, InstDeco, ID, ELEM) "\n"
57 
// MatchF(TEST_NAME, BODY) defines TEST_F(TypeMatch, TEST_NAME).
//
// The generated module works as follows:
// - It declares a fixed capability set shared by every test.
// - It declares %baseint (OpTypeInt 32 1) and %const (the value 3, which
//   PartArray uses as its length operand).
// - It splices in BODY. BODY must declare %type and capture it as `type` in
//   its CHECK lines.
// - It ends with %var, a Uniform variable of %type decorated with
//   LinkageAttributes "foo".
// The "; CHECK:" lines travel with the module text into ExpandAndMatch. The
// "{Import,Export}" placeholder is presumably expanded there into an
// importing and an exporting module (confirm in linker_fixture.h).
#define MatchF(TEST_NAME, BODY)                                       \
  TEST_F(TypeMatch, TEST_NAME) {                                      \
    const std::string module_text =                                   \
        "OpCapability Linkage\n"                                      \
        "OpCapability NamedBarrier\n"                                 \
        "OpCapability PipeStorage\n"                                  \
        "OpCapability Pipes\n"                                        \
        "OpCapability DeviceEnqueue\n"                                \
        "OpCapability Kernel\n"                                       \
        "OpCapability Shader\n"                                       \
        "OpCapability Addresses\n"                                    \
        "OpDecorate %var LinkageAttributes \"foo\" {Import,Export}\n" \
        "; CHECK: [[baseint:%\\w+]] = OpTypeInt 32 1\n"               \
        "%baseint = OpTypeInt 32 1\n"                                 \
        "; CHECK: [[const:%\\w+]] = OpConstant [[baseint]] 3\n"       \
        "%const = OpConstant %baseint 3\n"                            \
        BODY                                                          \
        "; CHECK: OpVariable [[type]] Uniform\n"                      \
        "%var = OpVariable %type Uniform";                            \
    ExpandAndMatch(module_text);                                      \
  }
79 
// Test generators. The test name spells out the type chain being linked:
//   Match1(Kind)             -> Type<Kind>             %type = Kind
//   Match2(Outer, Leaf)      -> <Outer>OfType<Leaf>    %a = Leaf,
//                                                      %type = Outer(%a)
//   Match3(Outer, Mid, Leaf) -> <Outer>Of<Mid>Of<Leaf> %b = Leaf,
//                                                      %a = Mid(%b),
//                                                      %type = Outer(%a)
// Declarations are emitted innermost first, so every id is defined (and
// captured) before it is referenced.
#define Match1(KIND) MatchF(Type##KIND, MatchPart1(KIND, type))
#define Match2(OUTER, LEAF)           \
  MatchF(OUTER##OfType##LEAF,         \
         MatchPart1(LEAF, a)          \
         MatchPart2(OUTER, type, a))
#define Match3(OUTER, MID, LEAF)      \
  MatchF(OUTER##Of##MID##Of##LEAF,    \
         MatchPart1(LEAF, b)          \
         MatchPart2(MID, a, b)        \
         MatchPart2(OUTER, type, a))
86 
87 // clang-format off
88 // Basic types
89 Match1(Int)
90 Match1(Float)
91 Match1(Opaque)
92 Match1(Sampler)
93 Match1(Event)
94 Match1(DeviceEvent)
95 Match1(ReserveId)
96 Match1(Queue)
97 Match1(Pipe)
98 Match1(PipeStorage)
99 Match1(NamedBarrier)
100 
101 // Simpler (restricted) compound types
102 Match2(Vector, Float)
103 Match3(Matrix, Vector, Float)
104 Match2(Image, Float)
105 
106 // Unrestricted compound types
107 #define MatchCompounds1(A) \
108   Match2(RuntimeArray, A)  \
109   Match2(Struct, A)        \
110   Match2(Pointer, A)       \
111   Match2(Function, A)      \
112   Match2(Array, A)
113 #define MatchCompounds2(A, B) \
114   Match3(RuntimeArray, A, B)  \
115   Match3(Struct, A, B)        \
116   Match3(Pointer, A, B)       \
117   Match3(Function, A, B)      \
118   Match3(Array, A, B)
119 
120 MatchCompounds1(Float)
121 MatchCompounds2(Array, Float)
122 MatchCompounds2(RuntimeArray, Float)
123 MatchCompounds2(Struct, Float)
124 MatchCompounds2(Pointer, Float)
125 MatchCompounds2(Function, Float)
126 // clang-format on
127 
// ForwardPointer tests, which don't fit into the previous mold.
// The pointer %type is first forward-declared with OpTypeForwardPointer. Then
// PAYLOAD declares the pointee %realtype, and finally %type is defined as a
// Workgroup pointer to %realtype.
#define MatchFpF(TEST_NAME, PAYLOAD)                                  \
  MatchF(TEST_NAME,                                                   \
         "; CHECK: OpTypeForwardPointer [[type:%\\w+]] Workgroup\n"   \
         "OpTypeForwardPointer %type Workgroup\n"                     \
         PAYLOAD                                                      \
         "; CHECK: [[type]] = OpTypePointer Workgroup [[realtype]]\n" \
         "%type = OpTypePointer Workgroup %realtype\n")
// MatchFp1(Kind)       -> ForwardPointerOf<Kind>: %realtype = Kind.
// MatchFp2(Kind, Leaf) -> ForwardPointerOf<Kind>: %a = Leaf,
//                         %realtype = Kind(%a).
#define MatchFp1(KIND) \
  MatchFpF(ForwardPointerOf##KIND, MatchPart1(KIND, realtype))
#define MatchFp2(KIND, LEAF)                          \
  MatchFpF(ForwardPointerOf##KIND,                    \
           MatchPart1(LEAF, a)                        \
           MatchPart2(KIND, realtype, a))
138 
139     // clang-format off
140 MatchFp1(Float)
141 MatchFp2(Array, Float)
142 MatchFp2(RuntimeArray, Float)
143 MatchFp2(Struct, Float)
144 MatchFp2(Function, Float)
145 // clang-format on
146 
147 }  // namespace
148 }  // namespace spvtools
149