1%
2% generate explicit CPTs
3%
4:- module(clpbn_aggregates, [
5        check_for_agg_vars/2,
6	cpt_average/6,
7	cpt_average/7,
8	cpt_max/6,
9	cpt_min/6
10	]).
11
12:- use_module(library(clpbn), [{}/1]).
13
14:- use_module(library(lists),
15	[last/2,
16	sumlist/2,
17	sum_list/3,
18	max_list/2,
19	min_list/2,
20	nth0/3
21    ]).
22
23:- use_module(library(matrix),
24	[matrix_new/3,
25	matrix_to_list/2,
26	matrix_set/3]).
27
28:- use_module(library('clpbn/dists'),
29	      [
30	       dist/4,
31	       get_dist_domain_size/2]).
32
33:- use_module(library('clpbn/matrix_cpt_utils'),
34	[normalise_CPT_on_lines/3]).
35
%
% check_for_agg_vars(+Vars, -KeptVars)
%
% Walks Vars and keeps, in order, every variable that carries both a
% key/1 and a dist/2 attribute in module clpbn; all other variables are
% dropped.  Each kept variable is handed to simplify_dist/6, which may
% return a different list of variables still to be visited.
%
check_for_agg_vars([], []).
check_for_agg_vars([Var|Pending], Kept) :-
	(   Kept = [Var|KeptRest],
	    clpbn:get_atts(Var, [key(Key), dist(DistId,Parents)])
	->  % attributed variable: give aggregates a chance to be rewritten
	    simplify_dist(DistId, Var, Key, Parents, Pending, Pending1),
	    check_for_agg_vars(Pending1, KeptRest)
	;   % no key/dist attributes: skip this variable
	    check_for_agg_vars(Pending, Kept)
	).
43
%
% simplify_dist(+Dist, +V, +Key, +Parents, +Vs0, -VsF)
%
% transform aggregate distribution into tree.
% When Dist is avg(Domain), cpt_average/6 computes a new distribution for
% [V|Parents], dist/4 registers it, and the dist/2 attribute of V is
% replaced.  Any other Dist leaves V alone and returns Vs0 unchanged.
%
simplify_dist(Dist, V, Key, Parents, Vs0, VsF) :-
	(   Dist = avg(Domain)
	->  cpt_average([V|Parents], Key, Domain, NewDist, Vs0, VsF),
	    dist(NewDist, Id, Key, ParentsF),
	    clpbn:put_atts(V, [dist(Id,ParentsF)])
	;   VsF = Vs0
	).
50
%
% cpt_average(+AllVars, +Key, +Els0, -Tab, +Vs, -NewVs)
%
% Six-argument entry point: identical to cpt_average/7 called with a
% softness of 1.0.
%
cpt_average(AllVars, Key, Domain, Table, Vs0, VsF) :-
	Softness = 1.0,
	cpt_average(AllVars, Key, Domain, Softness, Table, Vs0, VsF).
53
%
% cpt_average(+[Ev|Vars], +Key, +Els0, +Softness, -Dist, +Vs, -NewVs)
%
% Builds Dist = p(Els0, CPT, NewParents) for the aggregate variable Ev
% whose parents are Vars.
%
% support variables with evidence from domain. This should make everyone's life easier.
%
% No output is written: the table is returned through Dist only.
%
cpt_average([Ev|Vars], Key, Els0, Softness, p(Els0, CPT, NewParents), Vs, NewVs) :-
	% add up the evidence of the observed parents; RVars are the others
	find_evidence(Vars, 0, TotEvidence, RVars),
	build_avg_table(RVars, Vars, Els0, Key, TotEvidence, Softness, MAT0, NewParents0, Vs, IVs),
	% take evidence on the aggregate variable itself into account
	include_qevidence(Ev, MAT0, MAT, NewParents0, NewParents, Vs, IVs, NewVs),
	matrix_to_list(MAT, CPT).
