# cython: language_level=3
# cython: profile=False
# -*- coding: utf-8 -*-
"""
Pattern objects for the Mathics pattern matcher.

Matching is written in callback (continuation-passing) style: rather than
returning or yielding results, each ``match`` method takes a ``yield_func``
and calls it once for every way the pattern matches.  A search is cut short
by raising a ``StopGenerator`` subclass from inside a callback and catching
it further up the stack (see ``Pattern.does_match``).
"""

from mathics.core.expression import Expression, system_symbols, ensure_context
from mathics.core.util import subsets, subranges, permutations
from itertools import chain


# Heads of pattern leaves that, inside a Flat head, may take more leaves than
# their own maximum match count (``ExpressionPattern.match_leaf`` lifts the
# upper bound for them).  Computed once instead of on every match_leaf call.
_FLATTENABLE_PATTERN_HEADS = system_symbols(
    "Pattern",
    "PatternTest",
    "Condition",
    "Optional",
    "Blank",
    "BlankSequence",
    "BlankNullSequence",
    "Alternatives",
    "OptionsPattern",
    "Repeated",
    "RepeatedNull",
)


def Pattern_create(expr):
    """Build the pattern object used to match against ``expr``.

    Heads registered in ``mathics.builtin.pattern_objects`` (keyed by head
    name) get their dedicated pattern class; any other atom becomes an
    ``AtomPattern`` and any other compound expression an
    ``ExpressionPattern``.
    """
    # Imported lazily -- presumably to avoid a circular import between
    # mathics.core and mathics.builtin (TODO confirm).
    from mathics.builtin import pattern_objects

    pattern_object = pattern_objects.get(expr.get_head_name())
    if pattern_object is not None:
        return pattern_object(expr)
    if expr.is_atom():
        return AtomPattern(expr)
    return ExpressionPattern(expr)


class StopGenerator(Exception):
    """Raised from a match callback to stop the search early.

    ``value`` carries an optional result back to the code that catches it.
    """

    def __init__(self, value=None):
        self.value = value


class StopGenerator_ExpressionPattern_match(StopGenerator):
    """Early-exit signal for ``ExpressionPattern.match``.

    Not raised by this module any more; kept for code that may import it.
    """

    pass


class StopGenerator_Pattern(StopGenerator):
    """Early-exit signal used by ``Pattern.does_match`` on the first match."""

    pass


class Pattern(object):
    """Base class of all pattern objects.

    Subclasses must set ``self.expr`` (the wrapped expression) and implement
    ``match``; most other methods here simply delegate to ``self.expr``.
    """

    create = staticmethod(Pattern_create)

    def match(
        self,
        yield_func,
        expression,
        vars,
        evaluation,
        head=None,
        leaf_index=None,
        leaf_count=None,
        fully=True,
        wrap_oneid=True,
    ):
        """Call ``yield_func(new_vars, rest)`` once for every way
        ``expression`` matches this pattern, given the already-bound pattern
        variables ``vars``.  Must be overridden by subclasses.
        """
        raise NotImplementedError

    def does_match(self, expression, evaluation, vars=None, fully=True):
        """Return True if ``expression`` matches this pattern at least once."""
        if vars is None:
            vars = {}

        def yield_match(sub_vars, rest):
            # The first match is enough: abort the rest of the search.
            raise StopGenerator_Pattern(True)

        try:
            self.match(yield_match, expression, vars, evaluation, fully=fully)
        except StopGenerator_Pattern as exc:
            return exc.value
        return False

    def get_name(self):
        return self.expr.get_name()

    def is_atom(self):
        return self.expr.is_atom()

    def get_head_name(self):
        return self.expr.get_head_name()

    def sameQ(self, other) -> bool:
        """Mathics SameQ"""
        return self.expr.sameQ(other.expr)

    def get_head(self):
        return self.expr.get_head()

    def get_leaves(self):
        return self.expr.get_leaves()

    def get_sort_key(self, pattern_sort=False):
        return self.expr.get_sort_key(pattern_sort=pattern_sort)

    def get_lookup_name(self):
        return self.expr.get_lookup_name()

    def get_attributes(self, definitions):
        return self.expr.get_attributes(definitions)

    def get_sequence(self):
        return self.expr.get_sequence()

    def get_option_values(self):
        return self.expr.get_option_values()

    def has_form(self, *args):
        return self.expr.has_form(*args)

    def get_match_candidates(
        self, leaves, expression, attributes, evaluation, vars=None
    ):
        """Return the leaves that could match this pattern (none by default)."""
        return []

    def get_match_candidates_count(
        self, leaves, expression, attributes, evaluation, vars=None
    ):
        """Return how many of ``leaves`` could match this pattern."""
        if vars is None:
            vars = {}
        return len(
            self.get_match_candidates(leaves, expression, attributes, evaluation, vars)
        )


