1
2%:- style_check(all).
3
4:- module(viterbi, [viterbi/4]).
5
6:- use_module(library(lists),
7	      [nth/3,
8	       member/2]).
9
10:- use_module(library(assoc)).
11
12:- use_module(library(dgraphs)).
13
14:- use_module(library(matrix)).
15
16:- use_module(library(clpbn), []).
17
18:- ensure_loaded(library('clpbn/hmm')).
19
20:- use_module(library('clpbn/dists'), [
21	get_dist_params/2]).
22
23:- meta_predicate viterbi(:,:,+,-).
24
25
%
% viterbi(:Start, :End, +String, -Trace)
%
% Entry point.  Start is run as a goal; the attributed variables that
% exist afterwards are compiled into an instruction list (mk_graph/3).
% The observation list String is turned into emission indices
% (compile_trace/2), the instructions are executed once per observation
% (compiled_viterbi/6), and Trace is rebuilt from the recorded
% back-pointers (backtrace/5).
%
% End is never called: it is only handed to get_id/3 to find the node at
% which the back-pointer walk starts.
%
viterbi(Start, End, String, Trace) :-
	% presumably resets the bookkeeping kept by clpbn/hmm -- confirm there
	init_hmm,
	call(Start),
	mk_graph(NodeCount, Map, Code),
	compile_trace(String, Emissions),
	get_id(Start, Map, StartId),
	get_id(End, Map, EndId),
	compiled_viterbi(NodeCount, StartId, Code, Emissions, Dump, Len),
	backtrace(Dump, EndId, Map, Len, Trace).
36
%
% state_from_goal(+Goal, -State)
%
% State is the most general term that has the functor name of Goal and
% two arguments fewer than Goal.  Any Module: qualification on Goal is
% stripped first.
%
state_from_goal(_:Start,S) :- !,
	% Commit to the unqualified goal.  Without the cut the clause below
	% also matches Module:Goal (functor ':'/2) and, on backtracking,
	% returns the atom ':' as a bogus second answer.  get_id/3 commits in
	% the same way.
	state_from_goal(Start,S).
state_from_goal(Start,S) :-
	functor(Start, N, Ar),
	% get rid of position and random var
	NAr is Ar-2,
	functor(S, N, NAr).
44
45
%
% mk_graph(-NOfNodes, -Map, -ViterbiCode)
%
% Collects every attributed variable currently alive, keeps the ones
% get_graph/5 accepts as template nodes, orders them with a topological
% sort of the node-to-parent edges, and compiles the ordered nodes into
% ViterbiCode.  Map pairs each node number with its slice-free key and
% NOfNodes is the number of nodes.
%
mk_graph(NOfNodes, Map, ViterbiCode) :-
	attributes:all_attvars(AttVars),
	empty_assoc(NoInfo),
	get_graph(AttVars, Vertices, Edges, NoInfo, NodeInfo),
	dgraph_new(Empty),
	dgraph_add_vertices(Empty, Vertices, WithVertices),
	dgraph_add_edges(WithVertices, Edges, Graph),
	dgraph_top_sort(Graph, Ordered),
	compile_viterbi(Ordered, NodeInfo, NOfNodes, Map, ViterbiCode).
55
%
% get_graph(+Vars, -Nodes, -Edges, +KeyMap0, -KeyMap)
%
% Walks the attributed variables.  A variable becomes a node when it has
% clpbn key/dist attributes and its key either has 2 as first argument
% or is exactly s(0); every other variable is skipped.  The node is
% named by its key with the first argument removed (s(0) becomes s).
% For each node the key map gets
%     nodeinfo(NodeNumber, DistId, Emission, ParentKeys)
% where NodeNumber is left unbound here.
%
get_graph([Var|Vars], [Node|Nodes], Edges, KeyMap0, KeyMap) :-
	clpbn:get_atts(Var, [key(Key), dist(DistId,Parents)]),
	(   Key =.. [Name,2|Args]
	->  true
	;   Key = s(0), Name = s, Args = []
	),
	!,
	Node =.. [Name|Args],
	fetch_edges(Parents, Node, Edges, MoreEdges, ParentKeys),
	get_emission(Var, Key, Emission),
	put_assoc(Node, KeyMap0, nodeinfo(_,DistId,Emission,ParentKeys), KeyMap1),
	get_graph(Vars, Nodes, MoreEdges, KeyMap1, KeyMap).
get_graph([_|Vars], Nodes, Edges, KeyMap0, KeyMap) :-
	% not a template node: ignore it
	get_graph(Vars, Nodes, Edges, KeyMap0, KeyMap).