60
%
% find_evidence(+Vars, +Acc0, -Total, -FreeVars)
%
% find all fixed kids, this simplifies significantly the function.
% Variables that have an evidence/1 attribute in module clpbn contribute
% their evidence value to the running sum (Acc0 .. Total) and are left
% out of FreeVars; the remaining variables are returned in order.
%
find_evidence([], Total, Total, []).
find_evidence([Var|Others], Acc0, Total, Free) :-
	(   clpbn:get_atts(Var, [evidence(Ev)])
	->  Acc is Acc0+Ev,
	    find_evidence(Others, Acc, Total, Free)
	;   Free = [Var|FreeRest],
	    find_evidence(Others, Acc0, Total, FreeRest)
	).
69
%
% cpt_max(+[_|Vars], +Key, +Els0, -CPT, +Vs, -NewVs)
%
% Table for a max aggregate.  The head of the variable list is ignored;
% the remaining variables, the domain Els0 (passed twice), Key and a
% softness of 1.0 are handed to build_max_table/8.
%
cpt_max([_|Parents], Key, Domain, Table, Vs0, VsF) :-
	Softness = 1.0,
	build_max_table(Parents, Domain, Domain, Key, Softness, Table, Vs0, VsF).
72
%
% cpt_min(+[_|Vars], +Key, +Els0, -CPT, +Vs, -NewVs)
%
% Table for a min aggregate.  The head of the variable list is ignored;
% the remaining variables, the domain Els0 (passed twice), Key and a
% softness of 1.0 are handed to build_min_table/8.
%
cpt_min([_|Parents], Key, Domain, Table, Vs0, VsF) :-
	Softness = 1.0,
	build_min_table(Parents, Domain, Domain, Key, Softness, Table, Vs0, VsF).
75
%
% build_avg_table(+Vars, +OVars, +Domain, +Key, +TotEvidence, +Softness,
%                 -CPT, -NewParents, +Vs, -NewVs)
%
% Direct case: when |Domain|^|Vars| lies in 1..256 the table is built
% over Vars themselves, which also become the parents; Vs is returned
% unchanged.
% Otherwise Vars is cut in two halves, each half is summarised by an
% intermediate sum variable (range 0 .. |Domain|-1 per variable), and
% the table is built over those two new variables, which are also
% pushed in front of the returned variable list.
%
build_avg_table(Vars, OVars, Domain, _, TotEvidence, Softness, CPT, Vars, Vs, Vs) :-
	length(Domain, DomainSize),
	int_power(Vars, DomainSize, 1, Entries),
	Entries =< 256,
	/* case gmp is not there !! */
	% presumably guards against the product overflowing to a non-positive value
	Entries > 0, !,
	average_cpt(Vars, OVars, Domain, TotEvidence, Softness, CPT).
build_avg_table(Vars, OVars, Domain, Key, TotEvidence, Softness, CPT, [Left,Right], Vs, [Left,Right|NewVs]) :-
	length(Vars, Count),
	HalfA is Count//2,
	HalfB is Count-HalfA,
	list_split(HalfA, Vars, VarsA, VarsB),
	length(Domain, DomainSize),
	Top is DomainSize-1,
	Op = sum(0,Top),
	% the counter threaded through both calls numbers the temporaries
	build_intermediate_table(HalfA, Op, VarsA, Left, Key, 1.0, 0, NextId, Vs, Vs1),
	build_intermediate_table(HalfB, Op, VarsB, Right, Key, 1.0, NextId, _, Vs1, NewVs),
	average_cpt([Left,Right], OVars, Domain, TotEvidence, Softness, CPT).
93
%
% build_max_table(+Vars, +Domain, +Softness, -Dist, +Vs, -NewVs)
%
% Six-argument form kept for compatibility; it runs the eight-argument
% form with an unbound Key.
%
build_max_table(Vars, Domain, Softness, Dist, Vs, NewVs) :-
	build_max_table(Vars, Domain, Domain, _Key, Softness, Dist, Vs, NewVs).

