1
2 #include "Cleanup.hh"
3 #include "Functional.hh"
4 #include "algorithms/factor_out.hh"
5 #include "algorithms/sort_product.hh"
6 #include <map>
7
8 using namespace cadabra;
9
factor_out(const Kernel & k,Ex & e,Ex & args,bool right)10 factor_out::factor_out(const Kernel& k, Ex& e, Ex& args, bool right)
11 : Algorithm(k, e), to_right(right)
12 {
13 cadabra::do_list(args, args.begin(), [&](Ex::iterator arg) {
14 to_factor_out.push_back(Ex(arg));
15 return true;
16 }
17 );
18 }
19
20 /// Check if the expression is a sum with more than one term
can_apply(iterator st)21 bool factor_out::can_apply(iterator st)
22 {
23 if(*st->name=="\\sum") return true;
24 return false;
25 }
26
apply(iterator & it)27 Algorithm::result_t factor_out::apply(iterator& it)
28 {
29 result_t result=result_t::l_no_action;
30
31 // For every term in the sum, we look at the factors in the product
32 // (or at the single object if there is no product). If this factor
33 // needs to be factored out, we determine if it can be moved all the
34 // way to the left of the expression. If so, move the object to
35 // a 'factored_out' temporary, and take out of the tree. Rinse/repeat.
36 // What's left at the end is two objects: the stuff factored out,
37 // and the rest. Look up if we already have 'the stuff factored out'.
38 // If not, create new. If so, add this term.
39
40 Ex_comparator comparator(kernel.properties);
41
42 typedef std::pair<Ex, std::vector<Ex> > new_term_t;
43 std::vector<new_term_t> new_terms;
44
45 auto term=tr.begin(it);
46 while(term!=tr.end(it)) {
47 auto next_term=term;
48 ++next_term;
49
50 iterator prod=term;
51 prod_wrap_single_term(prod);
52
53 Ex collector("\\prod"); // collect all factors that we have taken out
54
55 // Insert a dummy symbol at the very front or back.
56 // FIXME: there is now a 'can_move_to_front', use that.
57 iterator dummy;
58 if(to_right) dummy = tr.append_child(prod, str_node("dummy"));
59 else dummy = tr.prepend_child(prod, str_node("dummy"));
60
61 // Look at all factors in turn and determine if they should be taken out.
62 if(to_right) {
63 auto fac=tr.end(prod);
64 auto next=fac;
65 --next;
66 do {
67 fac=next;
68 --next;
69 for(size_t i=0; i<to_factor_out.size(); ++i) {
70 auto match=comparator.equal_subtree(fac, to_factor_out[i].begin());
71 if(match==Ex_comparator::match_t::subtree_match) {
72 int sign=comparator.can_move_adjacent(prod, dummy, fac, false);
73 if(sign!=0) {
74 collector.append_child(collector.begin(), iterator(fac));
75 multiply(prod->multiplier, sign);
76 next=tr.erase(fac);
77 result=result_t::l_applied;
78 break;
79 }
80 }
81 }
82 }
83 while(fac!=tr.begin(prod));
84 }
85 else {
86 auto fac=tr.begin(prod);
87 while(fac!=tr.end(prod)) {
88 auto next=fac;
89 ++next;
90 for(size_t i=0; i<to_factor_out.size(); ++i) {
91 auto match=comparator.equal_subtree(fac, to_factor_out[i].begin());
92 if(match==Ex_comparator::match_t::subtree_match) {
93 int sign=comparator.can_move_adjacent(prod, dummy, fac, true);
94 if(sign!=0) {
95 collector.append_child(collector.begin(), iterator(fac));
96 multiply(prod->multiplier, sign);
97 next=tr.erase(fac);
98 result=result_t::l_applied;
99 break;
100 }
101 }
102 }
103 fac=next;
104 }
105 }
106
107 tr.erase(dummy);
108 if(tr.number_of_children(prod)==0)
109 tr.append_child(prod, str_node("1"));
110
111 // std::cerr << "product after factoring out " << Ex(prod) << std::endl;
112
113 if(collector.number_of_children(collector.begin())!=0) {
114 // The stuff factored out of this term is in 'collector'. See if we have
115 // factored out that thing before. Because we may not always have collected
116 // factors in the same order (the original expression may not have had
117 // its product sorted), we first sort the collector product.
118
119 sort_product sp(kernel, collector);
120 sp.dont_cleanup(); // otherwise single-factor products will get stripped of the \prod wrapper.
121 auto coltop=collector.begin();
122 if(sp.can_apply(coltop)) {
123 sp.apply(coltop);
124 }
125 multiply(prod->multiplier, *coltop->multiplier);
126 one(coltop->multiplier);
127
128 // Scan through the things factored out so far.
129 bool found=false;
130 for(auto& nt: new_terms) {
131 if(nt.first==collector) { // have that factored out already, add the other factors
132 nt.second.push_back(Ex(prod));
133 found=true;
134 break;
135 }
136 }
137 // We hadn't factored this bit out before, make a new term.
138 if(!found) {
139 std::vector<Ex> v;
140 v.push_back(Ex(prod));
141 new_term_t nt(collector, v);
142 new_terms.push_back(nt);
143 }
144
145 // All info is now in new_terms; can remove the original.
146 tr.erase(prod);
147 }
148 else {
149 prod_unwrap_single_term(prod);
150 }
151 term=next_term;
152 }
153
154 // Everything has been collected now into new_terms. Expand those out
155 // into a proper sum of products.
156
157 for(auto& nt: new_terms) {
158 auto prod = tr.append_child(it, nt.first.begin());
159 if(nt.second.size()==1) { // factored, but only one term found.
160 auto top = nt.second[0].begin(); // prod node
161 if(to_right) {
162 auto ins = tr.end(top);
163 --ins;
164 while(tr.is_valid(ins)) {
165 tr.prepend_child(prod, iterator(ins));
166 --ins;
167 }
168 }
169 else {
170 auto ins = tr.begin(top);
171 while(ins!=tr.end(top)) {
172 tr.append_child(prod, iterator(ins));
173 ++ins;
174 }
175 }
176 multiply(prod->multiplier, *(nt.second[0].begin()->multiplier));
177 // FIXME: append_children has a BUG! Messes up the tree. But it is needed to
178 // handle terms where the sub-factor is not a simple element.
179 // tr.append_children(prod, nt.second[0].begin(top), nt.second[0].end(top));
180
181 cleanup_dispatch(kernel, tr, prod);
182 }
183 else {
184 iterator sum;
185 if(to_right)
186 sum = tr.prepend_child(prod, str_node("\\sum"));
187 else
188 sum = tr.append_child(prod, str_node("\\sum"));
189 for(auto& term: nt.second) {
190 auto tmp = tr.append_child(sum, term.begin());
191 cleanup_dispatch(kernel, tr, tmp);
192 }
193 }
194 }
195
196 // std::cerr << "end of factor_out: \n" << Ex(it) << std::endl;
197
198 return result;
199 }
200
201