1 #include "network.h"
2 #include "utils.h"
3 #include "parser.h"
4 #include "option_list.h"
5 #include "blas.h"
6
7
/* Display flags read by print_board() and test_go(). */
int inverted = 1;   /* non-zero: rows are labelled 19..1 from the top instead of 1..19 */
int noi = 1;        /* non-zero: column letters after 'H' are shifted by one, skipping 'I' */
//static const unsigned int n_ind = 5;
/* Number of top-ranked candidate moves that are tracked and displayed. */
#define n_ind 5
12
/* A list of raw move records, as produced by load_go_moves(). */
typedef struct {
    char **data;   /* array of n separately allocated records */
    int n;         /* number of records stored in data */
} moves;
17
/*
 * Read one fixed-size move record from fp.
 *
 * A record is 94 raw bytes. The callers in this file read byte 0 as the row,
 * byte 1 as the column and the bytes from offset 2 as the packed board.
 *
 * Returns a freshly allocated 94-byte buffer (not NUL-terminated) that the
 * caller must free(), or 0 when fp is NULL, the stream is at end of file, or
 * fewer than 94 bytes could be read.
 */
char *fgetgo(FILE *fp)
{
    // A NULL stream (e.g. a failed fopen in the caller) must not reach feof/fread.
    if (!fp || feof(fp)) return 0;

    size_t size = 94;
    char *line = (char*)xmalloc(size * sizeof(char));
    if (fread(line, sizeof(char), size, fp) != size) {
        // Short read: a truncated trailing record is dropped.
        free(line);
        return 0;
    }

    return line;
}
30
load_go_moves(char * filename)31 moves load_go_moves(char *filename)
32 {
33 moves m;
34 m.n = 128;
35 m.data = (char**)xcalloc(128, sizeof(char*));
36 FILE *fp = fopen(filename, "rb");
37 int count = 0;
38 char *line = 0;
39 while((line = fgetgo(fp))){
40 if(count >= m.n){
41 m.n *= 2;
42 m.data = (char**)xrealloc(m.data, m.n * sizeof(char*));
43 }
44 m.data[count] = line;
45 ++count;
46 }
47 printf("%d\n", count);
48 m.n = count;
49 m.data = (char**)xrealloc(m.data, count * sizeof(char*));
50 fclose(fp);
51 return m;
52 }
53
/*
 * Unpack a 91-byte packed position into 361 board cells.
 *
 * s:     packed position; each byte holds four cells, two bits per cell.
 *        Within a cell's bit pair the low bit means "me" and the high bit
 *        means "you".
 * board: receives 19*19 floats: 1 for "me", -1 for "you", 0 otherwise.
 *        When both bits are set the cell decodes as 1.
 */
void string_to_board(char *s, float *board)
{
    int cell;
    for (cell = 0; cell < 19*19; ++cell) {
        char packed = s[cell / 4];
        int shift = 2 * (cell % 4);
        int me = (packed >> shift) & 1;
        int you = (packed >> (shift + 1)) & 1;
        board[cell] = me ? 1 : (you ? -1 : 0);
    }
}
72
/*
 * Pack 361 board cells into a 91-byte string.
 *
 * s:     output buffer of at least 19*19/4 + 1 = 91 bytes; it is cleared first.
 * board: 19*19 floats. A cell equal to 1 sets the low bit of its two-bit
 *        slot, a cell equal to -1 sets the high bit, any other value leaves
 *        the slot at zero.
 */
void board_to_string(char *s, float *board)
{
    int cell;
    memset(s, 0, (19*19/4 + 1) * sizeof(char));
    for (cell = 0; cell < 19*19; ++cell) {
        char *slot = s + cell / 4;
        int shift = 2 * (cell % 4);
        if (board[cell] == 1) *slot = *slot | (1 << shift);
        if (board[cell] == -1) *slot = *slot | (1 << (shift + 1));
    }
}
89
/*
 * Fill a training batch with n randomly chosen, randomly transformed positions.
 *
 * m:      loaded move records (byte 0 = row, byte 1 = column, packed board
 *         from byte 2 onwards).
 * boards: receives n boards of 19*19 floats.
 * labels: receives n one-hot 19*19 targets marking the recorded move.
 * n:      number of samples to generate.
 *
 * Each sample's board and label are passed through the same flip_image /
 * rotate_image_cw calls so that they stay aligned with each other.
 */