%
% build_max_table(+Vars, +Domain, +Els, +Key, +Softness, -Dist, +Vs, -NewVs)
%
% This is the arity cpt_max/6 calls.  Dist is p(Domain, CPT, Parents).
% The third argument (cpt_max passes the domain again) is not used.
%
% Direct case: when |Domain|^|Vars| lies in 1..16 the table is built over
% Vars and Vs is returned unchanged.
% Otherwise Vars is split in two halves, each half is summarised by an
% intermediate 'MAX' variable keyed with Key, and the final table is
% built over those two variables.
%
build_max_table(Vars, Domain, _, _, Softness, p(Domain, CPT, Vars), Vs, Vs) :-
	length(Domain, SDomain),
	int_power(Vars, SDomain, 1, TabSize),
	TabSize =< 16,
	/* case gmp is not there !! */
	TabSize > 0, !,
	max_cpt(Vars, Domain, Softness, CPT).
build_max_table(Vars, Domain, _, Key, Softness, p(Domain, CPT, [V1,V2]), Vs, [V1,V2|NewVs]) :-
	length(Vars,L),
	LL1 is L//2,
	LL2 is L-LL1,
	list_split(LL1, Vars, L1, L2),
	% CPT is still unbound here: the intermediate variables share the
	% table that max_cpt/4 binds in the last goal.
	build_intermediate_table(LL1, max(Domain,CPT), L1, V1, Key, 1.0,  0, I1, Vs, Vs1),
	build_intermediate_table(LL2, max(Domain,CPT), L2, V2, Key, 1.0, I1, _, Vs1, NewVs),
	max_cpt([V1,V2], Domain, Softness, CPT).
109
%
% build_min_table(+Vars, +Domain, +Softness, -Dist, +Vs, -NewVs)
%
% Six-argument form kept for compatibility; it runs the eight-argument
% form with an unbound Key.
%
build_min_table(Vars, Domain, Softness, Dist, Vs, NewVs) :-
	build_min_table(Vars, Domain, Domain, _Key, Softness, Dist, Vs, NewVs).

%
% build_min_table(+Vars, +Domain, +Els, +Key, +Softness, -Dist, +Vs, -NewVs)
%
% This is the arity cpt_min/6 calls.  Dist is p(Domain, CPT, Parents).
% The third argument (cpt_min passes the domain again) is not used.
%
% Direct case: when |Domain|^|Vars| lies in 1..16 the table is built over
% Vars and Vs is returned unchanged.
% Otherwise Vars is split in two halves, each half is summarised by an
% intermediate 'MIN' variable keyed with Key, and the final table is
% built over those two variables.
%
build_min_table(Vars, Domain, _, _, Softness, p(Domain, CPT, Vars), Vs, Vs) :-
	length(Domain, SDomain),
	int_power(Vars, SDomain, 1, TabSize),
	TabSize =< 16,
	/* case gmp is not there !! */
	TabSize > 0, !,
	min_cpt(Vars, Domain, Softness, CPT).
build_min_table(Vars, Domain, _, Key, Softness, p(Domain, CPT, [V1,V2]), Vs, [V1,V2|NewVs]) :-
	length(Vars,L),
	LL1 is L//2,
	LL2 is L-LL1,
	list_split(LL1, Vars, L1, L2),
	% CPT is still unbound here: the intermediate variables share the
	% table that min_cpt/4 binds in the last goal.
	build_intermediate_table(LL1, min(Domain,CPT), L1, V1, Key, 1.0,  0, I1, Vs, Vs1),
	build_intermediate_table(LL2, min(Domain,CPT), L2, V2, Key, 1.0, I1, _, Vs1, NewVs),
	min_cpt([V1,V2], Domain, Softness, CPT).
125
%
% int_power(+List, +X, +Acc0, -Power)
%
% Multiplies Acc0 by X once for every element of List, so that
% Power = Acc0 * X^length(List).  The elements themselves are ignored.
%
int_power([], _, Power, Power).
int_power([_|Rest], X, Acc0, Power) :-
	Acc is Acc0*X,
	int_power(Rest, X, Acc, Power).
