1
2 #include "Cleanup.hh"
3 #include "algorithms/replace_match.hh"
4 #include "algorithms/substitute.hh"
5
6 using namespace cadabra;
7
replace_match(const Kernel & k,Ex & e)8 replace_match::replace_match(const Kernel& k, Ex& e)
9 : Algorithm(k, e)
10 {
11 }
12
can_apply(iterator)13 bool replace_match::can_apply(iterator)
14 {
15 if(tr.history_size()>0) return true;
16 return false;
17 }
18
apply(iterator & it)19 Algorithm::result_t replace_match::apply(iterator& it)
20 {
21 // Preserve the expression before popping. After this, the 'tr' is
22 // the original expression from before 'take_match'.
23 Ex current(tr);
24 auto to_keep=tr.pop_history();
25 if(to_keep.size()==0) {
26 return result_t::l_applied;
27 }
28
29 // Remove the terms which we will replace, by converting
30 // the 'to_keep' paths above to iterators, then removing.
31 iterator sum_node = tr.parent(tr.iterator_from_path(to_keep[0], tr.begin()));
32 std::vector<iterator> to_erase;
33 for(const auto& p: to_keep)
34 to_erase.push_back( tr.iterator_from_path(p, tr.begin()) );
35 for(auto& erase: to_erase)
36 tr.erase(erase);
37
38 // If the replacement is zero, there is nothing to substitute.
39 if(!current.begin()->is_zero()) {
40
41 // We already have an iterator to the sum node in the now-current
42 // expression (sum_node). We also need one to the sum node in the
43 // replacement sum.
44 iterator replacement_sum_node = current.iterator_from_path(tr.path_from_iterator(sum_node, tr.begin()), current.begin());
45
46 // If the original sum has disappeared (because subsequent manipulations
47 // made all but one terms vanish), wrap it again in a sum.
48 if(*replacement_sum_node->name!="\\sum")
49 replacement_sum_node = current.wrap(replacement_sum_node, str_node("\\sum"));
50
51 // If we are inside an integral, determine the \int multiplier in the original
52 // and in the replacement.
53 multiplier_t rescale=1;
54 if(!tr.is_head(it) && *tr.parent(it)->name=="\\int") {
55 multiplier_t orig_mult = *tr.parent(it)->multiplier;
56 multiplier_t repl_mult = *current.parent(replacement_sum_node)->multiplier;
57 rescale = repl_mult/orig_mult;
58 }
59
60 sibling_iterator repit=current.begin(replacement_sum_node);
61 while(repit!=current.end(replacement_sum_node)) {
62 multiply( tr.append_child(sum_node, iterator(repit))->multiplier, rescale);
63 ++repit;
64 }
65 }
66
67 cleanup_dispatch(kernel, tr, it);
68 return result_t::l_applied;
69 }
70
71