1# cython: language_level=3
2# cython: profile=False
3# -*- coding: utf-8 -*-
4
5
6from mathics.core.expression import Expression, system_symbols, ensure_context
7from mathics.core.util import subsets, subranges, permutations
8from itertools import chain
9
10
11# from mathics.core.pattern_nocython import (
12#    StopGenerator #, Pattern #, ExpressionPattern)
13# from mathics.core import pattern_nocython
14
15
def Pattern_create(expr):
    """Compile *expr* into the pattern object that will match against it.

    A head name with an entry in ``mathics.builtin.pattern_objects`` selects
    that dedicated class; otherwise atoms become ``AtomPattern`` and compound
    expressions become ``ExpressionPattern``.
    """
    # Imported at call time rather than at module level, presumably to avoid
    # a circular import with mathics.builtin -- TODO confirm.
    from mathics.builtin import pattern_objects

    special_class = pattern_objects.get(expr.get_head_name())
    if special_class is not None:
        return special_class(expr)
    return AtomPattern(expr) if expr.is_atom() else ExpressionPattern(expr)
29
30
class StopGenerator(Exception):
    """Exception used to unwind a callback-driven match early.

    The optional *value* travels with the exception so that whoever catches
    it can read a result from ``exc.value``.
    """

    def __init__(self, value=None):
        # BaseException.__new__ has already stored the constructor arguments
        # in ``self.args``; only the payload attribute is added here.
        self.value = value
34
35
class StopGenerator_ExpressionPattern_match(StopGenerator):
    """Raised inside ``ExpressionPattern.match`` to abandon the match attempt."""
38
39
class StopGenerator_Pattern(StopGenerator):
    """Raised by ``Pattern.does_match`` to stop at the first successful match."""
42
43
class Pattern(object):
    """Base class of compiled patterns.

    Matching is written in continuation-passing style: ``match`` does not
    return results; it calls ``yield_func(vars, rest)`` once for every
    successful match.  The accessor methods below delegate to the wrapped
    expression ``self.expr``, which subclasses are expected to set.
    """

    create = staticmethod(Pattern_create)

    def match(
        self,
        yield_func,
        expression,
        vars,
        evaluation,
        head=None,
        leaf_index=None,
        leaf_count=None,
        fully=True,
        wrap_oneid=True,
    ):
        """Call ``yield_func(new_vars, rest)`` for every way *expression* matches.

        *vars* holds the pattern-variable bindings made so far.  Subclasses
        must override this method.
        """
        raise NotImplementedError

    def does_match(self, expression, evaluation, vars=None, fully=True):
        """Return True if this pattern matches *expression* at least once.

        Matching stops at the first success: the callback raises
        ``StopGenerator_Pattern``, which unwinds the whole ``match`` call.
        Returns False when ``match`` finishes without ever calling back.
        """
        if vars is None:
            vars = {}

        def yield_match(sub_vars, rest):
            raise StopGenerator_Pattern(True)

        try:
            self.match(yield_match, expression, vars, evaluation, fully=fully)
        except StopGenerator_Pattern as exc:
            return exc.value
        return False

    # --- Accessors delegating to the wrapped expression ---------------------

    def get_name(self):
        return self.expr.get_name()

    def is_atom(self):
        return self.expr.is_atom()

    def get_head_name(self):
        return self.expr.get_head_name()

    def sameQ(self, other) -> bool:
        """Mathics SameQ"""
        return self.expr.sameQ(other.expr)

    def get_head(self):
        return self.expr.get_head()

    def get_leaves(self):
        return self.expr.get_leaves()

    def get_sort_key(self, pattern_sort=False):
        return self.expr.get_sort_key(pattern_sort=pattern_sort)

    def get_lookup_name(self):
        return self.expr.get_lookup_name()

    def get_attributes(self, definitions):
        return self.expr.get_attributes(definitions)

    def get_sequence(self):
        return self.expr.get_sequence()

    def get_option_values(self):
        return self.expr.get_option_values()

    def has_form(self, *args):
        return self.expr.has_form(*args)

    # --- Candidate pre-filtering -------------------------------------------

    def get_match_candidates(
        self, leaves, expression, attributes, evaluation, vars=None
    ):
        """Return the subset of *leaves* this pattern could match.

        The base implementation reports no candidates; subclasses refine it.
        ``vars=None`` replaces the former shared mutable ``{}`` default.
        """
        return []

    def get_match_candidates_count(
        self, leaves, expression, attributes, evaluation, vars=None
    ):
        """Return how many of *leaves* this pattern could match."""
        # Subclass overrides of get_match_candidates may call vars.get(...),
        # so hand them a fresh dict rather than None.
        if vars is None:
            vars = {}
        return len(
            self.get_match_candidates(leaves, expression, attributes, evaluation, vars)
        )
