1 
2 #include "Cleanup.hh"
3 #include "algorithms/replace_match.hh"
4 #include "algorithms/substitute.hh"
5 
6 using namespace cadabra;
7 
replace_match(const Kernel & k,Ex & e)8 replace_match::replace_match(const Kernel& k, Ex& e)
9 	: Algorithm(k, e)
10 	{
11 	}
12 
can_apply(iterator)13 bool replace_match::can_apply(iterator)
14 	{
15 	if(tr.history_size()>0) return true;
16 	return false;
17 	}
18 
apply(iterator & it)19 Algorithm::result_t replace_match::apply(iterator& it)
20 	{
21 	// Preserve the expression before popping. After this, the 'tr' is
22 	// the original expression from before 'take_match'.
23 	Ex current(tr);
24 	auto to_keep=tr.pop_history();
25 	if(to_keep.size()==0) {
26 		return result_t::l_applied;
27 		}
28 
29 	// Remove the terms which we will replace, by converting
30 	// the 'to_keep' paths above to iterators, then removing.
31 	iterator sum_node = tr.parent(tr.iterator_from_path(to_keep[0], tr.begin()));
32 	std::vector<iterator> to_erase;
33 	for(const auto& p: to_keep)
34 		to_erase.push_back( tr.iterator_from_path(p, tr.begin()) );
35 	for(auto& erase: to_erase)
36 		tr.erase(erase);
37 
38 	// If the replacement is zero, there is nothing to substitute.
39 	if(!current.begin()->is_zero()) {
40 
41 		// We already have an iterator to the sum node in the now-current
42 		// expression (sum_node). We also need one to the sum node in the
43 		// replacement sum.
44 		iterator replacement_sum_node = current.iterator_from_path(tr.path_from_iterator(sum_node, tr.begin()), current.begin());
45 
46 		// If the original sum has disappeared (because subsequent manipulations
47 		// made all but one terms vanish), wrap it again in a sum.
48 		if(*replacement_sum_node->name!="\\sum")
49 			replacement_sum_node = current.wrap(replacement_sum_node, str_node("\\sum"));
50 
51 		// If we are inside an integral, determine the \int multiplier in the original
52 		// and in the replacement.
53 		multiplier_t rescale=1;
54 		if(!tr.is_head(it) && *tr.parent(it)->name=="\\int") {
55 			multiplier_t orig_mult = *tr.parent(it)->multiplier;
56 			multiplier_t repl_mult = *current.parent(replacement_sum_node)->multiplier;
57 			rescale = repl_mult/orig_mult;
58 			}
59 
60 		sibling_iterator repit=current.begin(replacement_sum_node);
61 		while(repit!=current.end(replacement_sum_node)) {
62 			multiply( tr.append_child(sum_node, iterator(repit))->multiplier, rescale);
63 			++repit;
64 			}
65 		}
66 
67 	cleanup_dispatch(kernel, tr, it);
68 	return result_t::l_applied;
69 	}
70 
71