class AtomPattern(Pattern):
    """Pattern that matches exactly one atom (compared with SameQ)."""

    def __init__(self, expr):
        self.atom = expr
        self.expr = expr

    def __repr__(self):
        return "<AtomPattern: %s>" % self.atom

    def match(
        self,
        yield_func,
        expression,
        vars,
        evaluation,
        head=None,
        leaf_index=None,
        leaf_count=None,
        fully=True,
        wrap_oneid=True,
    ):
        """Call ``yield_func(vars, None)`` if ``expression`` is SameQ to the atom."""
        if expression.sameQ(self.atom):
            yield_func(vars, None)

    def get_match_candidates(
        self, leaves, expression, attributes, evaluation, vars=None
    ):
        return [leaf for leaf in leaves if leaf.sameQ(self.atom)]

    def get_match_count(self, vars=None):
        """An atom always matches exactly one leaf: returns ``(min, max)``."""
        return (1, 1)


class ExpressionPattern(Pattern):
    """Pattern for a compound expression ``head[leaf1, leaf2, ...]``.

    The head and every leaf are converted to pattern objects with
    ``Pattern.create``.
    """

    def __init__(self, expr):
        self.head = Pattern.create(expr.head)
        self.leaves = [Pattern.create(leaf) for leaf in expr.leaves]
        self.expr = expr

    def __repr__(self):
        return "<ExpressionPattern: %s>" % self.expr

    def match(
        self,
        yield_func,
        expression,
        vars,
        evaluation,
        head=None,
        leaf_index=None,
        leaf_count=None,
        fully=True,
        wrap_oneid=True,
    ):
        """Call ``yield_func(new_vars, rest)`` for every match of ``expression``.

        First the head of ``expression`` is matched against ``self.head`` and
        its leaves against ``self.leaves`` (via ``match_leaf``).  Then, if the
        head has the ``OneIdentity`` attribute and ``expression`` does not
        already have that head, ``expression`` is also tried as if it were
        ``head[expression]``.

        ``fully=False`` (honoured only for ``Flat`` heads) lets some leaves
        of ``expression`` remain unmatched; they are reported via ``rest``.
        ``wrap_oneid=False`` disables the ``OneIdentity`` attempt.  ``head``,
        ``leaf_index`` and ``leaf_count`` are accepted for interface
        compatibility and are not used here.
        """
        evaluation.check_stopped()

        attributes = self.head.get_attributes(evaluation.definitions)
        if "System`Flat" not in attributes:
            fully = True

        if not expression.is_atom():

            def yield_choice(pre_vars):
                # Read the leaves only now: get_pre_choices may have
                # re-sorted self.leaves (Orderless heads).
                next_leaf = self.leaves[0]
                next_leaves = self.leaves[1:]

                # Cheap pre-check before the full match_leaf search.
                #
                # "leading_blanks" handles expressions with leading Blanks
                # H[x_, y_, ...] much more efficiently by not calling
                # get_match_candidates_count() on leaves that have already
                # been matched with one of the leading Blanks.  This is only
                # valid for heads that are not Orderless (with Orderless, the
                # concept of leading items does not exist).
                #
                # Simple performance test case:
                #
                #     f[x_, {a__, b_}] = 0;
                #     f[x_, y_] := y + Total[x];
                #     First[Timing[f[Range[5000], 1]]]
                #
                # Without "leading_blanks", Range[5000] would be tested
                # against {a__, b_} in get_match_candidates_count(), which
                # is slow.
                unmatched_leaves = expression.leaves
                leading_blanks = "System`Orderless" not in attributes

                for leaf in self.leaves:
                    match_count = leaf.get_match_count()

                    if leading_blanks:
                        if tuple(match_count) == (1, 1):
                            # The leaf takes exactly one element (e.g. a
                            # Blank): it must match the next expression leaf.
                            if not unmatched_leaves or not leaf.does_match(
                                unmatched_leaves[0], evaluation, pre_vars
                            ):
                                # No match under these pre-bindings; other
                                # pre-choices may still succeed.
                                return
                            unmatched_leaves = unmatched_leaves[1:]
                        else:
                            leading_blanks = False

                    if not leading_blanks:
                        candidates = leaf.get_match_candidates_count(
                            unmatched_leaves,
                            expression,
                            attributes,
                            evaluation,
                            pre_vars,
                        )
                        if candidates < match_count[0]:
                            return

                self.match_leaf(
                    yield_func,
                    next_leaf,
                    next_leaves,
                    ([], expression.leaves),
                    pre_vars,
                    expression,
                    attributes,
                    evaluation,
                    first=True,
                    fully=fully,
                    leaf_count=len(self.leaves),
                    # OneIdentity wrapping is switched off inside MakeBoxes;
                    # the reason is not visible here.
                    wrap_oneid=expression.get_head_name() != "System`MakeBoxes",
                )

            def yield_head(head_vars, _):
                if self.leaves:
                    self.get_pre_choices(
                        yield_choice, expression, attributes, head_vars
                    )
                elif not expression.leaves:
                    # A leafless pattern h[] only matches a leafless h[].
                    yield_func(head_vars, None)

            self.head.match(yield_head, expression.get_head(), vars, evaluation)

        if (
            wrap_oneid
            and not evaluation.ignore_oneidentity
            and "System`OneIdentity" in attributes
            # Wrapping needs at least one pattern leaf to receive expression.
            and self.leaves
            # Compare against the head *expression*: self.head is a Pattern
            # object, which never compares equal to an expression.
            and not expression.get_head().sameQ(self.head.expr)
            and not expression.sameQ(self.head.expr)
        ):
            new_expression = Expression(self.head, expression)
            # The only candidate is ``expression`` itself, so every pattern
            # leaf must be able to get by with at most one element.
            for leaf in self.leaves:
                if leaf.get_match_count()[0] > 1:
                    return
            self.match_leaf(
                yield_func,
                self.leaves[0],
                self.leaves[1:],
                ([], [expression]),
                vars,
                new_expression,
                attributes,
                evaluation,
                first=True,
                fully=fully,
                leaf_count=len(self.leaves),
                wrap_oneid=True,
            )

    def get_pre_choices(self, yield_func, expression, attributes, vars):
        """Call ``yield_func(pre_vars)`` once per candidate set of pre-bindings.

        For heads without ``Orderless`` this calls ``yield_func(vars)`` once.

        For ``Orderless`` heads the pattern leaves are sorted first.  Adjacent
        ``Pattern`` leaves that share a variable name not yet bound in
        ``vars`` (e.g. ``f[x_, x_]``) form a group.  Each group's variable is
        pre-bound to every value that could be shared by all of its patterns,
        built from duplicated leaves of ``expression``.  The cartesian product
        of the groups' bindings is passed to ``yield_func``.  ``vars`` itself
        is never modified; each setting is a fresh copy.
        """
        if "System`Orderless" not in attributes:
            yield_func(vars)
            return

        self.sort()
        groups = {}
        prev_pattern = prev_name = None
        for pattern in self.filter_leaves("Pattern"):
            name = pattern.leaves[0].get_name()
            if vars.get(name, None) is None:
                # No pre-choice is needed when the variable is already set.
                if name == prev_name:
                    if name in groups:
                        groups[name].append(pattern)
                    else:
                        groups[name] = [prev_pattern, pattern]
                prev_pattern = pattern
                prev_name = name

        # Multiplicity of every distinct leaf of the expression, as an
        # ordered list of (leaf, count) pairs that is sliced, never mutated.
        expr_counts = {}
        for leaf in expression.leaves:
            expr_counts[leaf] = expr_counts.get(leaf, 0) + 1
        expr_items = list(expr_counts.items())

        def per_name(yield_name, groups, vars):
            """Call ``yield_name(setting)`` for every combination of bindings
            of the variables in ``groups`` (a list of ``(name, patterns)``)."""
            if not groups:
                yield_name(vars)
                return

            name, patterns = groups[0]

            # Combined length bounds for the variable: the largest minimum
            # and the smallest non-None maximum over the group.
            match_count = [0, None]
            for pattern in patterns:
                sub_match_count = pattern.get_match_count()
                if sub_match_count[0] > match_count[0]:
                    match_count[0] = sub_match_count[0]
                if match_count[1] is None or (
                    sub_match_count[1] is not None
                    and sub_match_count[1] < match_count[1]
                ):
                    match_count[1] = sub_match_count[1]

            def per_expr(yield_sequence, items, total=0):
                """Call ``yield_sequence(values)`` for every list of leaves
                that could be the value of ``name``.  A distinct leaf that
                occurs ``count`` times contributes between
                ``count // len(patterns)`` and 0 copies, because every
                pattern in the group must receive the same value."""
                if items:
                    expr, count = items[0]
                    for per_pattern in range(count // len(patterns), -1, -1):

                        def extend(rest, per_pattern=per_pattern):
                            yield_sequence([expr] * per_pattern + rest)

                        per_expr(extend, items[1:], total + per_pattern)
                elif total >= match_count[0]:
                    yield_sequence([])

            def yield_sequence(sequence):
                def yield_wrapping(wrapping):
                    def yield_setting(setting):
                        setting = setting.copy()
                        setting[name] = wrapping
                        yield_name(setting)

                    per_name(yield_setting, groups[1:], vars)

                self.get_wrappings(
                    yield_wrapping, sequence, match_count[1], expression, attributes
                )

            per_expr(yield_sequence, expr_items)

        per_name(yield_func, list(groups.items()), vars)

    def filter_leaves(self, head_name):
        """Return the pattern leaves whose head is ``head_name``
        (the ``System``` context is added if missing)."""
        head_name = ensure_context(head_name)
        return [leaf for leaf in self.leaves if leaf.get_head_name() == head_name]

    def get_match_count(self, vars=None):
        """A compound pattern always matches exactly one leaf: ``(min, max)``."""
        return (1, 1)

    def get_wrappings(
        self,
        yield_func,
        items,
        max_count,
        expression,
        attributes,
        include_flattened=True,
    ):
        """Call ``yield_func`` with each way ``items`` can be presented as the
        single value of one pattern leaf.

        A single item is passed on unchanged.  Otherwise, if there are at
        most ``max_count`` items (``None`` means unbounded), they are wrapped
        as ``Sequence[...]``.  For Orderless heads this is done once per
        permutation.  For Flat heads, when ``include_flattened`` is set, they
        are also wrapped in the head of ``expression``.
        """
        if len(items) == 1:
            yield_func(items[0])
            return

        if max_count is None or len(items) <= max_count:
            if "System`Orderless" in attributes:
                orderings = permutations(items)
            else:
                orderings = (items,)
            for ordering in orderings:
                sequence = Expression("Sequence", *ordering)
                sequence.pattern_sequence = True
                yield_func(sequence)

        if "System`Flat" in attributes and include_flattened:
            yield_func(Expression(expression.get_head(), *items))

    def match_leaf(
        self,
        yield_func,
        leaf,
        rest_leaves,
        rest_expression,
        vars,
        expression,
        attributes,
        evaluation,
        leaf_index=1,
        leaf_count=None,
        first=False,
        fully=True,
        depth=1,
        wrap_oneid=True,
    ):
        """Match pattern ``leaf``, followed by ``rest_leaves``, against the
        expression leaves in ``rest_expression[1]``.

        ``rest_expression`` is a pair of leaf lists (``None`` means
        ``([], [])``); its second entry holds the leaves still available.
        ``leaf`` may take any subset of them (Orderless heads) or a range
        (other heads).  Each complete assignment of all remaining pattern
        leaves calls ``yield_func(new_vars, rest)``.
        """
        if rest_expression is None:
            rest_expression = ([], [])

        evaluation.check_stopped()

        match_count = leaf.get_match_count(vars)
        leaf_candidates = leaf.get_match_candidates(
            rest_expression[1],
            expression,
            attributes,
            evaluation,
            vars,
        )
        if len(leaf_candidates) < match_count[0]:
            return

        candidates = rest_expression[1]
        is_flat = "System`Flat" in attributes

        # "Artificially" only use more leaves than specified for some kinds
        # of pattern.
        # TODO: This could be further optimized!
        try_flattened = is_flat and leaf.get_head_name() in _FLATTENABLE_PATTERN_HEADS
        if try_flattened:
            set_lengths = (match_count[0], None)
        else:
            set_lengths = match_count

        # From here on try_flattened decides whether several leaves may be
        # wrapped into one operand; that is also allowed when the head is
        # Flat and the pattern leaf has the same head as the expression.
        try_flattened = try_flattened or (
            is_flat and leaf.get_head() == expression.head
        )

        less_first = len(rest_leaves) > 0

        if "System`Orderless" in attributes:
            # A set only pays off for Orderless heads; for large ordered
            # lists building it is slow.  Performance test case:
            # x = Range[100000]; Timing[Combinatorica`BinarySearch[x, 100]]
            leaf_candidates = set(leaf_candidates)  # for fast lookup

            sets = None
            if leaf.get_head_name() == "System`Pattern":
                varname = leaf.leaves[0].get_name()
                existing = vars.get(varname, None)
                if existing is not None:
                    # The variable is already bound: the leaf must consume
                    # exactly the bound leaves, if they are all available.
                    existing_head = existing.get_head()
                    if existing_head.get_name() == "System`Sequence" or (
                        is_flat and existing_head == expression.get_head()
                    ):
                        needed = existing.leaves
                    else:
                        needed = [existing]
                    # list() so that .remove works whatever sequence type
                    # holds the leaves.
                    available = list(candidates)
                    for needed_leaf in needed:
                        if needed_leaf in available and needed_leaf in leaf_candidates:
                            available.remove(needed_leaf)
                        else:
                            return
                    sets = [(needed, ([], available))]

            if sets is None:
                sets = subsets(
                    candidates,
                    included=leaf_candidates,
                    less_first=less_first,
                    *set_lengths
                )
        else:
            sets = subranges(
                candidates,
                flexible_start=first and not fully,
                included=leaf_candidates,
                less_first=less_first,
                *set_lengths
            )

        if rest_leaves:
            next_leaf = rest_leaves[0]
            next_rest_leaves = rest_leaves[1:]
            next_depth = depth + 1
            next_index = leaf_index + 1

        for items, items_rest in sets:
            # Include wrappings like Plus[a, b] only if not all items are
            # taken -- otherwise we would match the same expression over and
            # over.
            include_flattened = try_flattened and 0 < len(items) < len(
                expression.leaves
            )

            # The callbacks below run synchronously within this iteration,
            # so capturing the loop variables is safe.
            def leaf_yield(next_vars, next_rest):
                remaining = [] if next_rest is None else next_rest[1]
                yield_func(
                    next_vars,
                    (list(chain(rest_expression[0], items_rest[0])), remaining),
                )

            def match_yield(new_vars, _):
                if rest_leaves:
                    self.match_leaf(
                        leaf_yield,
                        next_leaf,
                        next_rest_leaves,
                        items_rest,
                        new_vars,
                        expression,
                        attributes,
                        evaluation,
                        fully=fully,
                        depth=next_depth,
                        leaf_index=next_index,
                        leaf_count=leaf_count,
                        wrap_oneid=wrap_oneid,
                    )
                elif not fully or (not items_rest[0] and not items_rest[1]):
                    yield_func(new_vars, items_rest)

            def yield_wrapping(item):
                leaf.match(
                    match_yield,
                    item,
                    vars,
                    evaluation,
                    fully=True,
                    head=expression.head,
                    leaf_index=leaf_index,
                    leaf_count=leaf_count,
                    wrap_oneid=wrap_oneid,
                )

            self.get_wrappings(
                yield_wrapping,
                items,
                match_count[1],
                expression,
                attributes,
                include_flattened=include_flattened,
            )

    def get_match_candidates(
        self, leaves, expression, attributes, evaluation, vars=None
    ):
        """
        Finds possible leaves that could match the pattern, ignoring future
        pattern variable definitions, but taking into account already fixed
        variables.
        """
        # TODO: fixed_vars!
        return [leaf for leaf in leaves if self.does_match(leaf, evaluation, vars)]

    def get_match_candidates_count(
        self, leaves, expression, attributes, evaluation, vars=None
    ):
        """
        Counts the leaves that could match the pattern, ignoring future
        pattern variable definitions, but taking into account already fixed
        variables.
        """
        # TODO: fixed_vars!
        return sum(1 for leaf in leaves if self.does_match(leaf, evaluation, vars))

    def sort(self):
        """Sort the pattern leaves in place by their pattern sort key."""
        self.leaves.sort(key=lambda e: e.get_sort_key(pattern_sort=True))