1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file match_exhaustion.cc
22  * \brief Checking Relay match expression exhaustiveness.
23  *
24  * This file implements a function that checks whether a match
25  * expression is exhaustive, that is, whether a given match clause
26  * matches every possible case. This is important for ensuring
27  * code correctness, since hitting an unmatched case results in a
28  * dynamic error unless exhaustiveness is checked in advance.
29  */
30 #include <tvm/relay/adt.h>
31 #include <tvm/relay/error.h>
32 #include <tvm/relay/expr_functor.h>
33 #include <tvm/relay/pattern_functor.h>
34 #include <stack>
35 
36 namespace tvm {
37 namespace relay {
38 
/*!
 * \brief Result of comparing a clause pattern against a candidate pattern.
 *
 * Returned by CandidateChecker::Check(pat, candidate); UnmatchedCases uses it
 * to decide whether a candidate is covered, must be skipped for this clause,
 * or must be refined further.
 */
enum MatchResult : int {
  kMatch = 0,        // pattern matches: the clause covers the whole candidate
  kClash = 1,        // pattern conflicts: a constructor differs somewhere
  kUnspecified = 2,  // ambiguous: candidate needs more constructors specified
};
45 
46 class CandidateChecker : public PatternFunctor<MatchResult(const Pattern&, const Pattern&)> {
47  public:
CandidateChecker()48   explicit CandidateChecker() {}
49 
Check(const Pattern & pat,const Pattern & candidate)50   MatchResult Check(const Pattern& pat, const Pattern& candidate) {
51     return this->VisitPattern(pat, candidate);
52   }
53 
54   // for a constructor pattern, we must ensure that the candidate is
55   // a ConstructorPattern, that it has the same constructor, and
56   // that its fields match the subpatterns.
VisitPattern_(const PatternConstructorNode * op,const Pattern & cand)57   MatchResult VisitPattern_(const PatternConstructorNode* op, const Pattern& cand) override {
58     auto* ctor_cand = cand.as<PatternConstructorNode>();
59     // attempting to match non-constructor to constructor pattern: need to specify
60     if (ctor_cand == nullptr) {
61       return MatchResult::kUnspecified;
62     }
63 
64     // check that constructors match
65     if (!op->constructor.same_as(ctor_cand->constructor)) {
66       return MatchResult::kClash;
67     }
68 
69     // now check that subpatterns match
70     CHECK_EQ(op->patterns.size(), ctor_cand->patterns.size());
71     bool unspecified = false;
72     for (size_t i = 0; i < op->patterns.size(); i++) {
73       MatchResult submatch = this->Check(op->patterns[i], ctor_cand->patterns[i]);
74       // if we have a clash anywhere, then we can return clash
75       if (submatch == MatchResult::kClash) {
76         return MatchResult::kClash;
77       }
78       if (submatch == MatchResult::kUnspecified) {
79         unspecified = true;
80       }
81     }
82     // only return unspecified if we have ruled out a clash
83     if (unspecified) {
84       return MatchResult::kUnspecified;
85     }
86     return MatchResult::kMatch;
87   }
88 
VisitPattern_(const PatternTupleNode * op,const Pattern & cand)89   MatchResult VisitPattern_(const PatternTupleNode* op, const Pattern& cand) override {
90     auto* tuple_cand = cand.as<PatternTupleNode>();
91     // attempting to match non-tuple to constructor pattern: need to specify
92     if (tuple_cand == nullptr) {
93       return MatchResult::kUnspecified;
94     }
95 
96     // now check that subpatterns match
97     CHECK_EQ(op->patterns.size(), tuple_cand->patterns.size());
98     bool unspecified = false;
99     for (size_t i = 0; i < op->patterns.size(); i++) {
100       MatchResult submatch = this->Check(op->patterns[i], tuple_cand->patterns[i]);
101       // if we have a clash anywhere, then we can return clash
102       if (submatch == MatchResult::kClash) {
103         return MatchResult::kClash;
104       }
105       if (submatch == MatchResult::kUnspecified) {
106         unspecified = true;
107       }
108     }
109     // only return unspecified if we have ruled out a clash
110     if (unspecified) {
111       return MatchResult::kUnspecified;
112     }
113     return MatchResult::kMatch;
114   }
115 
116   // wildcard and var patterns always match
VisitPattern_(const PatternWildcardNode *,const Pattern &)117   MatchResult VisitPattern_(const PatternWildcardNode*, const Pattern&) override {
118     return MatchResult::kMatch;
119   }
120 
VisitPattern_(const PatternVarNode *,const Pattern &)121   MatchResult VisitPattern_(const PatternVarNode*, const Pattern&) override {
122     return MatchResult::kMatch;
123   }
124 };
125 
126 // Returns list of arrays corresponding to Cartesian product of input list
CartesianProduct(Array<Array<Pattern>> fields)127 Array<Array<Pattern>> CartesianProduct(Array<Array<Pattern>> fields) {
128   CHECK_NE(fields.size(), 0);
129   Array<Pattern> field_vals = fields[fields.size() - 1];
130   Array<Array<Pattern>> ret;
131 
132   // base case: this is the last field left
133   if (fields.size() == 1) {
134     for (auto val : field_vals) {
135       ret.push_back(Array<Pattern>{val});
136     }
137     return ret;
138   }
139 
140   // if we have more fields left, get the sub-candidates by getting
141   // their cartesian product and appending the elements here onto those
142   Array<Array<Pattern>> remaining_fields;
143   for (size_t i = 0; i < fields.size() - 1; i++) {
144     remaining_fields.push_back(fields[i]);
145   }
146   Array<Array<Pattern>> candidates = CartesianProduct(remaining_fields);
147   for (auto val : field_vals) {
148     for (auto candidate : candidates) {
149       candidate.push_back(val);
150       ret.push_back(candidate);
151     }
152   }
153   return ret;
154 }
155 
// Forward declarations: ExpandWildcards (defined below) dispatches to these two
// helpers, and each of them calls ExpandWildcards recursively on sub-patterns.
Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
                                          const Pattern& cand,
                                          const Module& mod);

Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
                                    const Pattern& cand,
                                    const Module& mod);