get_graph([], [], [], KeyMap, KeyMap).
67
%
% get_emission(+Var, +Key, -EmissionProbs)
%
% If Var carries an hmm emission attribute, EmissionProbs is whatever
% user:emission_cpt/2 gives for Key; otherwise it is [] (the marker
% compile_emission//2 uses for "no emission").
%
get_emission(Var, Key, EmissionProbs) :-
	(   hmm:get_atts(Var, [emission(_)])
	->  user:emission_cpt(Key, EmissionProbs)
	;   EmissionProbs = []
	).
72
%
% fetch_edges(+Parents, +Node, -Edges, ?EdgesTail, -ParentKeys)
%
% Parents holds attributed variables (their clpbn key is looked up) or
% keys given directly.  Every parent contributes Slice-AbstractKey to
% ParentKeys.  Only parents whose slice is below 3 also contribute a
% graph edge Node-AbstractKey; Edges is a difference list ending in
% EdgesTail.
%
fetch_edges([Var|Parents], Node, Edges, EdgesTail, [Slice-AKey|PKeys]) :-
	var(Var), !,
	clpbn:get_atts(Var, [key(Key)]),
	abstract_key(Key, AKey, Slice),
	fetch_edges_link(Slice, Node-AKey, Edges, Edges1),
	fetch_edges(Parents, Node, Edges1, EdgesTail, PKeys).
fetch_edges([Key|Parents], Node, Edges, EdgesTail, [Slice-AKey|PKeys]) :-
	abstract_key(Key, AKey, Slice),
	fetch_edges_link(Slice, Node-AKey, Edges, Edges1),
	fetch_edges(Parents, Node, Edges1, EdgesTail, PKeys).
fetch_edges([], _, Edges, Edges, []).

% fetch_edges_link(+Slice, +Edge, -Edges, ?Rest)
% Edges is [Edge|Rest] when Slice < 3 and just Rest otherwise.
fetch_edges_link(Slice, Edge, Edges, Rest) :-
	(   Slice < 3
	->  Edges = [Edge|Rest]
	;   Edges = Rest
	).
96
%
% abstract_key(+Key, -NKey, -Slice)
%
% Splits Key into its first argument (Slice) and the same term without
% that argument (NKey); e.g. e(3,a) gives Slice = 3 and NKey = e(a).
% Fails for a Key with no arguments.
%
abstract_key(Key, NKey, Slice) :-
	Key =.. [Name|AllArgs],
	AllArgs = [Slice|Args],
	NKey =.. [Name|Args].
100
101
%
% compile_viterbi(+Keys, +KeyMap, -Nodes, -Map, -ViterbiCode)
%
% Two passes over the sorted node keys: the first numbers the nodes
% (filling in the node-number slots inside KeyMap and producing Map and
% the node count Nodes), the second emits the instruction list.
%
compile_viterbi(SortedKeys, KeyMap, Nodes, Map, ViterbiCode) :-
	enum_keys(SortedKeys, KeyMap, 0, Nodes, Map),
	compile_keys(SortedKeys, KeyMap, ViterbiCode).
105
%
% enum_keys(+Keys, +KeyMap, +First, -Next, -Map)
%
% Numbers the keys First, First+1, ... in list order.  Map lists the
% Number-Key pairs and Next is the first unused number.  Numbering works
% by unifying the number with the first argument of the key's nodeinfo/4
% entry in KeyMap, so later lookups of that entry see the number too.
%
enum_keys([], _, Next, Next, []).
enum_keys([Key|Keys], KeyMap, Id, Next, [Id-Key|Map]) :-
	get_assoc(Key, KeyMap, nodeinfo(Id,_,_,_)),
	Id1 is Id+1,
	enum_keys(Keys, KeyMap, Id1, Next, Map).
112
%
% compile_keys(+Keys, +KeyMap, -ViterbiCode)
%
% Emits, for each key in order, an optional emit/2 instruction followed
% by one propagation instruction per parent.  The parameters fetched for
% the node's distribution are consumed one per parent, in parent order.
%
compile_keys([], _, []).
compile_keys([Key|Keys], KeyMap, Code) :-
	get_assoc(Key, KeyMap, nodeinfo(NodeId,DistId,Emission,ParentKeys)),
	compile_emission(Emission, NodeId, Code, AfterEmit),
	get_dist_params(DistId, Probs),
	compile_propagation(ParentKeys, Probs, NodeId, KeyMap, AfterEmit, Rest),
	compile_keys(Keys, KeyMap, Rest).
120
121
%
% compile_emission(+Emission, +NodeId)//
%
% Emits emit(NodeId, Emission) unless Emission is [], the "node emits
% nothing" marker, in which case no instruction is produced.
%
compile_emission(Emission, NodeId) -->
	(   { Emission = [] }
	->  []
	;   [emit(NodeId,Emission)]
	).
125
%
% compile_propagation(+ParentKeys, +Probs, +NodeId, +KeyMap)//
%
% ParentKeys (Slice-Key pairs) and Probs are walked in lockstep.  A
% parent in slice 0 or 2 yields prop_same(NodeId, ParentId, Prob) and a
% parent in slice 3 yields prop_next(NodeId, ParentId, Prob).  ParentId
% is the parent's node number from KeyMap.  Any other slice makes the
% whole compilation fail.
%
compile_propagation([], [], _, _) --> [].
compile_propagation([Slice-PKey|PKeys], [Prob|Probs], NodeId, KeyMap) -->
	{ compile_propagation_inst(Slice, NodeId, ParentId, Prob, Inst) },
	[Inst],
	{ get_assoc(PKey, KeyMap, nodeinfo(ParentId,_,_,_)) },
	compile_propagation(PKeys, Probs, NodeId, KeyMap).

% compile_propagation_inst(?Slice, ?NodeId, ?ParentId, ?Prob, ?Inst)
% Table from parent slice to the instruction that handles it.
compile_propagation_inst(0, I, P, Prob, prop_same(I,P,Prob)).
compile_propagation_inst(2, I, P, Prob, prop_same(I,P,Prob)).
compile_propagation_inst(3, I, P, Prob, prop_next(I,P,Prob)).
139
%
% get_id(+Goal, +Map, -Id)
%
% Id is the number of the first Map entry (Id-Key pairs) whose key has
% the functor name of Goal and two arguments fewer than Goal.  A
% Module: qualification on Goal is ignored.
%
get_id(_:Goal, Map, Id) :- !,
	get_id(Goal, Map, Id).
get_id(Goal, Map, Id) :-
	functor(Goal, Name, Arity),
	KeyArity is Arity-2,
	functor(KeyPattern, Name, KeyArity),
	(   member(Id-KeyPattern, Map)
	->  true
	).
147
%
% compile_trace(+Observations, -Emissions)
%
% Replaces every observed symbol by its position in the value list of
% the HMM domain (user:hmm_domain/1).  When the domain is an atom, it is
% first expanded to a value list with hmm:cvt_vals/2.  Fails if a symbol
% is not in the domain.
%
compile_trace(Observations, Emissions) :-
	user:hmm_domain(Domain),
	compile_trace_values(Domain, Vals),
	compile_trace(Observations, Vals, Emissions).

% compile_trace_values(+Domain, -Vals)
% A named domain is expanded; an explicit value list is used as is.
compile_trace_values(Domain, Vals) :-
	atom(Domain), !,
	hmm:cvt_vals(Domain, Vals).
compile_trace_values(Vals, Vals).

%
% compile_trace(+Observations, +Vals, -Emissions)
%
% Emissions holds, for each observation, the index nth/3 reports for its
% first occurrence in Vals.
%
compile_trace([], _, []).
compile_trace([Symbol|Symbols], Vals, [Index|Indices]) :-
	(   nth(Index, Vals, Symbol)
	->  true
	),
	compile_trace(Symbols, Vals, Indices).
161
%
% compiled_viterbi(+Nodes, +S, +Commands, +Input, -Trace, -L)
%
% Runs the instruction list Commands once per element of Input.  L is
% the length of Input.  Two integer score vectors of size Nodes are
% created and filled with the smallest tagged integer, which acts as the
% "not reached" score; the start node S gets score 0 in the first one.
% Trace is an (L+1) x Nodes integer matrix that receives the
% back-pointers.
%
compiled_viterbi(Nodes, S, Commands, Input, Trace, L) :-
	length(Input, L),
	prolog_flag(min_tagged_integer, Unreached),
	matrix_new_set(ints, [Nodes], Unreached, Scores),
	matrix_new_set(ints, [Nodes], Unreached, NextScores),
	Rows is L+1,
	matrix_new(ints, [Rows,Nodes], Trace),
	matrix_set(Scores, [S], 0),
	run_commands(Input, Commands, 0, Scores, NextScores, Trace, Unreached).
171
172
%
% run_commands(+Input, +Commands, +I, +Current, +Next, +Trace, +Min)
%
% Executes Commands once for every emission index in Input; I is the
% position of the index being processed.  After each step the Current
% score vector is reset to Min and the two vectors swap roles, so the
% scores written into Next become the Current scores of the following
% step.
%
% Earlier versions read nodes 32 and 34 of Current into unused variables
% after each step, a debugging leftover that indexed past the end of the
% vector on models with fewer than 35 nodes.  Those reads are gone.
%
run_commands([], _, _, _, _, _, _).
run_commands([E|Input], Commands, I, Current, Next, Trace, Min) :-
	run_code(Commands, E, I, Current, Next, Trace),
	matrix_set_all(Current,Min),
	I1 is I+1,
	run_commands(Input, Commands, I1, Next, Current, Trace, Min).
181
%
% run_code(+Instructions, +E, +I, +Current, +Next, +Trace)
%
% Executes the instructions one after the other for emission index E at
% position I.  Fails as soon as one instruction fails.
%
run_code([], _, _, _, _, _).
run_code([Inst|Insts], E, I, Current, Next, Trace) :-
	run_inst(Inst, E, I, Current, Next, Trace),
	run_code(Insts, E, I, Current, Next, Trace).
186
%
% run_inst(+Instruction, +E, +SP, +Current, +Next, +Trace)
%
% Executes one instruction at position SP:
%
%   emit(Id, T)            adds argument E of T to the score of node Id.
%   prop_same(I, P, Prob)  offers score(I)+Prob to node P in the current
%                          vector; if it beats P's score, P takes it and
%                          Trace[SP,P] records I.
%   prop_next(I, P, Prob)  same offer, but to node P in the Next vector;
%                          the back-pointer goes to Trace[SP+1,P] and is
%                          stored negated (-I) to mark the step across
%                          positions.
%
run_inst(emit(Id,Table), E, _, Current, _, _) :-
	arg(E, Table, Score),
	matrix_add(Current, [Id], Score).
run_inst(prop_same(From,To,Prob), _, SP, Current, _, Trace) :-
	matrix_get(Current, [From], FromScore),
	Offer is FromScore+Prob,
	matrix_get(Current, [To], Best),
	(   Offer > Best
	->  matrix_set(Current, [To], Offer),
	    matrix_set(Trace, [SP,To], From)
	;   true
	).
run_inst(prop_next(From,To,Prob), _, SP, Current, Next, Trace) :-
	matrix_get(Current, [From], FromScore),
	Offer is FromScore+Prob,
	matrix_get(Next, [To], Best),
	(   Offer > Best
	->  matrix_set(Next, [To], Offer),
	    NextSP is SP+1,
	    Marked is -From,
	    matrix_set(Trace, [NextSP,To], Marked)
	;   true
	).
212
%
% backtrace(+Dump, +EI, +Map, +L, -Trace)
%
% Starts the back-pointer walk at row L-1, column EI of Dump and returns
% the list of visited keys in Trace.  The end node EI itself is not put
% on the list; the walk begins with its recorded predecessor.
%
backtrace(Dump, EI, Map, L, Trace) :-
	LastRow is L-1,
	matrix_get(Dump, [LastRow,EI], Pointer),
	trace(LastRow, Pointer, Dump, Map, [], Trace).
218
%
% trace(+Row, +Pointer, +Dump, +Map, +Acc, -Trace)
%
% Follows back-pointers through Dump.  A negative Pointer names node
% -Pointer in the previous row (the form prop_next stores); any other
% Pointer names a node in the same row.  The key of each visited node,
% with its row inserted as first argument, is pushed onto Acc.  The walk
% ends when both Row and Pointer are 0, and Trace is the accumulated
% list.
%
trace(0, 0, _, _, Trace, Trace) :- !.
trace(Row, Pointer, Dump, Map, Acc, Trace) :-
	trace_target(Pointer, Row, NodeRow, Node),
	(   member(Node-AKey, Map)
	->  true
	),
	AKey =.. [Name|Args],
	Key =.. [Name,NodeRow|Args],
	matrix_get(Dump, [NodeRow,Node], NextPointer),
	trace(NodeRow, NextPointer, Dump, Map, [Key|Acc], Trace).

% trace_target(+Pointer, +Row, -NodeRow, -Node)
% Decodes a back-pointer into the row and number of the node it names.
trace_target(Pointer, Row, NodeRow, Node) :-
	Pointer < 0, !,
	NodeRow is Row-1,
	Node is -Pointer.
trace_target(Pointer, Row, Row, Pointer).
233
234
235
236