void random_go_moves(moves m, float *boards, float *labels, int n)
{
    int sample;
    memset(labels, 0, 19*19*n*sizeof(float));
    for (sample = 0; sample < n; ++sample) {
        float *board = boards + sample*19*19;
        float *label = labels + sample*19*19;

        char *record = m.data[rand() % m.n];
        int row = record[0];
        int col = record[1];
        int target = col + 19*row;

        label[target] = 1;
        string_to_board(record + 2, board);
        // The point of the recorded move is cleared on the input board.
        board[target] = 0;

        int flip = rand() % 2;
        int rotate = rand() % 4;
        image in = float_to_image(19, 19, 1, board);
        image out = float_to_image(19, 19, 1, label);
        if (flip) {
            flip_image(in);
            flip_image(out);
        }
        rotate_image_cw(in, rotate);
        rotate_image_cw(out, rotate);
    }
}
114
115
train_go(char * cfgfile,char * weightfile)116 void train_go(char *cfgfile, char *weightfile)
117 {
118 srand(time(0));
119 float avg_loss = -1;
120 char *base = basecfg(cfgfile);
121 printf("%s\n", base);
122 network net = parse_network_cfg(cfgfile);
123 if(weightfile){
124 load_weights(&net, weightfile);
125 }
126 printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
127
128 char* backup_directory = "backup/";
129
130 char buff[256];
131 float* board = (float*)xcalloc(19 * 19 * net.batch, sizeof(float));
132 float* move = (float*)xcalloc(19 * 19 * net.batch, sizeof(float));
133 moves m = load_go_moves("backup/go.train");
134 //moves m = load_go_moves("games.txt");
135
136 int N = m.n;
137 int epoch = (*net.seen)/N;
138 while(get_current_batch(net) < net.max_batches || net.max_batches == 0){
139 clock_t time=clock();
140
141 random_go_moves(m, board, move, net.batch);
142 float loss = train_network_datum(net, board, move) / net.batch;
143 if(avg_loss == -1) avg_loss = loss;
144 avg_loss = avg_loss*.95 + loss*.05;
145 printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen);
146 if(*net.seen/N > epoch){
147 epoch = *net.seen/N;
148 char buff[256];
149 sprintf(buff, "%s/%s_%d.weights", backup_directory,base, epoch);
150 save_weights(net, buff);
151
152 }
153 if(get_current_batch(net)%100 == 0){
154 char buff[256];
155 sprintf(buff, "%s/%s.backup",backup_directory,base);
156 save_weights(net, buff);
157 }
158 if(get_current_batch(net)%10000 == 0){
159 char buff[256];
160 sprintf(buff, "%s/%s_%d.backup",backup_directory,base,get_current_batch(net));
161 save_weights(net, buff);
162 }
163 }
164 sprintf(buff, "%s/%s.weights", backup_directory, base);
165 save_weights(net, buff);
166
167 free_network(net);
168 free(base);
169 free(board);
170 free(move);
171 }
172
/*
 * Flood-fill helper: add one to lib[] for every stone of `side` connected
 * (orthogonally) to (row, col).
 *
 * board:   19*19 cell values.
 * lib:     per-cell counters; each newly reached stone is incremented once.
 * visited: per-cell flags; cells already flagged are skipped and every
 *          reached cell is flagged, so one fill touches each stone once.
 * side:    cell value that identifies the group being followed.
 *
 * Off-board coordinates and cells of another value end the recursion.
 */
void propagate_liberty(float *board, int *lib, int *visited, int row, int col, int side)
{
    static const int d_row[4] = {1, -1, 0, 0};
    static const int d_col[4] = {0, 0, 1, -1};
    int on_board = (row >= 0 && row <= 18 && col >= 0 && col <= 18);
    int cell, k;

    if (!on_board) return;
    cell = 19*row + col;
    if (board[cell] != side || visited[cell]) return;

    visited[cell] = 1;
    lib[cell] += 1;
    for (k = 0; k < 4; ++k) {
        propagate_liberty(board, lib, visited, row + d_row[k], col + d_col[k], side);
    }
}
186
187
/*
 * Count liberties for every stone on the board.
 *
 * For each empty point, every distinct group touching it gets one added to
 * the counter of each of its stones (the visited map is reset per empty
 * point so a group touching the same point twice is counted once).
 *
 * Returns a newly allocated array of 19*19 ints that the caller must free().
 */
