1 #include "network.h"
2 #include "utils.h"
3 #include "parser.h"
4 #include "option_list.h"
5 #include "blas.h"
6 
7 
// When nonzero, board rows are labelled from 19 (first printed row) down to 1;
// when zero they are labelled 1..19. Read by print_board and test_go.
int inverted = 1;
// When nonzero, column letters skip 'I': columns past the eighth are shifted
// by one letter when printed and shifted back when parsed.
int noi = 1;
//static const unsigned int n_ind = 5;
// Number of top-ranked candidate moves kept in the `indexes` arrays and shown
// on the printed board.
#define n_ind 5
12 
// A set of Go position records loaded by load_go_moves.
typedef struct {
    char **data;   // array of n record pointers; each record is a 94-byte buffer from fgetgo
    int n;         // number of records in data
} moves;
17 
// Read one fixed-size record (94 bytes) from fp.
// Returns a newly allocated buffer holding the raw bytes (no NUL terminator
// is added), or 0 if the stream is already at end-of-file or a complete
// record could not be read. The caller is responsible for freeing the buffer.
char *fgetgo(FILE *fp)
{
    const size_t record_bytes = 94;
    char *record;

    if (feof(fp)) {
        return 0;
    }

    record = (char*)xmalloc(record_bytes * sizeof(char));
    if (fread(record, sizeof(char), record_bytes, fp) == record_bytes) {
        return record;
    }

    // Short read: discard the partial record.
    free(record);
    return 0;
}
30 
load_go_moves(char * filename)31 moves load_go_moves(char *filename)
32 {
33     moves m;
34     m.n = 128;
35     m.data = (char**)xcalloc(128, sizeof(char*));
36     FILE *fp = fopen(filename, "rb");
37     int count = 0;
38     char *line = 0;
39     while((line = fgetgo(fp))){
40         if(count >= m.n){
41             m.n *= 2;
42             m.data = (char**)xrealloc(m.data, m.n * sizeof(char*));
43         }
44         m.data[count] = line;
45         ++count;
46     }
47     printf("%d\n", count);
48     m.n = count;
49     m.data = (char**)xrealloc(m.data, count * sizeof(char*));
50     fclose(fp);
51     return m;
52 }
53 
// Unpack a 91-byte packed position into 19*19 floats.
// Each byte holds four points, two bits per point: the low bit of the pair
// marks "me" (written as 1), the high bit marks "you" (written as -1);
// neither bit set yields 0. If both bits are set, "me" wins.
// Exactly 19*19 entries of `board` are written.
void string_to_board(char *s, float *board)
{
    int cell = 0;
    int byte_i;
    for(byte_i = 0; byte_i < 91; ++byte_i){
        char packed = s[byte_i];
        int slot;
        for(slot = 0; slot < 4; ++slot){
            int shift = 2*slot;
            float value = 0;
            if((packed >> shift) & 1){
                value = 1;
            } else if((packed >> (shift + 1)) & 1){
                value = -1;
            }
            board[cell] = value;
            ++cell;
            // The last byte only carries one point (361 = 90*4 + 1).
            if(cell >= 19*19) break;
        }
    }
}
72 
// Pack a 19*19 float board into 91 bytes, four points per byte and two bits
// per point: the low bit of each pair is set where the board value is 1,
// the high bit where it is -1. Any other value leaves both bits clear.
// `s` must have room for 91 bytes; it is fully overwritten.
void board_to_string(char *s, float *board)
{
    int cell = 0;
    int byte_i, slot;

    memset(s, 0, (19*19/4+1)*sizeof(char));
    for(byte_i = 0; byte_i < 91; ++byte_i){
        for(slot = 0; slot < 4; ++slot){
            float stone = board[cell];
            if(stone == 1)  s[byte_i] |= (1 << (2*slot));
            if(stone == -1) s[byte_i] |= (1 << (2*slot + 1));
            ++cell;
            // Only the first point of the last byte is used.
            if(cell >= 19*19) break;
        }
    }
}
89 
// Fill `boards` and `labels` with `n` randomly drawn training samples.
// For each sample a record is picked with rand(); its first two bytes are
// the target row and column and the rest is the packed position. The label
// plane gets a single 1 at the target point, the board plane gets the
// unpacked position with the target point cleared. Both planes are then
// given the same random flip and rotation.
// `boards` and `labels` must each hold 19*19*n floats.
void random_go_moves(moves m, float *boards, float *labels, int n)
{
    int sample;
    memset(labels, 0, 19*19*n*sizeof(float));
    for(sample = 0; sample < n; ++sample){
        char *record = m.data[rand()%m.n];
        int row = record[0];
        int col = record[1];
        float *board_plane = boards + sample*19*19;
        float *label_plane = labels + sample*19*19;
        int target = col + 19*row;

        label_plane[target] = 1;
        string_to_board(record+2, board_plane);
        board_plane[target] = 0;

        // Draw the augmentation after the record so the rand() sequence is
        // record, flip, rotation.
        int flip = rand()%2;
        int rotate = rand()%4;
        image in = float_to_image(19, 19, 1, board_plane);
        image out = float_to_image(19, 19, 1, label_plane);
        if(flip){
            flip_image(in);
            flip_image(out);
        }
        rotate_image_cw(in, rotate);
        rotate_image_cw(out, rotate);
    }
}
114 
115 
train_go(char * cfgfile,char * weightfile)116 void train_go(char *cfgfile, char *weightfile)
117 {
118     srand(time(0));
119     float avg_loss = -1;
120     char *base = basecfg(cfgfile);
121     printf("%s\n", base);
122     network net = parse_network_cfg(cfgfile);
123     if(weightfile){
124         load_weights(&net, weightfile);
125     }
126     printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
127 
128     char* backup_directory = "backup/";
129 
130     char buff[256];
131     float* board = (float*)xcalloc(19 * 19 * net.batch, sizeof(float));
132     float* move = (float*)xcalloc(19 * 19 * net.batch, sizeof(float));
133     moves m = load_go_moves("backup/go.train");
134     //moves m = load_go_moves("games.txt");
135 
136     int N = m.n;
137     int epoch = (*net.seen)/N;
138     while(get_current_batch(net) < net.max_batches || net.max_batches == 0){
139         clock_t time=clock();
140 
141         random_go_moves(m, board, move, net.batch);
142         float loss = train_network_datum(net, board, move) / net.batch;
143         if(avg_loss == -1) avg_loss = loss;
144         avg_loss = avg_loss*.95 + loss*.05;
145         printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen);
146         if(*net.seen/N > epoch){
147             epoch = *net.seen/N;
148             char buff[256];
149             sprintf(buff, "%s/%s_%d.weights", backup_directory,base, epoch);
150             save_weights(net, buff);
151 
152         }
153         if(get_current_batch(net)%100 == 0){
154             char buff[256];
155             sprintf(buff, "%s/%s.backup",backup_directory,base);
156             save_weights(net, buff);
157         }
158         if(get_current_batch(net)%10000 == 0){
159             char buff[256];
160             sprintf(buff, "%s/%s_%d.backup",backup_directory,base,get_current_batch(net));
161             save_weights(net, buff);
162         }
163     }
164     sprintf(buff, "%s/%s.weights", backup_directory, base);
165     save_weights(net, buff);
166 
167     free_network(net);
168     free(base);
169     free(board);
170     free(move);
171 }
172 
// Flood-fill from (row, col) across orthogonally connected points whose board
// value equals `side`, adding 1 to `lib` at each point reached.
// `visited` marks points already counted so that each point is incremented
// at most once per fill; the caller decides when to reset it.
// Off-board coordinates and points of a different value are ignored.
void propagate_liberty(float *board, int *lib, int *visited, int row, int col, int side)
{
    static const int d_row[4] = { 1, -1, 0,  0 };
    static const int d_col[4] = { 0,  0, 1, -1 };
    int on_board = (row >= 0 && row <= 18 && col >= 0 && col <= 18);
    int cell, k;

    if(!on_board) return;
    cell = row*19 + col;
    if(board[cell] != side || visited[cell]) return;

    visited[cell] = 1;
    lib[cell] += 1;
    for(k = 0; k < 4; ++k){
        propagate_liberty(board, lib, visited, row + d_row[k], col + d_col[k], side);
    }
}
186 
187 
// Build a 19*19 table of liberty counts for `board`.
// For every empty point, each distinct connected group of stones touching it
// has the count of all its stones incremented once. Empty points keep 0.
// Returns a newly allocated array that the caller must free.
int *calculate_liberties(float *board)
{
    int *lib = (int*)xcalloc(19 * 19, sizeof(int));
    int visited[361];
    int row, col;

    for(row = 0; row < 19; ++row){
        for(col = 0; col < 19; ++col){
            int cell = row*19 + col;
            // Fresh visited map per point, so one empty point counts once
            // for a group even if it touches the group on several sides.
            memset(visited, 0, 19*19*sizeof(int));
            if(board[cell] != 0) continue;

            if(col > 0 && board[cell - 1]){
                propagate_liberty(board, lib, visited, row, col-1, board[cell-1]);
            }
            if(col < 18 && board[cell + 1]){
                propagate_liberty(board, lib, visited, row, col+1, board[cell+1]);
            }
            if(row > 0 && board[cell - 19]){
                propagate_liberty(board, lib, visited, row-1, col, board[cell-19]);
            }
            if(row < 18 && board[cell + 19]){
                propagate_liberty(board, lib, visited, row+1, col, board[cell+19]);
            }
        }
    }
    return lib;
}
207 
print_board(float * board,int swap,int * indexes)208 void print_board(float *board, int swap, int *indexes)
209 {
210     //FILE *stream = stdout;
211     FILE *stream = stderr;
212     int i,j,n;
213     fprintf(stream, "\n\n");
214     fprintf(stream, "   ");
215     for(i = 0; i < 19; ++i){
216         fprintf(stream, "%c ", 'A' + i + 1*(i > 7 && noi));
217     }
218     fprintf(stream, "\n");
219     for(j = 0; j < 19; ++j){
220         fprintf(stream, "%2d", (inverted) ? 19-j : j+1);
221         for(i = 0; i < 19; ++i){
222             int index = j*19 + i;
223             if(indexes){
224                 int found = 0;
225                 for (n = 0; n < n_ind; ++n) {
226                     if(index == indexes[n]){
227                         found = 1;
228                         /*
229                         if(n == 0) fprintf(stream, "\uff11");
230                         else if(n == 1) fprintf(stream, "\uff12");
231                         else if(n == 2) fprintf(stream, "\uff13");
232                         else if(n == 3) fprintf(stream, "\uff14");
233                         else if(n == 4) fprintf(stream, "\uff15");
234                         */
235                         if(n == 0) fprintf(stream, " 1");
236                         else if(n == 1) fprintf(stream, " 2");
237                         else if(n == 2) fprintf(stream, " 3");
238                         else if(n == 3) fprintf(stream, " 4");
239                         else if(n == 4) fprintf(stream, " 5");
240                     }
241                 }
242                 if(found) continue;
243             }
244             //if(board[index]*-swap > 0) fprintf(stream, "\u25C9 ");
245             //else if(board[index]*-swap < 0) fprintf(stream, "\u25EF ");
246             if(board[index]*-swap > 0) fprintf(stream, " O");
247             else if(board[index]*-swap < 0) fprintf(stream, " X");
248             else fprintf(stream, "  ");
249         }
250         fprintf(stream, "\n");
251     }
252 }
253 
// Negate all 19*19 entries of `board` in place, swapping the two sides.
void flip_board(float *board)
{
    float *cell = board;
    float *end = board + 19*19;
    while(cell < end){
        *cell = -*cell;
        ++cell;
    }
}
261 
// Run the network on `board` and write a 19*19 score map into `move`.
// With `multi` nonzero, the board is additionally evaluated under seven
// transformed views (rotate_image_cw by i, plus flip_image for i >= 4); each
// output is transformed back and the eight results are averaged. The board
// is restored after every view.
// Scores at occupied points are set to 0.
void predict_move(network net, float *board, float *move, int multi)
{
    int i;
    float *prediction = network_predict(net, board);
    copy_cpu(19*19, prediction, 1, move, 1);

    if(multi){
        image board_im = float_to_image(19, 19, 1, board);
        for(i = 1; i < 8; ++i){
            int mirrored = (i >= 4);

            // Transform the input in place (board_im wraps `board`).
            rotate_image_cw(board_im, i);
            if(mirrored) flip_image(board_im);

            float *augmented = network_predict(net, board);
            image out_im = float_to_image(19, 19, 1, augmented);

            // Undo the transform on the output, in reverse order.
            if(mirrored) flip_image(out_im);
            rotate_image_cw(out_im, -i);

            axpy_cpu(19*19, 1, augmented, 1, move, 1);

            // Put the board back the way it was.
            if(mirrored) flip_image(board_im);
            rotate_image_cw(board_im, -i);
        }
        scal_cpu(19*19, 1./8., move, 1);
    }

    for(i = 0; i < 19*19; ++i){
        if(board[i] != 0) move[i] = 0;
    }
}
290 
// Starting at (r, c), clear every orthogonally connected point whose board
// value equals `p` and whose entry in `lib` is exactly 1.
// Off-board coordinates and points failing either test stop the spread.
void remove_connected(float *b, int *lib, int p, int r, int c)
{
    static const int d_row[4] = { 1, -1, 0,  0 };
    static const int d_col[4] = { 0,  0, 1, -1 };
    int cell, k;

    if(r < 0 || r >= 19 || c < 0 || c >= 19) return;
    cell = r*19 + c;
    if(b[cell] != p || lib[cell] != 1) return;

    b[cell] = 0;
    for(k = 0; k < 4; ++k){
        remove_connected(b, lib, p, r + d_row[k], c + d_col[k]);
    }
}
302 
303 
// Place a stone of side `p` at (r, c) and remove adjacent opposing stones.
// Liberties are computed on the position before the stone is placed, so
// neighbouring groups of side -p whose count was 1 are cleared.
void move_go(float *b, int p, int r, int c)
{
    static const int d_row[4] = { 1, -1, 0,  0 };
    static const int d_col[4] = { 0,  0, 1, -1 };
    int *libs = calculate_liberties(b);
    int opponent = -p;
    int k;

    b[r*19 + c] = p;
    for(k = 0; k < 4; ++k){
        remove_connected(b, libs, opponent, r + d_row[k], c + d_col[k]);
    }
    free(libs);
}
314 
// Report whether the neighbour at (r, c) makes a move by side `p` next to it
// safe. Returns 1 when the neighbour is empty, is an opposing stone with at
// most one liberty, or is any other stone with more than one liberty.
// Returns 0 otherwise, including when (r, c) is off the board.
int makes_safe_go(float *b, int *lib, int p, int r, int c){
    int on_board = (r >= 0 && r < 19 && c >= 0 && c < 19);
    if(!on_board) return 0;

    int cell = r*19 + c;
    float stone = b[cell];

    if(stone == -p) return (lib[cell] > 1) ? 0 : 1;
    if(stone == 0)  return 1;
    return (lib[cell] > 1) ? 1 : 0;
}
325 
// Returns 1 if none of the four neighbours of (r, c) makes a move by side
// `p` safe (see makes_safe_go), 0 otherwise. The board is not modified.
int suicide_go(float *b, int p, int r, int c)
{
    int *libs = calculate_liberties(b);
    int safe = makes_safe_go(b, libs, p, r+1, c)
            || makes_safe_go(b, libs, p, r-1, c)
            || makes_safe_go(b, libs, p, r, c+1)
            || makes_safe_go(b, libs, p, r, c-1);
    free(libs);
    return !safe;
}
337 
// Returns 1 if side `p` may play at (r, c), 0 otherwise.
// A move is rejected when the point is occupied, or when the packed position
// after the move equals the 91-byte position in `ko`.
// The move is tried on `b` and then undone, so `b` ends up unchanged.
int legal_go(float *b, char *ko, int p, int r, int c)
{
    char before[91];
    char after[91];
    int repeats_ko;

    if(b[r*19 + c] != 0) return 0;

    board_to_string(before, b);
    move_go(b, p, r, c);
    board_to_string(after, b);
    string_to_board(before, b);   // restore the original position

    repeats_ko = (memcmp(after, ko, 91) == 0);
    return !repeats_ko;
}
350 
// Choose a move for `player` (1 or -1) on `board`.
// The network always sees the position from the side to move, so the board
// is negated around the prediction for player < 0. Illegal points are zeroed,
// scores below `thresh` are dropped (if even the best score is below
// `thresh`, the n_ind-th best score is used as the threshold instead), and a
// move is then sampled from what remains.
// `temp` is stored as the temperature of every layer; `ko` is the packed
// position a move is not allowed to recreate; `print` nonzero dumps the
// board and the top candidates to stderr.
// Returns the board index (row*19 + col) of the chosen move, or -1 if the
// highest-scoring move would be suicide. A sampled move that is suicide is
// replaced by the highest-scoring one.
int generate_move(network net, int player, float *board, int multi, float thresh, float temp, char *ko, int print)
{
    int i, j;
    for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp;

    float move[361];
    if (player < 0) flip_board(board);
    predict_move(net, board, move, multi);
    if (player < 0) flip_board(board);


    for(i = 0; i < 19; ++i){
        for(j = 0; j < 19; ++j){
            if (!legal_go(board, ko, player, i, j)) move[i*19 + j] = 0;
        }
    }

    int indexes[n_ind];
    top_k(move, 19*19, n_ind, indexes);
    if(thresh > move[indexes[0]]) thresh = move[indexes[n_ind-1]];

    for(i = 0; i < 19; ++i){
        for(j = 0; j < 19; ++j){
            if (move[i*19 + j] < thresh) move[i*19 + j] = 0;
        }
    }


    int max = max_index(move, 19*19);
    int row = max / 19;
    int col = max % 19;
    int index = sample_array(move, 19*19);

    if(print){
        top_k(move, 19*19, n_ind, indexes);
        for(i = 0; i < n_ind; ++i){
            // -1 hides candidates with no remaining score from the board.
            if (!move[indexes[i]]) indexes[i] = -1;
        }
        print_board(board, player, indexes);
        for(i = 0; i < n_ind; ++i){
            // A hidden candidate had score 0; never index move[] with -1.
            float score = (indexes[i] < 0) ? 0 : move[indexes[i]];
            fprintf(stderr, "%d: %f\n", i+1, score);
        }
    }

    if(suicide_go(board, player, row, col)){
        return -1;
    }
    if(suicide_go(board, player, index/19, index%19)) index = max;
    return index;
}
401 
valid_go(char * cfgfile,char * weightfile,int multi)402 void valid_go(char *cfgfile, char *weightfile, int multi)
403 {
404     srand(time(0));
405     char *base = basecfg(cfgfile);
406     printf("%s\n", base);
407     network net = parse_network_cfg(cfgfile);
408     if(weightfile){
409         load_weights(&net, weightfile);
410     }
411     set_batch_network(&net, 1);
412     printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
413 
414     float* board = (float*)xcalloc(19 * 19, sizeof(float));
415     float* move = (float*)xcalloc(19 * 19, sizeof(float));
416     moves m = load_go_moves("backup/go.test");
417 
418     int N = m.n;
419     int i;
420     int correct = 0;
421     for(i = 0; i <N; ++i){
422         char *b = m.data[i];
423         int row = b[0];
424         int col = b[1];
425         int truth = col + 19*row;
426         string_to_board(b+2, board);
427         predict_move(net, board, move, multi);
428         int index = max_index(move, 19*19);
429         if(index == truth) ++correct;
430         printf("%d Accuracy %f\n", i, (float) correct/(i+1));
431     }
432     free(board);
433     free(move);
434 }
435 
engine_go(char * filename,char * weightfile,int multi)436 void engine_go(char *filename, char *weightfile, int multi)
437 {
438     network net = parse_network_cfg(filename);
439     if(weightfile){
440         load_weights(&net, weightfile);
441     }
442     srand(time(0));
443     set_batch_network(&net, 1);
444     float* board = (float*)xcalloc(19 * 19, sizeof(float));
445     char* one = (char*)xcalloc(91, sizeof(char));
446     char* two = (char*)xcalloc(91, sizeof(char));
447     int passed = 0;
448     while(1){
449         char buff[256];
450         int id = 0;
451         int has_id = (scanf("%d", &id) == 1);
452         scanf("%s", buff);
453         if (feof(stdin)) break;
454         char ids[256];
455         sprintf(ids, "%d", id);
456         //fprintf(stderr, "%s\n", buff);
457         if (!has_id) ids[0] = 0;
458         if (!strcmp(buff, "protocol_version")){
459             printf("=%s 2\n\n", ids);
460         } else if (!strcmp(buff, "name")){
461             printf("=%s DarkGo\n\n", ids);
462         } else if (!strcmp(buff, "version")){
463             printf("=%s 1.0\n\n", ids);
464         } else if (!strcmp(buff, "known_command")){
465             char comm[256];
466             scanf("%s", comm);
467             int known = (!strcmp(comm, "protocol_version") ||
468                     !strcmp(comm, "name") ||
469                     !strcmp(comm, "version") ||
470                     !strcmp(comm, "known_command") ||
471                     !strcmp(comm, "list_commands") ||
472                     !strcmp(comm, "quit") ||
473                     !strcmp(comm, "boardsize") ||
474                     !strcmp(comm, "clear_board") ||
475                     !strcmp(comm, "komi") ||
476                     !strcmp(comm, "final_status_list") ||
477                     !strcmp(comm, "play") ||
478                     !strcmp(comm, "genmove"));
479             if(known) printf("=%s true\n\n", ids);
480             else printf("=%s false\n\n", ids);
481         } else if (!strcmp(buff, "list_commands")){
482             printf("=%s protocol_version\nname\nversion\nknown_command\nlist_commands\nquit\nboardsize\nclear_board\nkomi\nplay\ngenmove\nfinal_status_list\n\n", ids);
483         } else if (!strcmp(buff, "quit")){
484             break;
485         } else if (!strcmp(buff, "boardsize")){
486             int boardsize = 0;
487             scanf("%d", &boardsize);
488             //fprintf(stderr, "%d\n", boardsize);
489             if(boardsize != 19){
490                 printf("?%s unacceptable size\n\n", ids);
491             } else {
492                 printf("=%s \n\n", ids);
493             }
494         } else if (!strcmp(buff, "clear_board")){
495             passed = 0;
496             memset(board, 0, 19*19*sizeof(float));
497             printf("=%s \n\n", ids);
498         } else if (!strcmp(buff, "komi")){
499             float komi = 0;
500             scanf("%f", &komi);
501             printf("=%s \n\n", ids);
502         } else if (!strcmp(buff, "play")){
503             char color[256];
504             scanf("%s ", color);
505             char c;
506             int r;
507             int count = scanf("%c%d", &c, &r);
508             int player = (color[0] == 'b' || color[0] == 'B') ? 1 : -1;
509             if(c == 'p' && count < 2) {
510                 passed = 1;
511                 printf("=%s \n\n", ids);
512                 char *line = fgetl(stdin);
513                 free(line);
514                 fflush(stdout);
515                 fflush(stderr);
516                 continue;
517             } else {
518                 passed = 0;
519             }
520             if(c >= 'A' && c <= 'Z') c = c - 'A';
521             if(c >= 'a' && c <= 'z') c = c - 'a';
522             if(c >= 8) --c;
523             r = 19 - r;
524             fprintf(stderr, "move: %d %d\n", r, c);
525 
526             char *swap = two;
527             two = one;
528             one = swap;
529             move_go(board, player, r, c);
530             board_to_string(one, board);
531 
532             printf("=%s \n\n", ids);
533             print_board(board, 1, 0);
534         } else if (!strcmp(buff, "genmove")){
535             char color[256];
536             scanf("%s", color);
537             int player = (color[0] == 'b' || color[0] == 'B') ? 1 : -1;
538 
539             int index = generate_move(net, player, board, multi, .1, .7, two, 1);
540             if(passed || index < 0){
541                 printf("=%s pass\n\n", ids);
542                 passed = 0;
543             } else {
544                 int row = index / 19;
545                 int col = index % 19;
546 
547                 char *swap = two;
548                 two = one;
549                 one = swap;
550 
551                 move_go(board, player, row, col);
552                 board_to_string(one, board);
553                 row = 19 - row;
554                 if (col >= 8) ++col;
555                 printf("=%s %c%d\n\n", ids, 'A' + col, row);
556                 print_board(board, 1, 0);
557             }
558 
559         } else if (!strcmp(buff, "p")){
560             //print_board(board, 1, 0);
561         } else if (!strcmp(buff, "final_status_list")){
562             char type[256];
563             scanf("%s", type);
564             fprintf(stderr, "final_status\n");
565             char *line = fgetl(stdin);
566             free(line);
567             if(type[0] == 'd' || type[0] == 'D'){
568                 FILE *f = fopen("game.txt", "w");
569                 int i, j;
570                 int count = 2;
571                 fprintf(f, "boardsize 19\n");
572                 fprintf(f, "clear_board\n");
573                 for(j = 0; j < 19; ++j){
574                     for(i = 0; i < 19; ++i){
575                         if(board[j*19 + i] == 1) fprintf(f, "play black %c%d\n", 'A'+i+(i>=8), 19-j);
576                         if(board[j*19 + i] == -1) fprintf(f, "play white %c%d\n", 'A'+i+(i>=8), 19-j);
577                         if(board[j*19 + i]) ++count;
578                     }
579                 }
580                 fprintf(f, "final_status_list dead\n");
581                 fclose(f);
582 #ifdef _WIN32
583 				FILE *p = _popen("./gnugo --mode gtp < game.txt", "r");
584 #else
585 				FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
586 #endif
587                 for(i = 0; i < count; ++i){
588                     free(fgetl(p));
589                     free(fgetl(p));
590                 }
591                 char *l = 0;
592                 while((l = fgetl(p))){
593                     printf("%s\n", l);
594                     free(l);
595                 }
596             } else {
597                 printf("?%s unknown command\n\n", ids);
598             }
599         } else {
600             char *line = fgetl(stdin);
601             free(line);
602             printf("?%s unknown command\n\n", ids);
603         }
604         fflush(stdout);
605         fflush(stderr);
606     }
607 }
608 
test_go(char * cfg,char * weights,int multi)609 void test_go(char *cfg, char *weights, int multi)
610 {
611     network net = parse_network_cfg(cfg);
612     if(weights){
613         load_weights(&net, weights);
614     }
615     srand(time(0));
616     set_batch_network(&net, 1);
617     float* board = (float*)xcalloc(19 * 19, sizeof(float));
618     float* move = (float*)xcalloc(19 * 19, sizeof(float));
619     int color = 1;
620     while(1){
621         float *output = network_predict(net, board);
622         copy_cpu(19*19, output, 1, move, 1);
623         int i;
624         if(multi){
625             image bim = float_to_image(19, 19, 1, board);
626             for(i = 1; i < 8; ++i){
627                 rotate_image_cw(bim, i);
628                 if(i >= 4) flip_image(bim);
629 
630                 float *output = network_predict(net, board);
631                 image oim = float_to_image(19, 19, 1, output);
632 
633                 if(i >= 4) flip_image(oim);
634                 rotate_image_cw(oim, -i);
635 
636                 axpy_cpu(19*19, 1, output, 1, move, 1);
637 
638                 if(i >= 4) flip_image(bim);
639                 rotate_image_cw(bim, -i);
640             }
641             scal_cpu(19*19, 1./8., move, 1);
642         }
643         for(i = 0; i < 19*19; ++i){
644             if(board[i]) move[i] = 0;
645         }
646 
647         int indexes[n_ind];
648         int row, col;
649         top_k(move, 19 * 19, n_ind, indexes);
650         print_board(board, color, indexes);
651         for (i = 0; i < n_ind; ++i) {
652             int index = indexes[i];
653             row = index / 19;
654             col = index % 19;
655             printf("%d: %c %d, %.2f%%\n", i+1, col + 'A' + 1*(col > 7 && noi), (inverted)?19 - row : row+1, move[index]*100);
656         }
657         //if(color == 1) printf("\u25EF Enter move: ");
658         //else printf("\u25C9 Enter move: ");
659         if(color == 1) printf("X Enter move: ");
660         else printf("O Enter move: ");
661 
662         char c;
663         char *line = fgetl(stdin);
664         int picked = 1;
665         int dnum = sscanf(line, "%d", &picked);
666         int cnum = sscanf(line, "%c", &c);
667         if (strlen(line) == 0 || dnum) {
668             --picked;
669             if (picked < n_ind){
670                 int index = indexes[picked];
671                 row = index / 19;
672                 col = index % 19;
673                 board[row*19 + col] = 1;
674             }
675         } else if (cnum){
676             if (c <= 'T' && c >= 'A'){
677                 int num = sscanf(line, "%c %d", &c, &row);
678                 row = (inverted)?19 - row : row-1;
679                 col = c - 'A';
680                 if (col > 7 && noi) col -= 1;
681                 if (num == 2) board[row*19 + col] = 1;
682             } else if (c == 'p') {
683                 // Pass
684             } else if(c=='b' || c == 'w'){
685                 char g;
686                 int num = sscanf(line, "%c %c %d", &g, &c, &row);
687                 row = (inverted)?19 - row : row-1;
688                 col = c - 'A';
689                 if (col > 7 && noi) col -= 1;
690                 if (num == 3) board[row*19 + col] = (g == 'b') ? color : -color;
691             } else if(c == 'c'){
692                 char g;
693                 int num = sscanf(line, "%c %c %d", &g, &c, &row);
694                 row = (inverted)?19 - row : row-1;
695                 col = c - 'A';
696                 if (col > 7 && noi) col -= 1;
697                 if (num == 3) board[row*19 + col] = 0;
698             }
699         }
700         free(line);
701         flip_board(board);
702         color = -color;
703     }
704 }
705 
// Score the position on `board` by replaying it into ./gnugo.
// The stones are written to "game.txt" as a list of play commands (komi 6.5,
// 1 = black, -1 = white) followed by final_score, and gnugo's reply of the
// form "= <player>+<score>" is parsed.
// Returns the score, negated when the reported winner is 'W', so a positive
// value means black is ahead. Returns 0 if game.txt cannot be written, gnugo
// cannot be started, or no score line is found.
float score_game(float *board)
{
    FILE *f = fopen("game.txt", "w");
    if(!f){
        fprintf(stderr, "Couldn't open file: game.txt\n");
        return 0;
    }
    int i, j;
    // Number of commands written before final_score; their replies are
    // skipped below.
    int count = 3;
    fprintf(f, "komi 6.5\n");
    fprintf(f, "boardsize 19\n");
    fprintf(f, "clear_board\n");
    for(j = 0; j < 19; ++j){
        for(i = 0; i < 19; ++i){
            if(board[j*19 + i] == 1) fprintf(f, "play black %c%d\n", 'A'+i+(i>=8), 19-j);
            if(board[j*19 + i] == -1) fprintf(f, "play white %c%d\n", 'A'+i+(i>=8), 19-j);
            if(board[j*19 + i]) ++count;
        }
    }
    fprintf(f, "final_score\n");
    fclose(f);
#ifdef _WIN32
    FILE *p = _popen("./gnugo --mode gtp < game.txt", "r");
#else
    FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
#endif
    if(!p){
        fprintf(stderr, "Couldn't run ./gnugo\n");
        return 0;
    }
    // Each earlier command produced two lines of reply.
    for(i = 0; i < count; ++i){
        free(fgetl(p));
        free(fgetl(p));
    }
    char *l = 0;
    float score = 0;
    char player = 0;
    while((l = fgetl(p))){
        fprintf(stderr, "%s  \t", l);
        int n = sscanf(l, "= %c+%f", &player, &score);
        free(l);
        if (n == 2) break;
    }
    if(player == 'W') score = -score;
#ifdef _WIN32
    _pclose(p);
#else
    pclose(p);
#endif
    return score;
}
749 
self_go(char * filename,char * weightfile,char * f2,char * w2,int multi)750 void self_go(char *filename, char *weightfile, char *f2, char *w2, int multi)
751 {
752     network net = parse_network_cfg(filename);
753     if(weightfile){
754         load_weights(&net, weightfile);
755     }
756 
757     network net2 = net;
758     if(f2){
759         net2 = parse_network_cfg(f2);
760         if(w2){
761             load_weights(&net2, w2);
762         }
763     }
764     srand(time(0));
765     char boards[300][93];
766     int count = 0;
767     set_batch_network(&net, 1);
768     set_batch_network(&net2, 1);
769     float* board = (float*)xcalloc(19 * 19, sizeof(float));
770     char* one = (char*)xcalloc(91, sizeof(char));
771     char* two = (char*)xcalloc(91, sizeof(char));
772     int done = 0;
773     int player = 1;
774     int p1 = 0;
775     int p2 = 0;
776     int total = 0;
777     while(1){
778         if (done || count >= 300){
779             float score = score_game(board);
780             int i = (score > 0)? 0 : 1;
781             if((score > 0) == (total%2==0)) ++p1;
782             else ++p2;
783             ++total;
784             fprintf(stderr, "Total: %d, Player 1: %f, Player 2: %f\n", total, (float)p1/total, (float)p2/total);
785             int j;
786             for(; i < count; i += 2){
787                 for(j = 0; j < 93; ++j){
788                     printf("%c", boards[i][j]);
789                 }
790                 printf("\n");
791             }
792             memset(board, 0, 19*19*sizeof(float));
793             player = 1;
794             done = 0;
795             count = 0;
796             fflush(stdout);
797             fflush(stderr);
798         }
799         //print_board(board, 1, 0);
800         //sleep(1);
801         network use = ((total%2==0) == (player==1)) ? net : net2;
802         int index = generate_move(use, player, board, multi, .1, .7, two, 0);
803         if(index < 0){
804             done = 1;
805             continue;
806         }
807         int row = index / 19;
808         int col = index % 19;
809 
810         char *swap = two;
811         two = one;
812         one = swap;
813 
814         if(player < 0) flip_board(board);
815         boards[count][0] = row;
816         boards[count][1] = col;
817         board_to_string(boards[count] + 2, board);
818         if(player < 0) flip_board(board);
819         ++count;
820 
821         move_go(board, player, row, col);
822         board_to_string(one, board);
823 
824         player = -player;
825     }
826     free(board);
827     free(one);
828     free(two);
829 }
830 
// Command-line entry point for the Go modes.
// argv[2] selects the mode (train, valid, self, test, engine), argv[3] is the
// cfg file, argv[4] the optional weights file, and argv[5]/argv[6] an
// optional second cfg/weights pair used by "self". The "-multi" flag is
// looked up with find_arg. An unrecognised mode does nothing.
void run_go(int argc, char **argv)
{
    //boards_go();
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }

    char *cfg = argv[3];
    char *weights = 0;
    char *c2 = 0;
    char *w2 = 0;
    if(argc > 4) weights = argv[4];
    if(argc > 5) c2 = argv[5];
    if(argc > 6) w2 = argv[6];

    // Positional arguments are captured before the flag lookup, and the mode
    // is read after it, matching the order in which argv is consulted.
    int multi = find_arg(argc, argv, "-multi");
    char *mode = argv[2];

    if(strcmp(mode, "train") == 0){
        train_go(cfg, weights);
    } else if(strcmp(mode, "valid") == 0){
        valid_go(cfg, weights, multi);
    } else if(strcmp(mode, "self") == 0){
        self_go(cfg, weights, c2, w2, multi);
    } else if(strcmp(mode, "test") == 0){
        test_go(cfg, weights, multi);
    } else if(strcmp(mode, "engine") == 0){
        engine_go(cfg, weights, multi);
    }
}
850