#!/usr/bin/env python3
#
# Used to dump training games in V2 format from MongoDB or V1 chunk files.
#
# Usage: v2_write_training [chunk_prefix]
#     If run without a chunk_prefix it reads from MongoDB.
#     With a chunk_prefix, it uses all chunk files with that prefix
#     as input.
#
# Sets up a dataflow pipeline that:
# 1. Reads from the input (MongoDB or V1 chunk files).
# 2. Splits the games into a test set and a training set.
# 3. Converts from V1 format to V2 format.
# 4. Shuffles the V2 records.
# 5. Writes them out to compressed V2 chunk files.
#

from chunkparser import ChunkParser
import glob
import gzip
import itertools
import multiprocessing as mp
import numpy as np
import pymongo
import sys
import zlib

def mongo_fetch_games(q_out, num_games):
    """
    Read V1 format games from MongoDB and put them
    in the output queue (q_out).

    Reads the list of networks from MongoDB, most recent first,
    then reads the games produced by those networks until
    'num_games' games have been read.
    """
    client = pymongo.MongoClient()
    db = client.test
    # MongoDB closes idle cursors after 10 minutes unless specific
    # options are given. That means this query will time out before
    # we finish. Rather than keeping it alive, increase the default
    # batch size so we're sure to get all networks in the first fetch.
    networks = db.networks.find(None, {"_id": False, "hash": True}).\
        sort("_id", pymongo.DESCENDING).batch_size(5000)

    game_count = 0
    for net in networks:
        print("Searching for {}".format(net['hash']))

        games = db.games.\
            find({"networkhash": net['hash']},
                 {"_id": False, "data": True})

        for game in games:
            game_data = game['data']
            q_out.put(game_data.encode("ascii"))

            game_count += 1
            if game_count >= num_games:
                q_out.put('STOP')
                return
            if game_count % 1000 == 0:
                print("{} games".format(game_count))

def disk_fetch_games(q_out, prefix):
    """
    Fetch chunk files off disk.

    Chunk files can be in either V1 or V2 format.
    """
    files = glob.glob(prefix + "*.gz")
    for f in files:
        with gzip.open(f, 'rb') as chunk_file:
            v = chunk_file.read()
            q_out.put(v)
        print("In {}".format(f))
    q_out.put('STOP')

def fake_fetch_games(q_out, num_games):
    """
    Generate fake games in V1 format. Used for testing and benchmarking.
    """
    for _ in range(num_games):
        # Generate a single random move, then repeat it below to make
        # a 200 move 'game'.
        # 1. 18 binary planes of length 361.
        planes = [np.random.randint(2, size=361).tolist() for _ in range(16)]
        stm = float(np.random.randint(2))
        planes.append([stm] * 361)
        planes.append([1. - stm] * 361)
        # 2. 362 probabilities.
        probs = np.random.randint(3, size=362).tolist()
        # 3. And a winner: 1 or -1.
        winner = [2 * float(np.random.randint(2)) - 1]

        # Convert that to a V1 text record.
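        # For reference, the V1 text record assembled below is 19 text
        # lines per move:
        #   lines 1-16: the binary input planes, hex encoded (90 hex
        #               chars for the first 360 bits, then the stray
        #               361st bit as a plain digit)
        #   line 17:    the side to move (a single digit)
        #   line 18:    the 362 move probabilities, space separated
        #   line 19:    the winner from the side to move's perspective
        #               (1 or -1)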
        items = []
        for p in range(16):
            # Generate the first 360 bits as hex.
            h = np.packbits([int(x) for x in planes[p][0:360]]).tobytes().hex()
            # Then add the stray single bit.
            h += str(planes[p][360]) + "\n"
            items.append(h)
        # Then the side to move.
        items.append(str(int(planes[17][0])) + "\n")
        # Then the probabilities.
        items.append(' '.join([str(x) for x in probs]) + "\n")
        # And finally whether the side to move is the winner.
        items.append(str(int(winner[0])) + "\n")
        game = ''.join(items)
        game = game * 200
        game = game.encode('ascii')

        q_out.put(game)
    q_out.put('STOP')

def queue_gen(q, out_qs):
    """
    Turn a queue into a generator.

    Yields items pulled from 'q' until a 'STOP' item is seen.
    The STOP item will be propagated to all the queues in
    the list 'out_qs' (if any).
    """
    while True:
        try:
            item = q.get()
        except:
            break
        if item == 'STOP':
            break
        yield item
    # There might be multiple workers reading from this queue,
    # and they all need to be stopped, so put the STOP token
    # back in the queue.
    q.put('STOP')
    # Stop any listed output queues as well.
    for x in out_qs:
        x.put('STOP')

def split_train_test(q_in, q_train, q_test):
    """
    Split a stream of chunks into separate training and test
    pools. 10% of the chunks are assigned to test.

    Uses hash sharding, so multiple runs will split chunks
    in the same way.
    """
    for item in queue_gen(q_in, [q_train, q_test]):
        # Use a hash of the game to determine the split. This means
        # that test games will never be used for training. A stable
        # checksum is used rather than the built-in hash(), which is
        # randomized per process and would break reproducibility.
        h = zlib.crc32(item) & 0xfff
        if h < 0.1 * 0xfff:
            # A test game.
            q_test.put(item)
        else:
            q_train.put(item)

class QueueChunkSrc:
    def __init__(self, q):
        self.q = q
        self.gen = None

    def next(self):
        print("Queue next")
        if self.gen is None:
            self.gen = queue_gen(self.q, [])
        try:
            return next(self.gen)
        except StopIteration:
            return None


def chunk_parser(q_in, q_out, shuffle_size, chunk_size):
    """
    Parse input chunks from 'q_in', shuffle, and put
    chunks of moves in V2 format into 'q_out'.

    Each output chunk contains 'chunk_size' moves.
    Moves are shuffled in a buffer of 'shuffle_size' moves.
    (A 2^20 item shuffle buffer is ~2.2GB of RAM.)
    """
    workers = max(1, mp.cpu_count() - 2)
    parse = ChunkParser(QueueChunkSrc(q_in),
                        shuffle_size=shuffle_size,
                        workers=workers)
    gen = parse.v2_gen()
    while True:
        s = list(itertools.islice(gen, chunk_size))
        if not len(s):
            break
        s = b''.join(s)
        q_out.put(s)
    q_out.put('STOP')

def chunk_writer(q_in, namesrc):
    """
    Write batches of moves out to disk as compressed files.

    Filenames are taken from the name source 'namesrc'.
    """
    for chunk in queue_gen(q_in, []):
        filename = namesrc.next()
        with gzip.open(filename, 'w', 1) as chunk_file:
            chunk_file.write(chunk)
    print("chunk_writer completed")

class NameSrc:
    """
    Generate a sequence of names, starting with 'prefix'.
    """
    def __init__(self, prefix):
        self.prefix = prefix
        self.n = 0

    def next(self):
        print("Name next")
        self.n += 1
        return self.prefix + "{:0>8d}.gz".format(self.n)

def main(args):
    # Build the pipeline.
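    # A sketch of the dataflow assembled below. Each stage runs in its
    # own process, the arrows are SimpleQueues, and the 'STOP' token
    # propagates through them to shut everything down:
    #
    #   fetch -> q_games -> split_train_test -> q_train -> chunk_parser -> q_write_train -> chunk_writer
    #                                        `-> q_test -> chunk_parser -> q_write_test -> chunk_writer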
    procs = []
    # Read from input.
    q_games = mp.SimpleQueue()
    if args:
        prefix = args.pop(0)
        print("Reading from chunkfiles {}".format(prefix))
        procs.append(mp.Process(target=disk_fetch_games,
                                args=(q_games, prefix)))
    else:
        print("Reading from MongoDB")
        #procs.append(mp.Process(target=fake_fetch_games, args=(q_games, 20)))
        procs.append(mp.Process(target=mongo_fetch_games,
                                args=(q_games, 275000)))
    # Split into train/test.
    q_test = mp.SimpleQueue()
    q_train = mp.SimpleQueue()
    procs.append(mp.Process(target=split_train_test,
                            args=(q_games, q_train, q_test)))
    # Convert V1 to V2 format and shuffle, writing 8192 moves per chunk.
    q_write_train = mp.SimpleQueue()
    q_write_test = mp.SimpleQueue()
    # The shuffle buffer is ~2.2GB of RAM with 2^20 (~1e6) entries. A game
    # is ~500 moves, so there are ~2000 games in the shuffle buffer.
    # Selecting 8k moves gives an expected number of ~4 moves from the
    # same game in a given chunk file (8192 * 500 / 2^20 ~= 3.9).
    #
    # The output files are read by parse.py through another 1e6-entry
    # shuffle buffer. At 8192 moves per chunk, there are ~128 chunks in
    # that shuffle buffer. With a batch size of 4096, the expected maximum
    # number of moves from the same game in a batch is < 1.14.
    procs.append(mp.Process(target=chunk_parser,
                            args=(q_train, q_write_train, 1 << 20, 8192)))
    procs.append(mp.Process(target=chunk_parser,
                            args=(q_test, q_write_test, 1 << 16, 8192)))
    # Write to output files.
    procs.append(mp.Process(target=chunk_writer,
                            args=(q_write_train, NameSrc('train_'))))
    procs.append(mp.Process(target=chunk_writer,
                            args=(q_write_test, NameSrc('test_'))))

    # Start all the child processes running.
    for p in procs:
        p.start()
    # Wait for everything to finish.
    for p in procs:
        p.join()
    # All done!

if __name__ == "__main__":
    mp.set_start_method('spawn')
    main(sys.argv[1:])