130
%
% build_intermediate_table(+N, +Op, +Vars, -V, +Key, +Softness, +I0, -If, +Vs, -NewVs)
%
% Reduces the N variables in Vars to a single variable V by combining
% them pairwise with Op.
%   N = 1: V is the variable itself; counter and variable list unchanged.
%   N = 2: one temporary numbered I0 is created over the pair and the
%          counter advances by one.
%   else:  Vars is split in two halves, each reduced recursively (with
%          the counter starting at I0+1), and a temporary numbered I0 is
%          created over the two results, which are also pushed in front
%          of the returned variable list.
%
build_intermediate_table(1, _, [Only], Only, _, _, Id, Id, Vs, Vs) :- !.
build_intermediate_table(2, Op, [A,B], Tmp, Key, Softness, Id0, IdF, Vs, Vs) :- !,
	generate_tmp_random(Op, 2, [A,B], Tmp, Key, Softness, Id0),
	IdF is Id0+1.
build_intermediate_table(N, Op, Vars, Tmp, Key, Softness, Id0, IdF, Vs, [Left,Right|NewVs]) :-
	HalfA is N//2,
	HalfB is N-HalfA,
	list_split(HalfA, Vars, VarsA, VarsB),
	Id1 is Id0+1,
	build_intermediate_table(HalfA, Op, VarsA, Left, Key, Softness, Id1, Id2, Vs, Vs1),
	build_intermediate_table(HalfB, Op, VarsB, Right, Key, Softness, Id2, IdF, Vs1, NewVs),
	generate_tmp_random(Op, N, [Left,Right], Tmp, Key, Softness, Id0).
143
%
% generate_tmp_random(+Op, +N, +[V1,V2], -V, +Key, +Softness, +I)
%
% Posts a CLP(BN) constraint that makes V a temporary random variable
% with parents V1 and V2.  All clauses take the same seven arguments,
% because build_intermediate_table/10 always passes Softness.
%
%   sum(Min,Max): averages are transformed into sums.  V ranges over
%                 Min*N .. Max*N and is keyed 'AVG'(I,Key).
%   max(Domain,CPT) / min(Domain,CPT): V gets the given Domain and CPT
%                 and is keyed 'MAX'(I,Key) / 'MIN'(I,Key); N and
%                 Softness are not used.
%
generate_tmp_random(sum(Min,Max), N, [V1,V2], V, Key, Softness, I) :-
	Lower is Min*N,
	Upper is Max*N,
	generate_list(Lower, Upper, Nbs),
	sum_cpt([V1,V2], Nbs, Softness, CPT),
%	write(sum(Nbs, CPT, [V1,V2])),nl, % debugging
	{ V = 'AVG'(I,Key) with p(Nbs,CPT,[V1,V2]) }.
generate_tmp_random(max(Domain,CPT), _, [V1,V2], V, Key, _, I) :-
	{ V = 'MAX'(I,Key) with p(Domain,CPT,[V1,V2]) }.
generate_tmp_random(min(Domain,CPT), _, [V1,V2], V, Key, _, I) :-
	{ V = 'MIN'(I,Key) with p(Domain,CPT,[V1,V2]) }.
156
%
% generate_list(+Low, +High, -Numbers)
%
% Numbers is the list Low, Low+1, ..., High.  An empty range (Low > High)
% yields [] instead of counting upwards forever.
%
generate_list(I, M, []) :-
	I > M, !.
generate_list(M, M, [M]) :- !.
generate_list(I, M, [I|Nbs]) :-
	I1 is I+1,
	generate_list(I1, M, Nbs).
161
%
% list_split(+N, +List, -Front, -Back)
%
% Front holds the first N elements of List and Back the remainder.
% Fails when List has fewer than N elements.
%
list_split(N, List, Front, Back) :-
	(   N = 0, Front = [], Back = List
	->  true
	;   List = [Head|Tail],
	    Front = [Head|FrontRest],
	    N1 is N-1,
	    list_split(N1, Tail, FrontRest, Back)
	).