135
136
class AtomPattern(Pattern):
    """Pattern for a literal atom; an expression matches when it is SameQ to it."""

    def __init__(self, expr):
        self.atom = expr
        self.expr = expr

    def __repr__(self):
        return "<AtomPattern: %s>" % self.atom

    def match(
        self,
        yield_func,
        expression,
        vars,
        evaluation,
        head=None,
        leaf_index=None,
        leaf_count=None,
        fully=True,
        wrap_oneid=True,
    ):
        """Call ``yield_func(vars, None)`` if *expression* is SameQ to the atom.

        No variables get bound, so *vars* is passed through unchanged and
        ``rest`` is ``None``.
        """
        if expression.sameQ(self.atom):
            yield_func(vars, None)

    def get_match_candidates(
        self, leaves, expression, attributes, evaluation, vars=None
    ):
        """Return the leaves that are SameQ to the atom (*vars* is unused)."""
        return [leaf for leaf in leaves if leaf.sameQ(self.atom)]

    def get_match_count(self, vars=None):
        """An atom always stands for exactly one leaf: ``(min, max) == (1, 1)``."""
        return (1, 1)
166
167
168# class StopGenerator_ExpressionPattern_match(StopGenerator):
169#    pass
170
171
class ExpressionPattern(Pattern):
    """Pattern for a compound expression ``head[leaf1, leaf2, ...]``.

    The head and every leaf are compiled with ``Pattern.create``.  Matching
    uses continuation-passing style: ``yield_func(vars, rest)`` is called once
    per successful match.  The attributes of the pattern head (Flat,
    Orderless, OneIdentity) decide how leaves may be regrouped, reordered or
    wrapped during matching.
    """

    # Heads of leaf patterns that, under a Flat head, may take more leaves
    # than their own match count allows (used by match_leaf).  Computed once
    # instead of on every match_leaf call.
    _FLATTENABLE_HEADS = frozenset(
        system_symbols(
            "Pattern",
            "PatternTest",
            "Condition",
            "Optional",
            "Blank",
            "BlankSequence",
            "BlankNullSequence",
            "Alternatives",
            "OptionsPattern",
            "Repeated",
            "RepeatedNull",
        )
    )

    def __init__(self, expr):
        self.head = Pattern.create(expr.head)
        self.leaves = [Pattern.create(leaf) for leaf in expr.leaves]
        self.expr = expr

    def __repr__(self):
        return "<ExpressionPattern: %s>" % self.expr

    def match(
        self,
        yield_func,
        expression,
        vars,
        evaluation,
        head=None,
        leaf_index=None,
        leaf_count=None,
        fully=True,
        wrap_oneid=True,
    ):
        """Call ``yield_func(new_vars, rest)`` for every way *expression* matches.

        The head is matched first, then the leaves via ``match_leaf``.  If the
        pattern head has OneIdentity (and *wrap_oneid* is set and the
        evaluation does not ignore OneIdentity), *expression* is also tried
        as the single leaf of ``head[expression]``.  ``head``, ``leaf_index``
        and ``leaf_count`` belong to the common ``Pattern.match`` signature
        and are not used here.
        """
        evaluation.check_stopped()

        attributes = self.head.get_attributes(evaluation.definitions)
        if "System`Flat" not in attributes:
            # Only Flat heads may leave some leaves of the expression unmatched.
            fully = True

        if not expression.is_atom():

            def yield_choice(pre_vars):
                # Read self.leaves only here: get_pre_choices may have
                # re-sorted them (Orderless heads).
                next_leaf = self.leaves[0]
                next_leaves = self.leaves[1:]

                # "leading_blanks" handles expressions with leading Blanks
                # H[x_, y_, ...] much more efficiently by not calling
                # get_match_candidates_count() on leaves that have already
                # been matched by one of the leading Blanks.  This is only
                # valid for heads that are not Orderless (with Orderless, the
                # concept of leading items does not exist).
                #
                # Simple performance test case:
                #
                #   f[x_, {a__, b_}] = 0;
                #   f[x_, y_] := y + Total[x];
                #   First[Timing[f[Range[5000], 1]]]
                #
                # Without "leading_blanks", Range[5000] would be tested
                # against {a__, b_} in get_match_candidates_count(), which is
                # slow.
                #
                # NOTE(review): under a Flat head a (1, 1) leaf may still take
                # several leaves through wrapping (see try_flattened in
                # match_leaf), and with fully=False matching may skip leading
                # leaves, so this shortcut may reject some valid Flat
                # matches -- confirm.
                unmatched_leaves = expression.leaves
                leading_blanks = "System`Orderless" not in attributes

                for leaf in self.leaves:
                    match_count = leaf.get_match_count()

                    if leading_blanks:
                        if tuple(match_count) == (1, 1):
                            # Exactly one leaf: it must be the next unmatched one.
                            if not unmatched_leaves or not leaf.does_match(
                                unmatched_leaves[0], evaluation, pre_vars
                            ):
                                raise StopGenerator_ExpressionPattern_match()
                            unmatched_leaves = unmatched_leaves[1:]
                        else:
                            leading_blanks = False

                    if not leading_blanks:
                        candidates = leaf.get_match_candidates_count(
                            unmatched_leaves,
                            expression,
                            attributes,
                            evaluation,
                            pre_vars,
                        )
                        if candidates < match_count[0]:
                            raise StopGenerator_ExpressionPattern_match()

                self.match_leaf(
                    yield_func,
                    next_leaf,
                    next_leaves,
                    ([], expression.leaves),
                    pre_vars,
                    expression,
                    attributes,
                    evaluation,
                    first=True,
                    fully=fully,
                    leaf_count=len(self.leaves),
                    wrap_oneid=expression.get_head_name() != "System`MakeBoxes",
                )

            def yield_head(head_vars, _):
                if self.leaves:
                    self.get_pre_choices(
                        yield_choice, expression, attributes, head_vars
                    )
                elif not expression.leaves:
                    # A pattern without leaves matches only an expression
                    # without leaves.
                    yield_func(head_vars, None)

            try:
                self.head.match(yield_head, expression.get_head(), vars, evaluation)
            except StopGenerator_ExpressionPattern_match:
                # Raised by yield_choice when the leaves cannot possibly match.
                return

        if (
            wrap_oneid
            and not evaluation.ignore_oneidentity
            and "System`OneIdentity" in attributes
            and expression.get_head() != self.head
            and expression != self.head
        ):
            if not self.leaves:
                # head[] can never match the one-leaf wrapper head[expression].
                # Without this guard, self.leaves[0] below raises IndexError.
                return
            new_expression = Expression(self.head, expression)
            for leaf in self.leaves:
                leaf.match_count = leaf.get_match_count()
                leaf.candidates = [expression]
                # The wrapper has a single leaf, so any leaf pattern that
                # needs more than one leaf cannot match.
                if len(leaf.candidates) < leaf.match_count[0]:
                    return
            self.match_leaf(
                yield_func,
                self.leaves[0],
                self.leaves[1:],
                ([], [expression]),
                vars,
                new_expression,
                attributes,
                evaluation,
                first=True,
                fully=fully,
                leaf_count=len(self.leaves),
                wrap_oneid=True,
            )

    def get_pre_choices(self, yield_func, expression, attributes, vars):
        """Call ``yield_func(pre_vars)`` once per candidate pre-binding.

        For heads without the Orderless attribute, this calls
        ``yield_func(vars)`` exactly once.  For Orderless heads, the leaf
        patterns are sorted first.  Every still-unbound pattern name that
        appears in several consecutive ``Pattern`` leaves (e.g.
        ``f[x_, x_]``) is then pre-bound to each value that all of its
        occurrences could share.  Each occurrence takes the same number of
        copies of every distinct leaf of *expression*.
        """
        if "System`Orderless" not in attributes:
            yield_func(vars)
            return

        self.sort()

        # Group consecutive Pattern leaves that share a still-unbound name.
        groups = {}
        prev_pattern = prev_name = None
        for pattern in self.filter_leaves("Pattern"):
            name = pattern.leaves[0].get_name()
            if vars.get(name, None) is not None:
                # No pre-choice is needed if the variable is already set.
                continue
            if name == prev_name:
                if name in groups:
                    groups[name].append(pattern)
                else:
                    groups[name] = [prev_pattern, pattern]
            prev_pattern = pattern
            prev_name = name

        # Multiplicity of each distinct leaf of the expression.
        expr_groups = {}
        for leaf in expression.leaves:
            expr_groups[leaf] = expr_groups.get(leaf, 0) + 1
        expr_items = list(expr_groups.items())

        def per_name(yield_name, groups, vars):
            """Call ``yield_name(setting)`` for every variable setting
            (dict) of the remaining pattern *groups*."""
            if not groups:
                yield_name(vars)
                return

            name, patterns = groups[0]

            # Combine the (min, max) match counts of all occurrences of name:
            # the largest minimum and the smallest finite maximum.
            match_count = [0, None]
            for pattern in patterns:
                sub_match_count = pattern.get_match_count()
                if sub_match_count[0] > match_count[0]:
                    match_count[0] = sub_match_count[0]
                if match_count[1] is None or (
                    sub_match_count[1] is not None
                    and sub_match_count[1] < match_count[1]
                ):
                    match_count[1] = sub_match_count[1]

            def per_expr(index=0, total=0):
                """Yield candidate value sequences (lists) for ``name``.

                Every occurrence of the name receives ``per_pattern`` copies
                of each distinct leaf, so at most ``count // len(patterns)``
                copies of a leaf occurring ``count`` times.  Larger counts
                are tried first.
                """
                if index == len(expr_items):
                    if total >= match_count[0]:
                        yield []
                    return
                expr, count = expr_items[index]
                max_per_pattern = count // len(patterns)
                for per_pattern in range(max_per_pattern, -1, -1):
                    for tail in per_expr(index + 1, total + per_pattern):
                        yield [expr] * per_pattern + tail

            def yield_wrapping(wrapping):
                def yield_next(next_vars):
                    setting = next_vars.copy()
                    setting[name] = wrapping
                    yield_name(setting)

                per_name(yield_next, groups[1:], vars)

            for sequence in per_expr():
                self.get_wrappings(
                    yield_wrapping, sequence, match_count[1], expression, attributes
                )

        per_name(yield_func, list(groups.items()), vars)

    def filter_leaves(self, head_name):
        """Return the leaf patterns whose head is ``head_name`` (context added)."""
        head_name = ensure_context(head_name)
        return [leaf for leaf in self.leaves if leaf.get_head_name() == head_name]

    def get_match_count(self, vars=None):
        """A compound pattern always stands for exactly one leaf."""
        return (1, 1)

    def get_wrappings(
        self,
        yield_func,
        items,
        max_count,
        expression,
        attributes,
        include_flattened=True,
    ):
        """Call *yield_func* with each expression that may represent *items*.

        - A single item is passed on unchanged.
        - Otherwise, if ``max_count`` is None or ``len(items) <= max_count``,
          ``Sequence[items...]`` is yielded (every permutation of it when the
          head is Orderless).
        - Additionally, for Flat heads and when *include_flattened* is set,
          ``head[items...]`` is yielded, with head taken from *expression*.
        """
        if len(items) == 1:
            yield_func(items[0])
            return

        if max_count is None or len(items) <= max_count:
            if "System`Orderless" in attributes:
                for perm in permutations(items):
                    sequence = Expression("Sequence", *perm)
                    sequence.pattern_sequence = True
                    yield_func(sequence)
            else:
                sequence = Expression("Sequence", *items)
                sequence.pattern_sequence = True
                yield_func(sequence)
        if "System`Flat" in attributes and include_flattened:
            yield_func(Expression(expression.get_head(), *items))

    def match_leaf(
        self,
        yield_func,
        leaf,
        rest_leaves,
        rest_expression,
        vars,
        expression,
        attributes,
        evaluation,
        leaf_index=1,
        leaf_count=None,
        first=False,
        fully=True,
        depth=1,
        wrap_oneid=True,
    ):
        """Match *leaf* against some of the still-available leaves, then recurse
        on *rest_leaves*.

        *rest_expression* is a pair ``(skipped, available)``.  ``available``
        are the leaves of *expression* not consumed yet.  ``skipped``
        presumably collects leaves passed over by partial (``fully=False``)
        matching, as produced by subsets/subranges -- confirm.  On success,
        ``yield_func(new_vars, rest)`` is called.
        """
        if rest_expression is None:
            rest_expression = ([], [])

        evaluation.check_stopped()

        candidates = rest_expression[1]
        match_count = leaf.get_match_count(vars)
        leaf_candidates = leaf.get_match_candidates(
            candidates, expression, attributes, evaluation, vars
        )
        if len(leaf_candidates) < match_count[0]:
            return

        is_flat = "System`Flat" in attributes

        # Under a Flat head, some pattern kinds may "artificially" take more
        # leaves than their match count allows.
        # TODO: This could be further optimized!
        try_flattened = is_flat and leaf.get_head_name() in self._FLATTENABLE_HEADS
        set_lengths = (match_count[0], None) if try_flattened else match_count

        # From here on, try_flattened decides whether several leaves may be
        # wrapped into one operand.  That is also the case when the head is
        # Flat and the leaf has the same head.
        try_flattened = try_flattened or (
            is_flat and leaf.get_head() == expression.head
        )

        less_first = len(rest_leaves) > 0

        if "System`Orderless" in attributes:
            # Only build a set (for fast lookup) when Orderless; for long
            # non-Orderless lists building a set is very slow.  Performance
            # test case:
            #   x = Range[100000]; Timing[Combinatorica`BinarySearch[x, 100]]
            leaf_candidates = set(leaf_candidates)

            sets = None
            if leaf.get_head_name() == "System`Pattern":
                varname = leaf.leaves[0].get_name()
                existing = vars.get(varname, None)
                if existing is not None:
                    # Variable already bound: exactly the leaves it stands
                    # for must still be available.
                    head = existing.get_head()
                    if head.get_name() == "System`Sequence" or (
                        is_flat and head == expression.get_head()
                    ):
                        needed = existing.leaves
                    else:
                        needed = [existing]
                    available = candidates[:]
                    for needed_leaf in needed:
                        if needed_leaf in available and needed_leaf in leaf_candidates:
                            available.remove(needed_leaf)
                        else:
                            return
                    sets = [(needed, ([], available))]

            if sets is None:
                sets = subsets(
                    candidates,
                    included=leaf_candidates,
                    less_first=less_first,
                    *set_lengths
                )
        else:
            sets = subranges(
                candidates,
                flexible_start=first and not fully,
                included=leaf_candidates,
                less_first=less_first,
                *set_lengths
            )

        if rest_leaves:
            next_leaf = rest_leaves[0]
            next_rest_leaves = rest_leaves[1:]
        next_depth = depth + 1
        next_index = leaf_index + 1

        for items, items_rest in sets:
            # Include wrappings like Plus[a, b] only if not all items are
            # taken; otherwise the same expression would be matched over and
            # over again.
            include_flattened = try_flattened and 0 < len(items) < len(
                expression.leaves
            )

            def leaf_yield(next_vars, next_rest):
                skipped = list(chain(rest_expression[0], items_rest[0]))
                if next_rest is None:
                    yield_func(next_vars, (skipped, []))
                else:
                    yield_func(next_vars, (skipped, next_rest[1]))

            def match_yield(new_vars, _):
                if rest_leaves:
                    self.match_leaf(
                        leaf_yield,
                        next_leaf,
                        next_rest_leaves,
                        items_rest,
                        new_vars,
                        expression,
                        attributes,
                        evaluation,
                        fully=fully,
                        depth=next_depth,
                        leaf_index=next_index,
                        leaf_count=leaf_count,
                        wrap_oneid=wrap_oneid,
                    )
                elif not fully or (not items_rest[0] and not items_rest[1]):
                    # Last leaf: with fully=True nothing may remain.
                    yield_func(new_vars, items_rest)

            def yield_wrapping(item):
                leaf.match(
                    match_yield,
                    item,
                    vars,
                    evaluation,
                    fully=True,
                    head=expression.head,
                    leaf_index=leaf_index,
                    leaf_count=leaf_count,
                    wrap_oneid=wrap_oneid,
                )

            self.get_wrappings(
                yield_wrapping,
                items,
                match_count[1],
                expression,
                attributes,
                include_flattened=include_flattened,
            )

    def get_match_candidates(
        self, leaves, expression, attributes, evaluation, vars=None
    ):
        """
        Finds possible leaves that could match the pattern, ignoring future
        pattern variable definitions, but taking into account already fixed
        variables.
        """
        # TODO: fixed_vars!
        if vars is None:
            vars = {}
        return [leaf for leaf in leaves if self.does_match(leaf, evaluation, vars)]

    def get_match_candidates_count(
        self, leaves, expression, attributes, evaluation, vars=None
    ):
        """
        Counts the leaves that could match the pattern, ignoring future
        pattern variable definitions, but taking into account already fixed
        variables.
        """
        # TODO: fixed_vars!
        if vars is None:
            vars = {}
        return sum(1 for leaf in leaves if self.does_match(leaf, evaluation, vars))

    def sort(self):
        """Sort the leaf patterns in place by their pattern sort key."""
        self.leaves.sort(key=lambda e: e.get_sort_key(pattern_sort=True))
692