int *calculate_liberties(float *board)
{
    int *lib = (int*)xcalloc(19 * 19, sizeof(int));
    int visited[19*19];
    int row, col;

    for (row = 0; row < 19; ++row) {
        for (col = 0; col < 19; ++col) {
            int cell = row*19 + col;
            if (board[cell] != 0) continue;   // only empty points hand out liberties

            memset(visited, 0, sizeof(visited));
            if (col > 0 && board[cell - 1]) {
                propagate_liberty(board, lib, visited, row, col - 1, board[cell - 1]);
            }
            if (col < 18 && board[cell + 1]) {
                propagate_liberty(board, lib, visited, row, col + 1, board[cell + 1]);
            }
            if (row > 0 && board[cell - 19]) {
                propagate_liberty(board, lib, visited, row - 1, col, board[cell - 19]);
            }
            if (row < 18 && board[cell + 19]) {
                propagate_liberty(board, lib, visited, row + 1, col, board[cell + 19]);
            }
        }
    }
    return lib;
}
207
print_board(float * board,int swap,int * indexes)208 void print_board(float *board, int swap, int *indexes)
209 {
210 //FILE *stream = stdout;
211 FILE *stream = stderr;
212 int i,j,n;
213 fprintf(stream, "\n\n");
214 fprintf(stream, " ");
215 for(i = 0; i < 19; ++i){
216 fprintf(stream, "%c ", 'A' + i + 1*(i > 7 && noi));
217 }
218 fprintf(stream, "\n");
219 for(j = 0; j < 19; ++j){
220 fprintf(stream, "%2d", (inverted) ? 19-j : j+1);
221 for(i = 0; i < 19; ++i){
222 int index = j*19 + i;
223 if(indexes){
224 int found = 0;
225 for (n = 0; n < n_ind; ++n) {
226 if(index == indexes[n]){
227 found = 1;
228 /*
229 if(n == 0) fprintf(stream, "\uff11");
230 else if(n == 1) fprintf(stream, "\uff12");
231 else if(n == 2) fprintf(stream, "\uff13");
232 else if(n == 3) fprintf(stream, "\uff14");
233 else if(n == 4) fprintf(stream, "\uff15");
234 */
235 if(n == 0) fprintf(stream, " 1");
236 else if(n == 1) fprintf(stream, " 2");
237 else if(n == 2) fprintf(stream, " 3");
238 else if(n == 3) fprintf(stream, " 4");
239 else if(n == 4) fprintf(stream, " 5");
240 }
241 }
242 if(found) continue;
243 }
244 //if(board[index]*-swap > 0) fprintf(stream, "\u25C9 ");
245 //else if(board[index]*-swap < 0) fprintf(stream, "\u25EF ");
246 if(board[index]*-swap > 0) fprintf(stream, " O");
247 else if(board[index]*-swap < 0) fprintf(stream, " X");
248 else fprintf(stream, " ");
249 }
250 fprintf(stream, "\n");
251 }
252 }
253
/* Negate all 19*19 cells of the board in place, swapping the two sides. */
void flip_board(float *board)
{
    float *cell = board;
    float *end = board + 19*19;
    while (cell < end) {
        *cell = -*cell;
        ++cell;
    }
}
261
/*
 * Run the network on a board and write per-cell move scores into move.
 *
 * net:   network to evaluate.
 * board: 19*19 input cells; temporarily transformed in place when multi is
 *        set and transformed back before returning.
 * move:  receives 19*19 scores; occupied cells are forced to 0.
 * multi: when non-zero, the prediction is averaged over eight passes: the
 *        plain board plus seven rotate/flip variants, each output being
 *        transformed back before it is accumulated.
 */
void predict_move(network net, float *board, float *move, int multi)
{
    int i;
    float *plain = network_predict(net, board);
    copy_cpu(19*19, plain, 1, move, 1);

    if (multi) {
        image bim = float_to_image(19, 19, 1, board);
        for (i = 1; i < 8; ++i) {
            int mirrored = (i >= 4);

            // Transform the input board in place.
            rotate_image_cw(bim, i);
            if (mirrored) flip_image(bim);

            float *out = network_predict(net, board);
            image oim = float_to_image(19, 19, 1, out);

            // Undo the transform on the prediction before accumulating it.
            if (mirrored) flip_image(oim);
            rotate_image_cw(oim, -i);

            axpy_cpu(19*19, 1, out, 1, move, 1);

            // Restore the input board.
            if (mirrored) flip_image(bim);
            rotate_image_cw(bim, -i);
        }
        scal_cpu(19*19, 1./8., move, 1);
    }

    for (i = 0; i < 19*19; ++i) {
        if (board[i]) move[i] = 0;
    }
}
290
/*
 * Remove the group of stones with value p that contains (r, c), provided
 * its stones have a liberty count of exactly 1 in lib.
 *
 * b:   board, modified in place (removed stones become 0).
 * lib: per-cell liberty counts.
 *
 * Off-board coordinates, cells of another value and stones whose count is
 * not 1 stop the recursion.
 */