166
%
% include_qevidence(+V, +MAT0, -MAT, +NewParents0, -NewParents, +Vs, +IVs, -NewVs)
%
% if we have evidence, we need to check if we are always consistent, never consistent, or can be consistent
%
% When V has an evidence/1 attribute in module clpbn, MAT0 is passed
% through normalise_CPT_on_lines/3 and check_consistency/11 picks the
% outputs.  Without evidence the inputs are returned as they are:
% MAT = MAT0, NewParents = NewParents0, NewVs = IVs.
%
include_qevidence(V, MAT0, MAT, NewParents0, NewParents, Vs, IVs, NewVs) :-
	(   clpbn:get_atts(V, [evidence(Ev)])
	->  normalise_CPT_on_lines(MAT0, MAT1, Lines),
	    check_consistency(Lines, Ev, MAT0, MAT1, Lines, MAT, NewParents0, NewParents, Vs, IVs, NewVs)
	;   MAT = MAT0,
	    NewParents = NewParents0,
	    NewVs = IVs
	).
175
%
% check_consistency(+L1, +Ev, +MAT0, +MAT1, +L1, -MAT, +NewParents0, -NewParents, +Vs, +IVs, -NewVs)
%
% L1 is the list produced by normalise_CPT_on_lines/3 (it is passed
% twice by the caller) and Val is its entry at position Ev (0-based).
%
%   Val equals the sum of L1: the evidence is always consistent; return
%       the normalised table MAT1, no parents, and the original Vs.
%   Val is 0.0: the evidence is never consistent; throw
%       error(domain_error(incompatible_evidence), evidence(Ev)).
%   otherwise: keep MAT0, NewParents0 and the extended list IVs.
%
% The predicate writes nothing to the output stream.
%
check_consistency(L1, Ev, MAT0, MAT1, L1, MAT, NewParents0, NewParents, Vs, IVs, NewVs) :-
	sumlist(L1, Tot),
	nth0(Ev, L1, Val),
	(   Val == Tot
	->  MAT1 = MAT,
	    NewParents = [],
	    Vs = NewVs
	;   Val == 0.0
	->  throw(error(domain_error(incompatible_evidence),evidence(Ev)))
	;   MAT0 = MAT,
	    NewParents = NewParents0,
	    IVs = NewVs
	).
194
195
196%
197% generate actual table, instead of trusting the solver
198%
199
%
% average_cpt(+Vs, +OVars, +Vals, +Base, +Softness, -MCPT)
%
% Creates a float matrix whose first dimension is the number of values
% in Vals and whose other dimensions are the domain sizes of Vs, then
% lets fill_in_average/4 populate it.  The divisor used for the average
% is the length of OVars, and Base is added to every sum.  Softness is
% ignored.
%
average_cpt(Vs, OVars, Vals, Base, _, MCPT) :-
	length(Vals, NVals),
	length(OVars, Count),
	get_ds_lengths(Vs, Sizes),
	matrix_new(floats, [NVals|Sizes], MCPT),
	fill_in_average(Sizes, Count, Base, MCPT).
206
%
% get_ds_lengths(+Vars, -Sizes)
%
% Sizes holds, position by position, what get_vdist_size/2 reports for
% each variable in Vars.
%
get_ds_lengths([], []).
get_ds_lengths([Var|Vars], [Size|Sizes]) :-
	get_vdist_size(Var, Size),
	get_ds_lengths(Vars, Sizes).
211
%
% fill_in_average(+Lengs, +N, +Base, +MCPT)
%
% Failure-driven loop: for every combination Case of parent values
% enumerated by generate/2, sets entry [Val|Case] of MCPT to 1.0, where
% Val is computed by average/4.  Always succeeds once the enumeration
% is exhausted.
%
fill_in_average(Lengs, N, Base, MCPT) :-
	(   generate(Lengs, Case),
	    average(Case, N, Base, Val),
	    matrix_set(MCPT, [Val|Case], 1.0),
	    fail
	;   true
	).
218
%
% generate(+Sizes, -Case)
%
% On backtracking enumerates every list Case with one entry per element
% of Sizes, each entry being produced by from(0, Size, Entry).
%
generate([], []).
generate([Size|Sizes], [Index|Indices]) :-
	from(0, Size, Index),
	generate(Sizes, Indices).
223
%
% from(+Low, +Bound, -Value)
%
% On backtracking Value takes Low first and then Low+1, Low+2, ... for
% as long as the value stays strictly below Bound.  Low itself is
% always produced, even when it is not below Bound.
%
from(Low, Bound, Value) :-
	(   Value = Low
	;   Next is Low+1,
	    Next < Bound,
	    from(Next, Bound, Value)
	).
