1 
2 #include "Cleanup.hh"
3 #include "Functional.hh"
4 #include "algorithms/factor_out.hh"
5 #include "algorithms/sort_product.hh"
6 #include <map>
7 
8 using namespace cadabra;
9 
factor_out(const Kernel & k,Ex & e,Ex & args,bool right)10 factor_out::factor_out(const Kernel& k, Ex& e, Ex& args, bool right)
11 	: Algorithm(k, e), to_right(right)
12 	{
13 	cadabra::do_list(args, args.begin(), [&](Ex::iterator arg) {
14 		to_factor_out.push_back(Ex(arg));
15 		return true;
16 		}
17 	                );
18 	}
19 
20 /// Check if the expression is a sum with more than one term
can_apply(iterator st)21 bool factor_out::can_apply(iterator st)
22 	{
23 	if(*st->name=="\\sum")  return true;
24 	return false;
25 	}
26 
apply(iterator & it)27 Algorithm::result_t factor_out::apply(iterator& it)
28 	{
29 	result_t result=result_t::l_no_action;
30 
31 	// For every term in the sum, we look at the factors in the product
32 	// (or at the single object if there is no product). If this factor
33 	// needs to be factored out, we determine if it can be moved all the
34 	// way to the left of the expression. If so, move the object to
35 	// a 'factored_out' temporary, and take out of the tree. Rinse/repeat.
36 	// What's left at the end is two objects: the stuff factored out,
37 	// and the rest. Look up if we already have 'the stuff factored out'.
38 	// If not, create new. If so, add this term.
39 
40 	Ex_comparator comparator(kernel.properties);
41 
42 	typedef std::pair<Ex, std::vector<Ex> > new_term_t;
43 	std::vector<new_term_t> new_terms;
44 
45 	auto term=tr.begin(it);
46 	while(term!=tr.end(it)) {
47 		auto next_term=term;
48 		++next_term;
49 
50 		iterator prod=term;
51 		prod_wrap_single_term(prod);
52 
53 		Ex collector("\\prod"); // collect all factors that we have taken out
54 
55 		// Insert a dummy symbol at the very front or back.
56 		// FIXME: there is now a 'can_move_to_front', use that.
57 		iterator dummy;
58 		if(to_right) dummy = tr.append_child(prod, str_node("dummy"));
59 		else         dummy = tr.prepend_child(prod, str_node("dummy"));
60 
61 		// Look at all factors in turn and determine if they should be taken out.
62 		if(to_right) {
63 			auto fac=tr.end(prod);
64 			auto next=fac;
65 			--next;
66 			do {
67 				fac=next;
68 				--next;
69 				for(size_t i=0; i<to_factor_out.size(); ++i) {
70 					auto match=comparator.equal_subtree(fac, to_factor_out[i].begin());
71 					if(match==Ex_comparator::match_t::subtree_match) {
72 						int sign=comparator.can_move_adjacent(prod, dummy, fac, false);
73 						if(sign!=0) {
74 							collector.append_child(collector.begin(), iterator(fac));
75 							multiply(prod->multiplier, sign);
76 							next=tr.erase(fac);
77 							result=result_t::l_applied;
78 							break;
79 							}
80 						}
81 					}
82 				}
83 			while(fac!=tr.begin(prod));
84 			}
85 		else {
86 			auto fac=tr.begin(prod);
87 			while(fac!=tr.end(prod)) {
88 				auto next=fac;
89 				++next;
90 				for(size_t i=0; i<to_factor_out.size(); ++i) {
91 					auto match=comparator.equal_subtree(fac, to_factor_out[i].begin());
92 					if(match==Ex_comparator::match_t::subtree_match) {
93 						int sign=comparator.can_move_adjacent(prod, dummy, fac, true);
94 						if(sign!=0) {
95 							collector.append_child(collector.begin(), iterator(fac));
96 							multiply(prod->multiplier, sign);
97 							next=tr.erase(fac);
98 							result=result_t::l_applied;
99 							break;
100 							}
101 						}
102 					}
103 				fac=next;
104 				}
105 			}
106 
107 		tr.erase(dummy);
108 		if(tr.number_of_children(prod)==0)
109 			tr.append_child(prod, str_node("1"));
110 
111 		// std::cerr << "product after factoring out " << Ex(prod) << std::endl;
112 
113 		if(collector.number_of_children(collector.begin())!=0) {
114 			// The stuff factored out of this term is in 'collector'. See if we have
115 			// factored out that thing before. Because we may not always have collected
116 			// factors in the same order (the original expression may not have had
117 			// its product sorted), we first sort the collector product.
118 
119 			sort_product sp(kernel, collector);
120 			sp.dont_cleanup(); // otherwise single-factor products will get stripped of the \prod wrapper.
121 			auto coltop=collector.begin();
122 			if(sp.can_apply(coltop)) {
123 				sp.apply(coltop);
124 				}
125 			multiply(prod->multiplier, *coltop->multiplier);
126 			one(coltop->multiplier);
127 
128 			// Scan through the things factored out so far.
129 			bool found=false;
130 			for(auto& nt: new_terms) {
131 				if(nt.first==collector) { // have that factored out already, add the other factors
132 					nt.second.push_back(Ex(prod));
133 					found=true;
134 					break;
135 					}
136 				}
137 			// We hadn't factored this bit out before, make a new term.
138 			if(!found) {
139 				std::vector<Ex> v;
140 				v.push_back(Ex(prod));
141 				new_term_t nt(collector, v);
142 				new_terms.push_back(nt);
143 				}
144 
145 			// All info is now in new_terms; can remove the original.
146 			tr.erase(prod);
147 			}
148 		else {
149 			prod_unwrap_single_term(prod);
150 			}
151 		term=next_term;
152 		}
153 
154 	// Everything has been collected now into new_terms. Expand those out
155 	// into a proper sum of products.
156 
157 	for(auto& nt: new_terms) {
158 		auto prod = tr.append_child(it, nt.first.begin());
159 		if(nt.second.size()==1) { // factored, but only one term found.
160 			auto top = nt.second[0].begin(); // prod node
161 			if(to_right) {
162 				auto ins = tr.end(top);
163 				--ins;
164 				while(tr.is_valid(ins)) {
165 					tr.prepend_child(prod, iterator(ins));
166 					--ins;
167 					}
168 				}
169 			else {
170 				auto ins = tr.begin(top);
171 				while(ins!=tr.end(top)) {
172 					tr.append_child(prod, iterator(ins));
173 					++ins;
174 					}
175 				}
176 			multiply(prod->multiplier, *(nt.second[0].begin()->multiplier));
177 			// FIXME: append_children has a BUG! Messes up the tree. But it is needed to
178 			// handle terms where the sub-factor is not a simple element.
179 			//			tr.append_children(prod, nt.second[0].begin(top), nt.second[0].end(top));
180 
181 			cleanup_dispatch(kernel, tr, prod);
182 			}
183 		else {
184 			iterator sum;
185 			if(to_right)
186 				sum = tr.prepend_child(prod, str_node("\\sum"));
187 			else
188 				sum = tr.append_child(prod, str_node("\\sum"));
189 			for(auto& term: nt.second) {
190 				auto tmp = tr.append_child(sum, term.begin());
191 				cleanup_dispatch(kernel, tr, tmp);
192 				}
193 			}
194 		}
195 
196 	// std::cerr << "end of factor_out: \n" << Ex(it) << std::endl;
197 
198 	return result;
199 	}
200 
201