void remove_connected(float *b, int *lib, int p, int r, int c)
{
    int on_board = (r >= 0 && r < 19 && c >= 0 && c < 19);
    if (!on_board) return;

    int cell = 19*r + c;
    if (b[cell] != p || lib[cell] != 1) return;

    b[cell] = 0;
    remove_connected(b, lib, p, r + 1, c);
    remove_connected(b, lib, p, r - 1, c);
    remove_connected(b, lib, p, r, c + 1);
    remove_connected(b, lib, p, r, c - 1);
}
302
303
/*
 * Play a stone for player p at (r, c) and remove captured opponent groups.
 *
 * b: board, modified in place.
 * p: stone value to place (the opponent is -p).
 *
 * Liberties are counted before the stone is placed, so an adjacent opponent
 * group whose count is 1 has the played point as its only liberty and is
 * removed. Coordinates outside the 19x19 board are ignored: callers pass
 * values parsed from user input, and writing b[r*19 + c] for them would
 * corrupt memory.
 */
void move_go(float *b, int p, int r, int c)
{
    if (r < 0 || r >= 19 || c < 0 || c >= 19) return;

    int *l = calculate_liberties(b);
    b[r*19 + c] = p;
    remove_connected(b, l, -p, r+1, c);
    remove_connected(b, l, -p, r-1, c);
    remove_connected(b, l, -p, r, c+1);
    remove_connected(b, l, -p, r, c-1);
    free(l);
}
314
/*
 * Decide whether the neighbouring point (r, c) keeps a stone of player p
 * played next to it alive.
 *
 * Returns 1 when the neighbour is empty, is an opponent stone (-p) with at
 * most one liberty, or is any other stone with more than one liberty.
 * Returns 0 otherwise, including when (r, c) is off the board.
 */
int makes_safe_go(float *b, int *lib, int p, int r, int c)
{
    if (r < 0 || r >= 19 || c < 0 || c >= 19) return 0;

    int cell = 19*r + c;
    float stone = b[cell];

    if (stone == -p) return lib[cell] <= 1;   // opponent group that would be captured
    if (stone == 0) return 1;                 // an empty neighbour is a liberty
    return lib[cell] > 1;                     // group that keeps a liberty afterwards
}
325
/*
 * Report whether playing at (r, c) would be suicide for player p.
 *
 * Returns 1 when none of the four neighbours makes the stone safe
 * (see makes_safe_go), 0 otherwise. The board is not modified.
 */
int suicide_go(float *b, int p, int r, int c)
{
    int *lib = calculate_liberties(b);
    int safe = makes_safe_go(b, lib, p, r + 1, c)
            || makes_safe_go(b, lib, p, r - 1, c)
            || makes_safe_go(b, lib, p, r, c + 1)
            || makes_safe_go(b, lib, p, r, c - 1);
    free(lib);
    return !safe;
}
337
/*
 * Check whether player p may play at (r, c).
 *
 * b:  board; the move is tried on it and the position is then restored
 *     from its packed form.
 * ko: 91-byte packed position that the move must not recreate.
 *
 * Returns 0 when the point is occupied or the resulting position equals ko,
 * 1 otherwise.
 */
int legal_go(float *b, char *ko, int p, int r, int c)
{
    char before[91];
    char after[91];

    if (b[r*19 + c] != 0) return 0;

    board_to_string(before, b);
    move_go(b, p, r, c);
    board_to_string(after, b);
    string_to_board(before, b);   // undo the trial move

    return memcmp(after, ko, 91) != 0;
}
350
/*
 * Choose a move for `player` on `board`.
 *
 * net:    network used for prediction; every layer's temperature is set to temp.
 * player: side to move; for a negative player the board is negated around
 *         the prediction so the network sees it from that side.
 * multi:  forwarded to predict_move().
 * thresh: scores below this are zeroed; if even the best score is below it,
 *         the threshold falls back to the n_ind-th best score.
 * ko:     packed position a move must not recreate (see legal_go).
 * print:  when non-zero, the board, the top n_ind candidates and their
 *         scores are written to stderr.
 *
 * Returns the chosen cell index (row*19 + col), or -1 when the highest
 * scoring move would be suicide. A sampled move that would be suicide is
 * replaced by the highest scoring one.
 */