229
%
% average(+Case, +N, +Base, -Val)
%
% Val is the integer nearest to (Base + sum of Case) / N.
%
average(Case, N, Base, Val) :-
	sum_list(Case, Base, Total),
	Mean is Total/N,
	Val is integer(round(Mean)).
233
234
%
% sum_cpt(+Vs, +Vals, +Softness, -CPT)
%
% CPT is the list form of a float matrix whose first dimension is the
% number of values in Vals and whose other dimensions are the domain
% sizes of Vs, filled by fill_in_sum/2.  Softness is ignored.
%
sum_cpt(Vs, Vals, _, CPT) :-
	length(Vals, NVals),
	get_ds_lengths(Vs, Sizes),
	matrix_new(floats, [NVals|Sizes], Matrix),
	fill_in_sum(Sizes, Matrix),
	matrix_to_list(Matrix, CPT).
241
%
% fill_in_sum(+Lengs, +MCPT)
%
% Failure-driven loop: for every combination Case enumerated by
% generate/2, sets entry [Sum|Case] of MCPT to 1.0, where Sum is the sum
% of the entries of Case.  Always succeeds at the end.
%
fill_in_sum(Lengs, MCPT) :-
	(   generate(Lengs, Case),
	    sumlist(Case, Sum),
	    matrix_set(MCPT, [Sum|Case], 1.0),
	    fail
	;   true
	).
248
249
%
% max_cpt(+Vs, +Vals, +Softness, -CPT)
%
% CPT is the list form of a float matrix whose first dimension is the
% number of values in Vals and whose other dimensions are the domain
% sizes of Vs, filled by fill_in_max/2.  Softness is ignored.
%
max_cpt(Vs, Vals, _, CPT) :-
	length(Vals, NVals),
	get_ds_lengths(Vs, Sizes),
	matrix_new(floats, [NVals|Sizes], Matrix),
	fill_in_max(Sizes, Matrix),
	matrix_to_list(Matrix, CPT).
256
%
% fill_in_max(+Lengs, +MCPT)
%
% Failure-driven loop: for every combination Case enumerated by
% generate/2, sets entry [Max|Case] of MCPT to 1.0, where Max is the
% largest entry of Case.  Always succeeds at the end.
%
fill_in_max(Lengs, MCPT) :-
	(   generate(Lengs, Case),
	    max_list(Case, Max),
	    matrix_set(MCPT, [Max|Case], 1.0),
	    fail
	;   true
	).
263
264
%
% min_cpt(+Vs, +Vals, +Softness, -CPT)
%
% CPT is the list form of a float matrix whose first dimension is the
% number of values in Vals and whose other dimensions are the domain
% sizes of Vs.  The matrix is filled by fill_in_min/2, so each parent
% configuration selects the minimum of its values.  Softness is ignored.
%
min_cpt(Vs,Vals,_,CPT) :-
	get_ds_lengths(Vs,Lengs),
	length(Vals,SVals),
	matrix_new(floats,[SVals|Lengs],MCPT),
	fill_in_min(Lengs,MCPT),
	matrix_to_list(MCPT,CPT).

%
% fill_in_min(+Lengs, +MCPT)
%
% Failure-driven loop: for every combination Case enumerated by
% generate/2, sets entry [Val|Case] of MCPT to 1.0, where Val is the
% smallest entry of Case.  Always succeeds at the end.
%
fill_in_min(Lengs,MCPT) :-
	generate(Lengs, Case),
	min_list(Case, Val),
	matrix_set(MCPT,[Val|Case],1.0),
	fail.
fill_in_min(_,_).
278
279
%
% get_vdist_size(+V, -Sz)
%
% Sz is what get_dist_domain_size/2 reports for the distribution stored
% in the dist/2 attribute of V (module clpbn).  Fails when V has no such
% attribute.
%
get_vdist_size(Var, Size) :-
	clpbn:get_atts(Var, [dist(DistId,_)]),
	get_dist_domain_size(DistId, Size).
283
284