1%%% -*- Mode: Prolog; -*- 2 3%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 4% 5% $Date: 2011-04-21 14:18:59 +0200 (Thu, 21 Apr 2011) $ 6% $Revision: 6364 $ 7% 8% This file is part of ProbLog 9% http://dtai.cs.kuleuven.be/problog 10% 11% ProbLog was developed at Katholieke Universiteit Leuven 12% 13% Copyright 2008, 2009, 2010 14% Katholieke Universiteit Leuven 15% 16% Main authors of this file: 17% Bernd Gutmann 18% 19%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 20% 21% Artistic License 2.0 22% 23% Copyright (c) 2000-2006, The Perl Foundation. 24% 25% Everyone is permitted to copy and distribute verbatim copies of this 26% license document, but changing it is not allowed. Preamble 27% 28% This license establishes the terms under which a given free software 29% Package may be copied, modified, distributed, and/or 30% redistributed. The intent is that the Copyright Holder maintains some 31% artistic control over the development of that Package while still 32% keeping the Package available as open source and free software. 33% 34% You are always permitted to make arrangements wholly outside of this 35% license directly with the Copyright Holder of a given Package. If the 36% terms of this license do not permit the full use that you propose to 37% make of the Package, you should contact the Copyright Holder and seek 38% a different licensing arrangement. Definitions 39% 40% "Copyright Holder" means the individual(s) or organization(s) named in 41% the copyright notice for the entire Package. 42% 43% "Contributor" means any party that has contributed code or other 44% material to the Package, in accordance with the Copyright Holder's 45% procedures. 46% 47% "You" and "your" means any person who would like to copy, distribute, 48% or modify the Package. 49% 50% "Package" means the collection of files distributed by the Copyright 51% Holder, and derivatives of that collection and/or of those files. 
A 52% given Package may consist of either the Standard Version, or a 53% Modified Version. 54% 55% "Distribute" means providing a copy of the Package or making it 56% accessible to anyone else, or in the case of a company or 57% organization, to others outside of your company or organization. 58% 59% "Distributor Fee" means any fee that you charge for Distributing this 60% Package or providing support for this Package to another party. It 61% does not mean licensing fees. 62% 63% "Standard Version" refers to the Package if it has not been modified, 64% or has been modified only in ways explicitly requested by the 65% Copyright Holder. 66% 67% "Modified Version" means the Package, if it has been changed, and such 68% changes were not explicitly requested by the Copyright Holder. 69% 70% "Original License" means this Artistic License as Distributed with the 71% Standard Version of the Package, in its current version or as it may 72% be modified by The Perl Foundation in the future. 73% 74% "Source" form means the source code, documentation source, and 75% configuration files for the Package. 76% 77% "Compiled" form means the compiled bytecode, object code, binary, or 78% any other form resulting from mechanical transformation or translation 79% of the Source form. 80% 81%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 82% 83% Permission for Use and Modification Without Distribution 84% 85% (1) You are permitted to use the Standard Version and create and use 86% Modified Versions for any purpose without restriction, provided that 87% you do not Distribute the Modified Version. 88% 89% Permissions for Redistribution of the Standard Version 90% 91% (2) You may Distribute verbatim copies of the Source form of the 92% Standard Version of this Package in any medium without restriction, 93% either gratis or for a Distributor Fee, provided that you duplicate 94% all of the original copyright notices and associated disclaimers. 
At 95% your discretion, such verbatim copies may or may not include a 96% Compiled form of the Package. 97% 98% (3) You may apply any bug fixes, portability changes, and other 99% modifications made available from the Copyright Holder. The resulting 100% Package will still be considered the Standard Version, and as such 101% will be subject to the Original License. 102% 103% Distribution of Modified Versions of the Package as Source 104% 105% (4) You may Distribute your Modified Version as Source (either gratis 106% or for a Distributor Fee, and with or without a Compiled form of the 107% Modified Version) provided that you clearly document how it differs 108% from the Standard Version, including, but not limited to, documenting 109% any non-standard features, executables, or modules, and provided that 110% you do at least ONE of the following: 111% 112% (a) make the Modified Version available to the Copyright Holder of the 113% Standard Version, under the Original License, so that the Copyright 114% Holder may include your modifications in the Standard Version. (b) 115% ensure that installation of your Modified Version does not prevent the 116% user installing or running the Standard Version. In addition, the 117% modified Version must bear a name that is different from the name of 118% the Standard Version. (c) allow anyone who receives a copy of the 119% Modified Version to make the Source form of the Modified Version 120% available to others under (i) the Original License or (ii) a license 121% that permits the licensee to freely copy, modify and redistribute the 122% Modified Version using the same licensing terms that apply to the copy 123% that the licensee received, and requires that the Source form of the 124% Modified Version, and of any works derived from it, be made freely 125% available in that license fees are prohibited but Distributor Fees are 126% allowed. 
127% 128% Distribution of Compiled Forms of the Standard Version or 129% Modified Versions without the Source 130% 131% (5) You may Distribute Compiled forms of the Standard Version without 132% the Source, provided that you include complete instructions on how to 133% get the Source of the Standard Version. Such instructions must be 134% valid at the time of your distribution. If these instructions, at any 135% time while you are carrying out such distribution, become invalid, you 136% must provide new instructions on demand or cease further 137% distribution. If you provide valid instructions or cease distribution 138% within thirty days after you become aware that the instructions are 139% invalid, then you do not forfeit any of your rights under this 140% license. 141% 142% (6) You may Distribute a Modified Version in Compiled form without the 143% Source, provided that you comply with Section 4 with respect to the 144% Source of the Modified Version. 145% 146% Aggregating or Linking the Package 147% 148% (7) You may aggregate the Package (either the Standard Version or 149% Modified Version) with other packages and Distribute the resulting 150% aggregation provided that you do not charge a licensing fee for the 151% Package. Distributor Fees are permitted, and licensing fees for other 152% components in the aggregation are permitted. The terms of this license 153% apply to the use and Distribution of the Standard or Modified Versions 154% as included in the aggregation. 155% 156% (8) You are permitted to link Modified and Standard Versions with 157% other works, to embed the Package in a larger work of your own, or to 158% build stand-alone binary or bytecode versions of applications that 159% include the Package, and Distribute the result without restriction, 160% provided the result does not expose a direct interface to the Package. 
161% 162% Items That are Not Considered Part of a Modified Version 163% 164% (9) Works (including, but not limited to, modules and scripts) that 165% merely extend or make use of the Package, do not, by themselves, cause 166% the Package to be a Modified Version. In addition, such works are not 167% considered parts of the Package itself, and are not subject to the 168% terms of this license. 169% 170% General Provisions 171% 172% (10) Any use, modification, and distribution of the Standard or 173% Modified Versions is governed by this Artistic License. By using, 174% modifying or distributing the Package, you accept this license. Do not 175% use, modify, or distribute the Package, if you do not accept this 176% license. 177% 178% (11) If your Modified Version has been derived from a Modified Version 179% made by someone other than you, you are nevertheless required to 180% ensure that your Modified Version complies with the requirements of 181% this license. 182% 183% (12) This license does not grant you the right to use any trademark, 184% service mark, tradename, or logo of the Copyright Holder. 185% 186% (13) This license includes the non-exclusive, worldwide, 187% free-of-charge patent license to make, have made, use, offer to sell, 188% sell, import and otherwise transfer the Package with respect to any 189% patent claims licensable by the Copyright Holder that are necessarily 190% infringed by the Package. If you institute patent litigation 191% (including a cross-claim or counterclaim) against any party alleging 192% that the Package constitutes direct or contributory patent 193% infringement, then this Artistic License to you shall terminate on the 194% date that such litigation is filed. 195% 196% (14) Disclaimer of Warranty: THE PACKAGE IS PROVIDED BY THE COPYRIGHT 197% HOLDER AND CONTRIBUTORS "AS IS' AND WITHOUT ANY EXPRESS OR IMPLIED 198% WARRANTIES. 
THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 199% PARTICULAR PURPOSE, OR NON-INFRINGEMENT ARE DISCLAIMED TO THE EXTENT 200% PERMITTED BY YOUR LOCAL LAW. UNLESS REQUIRED BY LAW, NO COPYRIGHT 201% HOLDER OR CONTRIBUTOR WILL BE LIABLE FOR ANY DIRECT, INDIRECT, 202% INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING IN ANY WAY OUT OF THE USE 203% OF THE PACKAGE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 204% 205%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 206 207 208:- module(learning,[do_learning/1, 209 do_learning/2, 210 reset_learning/0 211 ]). 212 213% switch on all the checks to reduce bug searching time 214:- style_check(all). 215:- yap_flag(unknown,error). 216 217% load modules from the YAP library 218:- use_module(library(lists), [max_list/2, min_list/2, sum_list/2]). 219:- use_module(library(system), [file_exists/1, shell/2]). 220 221% load our own modules 222:- use_module(problog). 223:- use_module('problog/logger'). 224:- use_module('problog/flags'). 225:- use_module('problog/os'). 226:- use_module('problog/print_learning'). 227:- use_module('problog/utils_learning'). 228:- use_module('problog/utils'). 229:- use_module('problog/tabling'). 230 231% used to indicate the state of the system 232:- dynamic(values_correct/0). 233:- dynamic(learning_initialized/0). 234:- dynamic(current_iteration/1). 235:- dynamic(example_count/1). 236:- dynamic(query_probability_intern/2). 237:- dynamic(query_gradient_intern/4). 238:- dynamic(last_mse/1). 239:- dynamic(query_is_similar/2). 240:- dynamic(query_md5/2). 241 242 243% used to identify queries which have identical proofs 244:- dynamic(query_is_similar/2). 245:- dynamic(query_md5/3). 246 247:- multifile(user:example/4). 248user:example(A,B,C,=) :- 249 current_predicate(user:example/3), 250 user:example(A,B,C). 251 252:- multifile(user:test_example/4). 253user:test_example(A,B,C,=) :- 254 current_predicate(user:test_example/3), 255 user:test_example(A,B,C). 
256 257 258%======================================================================== 259%= store the facts with the learned probabilities to a file 260%======================================================================== 261 262save_model:- 263 current_iteration(Iteration), 264 create_factprobs_file_name(Iteration,Filename), 265 export_facts(Filename). 266 267 268 269 270%======================================================================== 271%= find out whether some example IDs are used more than once 272%= if so, complain and stop 273%= 274%======================================================================== 275 276check_examples :- 277 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 278 % Check example IDs 279 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 280 ( 281 (user:example(ID,_,_,_), \+ atomic(ID)) 282 -> 283 ( 284 format(user_error,'The example id of training example ~q ',[ID]), 285 format(user_error,'is not atomic (e.g foo42, 23, bar, ...).~n',[]), 286 throw(error(examples)) 287 ); true 288 ), 289 290 ( 291 (user:test_example(ID,_,_,_), \+ atomic(ID)) 292 -> 293 ( 294 format(user_error,'The example id of test example ~q ',[ID]), 295 format(user_error,'is not atomic (e.g foo42, 23, bar, ...).~n',[]), 296 throw(error(examples)) 297 ); true 298 ), 299 300 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 301 % Check example probabilities 302 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 303 ( 304 (user:example(ID,_,P,_), (\+ number(P); P>1 ; P<0)) 305 -> 306 ( 307 format(user_error,'The training example ~q does not have a valid probability value (~q).~n',[ID,P]), 308 throw(error(examples)) 309 ); true 310 ), 311 312 ( 313 (user:test_example(ID,_,P,_), (\+ number(P); P>1 ; P<0)) 314 -> 315 ( 316 format(user_error,'The test example ~q does not have a valid probability value (~q).~n',[ID,P]), 317 throw(error(examples)) 318 ); true 319 ), 320 321 322 
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Check that no example ID is repeated,
    % and if it is repeated make sure the query is the same
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    (
     (
      (
       user:example(ID,QueryA,_,_),
       user:example(ID,QueryB,_,_),
       QueryA \= QueryB
      ) ;

      (
       user:test_example(ID,QueryA,_,_),
       user:test_example(ID,QueryB,_,_),
       QueryA \= QueryB
      );

      (
       user:example(ID,QueryA,_,_),
       user:test_example(ID,QueryB,_,_),
       QueryA \= QueryB
      )
     )
    ->
     (
      format(user_error,'The example id ~q is used several times.~n',[ID]),
      throw(error(examples))
     ); true
    ).

%========================================================================
%= forget the complete learning state (cached probabilities and
%= gradients, iteration counter, BDD similarity info) and restore the
%= default values of the alpha and learning_rate flags
%========================================================================

reset_learning :-
    retractall(learning_initialized),
    retractall(values_correct),
    retractall(current_iteration(_)),
    retractall(example_count(_)),
    retractall(query_probability_intern(_,_)),
    % query_gradient_intern is declared with arity 4
    retractall(query_gradient_intern(_,_,_,_)),
    retractall(last_mse(_)),
    retractall(query_is_similar(_,_)),
    retractall(query_md5(_,_,_)),

    set_problog_flag(alpha,auto),
    set_problog_flag(learning_rate,examples),
    logger_reset_all_variables.



%========================================================================
%= initialize everything and perform Iterations times gradient descent
%= can be called several times
%= if it is called with an epsilon parameter, it stops when the change
%= in the MSE is smaller than epsilon
%========================================================================

do_learning(Iterations) :-
    do_learning(Iterations,-1).

do_learning(Iterations,Epsilon) :-
    current_predicate(user:example/4),
    !,
    integer(Iterations),
    number(Epsilon),
    Iterations>0,
    do_learning_intern(Iterations,Epsilon).
do_learning(_,_) :-
    format(user_error,'~nError: No training examples specified.~n~n',[]).


do_learning_intern(0,_) :-
    !.
do_learning_intern(Iterations,Epsilon) :-
    Iterations>0,

    init_learning,
    current_iteration(CurrentIteration),
    retractall(current_iteration(_)),
    NextIteration is CurrentIteration+1,
    assertz(current_iteration(NextIteration)),
    EndIteration is CurrentIteration+Iterations-1,

    format_learning(1,'~nIteration ~d of ~d~n',[CurrentIteration,EndIteration]),
    logger_set_variable(iteration,CurrentIteration),

    logger_start_timer(duration),
    mse_testset,
    ground_truth_difference,
    gradient_descent,

    problog_flag(log_frequency,Log_Frequency),

    (
     ( Log_Frequency>0, 0 =:= CurrentIteration mod Log_Frequency)
    ->
     once(save_model);
     true
    ),

    update_values,

    % MSE_Diff is the change of the training MSE since the last iteration;
    % in the very first iteration it is set above Epsilon to keep going
    (
     last_mse(Last_MSE)
    ->
     (
      retractall(last_mse(_)),
      logger_get_variable(mse_trainingset,Current_MSE),
      assertz(last_mse(Current_MSE)),
      !,
      MSE_Diff is abs(Last_MSE-Current_MSE)
     ); (
      logger_get_variable(mse_trainingset,Current_MSE),
      assertz(last_mse(Current_MSE)),
      MSE_Diff is Epsilon+1
     )
    ),

    (
     (problog_flag(rebuild_bdds,BDDFreq),BDDFreq>0,0 =:= CurrentIteration mod BDDFreq)
    ->
     (
      retractall(values_correct),
      retractall(query_is_similar(_,_)),
      retractall(query_md5(_,_,_)),
      empty_bdd_directory,
      init_queries
     ); true
    ),


    !,
    logger_stop_timer(duration),


    logger_write_data,



    RemainingIterations is Iterations-1,

    (
     MSE_Diff>Epsilon
    ->
     do_learning_intern(RemainingIterations,Epsilon);
     true
    ).
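
%========================================================================
%= Usage sketch (illustrative only, not part of the original file).
%= Training and test examples are user:example/3,4 and
%= user:test_example/3,4 facts; the IDs, the query path/2 and the
%= probabilities below are made-up placeholders.
%=
%=   example(1, path(a,b), 0.8).          % target probability 0.8
%=   example(2, path(a,c), 0.3, '<').     % only penalized above 0.3
%=   test_example(3, path(b,c), 0.5).
%=
%=   ?- do_learning(20).                  % 20 iterations, no early stop
%=   ?- do_learning(20, 0.0001).          % also stops once the MSE
%=                                        % changes by at most 0.0001
%========================================================================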


%========================================================================
%= find proofs and build bdds for all training and test examples
%=
%=
%========================================================================
init_learning :-
    learning_initialized,
    !.
init_learning :-
    check_examples,

    empty_output_directory,
    logger_write_header,

    format_learning(1,'Initializing everything~n',[]),


    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Delete the BDDs from the previous run if they should
    % not be reused
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    (
     (
      problog_flag(reuse_initialized_bdds,true),
      problog_flag(rebuild_bdds,0)
     )
    ->
     true;
     empty_bdd_directory
    ),

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Check whether continuous facts are used.
    % If yes, switch to problog_exact
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    problog_flag(init_method,(_,_,_,_,OldCall)),
    (
     (
      continuous_fact(_),
      OldCall\=problog_exact_save(_,_,_,_,_)
     )
    ->
     (
      format_learning(2,'Theory uses continuous facts.~nWill use problog_exact/3 as initialization method.~2n',[]),
      set_problog_flag(init_method,(Query,Probability,BDDFile,ProbFile,problog_exact_save(Query,Probability,_Status,BDDFile,ProbFile)))
     );
     true
    ),

    (
     problog_tabled(_)
    ->
     (
      format_learning(2,'Theory uses tabling.~nWill use problog_exact/3 as initialization method.~2n',[]),
      set_problog_flag(init_method,(Query,Probability,BDDFile,ProbFile,problog_exact_save(Query,Probability,_Status,BDDFile,ProbFile)))
     );
     true
    ),


    succeeds_n_times(user:test_example(_,_,_,_),TestExampleCount),
    format_learning(3,'~q test examples~n',[TestExampleCount]),

    succeeds_n_times(user:example(_,_,_,_),TrainingExampleCount),
    assertz(example_count(TrainingExampleCount)),
    format_learning(3,'~q training examples~n',[TrainingExampleCount]),


    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % set learning rate and alpha
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    (
     problog_flag(learning_rate,examples)
    ->
     set_problog_flag(learning_rate,TrainingExampleCount);
     true
    ),

    (
     problog_flag(alpha,auto)
    ->
     (
      (user:example(_,_,P,_),P<1,P>0)
     ->
      set_problog_flag(alpha,1.0);
      (
       succeeds_n_times((user:example(_,_,P,=),P=:=1.0),Pos_Count),
       succeeds_n_times((user:example(_,_,P,=),P=:=0.0),Neg_Count),
       Alpha is Pos_Count/Neg_Count,
       set_problog_flag(alpha,Alpha)
      )
     );
     % alpha was set explicitly, keep it; without this branch
     % init_learning would fail whenever alpha is not 'auto'
     true
    ),

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % build BDD script for every example
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    once(init_queries),

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % done
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    assertz(current_iteration(0)),
    assertz(learning_initialized),

    format_learning(1,'~n',[]).



%========================================================================
%= This predicate goes over all training and test examples,
%= calls the inference method of ProbLog and stores the resulting
%= BDDs
%========================================================================


init_queries :-
    format_learning(2,'Build BDDs for examples~n',[]),
    forall(user:test_example(ID,Query,_Prob,_),init_one_query(ID,Query,test)),
    forall(user:example(ID,Query,_Prob,_),init_one_query(ID,Query,training)).

bdd_input_file(Filename) :-
    problog_flag(output_directory,Dir),
    concat_path_with_filename(Dir,'input.txt',Filename).

init_one_query(QueryID,Query,Type) :-
    format_learning(3,' ~q example ~q: ~q~n',[Type,QueryID,Query]),

    bdd_input_file(Probabilities_File),
    problog_flag(bdd_directory,Query_Directory),

    atomic_concat(['query_',QueryID],Filename1),
    concat_path_with_filename(Query_Directory,Filename1,Filename),

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % if BDD file does not exist, call ProbLog
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    (
     file_exists(Filename)
    ->
     format_learning(3,' Reuse existing BDD ~q~n~n',[Filename]);
     (
      problog_flag(init_method,(Query,_Prob,Filename,Probabilities_File,Call)),
      once(Call),
      delete_file_silently(Probabilities_File)
     )
    ),

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % check whether this BDD is similar to another BDD
    % (same MD5 sum of the BDD file and same example type)
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    (
     problog_flag(check_duplicate_bdds,true)
    ->
     (
      calc_md5(Filename,Query_MD5),
      (
       query_md5(OtherQueryID,Query_MD5,Type)
      ->
       (
        assertz(query_is_similar(QueryID,OtherQueryID)),
        format_learning(3, '~q is similar to ~q~2n', [QueryID,OtherQueryID])
       );
       assertz(query_md5(QueryID,Query_MD5,Type))
      )
     );

     true
    ),!,
    garbage_collect.




%========================================================================
%= updates all values of query_probability/2 and query_gradient/4
%= should always be called before these predicates are accessed
%= if the old values are still valid, nothing happens
%========================================================================

update_values :-
    values_correct,
    !.
655update_values :- 656 \+ values_correct, 657 658 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 659 % delete old values 660 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 661 retractall(query_probability_intern(_,_)), 662 retractall(query_gradient_intern(_,_,_,_)), 663 664 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 665 % start write current probabilities to file 666 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 667 bdd_input_file(Probabilities_File), 668 delete_file_silently(Probabilities_File), 669 670 open(Probabilities_File,'write',Handle), 671 672 forall(get_fact_probability(ID,Prob), 673 ( 674 (problog:dynamic_probability_fact(ID) -> 675 get_fact(ID, Term), 676 forall(grounding_is_known(Term, GID), ( 677 problog:dynamic_probability_fact_extract(Term, Prob2), 678 inv_sigmoid(Prob2,Value), 679 format(Handle, '@x~q_~q~n~10f~n', [ID,GID, Value]))) 680 ; non_ground_fact(ID) -> 681 inv_sigmoid(Prob,Value), 682 format(Handle,'@x~q_*~n~10f~n',[ID,Value]) 683 ; 684 inv_sigmoid(Prob,Value), 685 format(Handle,'@x~q~n~10f~n',[ID,Value]) 686 ) 687 )), 688 689 forall(get_continuous_fact_parameters(ID,gaussian(Mu,Sigma)), 690 format(Handle,'@x~q_*~n0~n0~n~10f;~10f~n',[ID,Mu,Sigma])), 691 692 close(Handle), 693 !, 694 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 695 % stop write current probabilities to file 696 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 697 698 assertz(values_correct). 699 700 701 702%======================================================================== 703%= 704%= 705%= 706%======================================================================== 707 708update_query_cleanup(QueryID) :- 709 ( 710 (query_is_similar(QueryID,_) ; query_is_similar(_,QueryID)) 711 -> 712 % either this query is similar to another or vice versa, 713 % therefore we don't delete anything 714 true; 715 retractall(query_gradient_intern(QueryID,_,_,_)) 716 ). 
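
%========================================================================
%= Sketch of the file written by update_values/0 (illustrative only,
%= read off the format/3 calls above; fact IDs and numbers are made
%= up). Each fact contributes a header line followed by its value in
%= inverse-sigmoid space:
%=
%=   @x3                          fact with ID 3
%=   1.3862943611
%=   @x4_*                        non-ground fact with ID 4
%=   -0.8472978604
%=   @x7_*                        continuous fact with ID 7:
%=   0                            two zero lines, then Mu;Sigma
%=   0
%=   0.5000000000;1.2000000000
%=
%= The annotations to the right are explanations, not file content.
%========================================================================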


update_query(QueryID,Symbol,What_To_Update) :-
    % fixme OS trouble
    problog_flag(output_directory,Output_Directory),
    problog_flag(bdd_directory,Query_Directory),
    bdd_input_file(Probabilities_File),
    (
     query_is_similar(QueryID,_)
    ->
     % we don't have to evaluate the BDD
     format_learning(4,'#',[]);
     (
      problog_flag(sigmoid_slope,Slope),
      problog_dir(PD),
      ((What_To_Update=all;query_is_similar(_,QueryID)) -> Method='g' ; Method='l'),
      atomic_concat([PD,
                     '/problogbdd',
                     ' -i "', Probabilities_File, '"',
                     ' -l "', Query_Directory,'/query_',QueryID, '"',
                     ' -m ', Method,
                     ' -id ', QueryID,
                     ' -sl ', Slope,
                     ' > "',
                     Output_Directory,
                     'values.pl"'],Command),
      shell(Command,Error),


      (
       Error = 2
      ->
       throw(error('SimpleCUDD has been interrupted.'));
       true
      ),
      (
       Error \= 0
      ->
       (
        format(user_error,'SimpleCUDD stopped with error code ~q, command was ~q~n',[Error, shell(Command,Error)]),
        throw(bdd_error(QueryID,Error)));
       true
      ),
      atomic_concat([Output_Directory,'values.pl'],Values_Filename),
      % the "file does not exist" branch is the else of file_exists/1;
      % originally it was nested inside the then-branch and unreachable
      (
       file_exists(Values_Filename)
      ->
       (
        once(my_load(Values_Filename,QueryID))
       ->
        true;
        (
         format(user_error,'ERROR: Tried to read the file ~q but my_load/2 fails.~n~q.~2n',[Values_Filename,update_query(QueryID,Symbol,What_To_Update)]),
         throw(error(my_load_fails))
        )
       );
       (
        format(user_error,'ERROR: Tried to read the file ~q but it does not exist.~n~q.~2n',[Values_Filename,update_query(QueryID,Symbol,What_To_Update)]),
        throw(error(output_file_does_not_exist))
       )
      ),

      delete_file_silently(Values_Filename),
      format_learning(4,'~w',[Symbol])
     )
    ).
785 786 787%======================================================================== 788%= This predicate reads probability and gradient values from the file 789%= the gradient ID is a mere check to uncover hidden bugs 790%= +Filename +QueryID -QueryProbability 791%======================================================================== 792 793my_load(File,QueryID) :- 794 open(File,'read',Handle), 795 read(Handle,Atom), 796 once(my_load_intern(Atom,Handle,QueryID)), 797 close(Handle). 798my_load(File,QueryID) :- 799 format(user_error,'Error at ~q.~2n',[my_load(File,QueryID)]), 800 throw(error(my_load(File,QueryID))). 801 802my_load_intern(end_of_file,_,_) :- 803 !. 804my_load_intern(query_probability(QueryID,Prob),Handle,QueryID) :- 805 !, 806 assertz(query_probability_intern(QueryID,Prob)), 807 read(Handle,X), 808 my_load_intern(X,Handle,QueryID). 809my_load_intern(query_gradient(QueryID,XFactID,Type,Value),Handle,QueryID) :- 810 !, 811 atomic_concat(x,StringFactID,XFactID), 812 atom_number(StringFactID,FactID), 813 assertz(query_gradient_intern(QueryID,FactID,Type,Value)), 814 read(Handle,X), 815 my_load_intern(X,Handle,QueryID). 816my_load_intern(X,Handle,QueryID) :- 817 format(user_error,'Unknown atom ~q in results file.~n',[X]), 818 read(Handle,X2), 819 my_load_intern(X2,Handle,QueryID). 820 821 822 823 824%======================================================================== 825%= 826%= 827%= 828%======================================================================== 829query_probability(QueryID,Prob) :- 830 ( 831 query_probability_intern(QueryID,Prob) 832 -> 833 true; 834 ( 835 query_is_similar(QueryID,OtherQueryID), 836 query_probability_intern(OtherQueryID,Prob) 837 ) 838 ). 839query_gradient(QueryID,Fact,Type,Value) :- 840 ( 841 query_gradient_intern(QueryID,Fact,Type,Value) 842 -> 843 true; 844 ( 845 query_is_similar(QueryID,OtherQueryID), 846 query_gradient_intern(OtherQueryID,Fact,Type,Value) 847 ) 848 ). 
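
%========================================================================
%= Sketch of a results file as my_load/2 expects it (illustrative only;
%= the query ID, the fact IDs, the type atom p and all numbers are
%= made up). Every term has to carry the QueryID that was passed to
%= my_load/2, and the fact ID of a gradient is an x followed by a
%= number:
%=
%=   query_probability(5,0.3719).
%=   query_gradient(5,x3,p,0.1201).
%=   query_gradient(5,x4,p,-0.0422).
%=
%= Any other term is reported as an unknown atom and skipped.
%========================================================================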
849 850%======================================================================== 851%= 852%= 853%= 854%======================================================================== 855 856 857 858% FIXME 859ground_truth_difference :- 860 findall(Diff,(tunable_fact(FactID,GroundTruth), 861 \+continuous_fact(FactID), 862 \+ var(GroundTruth), 863 get_fact_probability(FactID,Prob), 864 Diff is abs(GroundTruth-Prob)),AllDiffs), 865 ( 866 AllDiffs=[] 867 -> 868 ( 869 MinDiff=0.0, 870 MaxDiff=0.0, 871 DiffMean=0.0 872 ) ; 873 ( 874 length(AllDiffs,Len), 875 sum_list(AllDiffs,AllDiffsSum), 876 min_list(AllDiffs,MinDiff), 877 max_list(AllDiffs,MaxDiff), 878 DiffMean is AllDiffsSum/Len 879 ) 880 ), 881 882 logger_set_variable(ground_truth_diff,DiffMean), 883 logger_set_variable(ground_truth_mindiff,MinDiff), 884 logger_set_variable(ground_truth_maxdiff,MaxDiff). 885 886%======================================================================== 887%= Calculates the mse of training and test data 888%= 889%= -Float 890%======================================================================== 891 892mse_trainingset_only_for_linesearch(MSE) :- 893 update_values, 894 895 example_count(Example_Count), 896 897 bb_put(error_train_line_search,0.0), 898 forall(user:example(QueryID,_Query,QueryProb,Type), 899 ( 900 once(update_query(QueryID,'.',probability)), 901 query_probability(QueryID,CurrentProb), 902 once(update_query_cleanup(QueryID)), 903 ( 904 (Type == '='; (Type == '<', CurrentProb>QueryProb); (Type=='>',CurrentProb<QueryProb)) 905 -> 906 ( 907 bb_get(error_train_line_search,Old_Error), 908 New_Error is Old_Error + (CurrentProb-QueryProb)**2, 909 bb_put(error_train_line_search,New_Error) 910 );true 911 ) 912 ) 913 ), 914 bb_delete(error_train_line_search,Error), 915 MSE is Error/Example_Count, 916 format_learning(3,' (~8f)~n',[MSE]), 917 retractall(values_correct). 

mse_testset :-
    current_iteration(Iteration),
    create_test_predictions_file_name(Iteration,File_Name),
    open(File_Name,'write',Handle),
    format(Handle,"%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%~n",[]),
    format(Handle,"% Iteration, train/test, QueryID, Query, GroundTruth, Prediction %~n",[]),
    format(Handle,"%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%~n",[]),

    format_learning(2,'MSE_Test ',[]),
    update_values,
    bb_put(llh_test_queries,0.0),
    findall(SquaredError,
            (user:test_example(QueryID,Query,TrueQueryProb,Type),
             once(update_query(QueryID,'+',probability)),
             query_probability(QueryID,CurrentProb),
             format(Handle,'ex(~q,test,~q,~q,~10f,~10f).~n',[Iteration,QueryID,Query,TrueQueryProb,CurrentProb]),
             once(update_query_cleanup(QueryID)),
             % compare against TrueQueryProb; the unbound QueryProb used
             % here before raised an error for '<' and '>' examples
             (
              (Type == '='; (Type == '<', CurrentProb>TrueQueryProb); (Type=='>',CurrentProb<TrueQueryProb))
             ->
              SquaredError is (CurrentProb-TrueQueryProb)**2;
              SquaredError = 0.0
             ),
             bb_get(llh_test_queries,Old_LLH_Test_Queries),
             New_LLH_Test_Queries is Old_LLH_Test_Queries+log(CurrentProb),
             bb_put(llh_test_queries,New_LLH_Test_Queries)
            ),
            AllSquaredErrors),

    close(Handle),
    bb_delete(llh_test_queries,LLH_Test_Queries),

    length(AllSquaredErrors,Length),

    (
     Length>0
    ->
     (
      sum_list(AllSquaredErrors,SumAllSquaredErrors),
      min_list(AllSquaredErrors,MinError),
      max_list(AllSquaredErrors,MaxError),
      MSE is SumAllSquaredErrors/Length
     );(
      MSE=0.0,
      MinError=0.0,
      MaxError=0.0
     )
    ),

    logger_set_variable(mse_testset,MSE),
    logger_set_variable(mse_min_testset,MinError),
    logger_set_variable(mse_max_testset,MaxError),
    logger_set_variable(llh_test_queries,LLH_Test_Queries),
    format_learning(2,' (~8f)~n',[MSE]).

%========================================================================
%= Calculates the sigmoid function and its inverse, respectively
%= warning: applying inv_sigmoid to 0.0 or 1.0 will yield +/-inf
%=
%= +Float, -Float
%========================================================================

sigmoid(T,Sig) :-
	problog_flag(sigmoid_slope,Slope),
	Sig is 1/(1+exp(-T*Slope)).

inv_sigmoid(T,InvSig) :-
	problog_flag(sigmoid_slope,Slope),
	InvSig is -log(1/T-1)/Slope.


%========================================================================
%= Perform one iteration of gradient descent
%=
%= assumes that everything is initialized; if the current values
%= of query_probability/2 and query_gradient/4 are not up to date,
%= they will be recalculated
%= finally, values_correct/0 is retracted to signal that the
%= probabilities of the examples have to be recalculated
%========================================================================

save_old_probabilities :-
	forall(tunable_fact(FactID,_),
	       (
		continuous_fact(FactID)
	       ->
		(
		 get_continuous_fact_parameters(FactID,gaussian(OldMu,OldSigma)),
		 atomic_concat(['old_mu_',FactID],Key),
		 atomic_concat(['old_sigma_',FactID],Key2),
		 bb_put(Key,OldMu),
		 bb_put(Key2,OldSigma)
		);
		(
		 get_fact_probability(FactID,OldProbability),
		 atomic_concat(['old_prob_',FactID],Key),
		 bb_put(Key,OldProbability)
		)
	       )
	      ).
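
%========================================================================
%= Illustrative sketch, not part of the original ProbLog code.
%= It mirrors the formulas of sigmoid/2 and inv_sigmoid/2 defined
%= earlier in this file, but takes the slope as an explicit argument
%= instead of reading the sigmoid_slope flag, so the round trip can be
%= tried in isolation. The sketch_ names are hypothetical. Unlike
%= inv_sigmoid/2, the inverse here fails for inputs outside (0,1)
%= instead of yielding +/-inf.
%=
%= +Float, +Slope, -Float
%========================================================================

sketch_sigmoid(T,Slope,Sig) :-
	Sig is 1/(1+exp(-T*Slope)).

sketch_inv_sigmoid(P,Slope,InvSig) :-
	P > 0.0,
	P < 1.0,
	InvSig is -log(1/P-1)/Slope.

% ?- sketch_sigmoid(0.0,1.0,S).   S = 0.5
% ?- sketch_inv_sigmoid(0.25,1.0,T), sketch_sigmoid(T,1.0,P).
%    P equals 0.25 up to floating point rounding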


forget_old_probabilities :-
	forall(tunable_fact(FactID,_),
	       (
		continuous_fact(FactID)
	       ->
		(
		 atomic_concat(['old_mu_',FactID],Key),
		 atomic_concat(['old_sigma_',FactID],Key2),
		 atomic_concat(['grad_mu_',FactID],Key3),
		 atomic_concat(['grad_sigma_',FactID],Key4),
		 bb_delete(Key,_),
		 bb_delete(Key2,_),
		 bb_delete(Key3,_),
		 bb_delete(Key4,_)
		);
		(
		 atomic_concat(['old_prob_',FactID],Key),
		 atomic_concat(['grad_',FactID],Key2),
		 bb_delete(Key,_),
		 bb_delete(Key2,_)
		)
	       )
	      ).

add_gradient(Learning_Rate) :-
	forall(tunable_fact(FactID,_),
	       (
		continuous_fact(FactID)
	       ->
		(
		 atomic_concat(['old_mu_',FactID],Key),
		 atomic_concat(['old_sigma_',FactID],Key2),
		 atomic_concat(['grad_mu_',FactID],Key3),
		 atomic_concat(['grad_sigma_',FactID],Key4),

		 bb_get(Key,Old_Mu),
		 bb_get(Key2,Old_Sigma),
		 bb_get(Key3,Grad_Mu),
		 bb_get(Key4,Grad_Sigma),

		 Mu is Old_Mu -Learning_Rate* Grad_Mu,
		 Sigma is exp(log(Old_Sigma) -Learning_Rate* Grad_Sigma),

		 set_continuous_fact_parameters(FactID,gaussian(Mu,Sigma))
		);
		(
		 atomic_concat(['old_prob_',FactID],Key),
		 atomic_concat(['grad_',FactID],Key2),

		 bb_get(Key,OldProbability),
		 bb_get(Key2,GradValue),

		 inv_sigmoid(OldProbability,OldValue),
		 NewValue is OldValue -Learning_Rate*GradValue,
		 sigmoid(NewValue,NewProbability),

		 % Prevent "inf" by using values too close to 1.0
		 Prob_Secure is min(0.999999999,max(0.000000001,NewProbability)),
		 set_fact_probability(FactID,Prob_Secure)
		)
	       )
	      ),
	retractall(values_correct).
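
%========================================================================
%= Illustrative sketch, not part of the original ProbLog code.
%= It isolates the update that add_gradient/1 applies to a single
%= non-continuous fact: map the old probability to the real line with
%= the inverse sigmoid, take the gradient step there, map back, and
%= clamp the result away from 0.0 and 1.0. The slope is passed
%= explicitly as an assumption of the sketch; the name
%= sketch_logit_step/5 is hypothetical.
%=
%= +OldProb, +Grad, +LearningRate, +Slope, -NewProb
%========================================================================

sketch_logit_step(OldProb,Grad,Learning_Rate,Slope,NewProb) :-
	OldValue is -log(1/OldProb-1)/Slope,
	NewValue is OldValue - Learning_Rate*Grad,
	Prob is 1/(1+exp(-NewValue*Slope)),
	NewProb is min(0.999999999,max(0.000000001,Prob)).

% ?- sketch_logit_step(0.5,0.0,1.0,1.0,P).   zero gradient, P = 0.5
% ?- sketch_logit_step(0.5,2.0,1.0,1.0,P).   positive gradient, P < 0.5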


gradient_descent :-
	current_iteration(Iteration),
	create_training_predictions_file_name(Iteration,File_Name),
	open(File_Name,'write',Handle),
	format(Handle,"%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%~n",[]),
	format(Handle,"% Iteration, train/test, QueryID, Query, GroundTruth, Prediction %~n",[]),
	format(Handle,"%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%~n",[]),

	format_learning(2,'Gradient ',[]),

	save_old_probabilities,
	update_values,

	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% start set gradient to zero
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	forall(tunable_fact(FactID,_),
	       (
		continuous_fact(FactID)
	       ->
		(
		 atomic_concat(['grad_mu_',FactID],Key),
		 atomic_concat(['grad_sigma_',FactID],Key2),
		 bb_put(Key,0.0),
		 bb_put(Key2,0.0)
		);
		(
		 atomic_concat(['grad_',FactID],Key),
		 bb_put(Key,0.0)
		)
	       )
	      ),
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% stop set gradient to zero
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% start calculate gradient
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	bb_put(mse_train_sum, 0.0),
	% note: squared errors are non-negative, so with this start
	% value the logged minimum can never rise above 0.0
	bb_put(mse_train_min, 0.0),
	bb_put(mse_train_max, 0.0),
	bb_put(llh_training_queries, 0.0),

	problog_flag(alpha,Alpha),
	logger_set_variable(alpha,Alpha),
	example_count(Example_Count),

	forall(user:example(QueryID,Query,QueryProb,Type),
	       (
		once(update_query(QueryID,'.',all)),
		query_probability(QueryID,BDDProb),
		format(Handle,'ex(~q,train,~q,~q,~10f,~10f).~n',[Iteration,QueryID,Query,QueryProb,BDDProb]),
		(
		 QueryProb=:=0.0
		->
		 Y2=Alpha;
		 Y2=1.0
		),
		(
		 (Type == '='; (Type == '<', BDDProb>QueryProb); (Type=='>',BDDProb<QueryProb))
		->
		 Y is Y2*2/Example_Count * (BDDProb-QueryProb);
		 Y=0.0
		),


		% first do the calculations for the MSE on training set
		(
		 (Type == '='; (Type == '<', BDDProb>QueryProb); (Type=='>',BDDProb<QueryProb))
		->
		 Squared_Error is (BDDProb-QueryProb)**2;
		 Squared_Error=0.0
		),

		bb_get(mse_train_sum,Old_MSE_Train_Sum),
		bb_get(mse_train_min,Old_MSE_Train_Min),
		bb_get(mse_train_max,Old_MSE_Train_Max),
		bb_get(llh_training_queries,Old_LLH_Training_Queries),
		New_MSE_Train_Sum is Old_MSE_Train_Sum+Squared_Error,
		New_MSE_Train_Min is min(Old_MSE_Train_Min,Squared_Error),
		New_MSE_Train_Max is max(Old_MSE_Train_Max,Squared_Error),
		New_LLH_Training_Queries is Old_LLH_Training_Queries+log(BDDProb),
		bb_put(mse_train_sum,New_MSE_Train_Sum),
		bb_put(mse_train_min,New_MSE_Train_Min),
		bb_put(mse_train_max,New_MSE_Train_Max),
		bb_put(llh_training_queries,New_LLH_Training_Queries),



		( % go over all tunable facts
		  tunable_fact(FactID,_),
		  (
		   continuous_fact(FactID)
		  ->
		   (
		    atomic_concat(['grad_mu_',FactID],Key),
		    atomic_concat(['grad_sigma_',FactID],Key2),

		    % if the following query fails,
		    % it means, the fact is not used in the proof
		    % of QueryID, and the gradient is 0.0 and will
		    % not contribute to NewValue either way
		    % DON'T FORGET THIS IF YOU CHANGE SOMETHING HERE!
		    query_gradient(QueryID,FactID,mu,GradValueMu),
		    query_gradient(QueryID,FactID,sigma,GradValueSigma),

		    bb_get(Key,OldValueMu),
		    bb_get(Key2,OldValueSigma),

		    NewValueMu is OldValueMu + Y*GradValueMu,
		    NewValueSigma is OldValueSigma + Y*GradValueSigma,

		    bb_put(Key,NewValueMu),
		    bb_put(Key2,NewValueSigma)
		   );
		   (
		    atomic_concat(['grad_',FactID],Key),

		    % if the following query fails,
		    % it means, the fact is not used in the proof
		    % of QueryID, and the gradient is 0.0 and will
		    % not contribute to NewValue either way
		    % DON'T FORGET THIS IF YOU CHANGE SOMETHING HERE!
		    query_gradient(QueryID,FactID,p,GradValue),

		    bb_get(Key,OldValue),
		    NewValue is OldValue + Y*GradValue,
		    bb_put(Key,NewValue)
		   )
		  ),

		  fail; % go to next fact
		  true
		),

		once(update_query_cleanup(QueryID))
	       )),
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% stop calculate gradient
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	!,

	close(Handle),

	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% start statistics on gradient
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	findall(V, (
		    tunable_fact(FactID,_),
		    atomic_concat(['grad_',FactID],Key),
		    bb_get(Key,V)
		   ),Gradient_Values),

	(
	 Gradient_Values==[]
	->
	 (
	  logger_set_variable(gradient_mean,0.0),
	  logger_set_variable(gradient_min,0.0),
	  logger_set_variable(gradient_max,0.0)
	 );
	 (
	  sum_list(Gradient_Values,GradSum),
	  max_list(Gradient_Values,GradMax),
	  min_list(Gradient_Values,GradMin),
	  length(Gradient_Values,GradLength),
	  GradMean is GradSum/GradLength,

	  logger_set_variable(gradient_mean,GradMean),
	  logger_set_variable(gradient_min,GradMin),
	  logger_set_variable(gradient_max,GradMax)
	 )
	),
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% stop statistics on gradient
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

	bb_delete(mse_train_sum,MSE_Train_Sum),
	bb_delete(mse_train_min,MSE_Train_Min),
	bb_delete(mse_train_max,MSE_Train_Max),
	bb_delete(llh_training_queries,LLH_Training_Queries),
	MSE is MSE_Train_Sum/Example_Count,

	logger_set_variable(mse_trainingset,MSE),
	logger_set_variable(mse_min_trainingset,MSE_Train_Min),
	logger_set_variable(mse_max_trainingset,MSE_Train_Max),
	logger_set_variable(llh_training_queries,LLH_Training_Queries),

	format_learning(2,'~n',[]),

	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% start add gradient to current probabilities
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	(
	 problog_flag(line_search,false)
	->
	 problog_flag(learning_rate,LearningRate);
	 lineSearch(LearningRate,_)
	),
	format_learning(3,'learning rate:~8f~n',[LearningRate]),
	add_gradient(LearningRate),
	logger_set_variable(learning_rate,LearningRate),
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	% stop add gradient to current probabilities
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	!,
	forget_old_probabilities.

%========================================================================
%= Evaluates the MSE on the training set after taking a gradient step
%= of the given learning rate from the saved old probabilities
%=
%= +Float, -Float
%========================================================================

line_search_evaluate_point(Learning_Rate,MSE) :-
	add_gradient(Learning_Rate),
	format_learning(2,'Line search (h=~8f) ',[Learning_Rate]),
	mse_trainingset_only_for_linesearch(MSE).
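
%========================================================================
%= Illustrative sketch, not part of the original ProbLog code.
%= It isolates the factor Y that gradient_descent above multiplies with
%= each query_gradient/4 value, for an example of type '='. Y is the
%= derivative of the mean squared error with respect to the predicted
%= probability, 2/N * (Prediction-Target), and it is scaled by Alpha
%= when the target probability is 0.0 (a negative example). The name
%= sketch_example_weight/5 is hypothetical.
%=
%= +Target, +Prediction, +Alpha, +ExampleCount, -Y
%========================================================================

sketch_example_weight(Target,Prediction,Alpha,Example_Count,Y) :-
	(
	 Target=:=0.0
	->
	 Y2=Alpha;
	 Y2=1.0
	),
	Y is Y2*2/Example_Count * (Prediction-Target).

% ?- sketch_example_weight(1.0,0.5,1.0,4,Y).   Y = -0.25
% ?- sketch_example_weight(0.0,0.5,0.5,4,Y).   Y = 0.125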


lineSearch(Final_X,Final_Value) :-

	% Get Parameters for line search
	problog_flag(line_search_tolerance,Tol),
	problog_flag(line_search_tau,Tau),
	problog_flag(line_search_interval,(A,B)),

	format_learning(3,'Line search in interval (~4f,~4f)~n',[A,B]),

	% init values
	Acc is Tol * (B-A),
	InitRight is A + Tau*(B-A),
	InitLeft is B - Tau*(B-A),

	line_search_evaluate_point(A,Value_A),
	line_search_evaluate_point(B,Value_B),
	line_search_evaluate_point(InitRight,Value_InitRight),
	line_search_evaluate_point(InitLeft,Value_InitLeft),


	Parameters=ls(A,B,InitLeft,InitRight,Value_A,Value_B,Value_InitLeft,Value_InitRight,1),

	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	%%%% BEGIN BACK TRACKING
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	(
	 repeat,

	 Parameters=ls(Ak,Bk,Left,Right,Fl,Fr,FLeft,FRight,Iteration),

	 (
	  % check for infinity, if there is, go to the left
	  ( FLeft >= FRight, \+ FLeft = (+inf), \+ FRight = (+inf) )
	 ->
	  (
	   AkNew=Left,
	   FlNew=FLeft,
	   LeftNew=Right,
	   FLeftNew=FRight,
	   RightNew is Left + Bk - Right,
	   line_search_evaluate_point(RightNew,FRightNew),
	   BkNew=Bk,
	   FrNew=Fr,
	   Interval_Size is Bk-Left
	  );
	  (
	   BkNew=Right,
	   FrNew=FRight,
	   RightNew=Left,
	   FRightNew=FLeft,
	   LeftNew is Ak + Right - Left,

	   line_search_evaluate_point(LeftNew,FLeftNew),
	   AkNew=Ak,
	   FlNew=Fl,
	   Interval_Size is Right-Ak
	  )
	 ),

	 Next_Iteration is Iteration + 1,

	 nb_setarg(9,Parameters,Next_Iteration),
	 nb_setarg(1,Parameters,AkNew),
	 nb_setarg(2,Parameters,BkNew),
	 nb_setarg(3,Parameters,LeftNew),
	 nb_setarg(4,Parameters,RightNew),
	 nb_setarg(5,Parameters,FlNew),
	 nb_setarg(6,Parameters,FrNew),
	 nb_setarg(7,Parameters,FLeftNew),
	 nb_setarg(8,Parameters,FRightNew),

	 % is the search interval smaller than the tolerance level?
	 Interval_Size<Acc,

	 % apparently it is, so get me out of here and
	 % cut away the choice point from repeat
	 !
	),
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	%%%% END BACK TRACKING
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%



	% it doesn't harm to check also the value in the middle
	% of the current search interval
	Middle is (AkNew + BkNew) / 2.0,
	line_search_evaluate_point(Middle,Value_Middle),

	% return the optimal value
	my_5_min(Value_Middle,FlNew,FrNew,FLeftNew,FRightNew,
		 Middle,AkNew,BkNew,LeftNew,RightNew,
		 Optimal_Value,Optimal_X),

	line_search_postcheck(Optimal_Value,Optimal_X,Final_Value,Final_X).

line_search_postcheck(V,X,V,X) :-
	X>0,
	!.
line_search_postcheck(V,X,V,X) :-
	problog_flag(line_search_never_stop,false),
	!.
line_search_postcheck(_,_, LLH, FinalPosition) :-
	problog_flag(line_search_tolerance,Tolerance),
	problog_flag(line_search_interval,(Left,Right)),


	Offset is (Right - Left) * Tolerance,

	bb_put(line_search_offset,Offset),

	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
	(

	 repeat,

	 bb_get(line_search_offset,OldOffset),
	 NewOffset is OldOffset * Tolerance,
	 bb_put(line_search_offset,NewOffset),

	 Position is Left + NewOffset,
	 line_search_evaluate_point(Position,LLH),
	 bb_put(line_search_llh,LLH),

	 write(logAtom(lineSearchPostCheck(Position,LLH))),nl,


	 \+ LLH = (+inf),
	 !
	), % cut away choice point from repeat
	%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

	bb_delete(line_search_llh,LLH),
	bb_delete(line_search_offset,FinalOffset),
	FinalPosition is Left + FinalOffset.
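
%========================================================================
%= Illustrative sketch, not part of the original ProbLog code.
%= lineSearch/2 above narrows its interval with a golden-section style
%= update driven by repeat/0 and nb_setarg/3. The predicates below
%= show the same interval update on an arbitrary unary function,
%= written with plain recursion and without the infinity handling or
%= the final five-point comparison. All sketch_ names, and the
%= closure argument F called as call(F,X,Y), are assumptions of this
%= sketch.
%=
%= +F, +A, +B, +Tol, -X
%========================================================================

sketch_golden_section(F,A,B,Tol,X) :-
	Tau = 0.618033988749,
	Left is B - Tau*(B-A),
	Right is A + Tau*(B-A),
	call(F,Left,FLeft),
	call(F,Right,FRight),
	sketch_gs_loop(F,A,B,Left,Right,FLeft,FRight,Tol,X).

% stop as soon as the interval is narrower than Tol
sketch_gs_loop(_,A,B,_,_,_,_,Tol,X) :-
	B-A < Tol,
	!,
	X is (A+B)/2.0.
sketch_gs_loop(F,A,B,Left,Right,FLeft,FRight,Tol,X) :-
	(
	 FLeft >= FRight
	->
	 (
	  % the minimum lies in [Left,B], reuse Right as new left point
	  NewRight is Left + B - Right,
	  call(F,NewRight,FNewRight),
	  sketch_gs_loop(F,Left,B,Right,NewRight,FRight,FNewRight,Tol,X)
	 );
	 (
	  % the minimum lies in [A,Right], reuse Left as new right point
	  NewLeft is A + Right - Left,
	  call(F,NewLeft,FNewLeft),
	  sketch_gs_loop(F,A,Right,NewLeft,Left,FNewLeft,FLeft,Tol,X)
	 )
	).

% hypothetical test function with its minimum at X = 2.0
sketch_parabola(X,Y) :-
	Y is (X-2.0)**2.

% ?- sketch_golden_section(sketch_parabola,0.0,5.0,0.001,X).
%    X is close to 2.0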


my_5_min(V1,V2,V3,V4,V5,F1,F2,F3,F4,F5,VMin,FMin) :-
	(
	 V1<V2
	->
	 (VTemp1=V1,FTemp1=F1);
	 (VTemp1=V2,FTemp1=F2)
	),
	(
	 V3<V4
	->
	 (VTemp2=V3,FTemp2=F3);
	 (VTemp2=V4,FTemp2=F4)
	),
	(
	 VTemp1<VTemp2
	->
	 (VTemp3=VTemp1,FTemp3=FTemp1);
	 (VTemp3=VTemp2,FTemp3=FTemp2)
	),
	(
	 VTemp3<V5
	->
	 (VMin=VTemp3,FMin=FTemp3);
	 (VMin=V5,FMin=F5)
	).



%========================================================================
%= initialize the logger module and set the flags for learning
%= don't change anything here! use set_problog_flag/2 instead
%========================================================================

init_flags :-
	prolog_file_name('queries',Queries_Folder), % get absolute file name for './queries'
	prolog_file_name('output',Output_Folder), % get absolute file name for './output'
	problog_define_flag(bdd_directory, problog_flag_validate_directory, 'directory for BDD scripts', Queries_Folder,learning_general),
	problog_define_flag(output_directory, problog_flag_validate_directory, 'directory for logfiles etc', Output_Folder,learning_general,flags:learning_output_dir_handler),
	problog_define_flag(log_frequency, problog_flag_validate_posint, 'log results every nth iteration', 1, learning_general),
	problog_define_flag(rebuild_bdds, problog_flag_validate_nonegint, 'rebuild BDDs every nth iteration', 0, learning_general),
	problog_define_flag(reuse_initialized_bdds,problog_flag_validate_boolean, 'Reuse BDDs from previous runs',false, learning_general),
	problog_define_flag(check_duplicate_bdds,problog_flag_validate_boolean,'Store intermediate results in hash table',true,learning_general),
	problog_define_flag(init_method,problog_flag_validate_dummy,'ProbLog predicate to search proofs',(Query,Probability,BDDFile,ProbFile,problog_kbest_save(Query,100,Probability,_Status,BDDFile,ProbFile)),learning_general,flags:learning_init_handler),
	problog_define_flag(alpha,problog_flag_validate_number,'weight of negative examples (auto=n_p/n_n)',auto,learning_general,flags:auto_handler),
	problog_define_flag(sigmoid_slope,problog_flag_validate_posnumber,'slope of sigmoid function',1.0,learning_general),

	problog_define_flag(learning_rate,problog_flag_validate_posnumber,'Default learning rate (If line_search=false)',examples,learning_line_search,flags:examples_handler),
	problog_define_flag(line_search, problog_flag_validate_boolean,'estimate learning rate by line search',false,learning_line_search),
	problog_define_flag(line_search_never_stop, problog_flag_validate_boolean,'make tiny step if line search returns 0',true,learning_line_search),
	problog_define_flag(line_search_tau, problog_flag_validate_indomain_0_1_open,'tau value for line search',0.618033988749,learning_line_search),
	problog_define_flag(line_search_tolerance,problog_flag_validate_posnumber,'tolerance value for line search',0.05,learning_line_search),
	problog_define_flag(line_search_interval, problog_flag_validate_dummy,'interval for line search',(0,100),learning_line_search,flags:linesearch_interval_handler).


init_logger :-
	logger_define_variable(iteration, int),
	logger_define_variable(duration,time),
	logger_define_variable(mse_trainingset,float),
	logger_define_variable(mse_min_trainingset,float),
	logger_define_variable(mse_max_trainingset,float),
	logger_define_variable(mse_testset,float),
	logger_define_variable(mse_min_testset,float),
	logger_define_variable(mse_max_testset,float),
	logger_define_variable(gradient_mean,float),
	logger_define_variable(gradient_min,float),
	logger_define_variable(gradient_max,float),
	logger_define_variable(ground_truth_diff,float),
	logger_define_variable(ground_truth_mindiff,float),
	logger_define_variable(ground_truth_maxdiff,float),
	logger_define_variable(learning_rate,float),
	logger_define_variable(alpha,float),
	logger_define_variable(llh_training_queries,float),
	logger_define_variable(llh_test_queries,float).

:- initialization(init_flags).
:- initialization(init_logger).