int generate_move(network net, int player, float *board, int multi, float thresh, float temp, char *ko, int print)
{
    int i, j;
    for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp;

    float move[361];
    if (player < 0) flip_board(board);
    predict_move(net, board, move, multi);
    if (player < 0) flip_board(board);

    // Illegal points can never be chosen.
    for(i = 0; i < 19; ++i){
        for(j = 0; j < 19; ++j){
            if (!legal_go(board, ko, player, i, j)) move[i*19 + j] = 0;
        }
    }

    int indexes[n_ind];
    top_k(move, 19*19, n_ind, indexes);
    if(thresh > move[indexes[0]]) thresh = move[indexes[n_ind-1]];

    for(i = 0; i < 19; ++i){
        for(j = 0; j < 19; ++j){
            if (move[i*19 + j] < thresh) move[i*19 + j] = 0;
        }
    }

    int max = max_index(move, 19*19);
    int row = max / 19;
    int col = max % 19;
    int index = sample_array(move, 19*19);

    if(print){
        top_k(move, 19*19, n_ind, indexes);
        for(i = 0; i < n_ind; ++i){
            // Candidates with a zero score are hidden from the board display.
            if (!move[indexes[i]]) indexes[i] = -1;
        }
        print_board(board, player, indexes);
        for(i = 0; i < n_ind; ++i){
            // A hidden candidate has index -1; report its score as 0 instead
            // of reading move[-1].
            float score = (indexes[i] >= 0) ? move[indexes[i]] : 0;
            fprintf(stderr, "%d: %f\n", i+1, score);
        }
    }

    if(suicide_go(board, player, row, col)){
        return -1;
    }
    if(suicide_go(board, player, index/19, index%19)) index = max;
    return index;
}
401
valid_go(char * cfgfile,char * weightfile,int multi)402 void valid_go(char *cfgfile, char *weightfile, int multi)
403 {
404 srand(time(0));
405 char *base = basecfg(cfgfile);
406 printf("%s\n", base);
407 network net = parse_network_cfg(cfgfile);
408 if(weightfile){
409 load_weights(&net, weightfile);
410 }
411 set_batch_network(&net, 1);
412 printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
413
414 float* board = (float*)xcalloc(19 * 19, sizeof(float));
415 float* move = (float*)xcalloc(19 * 19, sizeof(float));
416 moves m = load_go_moves("backup/go.test");
417
418 int N = m.n;
419 int i;
420 int correct = 0;
421 for(i = 0; i <N; ++i){
422 char *b = m.data[i];
423 int row = b[0];
424 int col = b[1];
425 int truth = col + 19*row;
426 string_to_board(b+2, board);
427 predict_move(net, board, move, multi);
428 int index = max_index(move, 19*19);
429 if(index == truth) ++correct;
430 printf("%d Accuracy %f\n", i, (float) correct/(i+1));
431 }
432 free(board);
433 free(move);
434 }
435
engine_go(char * filename,char * weightfile,int multi)436 void engine_go(char *filename, char *weightfile, int multi)
437 {
438 network net = parse_network_cfg(filename);
439 if(weightfile){
440 load_weights(&net, weightfile);
441 }
442 srand(time(0));
443 set_batch_network(&net, 1);
444 float* board = (float*)xcalloc(19 * 19, sizeof(float));
445 char* one = (char*)xcalloc(91, sizeof(char));
446 char* two = (char*)xcalloc(91, sizeof(char));
447 int passed = 0;
448 while(1){
449 char buff[256];
450 int id = 0;
451 int has_id = (scanf("%d", &id) == 1);
452 scanf("%s", buff);
453 if (feof(stdin)) break;
454 char ids[256];
455 sprintf(ids, "%d", id);
456 //fprintf(stderr, "%s\n", buff);
457 if (!has_id) ids[0] = 0;
458 if (!strcmp(buff, "protocol_version")){
459 printf("=%s 2\n\n", ids);
460 } else if (!strcmp(buff, "name")){
461 printf("=%s DarkGo\n\n", ids);
462 } else if (!strcmp(buff, "version")){
463 printf("=%s 1.0\n\n", ids);
464 } else if (!strcmp(buff, "known_command")){
465 char comm[256];
466 scanf("%s", comm);
467 int known = (!strcmp(comm, "protocol_version") ||
468 !strcmp(comm, "name") ||
469 !strcmp(comm, "version") ||
470 !strcmp(comm, "known_command") ||
471 !strcmp(comm, "list_commands") ||
472 !strcmp(comm, "quit") ||
473 !strcmp(comm, "boardsize") ||
474 !strcmp(comm, "clear_board") ||
475 !strcmp(comm, "komi") ||
476 !strcmp(comm, "final_status_list") ||
477 !strcmp(comm, "play") ||
478 !strcmp(comm, "genmove"));
479 if(known) printf("=%s true\n\n", ids);
480 else printf("=%s false\n\n", ids);
481 } else if (!strcmp(buff, "list_commands")){
482 printf("=%s protocol_version\nname\nversion\nknown_command\nlist_commands\nquit\nboardsize\nclear_board\nkomi\nplay\ngenmove\nfinal_status_list\n\n", ids);
483 } else if (!strcmp(buff, "quit")){
484 break;
485 } else if (!strcmp(buff, "boardsize")){
486 int boardsize = 0;
487 scanf("%d", &boardsize);
488 //fprintf(stderr, "%d\n", boardsize);
489 if(boardsize != 19){
490 printf("?%s unacceptable size\n\n", ids);
491 } else {
492 printf("=%s \n\n", ids);
493 }
494 } else if (!strcmp(buff, "clear_board")){
495 passed = 0;
496 memset(board, 0, 19*19*sizeof(float));
497 printf("=%s \n\n", ids);
498 } else if (!strcmp(buff, "komi")){
499 float komi = 0;
500 scanf("%f", &komi);
501 printf("=%s \n\n", ids);
502 } else if (!strcmp(buff, "play")){
503 char color[256];
504 scanf("%s ", color);
505 char c;
506 int r;
507 int count = scanf("%c%d", &c, &r);
508 int player = (color[0] == 'b' || color[0] == 'B') ? 1 : -1;
509 if(c == 'p' && count < 2) {
510 passed = 1;
511 printf("=%s \n\n", ids);
512 char *line = fgetl(stdin);
513 free(line);
514 fflush(stdout);
515 fflush(stderr);
516 continue;
517 } else {
518 passed = 0;
519 }
520 if(c >= 'A' && c <= 'Z') c = c - 'A';
521 if(c >= 'a' && c <= 'z') c = c - 'a';
522 if(c >= 8) --c;
523 r = 19 - r;
524 fprintf(stderr, "move: %d %d\n", r, c);
525
526 char *swap = two;
527 two = one;
528 one = swap;
529 move_go(board, player, r, c);
530 board_to_string(one, board);
531
532 printf("=%s \n\n", ids);
533 print_board(board, 1, 0);
534 } else if (!strcmp(buff, "genmove")){
535 char color[256];
536 scanf("%s", color);
537 int player = (color[0] == 'b' || color[0] == 'B') ? 1 : -1;
538
539 int index = generate_move(net, player, board, multi, .1, .7, two, 1);
540 if(passed || index < 0){
541 printf("=%s pass\n\n", ids);
542 passed = 0;
543 } else {
544 int row = index / 19;
545 int col = index % 19;
546
547 char *swap = two;
548 two = one;
549 one = swap;
550
551 move_go(board, player, row, col);
552 board_to_string(one, board);
553 row = 19 - row;
554 if (col >= 8) ++col;
555 printf("=%s %c%d\n\n", ids, 'A' + col, row);
556 print_board(board, 1, 0);
557 }
558
559 } else if (!strcmp(buff, "p")){
560 //print_board(board, 1, 0);
561 } else if (!strcmp(buff, "final_status_list")){
562 char type[256];
563 scanf("%s", type);
564 fprintf(stderr, "final_status\n");
565 char *line = fgetl(stdin);
566 free(line);
567 if(type[0] == 'd' || type[0] == 'D'){
568 FILE *f = fopen("game.txt", "w");
569 int i, j;
570 int count = 2;
571 fprintf(f, "boardsize 19\n");
572 fprintf(f, "clear_board\n");
573 for(j = 0; j < 19; ++j){
574 for(i = 0; i < 19; ++i){
575 if(board[j*19 + i] == 1) fprintf(f, "play black %c%d\n", 'A'+i+(i>=8), 19-j);
576 if(board[j*19 + i] == -1) fprintf(f, "play white %c%d\n", 'A'+i+(i>=8), 19-j);
577 if(board[j*19 + i]) ++count;
578 }
579 }
580 fprintf(f, "final_status_list dead\n");
581 fclose(f);
582 #ifdef _WIN32
583 FILE *p = _popen("./gnugo --mode gtp < game.txt", "r");
584 #else
585 FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
586 #endif
587 for(i = 0; i < count; ++i){
588 free(fgetl(p));
589 free(fgetl(p));
590 }
591 char *l = 0;
592 while((l = fgetl(p))){
593 printf("%s\n", l);
594 free(l);
595 }
596 } else {
597 printf("?%s unknown command\n\n", ids);
598 }
599 } else {
600 char *line = fgetl(stdin);
601 free(line);
602 printf("?%s unknown command\n\n", ids);
603 }
604 fflush(stdout);
605 fflush(stderr);
606 }
607 }
608
test_go(char * cfg,char * weights,int multi)609 void test_go(char *cfg, char *weights, int multi)
610 {
611 network net = parse_network_cfg(cfg);
612 if(weights){
613 load_weights(&net, weights);
614 }
615 srand(time(0));
616 set_batch_network(&net, 1);
617 float* board = (float*)xcalloc(19 * 19, sizeof(float));
618 float* move = (float*)xcalloc(19 * 19, sizeof(float));
619 int color = 1;
620 while(1){
621 float *output = network_predict(net, board);
622 copy_cpu(19*19, output, 1, move, 1);
623 int i;
624 if(multi){
625 image bim = float_to_image(19, 19, 1, board);
626 for(i = 1; i < 8; ++i){
627 rotate_image_cw(bim, i);
628 if(i >= 4) flip_image(bim);
629
630 float *output = network_predict(net, board);
631 image oim = float_to_image(19, 19, 1, output);
632
633 if(i >= 4) flip_image(oim);
634 rotate_image_cw(oim, -i);
635
636 axpy_cpu(19*19, 1, output, 1, move, 1);
637
638 if(i >= 4) flip_image(bim);
639 rotate_image_cw(bim, -i);
640 }
641 scal_cpu(19*19, 1./8., move, 1);
642 }
643 for(i = 0; i < 19*19; ++i){
644 if(board[i]) move[i] = 0;
645 }
646
647 int indexes[n_ind];
648 int row, col;
649 top_k(move, 19 * 19, n_ind, indexes);
650 print_board(board, color, indexes);
651 for (i = 0; i < n_ind; ++i) {
652 int index = indexes[i];
653 row = index / 19;
654 col = index % 19;
655 printf("%d: %c %d, %.2f%%\n", i+1, col + 'A' + 1*(col > 7 && noi), (inverted)?19 - row : row+1, move[index]*100);
656 }
657 //if(color == 1) printf("\u25EF Enter move: ");
658 //else printf("\u25C9 Enter move: ");
659 if(color == 1) printf("X Enter move: ");
660 else printf("O Enter move: ");
661
662 char c;
663 char *line = fgetl(stdin);
664 int picked = 1;
665 int dnum = sscanf(line, "%d", &picked);
666 int cnum = sscanf(line, "%c", &c);
667 if (strlen(line) == 0 || dnum) {
668 --picked;
669 if (picked < n_ind){
670 int index = indexes[picked];
671 row = index / 19;
672 col = index % 19;
673 board[row*19 + col] = 1;
674 }
675 } else if (cnum){
676 if (c <= 'T' && c >= 'A'){
677 int num = sscanf(line, "%c %d", &c, &row);
678 row = (inverted)?19 - row : row-1;
679 col = c - 'A';
680 if (col > 7 && noi) col -= 1;
681 if (num == 2) board[row*19 + col] = 1;
682 } else if (c == 'p') {
683 // Pass
684 } else if(c=='b' || c == 'w'){
685 char g;
686 int num = sscanf(line, "%c %c %d", &g, &c, &row);
687 row = (inverted)?19 - row : row-1;
688 col = c - 'A';
689 if (col > 7 && noi) col -= 1;
690 if (num == 3) board[row*19 + col] = (g == 'b') ? color : -color;
691 } else if(c == 'c'){
692 char g;
693 int num = sscanf(line, "%c %c %d", &g, &c, &row);
694 row = (inverted)?19 - row : row-1;
695 col = c - 'A';
696 if (col > 7 && noi) col -= 1;
697 if (num == 3) board[row*19 + col] = 0;
698 }
699 }
700 free(line);
701 flip_board(board);
702 color = -color;
703 }
704 }
705
/*
 * Score a finished position by replaying it into an external gnugo process.
 *
 * The position is written to "game.txt" as GTP commands (komi 6.5), gnugo is
 * run on that file, and its "= X+score" answer is parsed.
 *
 * Returns the margin, positive when the reported winner is not 'W' and
 * negated when it is 'W'. Returns 0 when no score line could be parsed or
 * when the file or the gnugo process could not be opened.
 */