163 
164 // Expands all wildcards in the candidate pattern once
165 // Returns a list of all possible expansions.
ExpandWildcards(const Pattern & clause_pat,const Pattern & cand,const Module & mod)166 Array<Pattern> ExpandWildcards(const Pattern& clause_pat,
167                                const Pattern& cand,
168                                const Module& mod) {
169   if (auto clause_ctor = clause_pat.as<PatternConstructorNode>()) {
170     return ExpandWildcardsConstructor(GetRef<PatternConstructor>(clause_ctor), cand, mod);
171   } else {
172     return ExpandWildcardsTuple(Downcast<PatternTuple>(clause_pat), cand, mod);
173   }
174 }
175 
176 // Expands all wildcards in the candidate pattern once.
177 // Use the pattern to decide which constructors to insert.
178 // Returns a list of all possible expansions.
ExpandWildcardsConstructor(const PatternConstructor & clause_ctor,const Pattern & cand,const Module & mod)179 Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
180                                           const Pattern& cand,
181                                           const Module& mod) {
182   auto gtv = Downcast<GlobalTypeVar>(clause_ctor->constructor->belong_to);
183 
184   // for a wildcard node, create constructor nodes with wildcards for all args.
185   if (cand.as<PatternWildcardNode>()) {
186     TypeData td = mod->LookupDef(gtv);
187     // for each constructor add a candidate.
188     Array<Pattern> ret;
189     for (auto constructor : td->constructors) {
190       Array<Pattern> args;
191       for (auto inp : constructor->inputs) {
192         args.push_back(PatternWildcardNode::make());
193       }
194       ret.push_back(PatternConstructorNode::make(constructor, args));
195     }
196     return ret;
197   }
198 
199   auto ctor_cand = Downcast<PatternConstructor>(cand);
200 
201   // for constructors, we will expand the wildcards in any field that is an ADT.
202   Array<Array<Pattern>> values_by_field;
203   for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) {
204     bool subpattern =
205       clause_ctor->patterns[i].as<PatternConstructorNode>() ||
206       clause_ctor->patterns[i].as<PatternTupleNode>();
207     // for non-ADT fields, we can only have a wildcard for the value.
208     if (!subpattern) {
209       values_by_field.push_back({PatternWildcardNode::make()});
210     } else {
211       // otherwise, recursively expand.
212       values_by_field.push_back(ExpandWildcards(clause_ctor->patterns[i],
213                                                 ctor_cand->patterns[i],
214                                                 mod));
215     }
216   }
217 
218   // generate new candidates using a cartesian product.
219   auto all_subfields = CartesianProduct(values_by_field);
220   Array<Pattern> ret;
221   for (auto subfields : all_subfields) {
222     ret.push_back(PatternConstructorNode::make(ctor_cand->constructor, subfields));
223   }
224   return ret;
225 }
226 
227 // Expands all wildcards in the candidate pattern once.
228 // Returns a list of all possible expansions.
ExpandWildcardsTuple(const PatternTuple & clause_tuple,const Pattern & cand,const Module & mod)229 Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
230                                     const Pattern& cand,
231                                     const Module& mod) {
232   // for a wildcard node, create constructor nodes with wildcards for all args.
233   if (cand.as<PatternWildcardNode>()) {
234     Array<Pattern> args;
235     for (auto inp : clause_tuple->patterns) {
236       args.push_back(PatternWildcardNode::make());
237     }
238     return {PatternTupleNode::make(args)};
239   }
240 
241   auto tuple_cand = Downcast<PatternTuple>(cand);
242 
243   // for constructors, we will expand the wildcards in any field that is an ADT.
244   Array<Array<Pattern>> values_by_field;
245   for (size_t i = 0; i < tuple_cand->patterns.size(); i++) {
246     bool subpattern =
247       clause_tuple->patterns[i].as<PatternConstructorNode>() ||
248       clause_tuple->patterns[i].as<PatternTupleNode>();
249     // for non-ADT fields, we can only have a wildcard for the value
250     if (!subpattern) {
251       values_by_field.push_back({PatternWildcardNode::make()});
252     } else {
253       // otherwise, recursively expand
254       values_by_field.push_back(ExpandWildcards(clause_tuple->patterns[i],
255                                                 tuple_cand->patterns[i],
256                                                 mod));
257     }
258   }
259 
260   // generate new candidates using a cartesian product
261   auto all_subfields = CartesianProduct(values_by_field);
262   Array<Pattern> ret;
263   for (auto subfields : all_subfields) {
264     ret.push_back(PatternTupleNode::make(subfields));
265   }
266   return ret;
267 }
268 
269 /*!
270  * \brief Finds cases that the match expression does not catch, if any.
271  * \return Returns a list of cases that are not handled by the match
272  * expression.
273  */
UnmatchedCases(const Match & match,const Module & mod)274 Array<Pattern> UnmatchedCases(const Match& match, const Module& mod) {
275   /* algorithm:
276    * candidates = { Wildcard }
277    * while candidates not empty {
278    *   cand = candidates.pop()
279    *   for clause in clauses {
280    *     if clause fails: next clause
281    *     if clause matches candidate: next candidate
282    *     if candidate is not specific enough:
283    *        candidates += expand_possible_wildcards(cand)
284    *        next candidate
285    *   }
286    *   failed_candidates += { cand }
287    * }
288    * return failed_candidates
289    */
290   std::stack<Pattern> candidates;
291   candidates.push(PatternWildcardNode::make());
292   CandidateChecker checker;
293 
294   Array<Pattern> failures;
295 
296   while (!candidates.empty()) {
297     Pattern cand = candidates.top();
298     candidates.pop();
299 
300     bool failure = true;
301     for (auto clause : match->clauses) {
302       // if the check fails, we move on to the next
303       MatchResult check = checker.Check(clause->lhs, cand);
304       if (check == MatchResult::kClash) {
305         continue;
306       }
307 
308       // either success or we need to generate more candidates;
309       // either way, we're done with this candidate
310       failure = false;
311       if (check == MatchResult::kUnspecified) {
312         auto new_candidates = ExpandWildcards(clause->lhs, cand, mod);
313         for (auto candidate : new_candidates) {
314           candidates.push(candidate);
315         }
316       }
317       break;
318     }
319 
320     if (failure) {
321       failures.push_back(cand);
322     }
323   }
324 
325   return failures;
326 }
327 
328 // expose for testing only
329 TVM_REGISTER_API("relay._analysis.unmatched_cases")
330 .set_body_typed<Array<Pattern>(const Match&, const Module&)>(
__anon0eb961f60102(const Match& match, const Module& mod_ref) 331   [](const Match& match, const Module& mod_ref) {
332     Module call_mod = mod_ref;
333     if (!call_mod.defined()) {
334       call_mod = ModuleNode::make({}, {});
335     }
336     return UnmatchedCases(match, call_mod);
337   });
338 
339 }  // namespace relay
340 }  // namespace tvm
341