%%% -*- Mode: Prolog; -*-

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
%  $Date: 2011-04-21 14:18:59 +0200 (Thu, 21 Apr 2011) $
%  $Revision: 6364 $
%
%  This file is part of ProbLog
%  http://dtai.cs.kuleuven.be/problog
%
%  ProbLog was developed at Katholieke Universiteit Leuven
%
%  Copyright 2008, 2009, 2010
%  Katholieke Universiteit Leuven
%
%  Main authors of this file:
%  Bernd Gutmann
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% Artistic License 2.0
%
% Copyright (c) 2000-2006, The Perl Foundation.
%
% Everyone is permitted to copy and distribute verbatim copies of this
% license document, but changing it is not allowed.  Preamble
%
% This license establishes the terms under which a given free software
% Package may be copied, modified, distributed, and/or
% redistributed. The intent is that the Copyright Holder maintains some
% artistic control over the development of that Package while still
% keeping the Package available as open source and free software.
%
% You are always permitted to make arrangements wholly outside of this
% license directly with the Copyright Holder of a given Package. If the
% terms of this license do not permit the full use that you propose to
% make of the Package, you should contact the Copyright Holder and seek
% a different licensing arrangement.  Definitions
%
% "Copyright Holder" means the individual(s) or organization(s) named in
% the copyright notice for the entire Package.
%
% "Contributor" means any party that has contributed code or other
% material to the Package, in accordance with the Copyright Holder's
% procedures.
%
% "You" and "your" means any person who would like to copy, distribute,
% or modify the Package.
%
% "Package" means the collection of files distributed by the Copyright
% Holder, and derivatives of that collection and/or of those files. A
% given Package may consist of either the Standard Version, or a
% Modified Version.
%
% "Distribute" means providing a copy of the Package or making it
% accessible to anyone else, or in the case of a company or
% organization, to others outside of your company or organization.
%
% "Distributor Fee" means any fee that you charge for Distributing this
% Package or providing support for this Package to another party. It
% does not mean licensing fees.
%
% "Standard Version" refers to the Package if it has not been modified,
% or has been modified only in ways explicitly requested by the
% Copyright Holder.
%
% "Modified Version" means the Package, if it has been changed, and such
% changes were not explicitly requested by the Copyright Holder.
%
% "Original License" means this Artistic License as Distributed with the
% Standard Version of the Package, in its current version or as it may
% be modified by The Perl Foundation in the future.
%
% "Source" form means the source code, documentation source, and
% configuration files for the Package.
%
% "Compiled" form means the compiled bytecode, object code, binary, or
% any other form resulting from mechanical transformation or translation
% of the Source form.
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% Permission for Use and Modification Without Distribution
%
% (1) You are permitted to use the Standard Version and create and use
% Modified Versions for any purpose without restriction, provided that
% you do not Distribute the Modified Version.
%
% Permissions for Redistribution of the Standard Version
%
% (2) You may Distribute verbatim copies of the Source form of the
% Standard Version of this Package in any medium without restriction,
% either gratis or for a Distributor Fee, provided that you duplicate
% all of the original copyright notices and associated disclaimers. At
% your discretion, such verbatim copies may or may not include a
% Compiled form of the Package.
%
% (3) You may apply any bug fixes, portability changes, and other
% modifications made available from the Copyright Holder. The resulting
% Package will still be considered the Standard Version, and as such
% will be subject to the Original License.
%
% Distribution of Modified Versions of the Package as Source
%
% (4) You may Distribute your Modified Version as Source (either gratis
% or for a Distributor Fee, and with or without a Compiled form of the
% Modified Version) provided that you clearly document how it differs
% from the Standard Version, including, but not limited to, documenting
% any non-standard features, executables, or modules, and provided that
% you do at least ONE of the following:
%
% (a) make the Modified Version available to the Copyright Holder of the
% Standard Version, under the Original License, so that the Copyright
% Holder may include your modifications in the Standard Version.  (b)
% ensure that installation of your Modified Version does not prevent the
% user installing or running the Standard Version. In addition, the
% Modified Version must bear a name that is different from the name of
% the Standard Version.  (c) allow anyone who receives a copy of the
% Modified Version to make the Source form of the Modified Version
% available to others under (i) the Original License or (ii) a license
% that permits the licensee to freely copy, modify and redistribute the
% Modified Version using the same licensing terms that apply to the copy
% that the licensee received, and requires that the Source form of the
% Modified Version, and of any works derived from it, be made freely
% available in that license fees are prohibited but Distributor Fees are
% allowed.
%
% Distribution of Compiled Forms of the Standard Version or
% Modified Versions without the Source
%
% (5) You may Distribute Compiled forms of the Standard Version without
% the Source, provided that you include complete instructions on how to
% get the Source of the Standard Version. Such instructions must be
% valid at the time of your distribution. If these instructions, at any
% time while you are carrying out such distribution, become invalid, you
% must provide new instructions on demand or cease further
% distribution. If you provide valid instructions or cease distribution
% within thirty days after you become aware that the instructions are
% invalid, then you do not forfeit any of your rights under this
% license.
%
% (6) You may Distribute a Modified Version in Compiled form without the
% Source, provided that you comply with Section 4 with respect to the
% Source of the Modified Version.
%
% Aggregating or Linking the Package
%
% (7) You may aggregate the Package (either the Standard Version or
% Modified Version) with other packages and Distribute the resulting
% aggregation provided that you do not charge a licensing fee for the
% Package. Distributor Fees are permitted, and licensing fees for other
% components in the aggregation are permitted. The terms of this license
% apply to the use and Distribution of the Standard or Modified Versions
% as included in the aggregation.
%
% (8) You are permitted to link Modified and Standard Versions with
% other works, to embed the Package in a larger work of your own, or to
% build stand-alone binary or bytecode versions of applications that
% include the Package, and Distribute the result without restriction,
% provided the result does not expose a direct interface to the Package.
%
% Items That are Not Considered Part of a Modified Version
%
% (9) Works (including, but not limited to, modules and scripts) that
% merely extend or make use of the Package, do not, by themselves, cause
% the Package to be a Modified Version. In addition, such works are not
% considered parts of the Package itself, and are not subject to the
% terms of this license.
%
% General Provisions
%
% (10) Any use, modification, and distribution of the Standard or
% Modified Versions is governed by this Artistic License. By using,
% modifying or distributing the Package, you accept this license. Do not
% use, modify, or distribute the Package, if you do not accept this
% license.
%
% (11) If your Modified Version has been derived from a Modified Version
% made by someone other than you, you are nevertheless required to
% ensure that your Modified Version complies with the requirements of
% this license.
%
% (12) This license does not grant you the right to use any trademark,
% service mark, tradename, or logo of the Copyright Holder.
%
% (13) This license includes the non-exclusive, worldwide,
% free-of-charge patent license to make, have made, use, offer to sell,
% sell, import and otherwise transfer the Package with respect to any
% patent claims licensable by the Copyright Holder that are necessarily
% infringed by the Package. If you institute patent litigation
% (including a cross-claim or counterclaim) against any party alleging
% that the Package constitutes direct or contributory patent
% infringement, then this Artistic License to you shall terminate on the
% date that such litigation is filed.
%
% (14) Disclaimer of Warranty: THE PACKAGE IS PROVIDED BY THE COPYRIGHT
% HOLDER AND CONTRIBUTORS "AS IS" AND WITHOUT ANY EXPRESS OR IMPLIED
% WARRANTIES. THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
% PARTICULAR PURPOSE, OR NON-INFRINGEMENT ARE DISCLAIMED TO THE EXTENT
% PERMITTED BY YOUR LOCAL LAW. UNLESS REQUIRED BY LAW, NO COPYRIGHT
% HOLDER OR CONTRIBUTOR WILL BE LIABLE FOR ANY DIRECT, INDIRECT,
% INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING IN ANY WAY OUT OF THE USE
% OF THE PACKAGE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


:- module(learning,[do_learning/1,
	            do_learning/2,
		    reset_learning/0
		    ]).

% enable all style checks to make bugs easier to find
:- style_check(all).
:- yap_flag(unknown,error).

% load modules from the YAP library
:- use_module(library(lists), [max_list/2, min_list/2, sum_list/2]).
:- use_module(library(system), [file_exists/1, shell/2]).

% load our own modules
:- use_module(problog).
:- use_module('problog/logger').
:- use_module('problog/flags').
:- use_module('problog/os').
:- use_module('problog/print_learning').
:- use_module('problog/utils_learning').
:- use_module('problog/utils').
:- use_module('problog/tabling').

% used to indicate the state of the system
:- dynamic(values_correct/0).
:- dynamic(learning_initialized/0).
:- dynamic(current_iteration/1).
:- dynamic(example_count/1).
:- dynamic(query_probability_intern/2).
:- dynamic(query_gradient_intern/4).
:- dynamic(last_mse/1).


% used to identify queries which have identical proofs
:- dynamic(query_is_similar/2).
:- dynamic(query_md5/3).

:- multifile(user:example/4).
user:example(A,B,C,=) :-
	current_predicate(user:example/3),
	user:example(A,B,C).

:- multifile(user:test_example/4).
user:test_example(A,B,C,=) :-
	current_predicate(user:test_example/3),
	user:test_example(A,B,C).

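% Illustrative: a user file typically supplies example/3 facts such as
%   example(1, friend(ann,bob), 0.8).
% which the adapter above exposes as example(1, friend(ann,bob), 0.8, =),
% i.e. with '=' as the default comparison type (the type being '=', '<' or '>').
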

%========================================================================
%= store the facts with the learned probabilities to a file
%========================================================================

save_model:-
	current_iteration(Iteration),
	create_factprobs_file_name(Iteration,Filename),
	export_facts(Filename).



%========================================================================
%= find out whether some example IDs are used more than once
%= if so, complain and stop
%=
%========================================================================

check_examples :-
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% Check example IDs
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	(
	 (user:example(ID,_,_,_), \+ atomic(ID))
	->
	 (
	  format(user_error,'The example id of training example ~q ',[ID]),
	  format(user_error,'is not atomic (e.g. foo42, 23, bar, ...).~n',[]),
	  throw(error(examples))
	 ); true
	),

	(
	 (user:test_example(ID,_,_,_), \+ atomic(ID))
	->
	 (
	  format(user_error,'The example id of test example ~q ',[ID]),
	  format(user_error,'is not atomic (e.g. foo42, 23, bar, ...).~n',[]),
	  throw(error(examples))
	 ); true
	),

	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% Check example probabilities
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	(
	 (user:example(ID,_,P,_), (\+ number(P); P>1 ; P<0))
	->
	 (
	  format(user_error,'The training example ~q does not have a valid probability value (~q).~n',[ID,P]),
	  throw(error(examples))
	 ); true
	),

	(
	 (user:test_example(ID,_,P,_), (\+ number(P); P>1 ; P<0))
	->
	 (
	  format(user_error,'The test example ~q does not have a valid probability value (~q).~n',[ID,P]),
	  throw(error(examples))
	 ); true
	),


	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% Check that no example ID is repeated,
	% and if it is repeated make sure the query is the same
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	(
	 (
	  (
	   user:example(ID,QueryA,_,_),
	   user:example(ID,QueryB,_,_),
	   QueryA \= QueryB
	  ) ;

	  (
	   user:test_example(ID,QueryA,_,_),
	   user:test_example(ID,QueryB,_,_),
	   QueryA \= QueryB
	  );

	  (
	   user:example(ID,QueryA,_,_),
	   user:test_example(ID,QueryB,_,_),
	   QueryA \= QueryB
	  )
	 )
	->
	 (
	  format(user_error,'The example id ~q is used several times.~n',[ID]),
	  throw(error(examples))
	 ); true
	).
%========================================================================
%= reset the learning module to its initial state
%========================================================================

reset_learning :-
	retractall(learning_initialized),
	retractall(values_correct),
	retractall(current_iteration(_)),
	retractall(example_count(_)),
	retractall(query_probability_intern(_,_)),
	retractall(query_gradient_intern(_,_,_,_)),
	retractall(last_mse(_)),
	retractall(query_is_similar(_,_)),
	retractall(query_md5(_,_,_)),

	set_problog_flag(alpha,auto),
	set_problog_flag(learning_rate,examples),
	logger_reset_all_variables.


%========================================================================
%= initialize everything and perform Iterations iterations of
%= gradient descent; can be called several times
%= if it is called with an epsilon parameter, it stops when the change
%= in the MSE is smaller than epsilon
%========================================================================

do_learning(Iterations) :-
	do_learning(Iterations,-1).

do_learning(Iterations,Epsilon) :-
	current_predicate(user:example/4),
	!,
	integer(Iterations),
	number(Epsilon),
	Iterations>0,
	do_learning_intern(Iterations,Epsilon).
do_learning(_,_) :-
	format(user_error,'~nError: No training examples specified.~n~n',[]).

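% Illustrative usage, assuming a theory with example/3 or example/4 facts is loaded:
%   ?- do_learning(20).           % 20 iterations, no MSE-based stopping
%   ?- do_learning(100, 0.00001). % stop early once the MSE changes by less than 0.00001
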

do_learning_intern(0,_) :-
	!.
do_learning_intern(Iterations,Epsilon) :-
	Iterations>0,

	init_learning,
	current_iteration(CurrentIteration),
	retractall(current_iteration(_)),
	NextIteration is CurrentIteration+1,
	assertz(current_iteration(NextIteration)),
	EndIteration is CurrentIteration+Iterations-1,

	format_learning(1,'~nIteration ~d of ~d~n',[CurrentIteration,EndIteration]),
	logger_set_variable(iteration,CurrentIteration),

	logger_start_timer(duration),
	mse_testset,
	ground_truth_difference,
	gradient_descent,

	problog_flag(log_frequency,Log_Frequency),

	(
	 ( Log_Frequency>0, 0 =:= CurrentIteration mod Log_Frequency)
	->
	 once(save_model);
	 true
	),

	update_values,

	(
	 last_mse(Last_MSE)
	->
	 (
	  retractall(last_mse(_)),
	  logger_get_variable(mse_trainingset,Current_MSE),
	  assertz(last_mse(Current_MSE)),
	  !,
	  MSE_Diff is abs(Last_MSE-Current_MSE)
	 );  (
	      logger_get_variable(mse_trainingset,Current_MSE),
	      assertz(last_mse(Current_MSE)),
	      MSE_Diff is Epsilon+1
	     )
	),

	(
	 (problog_flag(rebuild_bdds,BDDFreq),BDDFreq>0,0 =:= CurrentIteration mod BDDFreq)
	->
	 (
	  retractall(values_correct),
	  retractall(query_is_similar(_,_)),
	  retractall(query_md5(_,_,_)),
	  empty_bdd_directory,
	  init_queries
	 ); true
	),

	!,
	logger_stop_timer(duration),

	logger_write_data,

	RemainingIterations is Iterations-1,

	(
	 MSE_Diff>Epsilon
	->
	 do_learning_intern(RemainingIterations,Epsilon);
	 true
	).


%========================================================================
%= find proofs and build BDDs for all training and test examples
%=
%=
%========================================================================
init_learning :-
	learning_initialized,
	!.
init_learning :-
	check_examples,

	empty_output_directory,
	logger_write_header,

	format_learning(1,'Initializing everything~n',[]),


	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% Delete the BDDs from the previous run if they should
	% not be reused
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	(
	 (
	  problog_flag(reuse_initialized_bdds,true),
	  problog_flag(rebuild_bdds,0)
	 )
	->
	 true;
	 empty_bdd_directory
	),

	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% Check whether continuous facts are used;
	% if yes, switch to problog_exact
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	problog_flag(init_method,(_,_,_,_,OldCall)),
	(
	 (
	  continuous_fact(_),
	  OldCall\=problog_exact_save(_,_,_,_,_)
	 )
	->
	 (
	  format_learning(2,'Theory uses continuous facts.~nWill use problog_exact/3 as initialization method.~2n',[]),
	  set_problog_flag(init_method,(Query,Probability,BDDFile,ProbFile,problog_exact_save(Query,Probability,_Status,BDDFile,ProbFile)))
	 );
	 true
	),

	(
	 problog_tabled(_)
	->
	 (
	  format_learning(2,'Theory uses tabling.~nWill use problog_exact/3 as initialization method.~2n',[]),
	  set_problog_flag(init_method,(Query,Probability,BDDFile,ProbFile,problog_exact_save(Query,Probability,_Status,BDDFile,ProbFile)))
	 );
	 true
	),


	succeeds_n_times(user:test_example(_,_,_,_),TestExampleCount),
	format_learning(3,'~q test examples~n',[TestExampleCount]),

	succeeds_n_times(user:example(_,_,_,_),TrainingExampleCount),
	assertz(example_count(TrainingExampleCount)),
	format_learning(3,'~q training examples~n',[TrainingExampleCount]),


	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% set learning rate and alpha
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	(
	 problog_flag(learning_rate,examples)
	->
	 set_problog_flag(learning_rate,TrainingExampleCount);
	 true
	),

	(
	 problog_flag(alpha,auto)
	->
	 (
	  (user:example(_,_,P,_),P<1,P>0)
	 ->
	  set_problog_flag(alpha,1.0);
	  (
	   succeeds_n_times((user:example(_,_,P,=),P=:=1.0),Pos_Count),
	   succeeds_n_times((user:example(_,_,P,=),P=:=0.0),Neg_Count),
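	   % the division below assumes at least one negative example
	   % (P =:= 0.0); with Neg_Count = 0 the arithmetic raises an error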
	   Alpha is Pos_Count/Neg_Count,
	   set_problog_flag(alpha,Alpha)
	  )
	 )
	),

	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% build BDD script for every example
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	once(init_queries),

	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% done
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	assertz(current_iteration(0)),
	assertz(learning_initialized),

	format_learning(1,'~n',[]).


%========================================================================
%= This predicate goes over all training and test examples,
%= calls the inference method of ProbLog and stores the resulting
%= BDDs
%========================================================================


init_queries :-
	format_learning(2,'Build BDDs for examples~n',[]),
	forall(user:test_example(ID,Query,_Prob,_),init_one_query(ID,Query,test)),
	forall(user:example(ID,Query,_Prob,_),init_one_query(ID,Query,training)).

bdd_input_file(Filename) :-
	problog_flag(output_directory,Dir),
	concat_path_with_filename(Dir,'input.txt',Filename).

init_one_query(QueryID,Query,Type) :-
	format_learning(3,' ~q example ~q: ~q~n',[Type,QueryID,Query]),

	bdd_input_file(Probabilities_File),
	problog_flag(bdd_directory,Query_Directory),

	atomic_concat(['query_',QueryID],Filename1),
	concat_path_with_filename(Query_Directory,Filename1,Filename),

	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% if BDD file does not exist, call ProbLog
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	(
	 file_exists(Filename)
	->
	 format_learning(3,' Reuse existing BDD ~q~n~n',[Filename]);
	 (
	  problog_flag(init_method,(Query,_Prob,Filename,Probabilities_File,Call)),
	  once(Call),
	  delete_file_silently(Probabilities_File)
	 )
	),

	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% check whether this BDD is similar to another BDD
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	(
	 problog_flag(check_duplicate_bdds,true)
	->
	 (
	  calc_md5(Filename,Query_MD5),
	  (
	    query_md5(OtherQueryID,Query_MD5,Type)
	  ->
	    (
	      assertz(query_is_similar(QueryID,OtherQueryID)),
	      format_learning(3, '~q is similar to ~q~2n', [QueryID,OtherQueryID])
	    );
	    assertz(query_md5(QueryID,Query_MD5,Type))
	  )
	 );

	 true
	),!,
	garbage_collect.
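
% With the default init_method flag (see init_flags/0 below), the Call
% above amounts to, illustratively:
%   problog_kbest_save(Query, 100, _Prob, _Status, Filename, Probabilities_File)
% i.e. the k-best proofs of Query are collected and the BDD script is
% written to Filename.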




%========================================================================
%= updates all values of query_probability/2 and query_gradient/4
%= should always be called before these predicates are accessed
%= if the old values are still valid, nothing happens
%========================================================================

update_values :-
	values_correct,
	!.
update_values :-
	\+ values_correct,

	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% delete old values
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	retractall(query_probability_intern(_,_)),
	retractall(query_gradient_intern(_,_,_,_)),

	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% start write current probabilities to file
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	bdd_input_file(Probabilities_File),
	delete_file_silently(Probabilities_File),

	open(Probabilities_File,'write',Handle),

	forall(get_fact_probability(ID,Prob),
	       (
		problog:dynamic_probability_fact(ID)
	       ->
		(
		 get_fact(ID,Term),
		 forall(grounding_is_known(Term,GID),
			(
			 problog:dynamic_probability_fact_extract(Term,Prob2),
			 inv_sigmoid(Prob2,Value),
			 format(Handle,'@x~q_~q~n~10f~n',[ID,GID,Value])
			))
		);
		non_ground_fact(ID)
	       ->
		(
		 inv_sigmoid(Prob,Value),
		 format(Handle,'@x~q_*~n~10f~n',[ID,Value])
		);
		(
		 inv_sigmoid(Prob,Value),
		 format(Handle,'@x~q~n~10f~n',[ID,Value])
		)
	       )),

	forall(get_continuous_fact_parameters(ID,gaussian(Mu,Sigma)),
	       format(Handle,'@x~q_*~n0~n0~n~10f;~10f~n',[ID,Mu,Sigma])),

	close(Handle),
	!,
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% stop write current probabilities to file
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

	assertz(values_correct).

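% Illustrative content of the file written above, for a simple fact with
% ID 42 and probability 0.2 (with sigmoid_slope 1.0, inv_sigmoid(0.2) = -log(4)):
%   @x42
%   -1.3862943611
% Ground instances of dynamic facts are written as @xID_GID, non-ground
% facts as @xID_*.
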

%========================================================================
%= evaluate the BDD of a query via problogbdd and clean up the
%= cached values afterwards
%========================================================================

update_query_cleanup(QueryID) :-
	(
	 (query_is_similar(QueryID,_) ; query_is_similar(_,QueryID))
	->
	 % either this query is similar to another or vice versa,
	 % therefore we don't delete anything
	 true;
	 retractall(query_gradient_intern(QueryID,_,_,_))
	).


update_query(QueryID,Symbol,What_To_Update) :-
	% fixme OS trouble
	problog_flag(output_directory,Output_Directory),
	problog_flag(bdd_directory,Query_Directory),
	bdd_input_file(Probabilities_File),
	(
	 query_is_similar(QueryID,_)
	->
	 % we don't have to evaluate the BDD
	 format_learning(4,'#',[]);
	 (
	  problog_flag(sigmoid_slope,Slope),
	  problog_dir(PD),
	  ((What_To_Update=all;query_is_similar(_,QueryID)) -> Method='g' ; Method='l'),
	  atomic_concat([PD,
			 '/problogbdd',
			 ' -i "', Probabilities_File, '"',
			 ' -l "', Query_Directory,'/query_',QueryID, '"',
			 ' -m ', Method,
			 ' -id ', QueryID,
			 ' -sl ', Slope,
			 ' > "',
			 Output_Directory,
			 'values.pl"'],Command),
	  shell(Command,Error),

	  (
	   Error = 2
	  ->
	   throw(error('SimpleCUDD has been interrupted.'));
	   true
	  ),
	  (
	   Error \= 0
	  ->
	   (
	    format(user_error,'SimpleCUDD stopped with error code ~q, command was ~q~n',[Error, shell(Command,Error)]),
	    throw(bdd_error(QueryID,Error)));
	   true
	  ),
	  atomic_concat([Output_Directory,'values.pl'],Values_Filename),
	  (
	   file_exists(Values_Filename)
	  ->
	   (
	    (
	     once(my_load(Values_Filename,QueryID))
	    ->
	     true;
	     (
	      format(user_error,'ERROR: Tried to read the file ~q but my_load/2 fails.~n~q.~2n',[Values_Filename,update_query(QueryID,Symbol,What_To_Update)]),
	      throw(error(my_load_fails))
	     )
	    );
	    (
	     format(user_error,'ERROR: Tried to read the file ~q but it does not exist.~n~q.~2n',[Values_Filename,update_query(QueryID,Symbol,What_To_Update)]),
	     throw(error(output_file_does_not_exist))
	    )
	   )
	  ),

	  delete_file_silently(Values_Filename),
	  format_learning(4,'~w',[Symbol])
	 )
	).

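% The assembled shell command looks roughly like this (paths illustrative,
% assuming output_directory ends with a path separator):
%   /problog/path/problogbdd -i "input.txt" -l "queries/query_1" -m g -id 1 -sl 1.0 > "output/values.pl"
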

%========================================================================
%= This predicate reads the probability and gradient values from the
%= file; the QueryID argument is a mere sanity check to uncover hidden bugs
%= +Filename, +QueryID
%========================================================================

my_load(File,QueryID) :-
	open(File,'read',Handle),
	read(Handle,Atom),
	once(my_load_intern(Atom,Handle,QueryID)),
	close(Handle).
my_load(File,QueryID) :-
	format(user_error,'Error at ~q.~2n',[my_load(File,QueryID)]),
	throw(error(my_load(File,QueryID))).

my_load_intern(end_of_file,_,_) :-
	!.
my_load_intern(query_probability(QueryID,Prob),Handle,QueryID) :-
	!,
	assertz(query_probability_intern(QueryID,Prob)),
	read(Handle,X),
	my_load_intern(X,Handle,QueryID).
my_load_intern(query_gradient(QueryID,XFactID,Type,Value),Handle,QueryID) :-
	!,
	atomic_concat(x,StringFactID,XFactID),
	atom_number(StringFactID,FactID),
	assertz(query_gradient_intern(QueryID,FactID,Type,Value)),
	read(Handle,X),
	my_load_intern(X,Handle,QueryID).
my_load_intern(X,Handle,QueryID) :-
	format(user_error,'Unknown atom ~q in results file.~n',[X]),
	read(Handle,X2),
	my_load_intern(X2,Handle,QueryID).

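% values.pl as produced by problogbdd contains terms such as (illustrative):
%   query_probability(1,0.7523).
%   query_gradient(1,x42,p,-0.0133).
% where x42 refers to the tunable fact with ID 42.
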



%========================================================================
%= access the probability/gradient of a query, falling back to the
%= values of a similar query if this query shares its BDD
%========================================================================
query_probability(QueryID,Prob) :-
	(
	 query_probability_intern(QueryID,Prob)
	->
	 true;
	 (
	  query_is_similar(QueryID,OtherQueryID),
	  query_probability_intern(OtherQueryID,Prob)
	 )
	).
query_gradient(QueryID,Fact,Type,Value) :-
	(
	 query_gradient_intern(QueryID,Fact,Type,Value)
	->
	 true;
	 (
	  query_is_similar(QueryID,OtherQueryID),
	  query_gradient_intern(OtherQueryID,Fact,Type,Value)
	 )
	).


%========================================================================
%= compare the learned probabilities of the tunable facts to the
%= ground truth values (if known) and log the min/max/mean differences
%========================================================================



% FIXME
ground_truth_difference :-
	findall(Diff,(tunable_fact(FactID,GroundTruth),
		      \+continuous_fact(FactID),
		      \+ var(GroundTruth),
		      get_fact_probability(FactID,Prob),
		      Diff is abs(GroundTruth-Prob)),AllDiffs),
	(
	 AllDiffs=[]
	->
	 (
	  MinDiff=0.0,
	  MaxDiff=0.0,
	  DiffMean=0.0
	 ) ;
	 (
	  length(AllDiffs,Len),
	  sum_list(AllDiffs,AllDiffsSum),
	  min_list(AllDiffs,MinDiff),
	  max_list(AllDiffs,MaxDiff),
	  DiffMean is AllDiffsSum/Len
	 )
	),

	logger_set_variable(ground_truth_diff,DiffMean),
	logger_set_variable(ground_truth_mindiff,MinDiff),
	logger_set_variable(ground_truth_maxdiff,MaxDiff).

%========================================================================
%= Calculates the MSE of the training and test data
%=
%= -Float
%========================================================================

mse_trainingset_only_for_linesearch(MSE) :-
	update_values,

	example_count(Example_Count),

	bb_put(error_train_line_search,0.0),
	forall(user:example(QueryID,_Query,QueryProb,Type),
	       (
		once(update_query(QueryID,'.',probability)),
		query_probability(QueryID,CurrentProb),
		once(update_query_cleanup(QueryID)),
		(
		 (Type == '='; (Type == '<', CurrentProb>QueryProb); (Type=='>',CurrentProb<QueryProb))
		->
		 (
		  bb_get(error_train_line_search,Old_Error),
		  New_Error is Old_Error + (CurrentProb-QueryProb)**2,
		  bb_put(error_train_line_search,New_Error)
		 );true
		)
	       )
	      ),
	bb_delete(error_train_line_search,Error),
	MSE is Error/Example_Count,
	format_learning(3,' (~8f)~n',[MSE]),
	retractall(values_correct).

mse_testset :-
	current_iteration(Iteration),
	create_test_predictions_file_name(Iteration,File_Name),
	open(File_Name,'write',Handle),
	format(Handle,"%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%~n",[]),
	format(Handle,"% Iteration, train/test, QueryID, Query, GroundTruth, Prediction %~n",[]),
	format(Handle,"%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%~n",[]),

	format_learning(2,'MSE_Test ',[]),
	update_values,
	bb_put(llh_test_queries,0.0),
	findall(SquaredError,
		(user:test_example(QueryID,Query,TrueQueryProb,Type),
		 once(update_query(QueryID,'+',probability)),
		 query_probability(QueryID,CurrentProb),
		 format(Handle,'ex(~q,test,~q,~q,~10f,~10f).~n',[Iteration,QueryID,Query,TrueQueryProb,CurrentProb]),
		 once(update_query_cleanup(QueryID)),
		 (
		  (Type == '='; (Type == '<', CurrentProb>TrueQueryProb); (Type=='>',CurrentProb<TrueQueryProb))
		 ->
		  SquaredError is (CurrentProb-TrueQueryProb)**2;
		  SquaredError = 0.0
		 ),
		 bb_get(llh_test_queries,Old_LLH_Test_Queries),
		 New_LLH_Test_Queries is Old_LLH_Test_Queries+log(CurrentProb),
		 bb_put(llh_test_queries,New_LLH_Test_Queries)
		),
		AllSquaredErrors),

	close(Handle),
	bb_delete(llh_test_queries,LLH_Test_Queries),

	length(AllSquaredErrors,Length),

	(
	 Length>0
	->
	 (
	  sum_list(AllSquaredErrors,SumAllSquaredErrors),
	  min_list(AllSquaredErrors,MinError),
	  max_list(AllSquaredErrors,MaxError),
	  MSE is SumAllSquaredErrors/Length
	 );(
	    MSE=0.0,
	    MinError=0.0,
	    MaxError=0.0
	   )
	),

	logger_set_variable(mse_testset,MSE),
	logger_set_variable(mse_min_testset,MinError),
	logger_set_variable(mse_max_testset,MaxError),
	logger_set_variable(llh_test_queries,LLH_Test_Queries),
	format_learning(2,' (~8f)~n',[MSE]).

%========================================================================
%= Calculates the sigmoid function and its inverse, respectively
%= warning: applying inv_sigmoid to 0.0 or 1.0 will yield +/-inf
%=
%= +Float, -Float
%========================================================================

sigmoid(T,Sig) :-
	problog_flag(sigmoid_slope,Slope),
	Sig is 1/(1+exp(-T*Slope)).

inv_sigmoid(T,InvSig) :-
	problog_flag(sigmoid_slope,Slope),
	InvSig is -log(1/T-1)/Slope.
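
% Illustrative round trip (assuming sigmoid_slope is 1.0):
%   ?- sigmoid(0.0,S).      % S = 0.5
%   ?- inv_sigmoid(0.5,T).  % T = 0.0
% For any finite X, sigmoid(X,S) followed by inv_sigmoid(S,X2)
% yields X2 =:= X (up to rounding).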



%========================================================================
%= Perform one iteration of gradient descent
%=
%= assumes that everything is initialized; if the current values
%= of query_probability/2 and query_gradient/4 are not up to date,
%= they will be recalculated
%= finally, values_correct/0 is retracted to signal that the
%= probabilities of the examples have to be recalculated
%========================================================================

save_old_probabilities :-
	forall(tunable_fact(FactID,_),
	       (
		continuous_fact(FactID)
	       ->
		(
		 get_continuous_fact_parameters(FactID,gaussian(OldMu,OldSigma)),
		 atomic_concat(['old_mu_',FactID],Key),
		 atomic_concat(['old_sigma_',FactID],Key2),
		 bb_put(Key,OldMu),
		 bb_put(Key2,OldSigma)
		);
		(
		 get_fact_probability(FactID,OldProbability),
		 atomic_concat(['old_prob_',FactID],Key),
		 bb_put(Key,OldProbability)
		)
	       )
	      ).



forget_old_probabilities :-
	forall(tunable_fact(FactID,_),
	       (
		continuous_fact(FactID)
	       ->
		(
		 atomic_concat(['old_mu_',FactID],Key),
		 atomic_concat(['old_sigma_',FactID],Key2),
		 atomic_concat(['grad_mu_',FactID],Key3),
		 atomic_concat(['grad_sigma_',FactID],Key4),
		 bb_delete(Key,_),
		 bb_delete(Key2,_),
		 bb_delete(Key3,_),
		 bb_delete(Key4,_)
		);
		(
		 atomic_concat(['old_prob_',FactID],Key),
		 atomic_concat(['grad_',FactID],Key2),
		 bb_delete(Key,_),
		 bb_delete(Key2,_)
		)
	       )
	      ).

add_gradient(Learning_Rate) :-
	forall(tunable_fact(FactID,_),
	       (
		continuous_fact(FactID)
	       ->
		(
		 atomic_concat(['old_mu_',FactID],Key),
		 atomic_concat(['old_sigma_',FactID],Key2),
		 atomic_concat(['grad_mu_',FactID],Key3),
		 atomic_concat(['grad_sigma_',FactID],Key4),

		 bb_get(Key,Old_Mu),
		 bb_get(Key2,Old_Sigma),
		 bb_get(Key3,Grad_Mu),
		 bb_get(Key4,Grad_Sigma),

		 Mu is Old_Mu - Learning_Rate*Grad_Mu,
		 Sigma is exp(log(Old_Sigma) - Learning_Rate*Grad_Sigma),

		 set_continuous_fact_parameters(FactID,gaussian(Mu,Sigma))
		);
		(
		 atomic_concat(['old_prob_',FactID],Key),
		 atomic_concat(['grad_',FactID],Key2),

		 bb_get(Key,OldProbability),
		 bb_get(Key2,GradValue),

		 inv_sigmoid(OldProbability,OldValue),
		 NewValue is OldValue - Learning_Rate*GradValue,
		 sigmoid(NewValue,NewProbability),

		 % prevent +/-inf in inv_sigmoid by clamping the probability
		 % away from 0.0 and 1.0
		 Prob_Secure is min(0.999999999,max(0.000000001,NewProbability)),
		 set_fact_probability(FactID,Prob_Secure)
		)
	       )
	      ),
	retractall(values_correct).
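
% In effect, for a probabilistic fact the update above happens in the
% unconstrained sigmoid space (a sketch, assuming sigmoid_slope 1.0):
%   X_new = inv_sigmoid(P_old) - Learning_Rate * Gradient
%   P_new = sigmoid(X_new)
% so P_new always lies in (0,1), additionally clamped to
% [0.000000001, 0.999999999] above.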


gradient_descent :-
	current_iteration(Iteration),
	create_training_predictions_file_name(Iteration,File_Name),
	open(File_Name,'write',Handle),
	format(Handle,"%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%~n",[]),
	format(Handle,"% Iteration, train/test, QueryID, Query, GroundTruth, Prediction %~n",[]),
	format(Handle,"%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%~n",[]),

	format_learning(2,'Gradient ',[]),

	save_old_probabilities,
	update_values,

	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% start set gradient to zero
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	forall(tunable_fact(FactID,_),
	       (
		continuous_fact(FactID)
	       ->
		(
		 atomic_concat(['grad_mu_',FactID],Key),
		 atomic_concat(['grad_sigma_',FactID],Key2),
		 bb_put(Key,0.0),
		 bb_put(Key2,0.0)
		);
		(
		 atomic_concat(['grad_',FactID],Key),
		 bb_put(Key,0.0)
		)
	       )
	      ),
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% stop set gradient to zero
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% start calculate gradient
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	bb_put(mse_train_sum, 0.0),
	bb_put(mse_train_min, 0.0),
	bb_put(mse_train_max, 0.0),
	bb_put(llh_training_queries, 0.0),

	problog_flag(alpha,Alpha),
	logger_set_variable(alpha,Alpha),
	example_count(Example_Count),

	forall(user:example(QueryID,Query,QueryProb,Type),
	       (
		once(update_query(QueryID,'.',all)),
		query_probability(QueryID,BDDProb),
		format(Handle,'ex(~q,train,~q,~q,~10f,~10f).~n',[Iteration,QueryID,Query,QueryProb,BDDProb]),
		(
		 QueryProb=:=0.0
		->
		 Y2=Alpha;
		 Y2=1.0
		),
		(
		 (Type == '='; (Type == '<', BDDProb>QueryProb); (Type=='>',BDDProb<QueryProb))
		->
		 Y is Y2*2/Example_Count * (BDDProb-QueryProb);
		 Y=0.0
		),
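		% Y folds the example weight (Alpha for negative examples)
		% into the outer derivative of the squared error,
		% 2/Example_Count * (BDDProb-QueryProb); the per-fact partial
		% derivative from the BDD is multiplied in below via
		% query_gradient/4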


		% first do the calculations for the MSE on the training set
		(
		 (Type == '='; (Type == '<', BDDProb>QueryProb); (Type=='>',BDDProb<QueryProb))
		->
		 Squared_Error is (BDDProb-QueryProb)**2;
		 Squared_Error=0.0
		),

		bb_get(mse_train_sum,Old_MSE_Train_Sum),
		bb_get(mse_train_min,Old_MSE_Train_Min),
		bb_get(mse_train_max,Old_MSE_Train_Max),
		bb_get(llh_training_queries,Old_LLH_Training_Queries),
		New_MSE_Train_Sum is Old_MSE_Train_Sum+Squared_Error,
		New_MSE_Train_Min is min(Old_MSE_Train_Min,Squared_Error),
		New_MSE_Train_Max is max(Old_MSE_Train_Max,Squared_Error),
		New_LLH_Training_Queries is Old_LLH_Training_Queries+log(BDDProb),
		bb_put(mse_train_sum,New_MSE_Train_Sum),
		bb_put(mse_train_min,New_MSE_Train_Min),
		bb_put(mse_train_max,New_MSE_Train_Max),
		bb_put(llh_training_queries,New_LLH_Training_Queries),

		(		% go over all tunable facts
		  tunable_fact(FactID,_),
		  (
		   continuous_fact(FactID)
		  ->
		   (
		    atomic_concat(['grad_mu_',FactID],Key),
		    atomic_concat(['grad_sigma_',FactID],Key2),

		    % if the following query fails, the fact is not used
		    % in the proof of QueryID; its gradient is 0.0 and
		    % would not contribute to NewValue either way
		    % DON'T FORGET THIS IF YOU CHANGE SOMETHING HERE!
		    query_gradient(QueryID,FactID,mu,GradValueMu),
		    query_gradient(QueryID,FactID,sigma,GradValueSigma),

		    bb_get(Key,OldValueMu),
		    bb_get(Key2,OldValueSigma),

		    NewValueMu is OldValueMu + Y*GradValueMu,
		    NewValueSigma is OldValueSigma + Y*GradValueSigma,

		    bb_put(Key,NewValueMu),
		    bb_put(Key2,NewValueSigma)
		   );
		   (
		    atomic_concat(['grad_',FactID],Key),

		    % if the following query fails, the fact is not used
		    % in the proof of QueryID; its gradient is 0.0 and
		    % would not contribute to NewValue either way
		    % DON'T FORGET THIS IF YOU CHANGE SOMETHING HERE!
		    query_gradient(QueryID,FactID,p,GradValue),

		    bb_get(Key,OldValue),
		    NewValue is OldValue + Y*GradValue,
		    bb_put(Key,NewValue)
		   )
		  ),

		  fail; % go to the next fact
		  true
		),

		once(update_query_cleanup(QueryID))
	       )),
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% stop calculate gradient
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	!,

	close(Handle),

	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% start statistics on gradient
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	findall(V, (
		    tunable_fact(FactID,_),
		    atomic_concat(['grad_',FactID],Key),
		    bb_get(Key,V)
		   ),Gradient_Values),

	(
	 Gradient_Values==[]
	->
	 (
	  logger_set_variable(gradient_mean,0.0),
	  logger_set_variable(gradient_min,0.0),
	  logger_set_variable(gradient_max,0.0)
	 );
	 (
	  sum_list(Gradient_Values,GradSum),
	  max_list(Gradient_Values,GradMax),
	  min_list(Gradient_Values,GradMin),
	  length(Gradient_Values,GradLength),
	  GradMean is GradSum/GradLength,

	  logger_set_variable(gradient_mean,GradMean),
	  logger_set_variable(gradient_min,GradMin),
	  logger_set_variable(gradient_max,GradMax)
	 )
	),
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% stop statistics on gradient
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

	bb_delete(mse_train_sum,MSE_Train_Sum),
	bb_delete(mse_train_min,MSE_Train_Min),
	bb_delete(mse_train_max,MSE_Train_Max),
	bb_delete(llh_training_queries,LLH_Training_Queries),
	MSE is MSE_Train_Sum/Example_Count,

	logger_set_variable(mse_trainingset,MSE),
	logger_set_variable(mse_min_trainingset,MSE_Train_Min),
	logger_set_variable(mse_max_trainingset,MSE_Train_Max),
	logger_set_variable(llh_training_queries,LLH_Training_Queries),

	format_learning(2,'~n',[]),

	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% start add gradient to current probabilities
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	(
	 problog_flag(line_search,false)
	->
	 problog_flag(learning_rate,LearningRate);
	 lineSearch(LearningRate,_)
	),
	format_learning(3,'learning rate:~8f~n',[LearningRate]),
	add_gradient(LearningRate),
	logger_set_variable(learning_rate,LearningRate),
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% stop add gradient to current probabilities
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	!,
	forget_old_probabilities.


%========================================================================
%= golden-section style line search for the learning rate
%========================================================================

line_search_evaluate_point(Learning_Rate,MSE) :-
	add_gradient(Learning_Rate),
	format_learning(2,'Line search (h=~8f) ',[Learning_Rate]),
	mse_trainingset_only_for_linesearch(MSE).


lineSearch(Final_X,Final_Value) :-

	% get the parameters for the line search
	problog_flag(line_search_tolerance,Tol),
	problog_flag(line_search_tau,Tau),
	problog_flag(line_search_interval,(A,B)),

	format_learning(3,'Line search in interval (~4f,~4f)~n',[A,B]),

	% init values
	Acc is Tol * (B-A),
	InitRight is A + Tau*(B-A),
	InitLeft is B - Tau*(B-A),

	line_search_evaluate_point(A,Value_A),
	line_search_evaluate_point(B,Value_B),
	line_search_evaluate_point(InitRight,Value_InitRight),
	line_search_evaluate_point(InitLeft,Value_InitLeft),


	Parameters=ls(A,B,InitLeft,InitRight,Value_A,Value_B,Value_InitLeft,Value_InitRight,1),

	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	%%%% BEGIN BACK TRACKING
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	(
	 repeat,

	 Parameters=ls(Ak,Bk,Left,Right,Fl,Fr,FLeft,FRight,Iteration),

	 (
	  % check for infinity; if present, move to the left
	  ( FLeft >= FRight, \+ FLeft = (+inf), \+ FRight = (+inf) )
	 ->
	  (
	   AkNew=Left,
	   FlNew=FLeft,
	   LeftNew=Right,
	   FLeftNew=FRight,
	   RightNew is Left + Bk - Right,
	   line_search_evaluate_point(RightNew,FRightNew),
	   BkNew=Bk,
	   FrNew=Fr,
	   Interval_Size is Bk-Left
	  );
	  (
	   BkNew=Right,
	   FrNew=FRight,
	   RightNew=Left,
	   FRightNew=FLeft,
	   LeftNew is Ak + Right - Left,

	   line_search_evaluate_point(LeftNew,FLeftNew),
	   AkNew=Ak,
	   FlNew=Fl,
	   Interval_Size is Right-Ak
	  )
	 ),

	 Next_Iteration is Iteration + 1,

	 nb_setarg(9,Parameters,Next_Iteration),
	 nb_setarg(1,Parameters,AkNew),
	 nb_setarg(2,Parameters,BkNew),
	 nb_setarg(3,Parameters,LeftNew),
	 nb_setarg(4,Parameters,RightNew),
	 nb_setarg(5,Parameters,FlNew),
	 nb_setarg(6,Parameters,FrNew),
	 nb_setarg(7,Parameters,FLeftNew),
	 nb_setarg(8,Parameters,FRightNew),

	 % is the search interval smaller than the tolerance level?
	 Interval_Size<Acc,

	 % apparently it is, so get out of here and
	 % cut away the choice point from repeat
	 !
	),
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	%%%% END BACK TRACKING
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


	% it doesn't harm to also check the value in the middle
	% of the current search interval
	Middle is (AkNew + BkNew) / 2.0,
	line_search_evaluate_point(Middle,Value_Middle),

	% return the optimal value
	my_5_min(Value_Middle,FlNew,FrNew,FLeftNew,FRightNew,
		 Middle,AkNew,BkNew,LeftNew,RightNew,
		 Optimal_Value,Optimal_X),

	line_search_postcheck(Optimal_Value,Optimal_X,Final_Value,Final_X).

line_search_postcheck(V,X,V,X) :-
	X>0,
	!.
line_search_postcheck(V,X,V,X) :-
	problog_flag(line_search_never_stop,false),
	!.
line_search_postcheck(_,_, LLH, FinalPosition) :-
	problog_flag(line_search_tolerance,Tolerance),
	problog_flag(line_search_interval,(Left,Right)),

	Offset is (Right - Left) * Tolerance,

	bb_put(line_search_offset,Offset),

	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	(

	 repeat,

	 bb_get(line_search_offset,OldOffset),
	 NewOffset is OldOffset * Tolerance,
	 bb_put(line_search_offset,NewOffset),

	 Position is Left + NewOffset,
	 line_search_evaluate_point(Position,LLH),
	 bb_put(line_search_llh,LLH),

	 write(logAtom(lineSearchPostCheck(Position,LLH))),nl,

	 \+ LLH = (+inf),
	 !
	),  % cut away choice point from repeat
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

	bb_delete(line_search_llh,LLH),
	bb_delete(line_search_offset,FinalOffset),
	FinalPosition is Left + FinalOffset.


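% my_5_min(+V1,+V2,+V3,+V4,+V5,+F1,+F2,+F3,+F4,+F5,-VMin,-FMin)
% returns the smallest of the five values V1..V5 together with the
% associated position F1..F5.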
my_5_min(V1,V2,V3,V4,V5,F1,F2,F3,F4,F5,VMin,FMin) :-
	(
	 V1<V2
	->
	 (VTemp1=V1,FTemp1=F1);
	 (VTemp1=V2,FTemp1=F2)
	),
	(
	 V3<V4
	->
	 (VTemp2=V3,FTemp2=F3);
	 (VTemp2=V4,FTemp2=F4)
	),
	(
	 VTemp1<VTemp2
	->
	 (VTemp3=VTemp1,FTemp3=FTemp1);
	 (VTemp3=VTemp2,FTemp3=FTemp2)
	),
	(
	 VTemp3<V5
	->
	 (VMin=VTemp3,FMin=FTemp3);
	 (VMin=V5,FMin=F5)
	).


%========================================================================
%= initialize the logger module and set the flags for learning
%= don't change anything here! use set_problog_flag/2 instead
%========================================================================

init_flags :-
	prolog_file_name('queries',Queries_Folder), % get absolute file name for './queries'
	prolog_file_name('output',Output_Folder), % get absolute file name for './output'
	problog_define_flag(bdd_directory, problog_flag_validate_directory, 'directory for BDD scripts', Queries_Folder,learning_general),
	problog_define_flag(output_directory, problog_flag_validate_directory, 'directory for logfiles etc', Output_Folder,learning_general,flags:learning_output_dir_handler),
	problog_define_flag(log_frequency, problog_flag_validate_posint, 'log results every nth iteration', 1, learning_general),
	problog_define_flag(rebuild_bdds, problog_flag_validate_nonegint, 'rebuild BDDs every nth iteration', 0, learning_general),
	problog_define_flag(reuse_initialized_bdds,problog_flag_validate_boolean, 'Reuse BDDs from previous runs',false, learning_general),
	problog_define_flag(check_duplicate_bdds,problog_flag_validate_boolean,'Store intermediate results in hash table',true,learning_general),
	problog_define_flag(init_method,problog_flag_validate_dummy,'ProbLog predicate to search proofs',(Query,Probability,BDDFile,ProbFile,problog_kbest_save(Query,100,Probability,_Status,BDDFile,ProbFile)),learning_general,flags:learning_init_handler),
	problog_define_flag(alpha,problog_flag_validate_number,'weight of negative examples (auto=n_p/n_n)',auto,learning_general,flags:auto_handler),
	problog_define_flag(sigmoid_slope,problog_flag_validate_posnumber,'slope of sigmoid function',1.0,learning_general),

	problog_define_flag(learning_rate,problog_flag_validate_posnumber,'Default learning rate (If line_search=false)',examples,learning_line_search,flags:examples_handler),
	problog_define_flag(line_search, problog_flag_validate_boolean,'estimate learning rate by line search',false,learning_line_search),
	problog_define_flag(line_search_never_stop, problog_flag_validate_boolean,'make tiny step if line search returns 0',true,learning_line_search),
	problog_define_flag(line_search_tau, problog_flag_validate_indomain_0_1_open,'tau value for line search',0.618033988749,learning_line_search),
	problog_define_flag(line_search_tolerance,problog_flag_validate_posnumber,'tolerance value for line search',0.05,learning_line_search),
	problog_define_flag(line_search_interval, problog_flag_validate_dummy,'interval for line search',(0,100),learning_line_search,flags:linesearch_interval_handler).


init_logger :-
	logger_define_variable(iteration, int),
	logger_define_variable(duration,time),
	logger_define_variable(mse_trainingset,float),
	logger_define_variable(mse_min_trainingset,float),
	logger_define_variable(mse_max_trainingset,float),
	logger_define_variable(mse_testset,float),
	logger_define_variable(mse_min_testset,float),
	logger_define_variable(mse_max_testset,float),
	logger_define_variable(gradient_mean,float),
	logger_define_variable(gradient_min,float),
	logger_define_variable(gradient_max,float),
	logger_define_variable(ground_truth_diff,float),
	logger_define_variable(ground_truth_mindiff,float),
	logger_define_variable(ground_truth_maxdiff,float),
	logger_define_variable(learning_rate,float),
	logger_define_variable(alpha,float),
	logger_define_variable(llh_training_queries,float),
	logger_define_variable(llh_test_queries,float).

:- initialization(init_flags).
:- initialization(init_logger).