float score_game(float *board)
{
    FILE *f = fopen("game.txt", "w");
    if(!f){
        fprintf(stderr, "Couldn't open file: game.txt\n");
        return 0;
    }
    int i, j;
    int count = 3;   // commands written before the final query
    fprintf(f, "komi 6.5\n");
    fprintf(f, "boardsize 19\n");
    fprintf(f, "clear_board\n");
    for(j = 0; j < 19; ++j){
        for(i = 0; i < 19; ++i){
            if(board[j*19 + i] == 1) fprintf(f, "play black %c%d\n", 'A'+i+(i>=8), 19-j);
            if(board[j*19 + i] == -1) fprintf(f, "play white %c%d\n", 'A'+i+(i>=8), 19-j);
            if(board[j*19 + i]) ++count;
        }
    }
    fprintf(f, "final_score\n");
    fclose(f);
#ifdef _WIN32
    FILE *p = _popen("./gnugo --mode gtp < game.txt", "r");
#else
    FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
#endif
    if(!p){
        fprintf(stderr, "Couldn't run gnugo\n");
        return 0;
    }
    // Skip two output lines for each command issued before the query.
    for(i = 0; i < count; ++i){
        free(fgetl(p));
        free(fgetl(p));
    }
    char *l = 0;
    float score = 0;
    char player = 0;
    while((l = fgetl(p))){
        fprintf(stderr, "%s \t", l);
        int n = sscanf(l, "= %c+%f", &player, &score);
        free(l);
        if (n == 2) break;
    }
    if(player == 'W') score = -score;
#ifdef _WIN32
    _pclose(p);
#else
    pclose(p);
#endif
    return score;
}
749
self_go(char * filename,char * weightfile,char * f2,char * w2,int multi)750 void self_go(char *filename, char *weightfile, char *f2, char *w2, int multi)
751 {
752 network net = parse_network_cfg(filename);
753 if(weightfile){
754 load_weights(&net, weightfile);
755 }
756
757 network net2 = net;
758 if(f2){
759 net2 = parse_network_cfg(f2);
760 if(w2){
761 load_weights(&net2, w2);
762 }
763 }
764 srand(time(0));
765 char boards[300][93];
766 int count = 0;
767 set_batch_network(&net, 1);
768 set_batch_network(&net2, 1);
769 float* board = (float*)xcalloc(19 * 19, sizeof(float));
770 char* one = (char*)xcalloc(91, sizeof(char));
771 char* two = (char*)xcalloc(91, sizeof(char));
772 int done = 0;
773 int player = 1;
774 int p1 = 0;
775 int p2 = 0;
776 int total = 0;
777 while(1){
778 if (done || count >= 300){
779 float score = score_game(board);
780 int i = (score > 0)? 0 : 1;
781 if((score > 0) == (total%2==0)) ++p1;
782 else ++p2;
783 ++total;
784 fprintf(stderr, "Total: %d, Player 1: %f, Player 2: %f\n", total, (float)p1/total, (float)p2/total);
785 int j;
786 for(; i < count; i += 2){
787 for(j = 0; j < 93; ++j){
788 printf("%c", boards[i][j]);
789 }
790 printf("\n");
791 }
792 memset(board, 0, 19*19*sizeof(float));
793 player = 1;
794 done = 0;
795 count = 0;
796 fflush(stdout);
797 fflush(stderr);
798 }
799 //print_board(board, 1, 0);
800 //sleep(1);
801 network use = ((total%2==0) == (player==1)) ? net : net2;
802 int index = generate_move(use, player, board, multi, .1, .7, two, 0);
803 if(index < 0){
804 done = 1;
805 continue;
806 }
807 int row = index / 19;
808 int col = index % 19;
809
810 char *swap = two;
811 two = one;
812 one = swap;
813
814 if(player < 0) flip_board(board);
815 boards[count][0] = row;
816 boards[count][1] = col;
817 board_to_string(boards[count] + 2, board);
818 if(player < 0) flip_board(board);
819 ++count;
820
821 move_go(board, player, row, col);
822 board_to_string(one, board);
823
824 player = -player;
825 }
826 free(board);
827 free(one);
828 free(two);
829 }
830
/*
 * Command-line entry point for the Go tools.
 *
 * argv[2] selects the mode ("train", "valid", "self", "test" or "engine"),
 * argv[3] is the network configuration, argv[4] optional weights, and
 * argv[5] / argv[6] an optional second configuration and weights (used by
 * "self"). The "-multi" flag is looked up with find_arg(). With fewer than
 * four arguments a usage line is printed; an unrecognised mode does nothing.
 */
void run_go(int argc, char **argv)
{
    //boards_go();
    if (argc < 4) {
        fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }

    char *cfg = argv[3];
    char *weights = 0;
    char *c2 = 0;
    char *w2 = 0;
    if (argc > 4) weights = argv[4];
    if (argc > 5) c2 = argv[5];
    if (argc > 6) w2 = argv[6];

    int multi = find_arg(argc, argv, "-multi");
    char *mode = argv[2];

    if (strcmp(mode, "train") == 0) {
        train_go(cfg, weights);
    } else if (strcmp(mode, "valid") == 0) {
        valid_go(cfg, weights, multi);
    } else if (strcmp(mode, "self") == 0) {
        self_go(cfg, weights, c2, w2, multi);
    } else if (strcmp(mode, "test") == 0) {
        test_go(cfg, weights, multi);
    } else if (strcmp(mode, "engine") == 0) {
        engine_go(cfg, weights, multi);
    }
}
850