1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file match_exhaustion.cc
22 * \brief Checking Relay match expression exhaustiveness.
23 *
24 * This file implements a function that checks whether a match
25 expression is exhaustive, that is, whether its clauses together
26 match every possible case. This is important for ensuring
27 * code correctness, since hitting an unmatched case results in a
28 * dynamic error unless exhaustiveness is checked in advance.
29 */
30 #include <tvm/relay/adt.h>
31 #include <tvm/relay/error.h>
32 #include <tvm/relay/expr_functor.h>
33 #include <tvm/relay/pattern_functor.h>
34 #include <stack>
35
36 namespace tvm {
37 namespace relay {
38
/*!
 * \brief Outcome of checking a clause pattern against a candidate pattern.
 */
enum MatchResult : int {
  /*! \brief The clause pattern matches the candidate. */
  kMatch = 0,
  /*! \brief The clause pattern conflicts with the candidate. */
  kClash = 1,
  /*! \brief Ambiguous: the candidate needs more constructors specified. */
  kUnspecified = 2,
};
45
46 class CandidateChecker : public PatternFunctor<MatchResult(const Pattern&, const Pattern&)> {
47 public:
CandidateChecker()48 explicit CandidateChecker() {}
49
Check(const Pattern & pat,const Pattern & candidate)50 MatchResult Check(const Pattern& pat, const Pattern& candidate) {
51 return this->VisitPattern(pat, candidate);
52 }
53
54 // for a constructor pattern, we must ensure that the candidate is
55 // a ConstructorPattern, that it has the same constructor, and
56 // that its fields match the subpatterns.
VisitPattern_(const PatternConstructorNode * op,const Pattern & cand)57 MatchResult VisitPattern_(const PatternConstructorNode* op, const Pattern& cand) override {
58 auto* ctor_cand = cand.as<PatternConstructorNode>();
59 // attempting to match non-constructor to constructor pattern: need to specify
60 if (ctor_cand == nullptr) {
61 return MatchResult::kUnspecified;
62 }
63
64 // check that constructors match
65 if (!op->constructor.same_as(ctor_cand->constructor)) {
66 return MatchResult::kClash;
67 }
68
69 // now check that subpatterns match
70 CHECK_EQ(op->patterns.size(), ctor_cand->patterns.size());
71 bool unspecified = false;
72 for (size_t i = 0; i < op->patterns.size(); i++) {
73 MatchResult submatch = this->Check(op->patterns[i], ctor_cand->patterns[i]);
74 // if we have a clash anywhere, then we can return clash
75 if (submatch == MatchResult::kClash) {
76 return MatchResult::kClash;
77 }
78 if (submatch == MatchResult::kUnspecified) {
79 unspecified = true;
80 }
81 }
82 // only return unspecified if we have ruled out a clash
83 if (unspecified) {
84 return MatchResult::kUnspecified;
85 }
86 return MatchResult::kMatch;
87 }
88
VisitPattern_(const PatternTupleNode * op,const Pattern & cand)89 MatchResult VisitPattern_(const PatternTupleNode* op, const Pattern& cand) override {
90 auto* tuple_cand = cand.as<PatternTupleNode>();
91 // attempting to match non-tuple to constructor pattern: need to specify
92 if (tuple_cand == nullptr) {
93 return MatchResult::kUnspecified;
94 }
95
96 // now check that subpatterns match
97 CHECK_EQ(op->patterns.size(), tuple_cand->patterns.size());
98 bool unspecified = false;
99 for (size_t i = 0; i < op->patterns.size(); i++) {
100 MatchResult submatch = this->Check(op->patterns[i], tuple_cand->patterns[i]);
101 // if we have a clash anywhere, then we can return clash
102 if (submatch == MatchResult::kClash) {
103 return MatchResult::kClash;
104 }
105 if (submatch == MatchResult::kUnspecified) {
106 unspecified = true;
107 }
108 }
109 // only return unspecified if we have ruled out a clash
110 if (unspecified) {
111 return MatchResult::kUnspecified;
112 }
113 return MatchResult::kMatch;
114 }
115
116 // wildcard and var patterns always match
VisitPattern_(const PatternWildcardNode *,const Pattern &)117 MatchResult VisitPattern_(const PatternWildcardNode*, const Pattern&) override {
118 return MatchResult::kMatch;
119 }
120
VisitPattern_(const PatternVarNode *,const Pattern &)121 MatchResult VisitPattern_(const PatternVarNode*, const Pattern&) override {
122 return MatchResult::kMatch;
123 }
124 };
125
126 // Returns list of arrays corresponding to Cartesian product of input list
CartesianProduct(Array<Array<Pattern>> fields)127 Array<Array<Pattern>> CartesianProduct(Array<Array<Pattern>> fields) {
128 CHECK_NE(fields.size(), 0);
129 Array<Pattern> field_vals = fields[fields.size() - 1];
130 Array<Array<Pattern>> ret;
131
132 // base case: this is the last field left
133 if (fields.size() == 1) {
134 for (auto val : field_vals) {
135 ret.push_back(Array<Pattern>{val});
136 }
137 return ret;
138 }
139
140 // if we have more fields left, get the sub-candidates by getting
141 // their cartesian product and appending the elements here onto those
142 Array<Array<Pattern>> remaining_fields;
143 for (size_t i = 0; i < fields.size() - 1; i++) {
144 remaining_fields.push_back(fields[i]);
145 }
146 Array<Array<Pattern>> candidates = CartesianProduct(remaining_fields);
147 for (auto val : field_vals) {
148 for (auto candidate : candidates) {
149 candidate.push_back(val);
150 ret.push_back(candidate);
151 }
152 }
153 return ret;
154 }
155
156 Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
157 const Pattern& cand,
158 const Module& mod);
159
160 Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
161 const Pattern& cand,
162 const Module& mod);
163
164 // Expands all wildcards in the candidate pattern once
165 // Returns a list of all possible expansions.
ExpandWildcards(const Pattern & clause_pat,const Pattern & cand,const Module & mod)166 Array<Pattern> ExpandWildcards(const Pattern& clause_pat,
167 const Pattern& cand,
168 const Module& mod) {
169 if (auto clause_ctor = clause_pat.as<PatternConstructorNode>()) {
170 return ExpandWildcardsConstructor(GetRef<PatternConstructor>(clause_ctor), cand, mod);
171 } else {
172 return ExpandWildcardsTuple(Downcast<PatternTuple>(clause_pat), cand, mod);
173 }
174 }
175
176 // Expands all wildcards in the candidate pattern once.
177 // Use the pattern to decide which constructors to insert.
178 // Returns a list of all possible expansions.
ExpandWildcardsConstructor(const PatternConstructor & clause_ctor,const Pattern & cand,const Module & mod)179 Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
180 const Pattern& cand,
181 const Module& mod) {
182 auto gtv = Downcast<GlobalTypeVar>(clause_ctor->constructor->belong_to);
183
184 // for a wildcard node, create constructor nodes with wildcards for all args.
185 if (cand.as<PatternWildcardNode>()) {
186 TypeData td = mod->LookupDef(gtv);
187 // for each constructor add a candidate.
188 Array<Pattern> ret;
189 for (auto constructor : td->constructors) {
190 Array<Pattern> args;
191 for (auto inp : constructor->inputs) {
192 args.push_back(PatternWildcardNode::make());
193 }
194 ret.push_back(PatternConstructorNode::make(constructor, args));
195 }
196 return ret;
197 }
198
199 auto ctor_cand = Downcast<PatternConstructor>(cand);
200
201 // for constructors, we will expand the wildcards in any field that is an ADT.
202 Array<Array<Pattern>> values_by_field;
203 for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) {
204 bool subpattern =
205 clause_ctor->patterns[i].as<PatternConstructorNode>() ||
206 clause_ctor->patterns[i].as<PatternTupleNode>();
207 // for non-ADT fields, we can only have a wildcard for the value.
208 if (!subpattern) {
209 values_by_field.push_back({PatternWildcardNode::make()});
210 } else {
211 // otherwise, recursively expand.
212 values_by_field.push_back(ExpandWildcards(clause_ctor->patterns[i],
213 ctor_cand->patterns[i],
214 mod));
215 }
216 }
217
218 // generate new candidates using a cartesian product.
219 auto all_subfields = CartesianProduct(values_by_field);
220 Array<Pattern> ret;
221 for (auto subfields : all_subfields) {
222 ret.push_back(PatternConstructorNode::make(ctor_cand->constructor, subfields));
223 }
224 return ret;
225 }
226
227 // Expands all wildcards in the candidate pattern once.
228 // Returns a list of all possible expansions.
ExpandWildcardsTuple(const PatternTuple & clause_tuple,const Pattern & cand,const Module & mod)229 Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
230 const Pattern& cand,
231 const Module& mod) {
232 // for a wildcard node, create constructor nodes with wildcards for all args.
233 if (cand.as<PatternWildcardNode>()) {
234 Array<Pattern> args;
235 for (auto inp : clause_tuple->patterns) {
236 args.push_back(PatternWildcardNode::make());
237 }
238 return {PatternTupleNode::make(args)};
239 }
240
241 auto tuple_cand = Downcast<PatternTuple>(cand);
242
243 // for constructors, we will expand the wildcards in any field that is an ADT.
244 Array<Array<Pattern>> values_by_field;
245 for (size_t i = 0; i < tuple_cand->patterns.size(); i++) {
246 bool subpattern =
247 clause_tuple->patterns[i].as<PatternConstructorNode>() ||
248 clause_tuple->patterns[i].as<PatternTupleNode>();
249 // for non-ADT fields, we can only have a wildcard for the value
250 if (!subpattern) {
251 values_by_field.push_back({PatternWildcardNode::make()});
252 } else {
253 // otherwise, recursively expand
254 values_by_field.push_back(ExpandWildcards(clause_tuple->patterns[i],
255 tuple_cand->patterns[i],
256 mod));
257 }
258 }
259
260 // generate new candidates using a cartesian product
261 auto all_subfields = CartesianProduct(values_by_field);
262 Array<Pattern> ret;
263 for (auto subfields : all_subfields) {
264 ret.push_back(PatternTupleNode::make(subfields));
265 }
266 return ret;
267 }
268
269 /*!
270 * \brief Finds cases that the match expression does not catch, if any.
271 * \return Returns a list of cases that are not handled by the match
272 * expression.
273 */
UnmatchedCases(const Match & match,const Module & mod)274 Array<Pattern> UnmatchedCases(const Match& match, const Module& mod) {
275 /* algorithm:
276 * candidates = { Wildcard }
277 * while candidates not empty {
278 * cand = candidates.pop()
279 * for clause in clauses {
280 * if clause fails: next clause
281 * if clause matches candidate: next candidate
282 * if candidate is not specific enough:
283 * candidates += expand_possible_wildcards(cand)
284 * next candidate
285 * }
286 * failed_candidates += { cand }
287 * }
288 * return failed_candidates
289 */
290 std::stack<Pattern> candidates;
291 candidates.push(PatternWildcardNode::make());
292 CandidateChecker checker;
293
294 Array<Pattern> failures;
295
296 while (!candidates.empty()) {
297 Pattern cand = candidates.top();
298 candidates.pop();
299
300 bool failure = true;
301 for (auto clause : match->clauses) {
302 // if the check fails, we move on to the next
303 MatchResult check = checker.Check(clause->lhs, cand);
304 if (check == MatchResult::kClash) {
305 continue;
306 }
307
308 // either success or we need to generate more candidates;
309 // either way, we're done with this candidate
310 failure = false;
311 if (check == MatchResult::kUnspecified) {
312 auto new_candidates = ExpandWildcards(clause->lhs, cand, mod);
313 for (auto candidate : new_candidates) {
314 candidates.push(candidate);
315 }
316 }
317 break;
318 }
319
320 if (failure) {
321 failures.push_back(cand);
322 }
323 }
324
325 return failures;
326 }
327
328 // expose for testing only
329 TVM_REGISTER_API("relay._analysis.unmatched_cases")
330 .set_body_typed<Array<Pattern>(const Match&, const Module&)>(
__anon0eb961f60102(const Match& match, const Module& mod_ref) 331 [](const Match& match, const Module& mod_ref) {
332 Module call_mod = mod_ref;
333 if (!call_mod.defined()) {
334 call_mod = ModuleNode::make({}, {});
335 }
336 return UnmatchedCases(match, call_mod);
337 });
338
339 } // namespace relay
340 } // namespace tvm
341