1#!/usr/bin/env python3
3# Used to dump training games in V2 format from MongoDB or V1 chunk files.
5# Usage: v2_write_training [chunk_prefix]
6#   If run without a chunk_prefix it reads from MongoDB.
7#  With a chunk prefix, it uses all chunk files with that prefix
8#  as input.
10# Sets up a dataflow pipeline that:
11# 1. Reads from input (MongoDB or v1 chunk files)
# 2. Splits the games into a test set and a training set.
# 3. Converts them from v1 format to v2 format.
# 4. Shuffles the V2 records.
# 5. Writes them out to compressed v2 chunk files.
18from chunkparser import ChunkParser
19import glob
20import gzip
21import itertools
22import multiprocessing as mp
23import numpy as np
24import pymongo
25import sys
def mongo_fetch_games(q_out, num_games):
    """
        Read V1 format games from MongoDB and put them
        in the output queue (q_out).

        Reads the network list from the 'test' database, most recent
        first (descending '_id'), then reads the games produced by
        each of those networks until 'num_games' games have been
        queued or the database runs out of games.

        Each game's 'data' field is ASCII-encoded to bytes before
        being queued. A single 'STOP' token is always queued last,
        even when fewer than 'num_games' games exist or an error
        interrupts the read, so downstream stages never block
        forever waiting for the end-of-stream marker.
    """
    client = pymongo.MongoClient()
    try:
        db = client.test
        # MongoDB closes idle cursors after 10 minutes unless specific
        # options are given. That means this query will time out before
        # we finish. Rather than keeping it alive, increase the default
        # batch size so we're sure to get all networks in the first fetch.
        networks = db.networks.find(None, {"_id": False, "hash": True}).\
            sort("_id", pymongo.DESCENDING).batch_size(5000)

        game_count = 0
        for net in networks:
            print("Searching for {}".format(net['hash']))

            games = db.games.\
                find({"networkhash": net['hash']},
                     {"_id": False, "data": True})

            for game in games:
                game_data = game['data']
                q_out.put(game_data.encode("ascii"))

                game_count += 1
                if game_count >= num_games:
                    # Enough games; the 'finally' below sends STOP.
                    return
                if game_count % 1000 == 0:
                    print("{} games".format(game_count))
    finally:
        # Send the end-of-stream marker exactly once on every exit path
        # (quota reached, database exhausted, or exception).
        try:
            q_out.put('STOP')
        finally:
            client.close()
def disk_fetch_games(q_out, prefix):
    """
        Feed gzip-compressed chunk files from disk into 'q_out'.

        Every file matching '<prefix>*.gz' is decompressed in full and
        its raw bytes are queued as one item; the contents are not
        inspected here (per the original note, the files may hold
        either v1 or v2 data). A 'STOP' token is queued after the
        last file.
    """
    pattern = prefix + "*.gz"
    for path in glob.glob(pattern):
        with gzip.open(path, 'rb') as handle:
            payload = handle.read()
            q_out.put(payload)
            print("In {}".format(path))
    # End-of-stream marker for the consumers of q_out.
    q_out.put('STOP')
def fake_fetch_games(q_out, num_games):
    """
        Queue 'num_games' randomly generated games in V1 text format,
        followed by a 'STOP' token. Used for testing and benchmarking.

        Each fake game is a single random move record repeated 200
        times and encoded as ASCII bytes. A move record is 19 lines:
        16 bit-plane lines, the side to move, 362 space-separated
        probabilities and the winner (1 or -1).
    """
    for _ in range(num_games):
        # Random inputs for one move: 16 planes of 361 bits each...
        bit_planes = [np.random.randint(2, size=361).tolist()
                      for _ in range(16)]
        # ...the side-to-move flag (as a float 0.0 / 1.0)...
        stm = float(np.random.randint(2))
        # ...362 probability values drawn from {0, 1, 2}...
        policy = np.random.randint(3, size=362).tolist()
        # ...and the outcome, mapped from {0, 1} onto {-1, 1}.
        outcome = 2 * float(np.random.randint(2)) - 1

        lines = []
        for plane in bit_planes:
            # The first 360 bits are packed into hex; the stray 361st
            # bit is appended as a single '0'/'1' character.
            packed = np.packbits([int(b) for b in plane[:360]]).tobytes().hex()
            lines.append(packed + str(plane[360]) + "\n")
        # Side-to-move line: the complement of the stm flag.
        lines.append(str(int(1. - stm)) + "\n")
        lines.append(' '.join(map(str, policy)) + "\n")
        lines.append(str(int(outcome)) + "\n")

        # Repeat the same record 200 times to form a 200 move 'game'.
        move_record = ''.join(lines)
        q_out.put((move_record * 200).encode('ascii'))
    q_out.put('STOP')
def queue_gen(q, out_qs):
    """
        Turn a queue into a generator.

        Yields items pulled from 'q' until a 'STOP' item is seen or
        reading from the queue fails. Once the stream has ended, the
        STOP item is put back on 'q' and propagated to all the queues
        in the list 'out_qs' (if any).

        Only ordinary errors (subclasses of Exception) raised by
        q.get() are treated as end of stream; KeyboardInterrupt and
        SystemExit are allowed to propagate.
    """
    while True:
        try:
            item = q.get()
        except Exception:
            # Best effort: a broken or closed queue ends the stream.
            break
        # Check the type first so bytes payloads are never compared
        # against the str sentinel.
        if isinstance(item, str) and item == 'STOP':
            break
        yield item
    # There might be multiple workers reading from this queue,
    # and they all need to be stopped, so put the STOP token
    # back in the queue.
    q.put('STOP')
    # Stop any listed output queues as well
    for x in out_qs:
        x.put('STOP')
def split_train_test(q_in, q_train, q_test):
    """
        Split a stream of chunks into separate train and test
        pools. About 10% of the chunks are assigned to test.

        Uses hash sharding on the chunk contents, so multiple runs
        will split chunks in the same way and a test game is never
        used for training. The shard is derived from a CRC32 of the
        chunk rather than the built-in hash(), because hash() of
        bytes/str is salted per interpreter process and would give a
        different split on every run.

        When the input stream ends, 'STOP' is forwarded to both
        'q_train' and 'q_test' (via queue_gen).
    """
    # Local import: zlib is only needed by this function.
    import zlib

    # Shards 0..409 of 4096 (~10%) go to the test set.
    test_threshold = 0.1 * 0xfff
    for item in queue_gen(q_in, [q_train, q_test]):
        # Chunks are expected to be bytes; anything else is hashed via
        # its UTF-8 text form so the split stays deterministic.
        if isinstance(item, (bytes, bytearray)):
            payload = item
        else:
            payload = str(item).encode("utf-8")
        shard = zlib.crc32(payload) & 0xfff
        if shard < test_threshold:
            # a test game.
            q_test.put(item)
        else:
            q_train.put(item)
class QueueChunkSrc:
    """
        Adapter exposing a queue as a chunk source with a next() method.

        The underlying generator over the queue is created lazily on
        the first call to next(), not in the constructor.
        NOTE(review): presumably so the object can be handed to other
        processes before use (generators cannot be pickled) -- confirm.
    """
    def __init__(self, q):
        self.q = q
        # Created on first use by next().
        self.gen = None

    def next(self):
        """
            Return the next chunk from the queue, or None once the
            stream has ended. Errors other than normal exhaustion
            are propagated instead of being reported as end of data.
        """
        print("Queue next")
        if self.gen is None:
            self.gen = queue_gen(self.q, [])
        try:
            return next(self.gen)
        except StopIteration:
            return None
def chunk_parser(q_in, q_out, shuffle_size, chunk_size):
    """
        Run a ChunkParser over the chunks arriving on 'q_in' and
        queue its v2 output on 'q_out' in fixed-size batches.

        Records produced by the parser's v2_gen() are grouped
        'chunk_size' at a time (the final group may be shorter),
        concatenated into one bytes object and queued. 'shuffle_size'
        is passed to the parser as its shuffle buffer size (per the
        original note, a 2^20 item buffer is ~ 2.2GB of RAM).
        A 'STOP' token is queued once the parser is exhausted.
    """
    # Leave two cores free, but always use at least one worker.
    n_workers = mp.cpu_count() - 2
    if n_workers < 1:
        n_workers = 1

    source = QueueChunkSrc(q_in)
    parser = ChunkParser(source,
                         shuffle_size=shuffle_size,
                         workers=n_workers)
    records = parser.v2_gen()

    def take_batch():
        return list(itertools.islice(records, chunk_size))

    # iter() with a sentinel stops at the first empty batch.
    for batch in iter(take_batch, []):
        q_out.put(b''.join(batch))
    q_out.put('STOP')
def chunk_writer(q_in, namesrc):
    """
        Write batches of moves out to disk as compressed files.

        Each chunk received on 'q_in' (until 'STOP') is written to its
        own gzip file at compression level 1. Filenames are taken
        from 'namesrc' by calling its next() method once per chunk.
    """
    for chunk in queue_gen(q_in, []):
        filename = namesrc.next()
        # 'with' guarantees the file is closed even if the write fails.
        with gzip.open(filename, 'w', 1) as chunk_file:
            chunk_file.write(chunk)
    print("chunk_writer completed")
class NameSrc:
    """
        Produces a sequence of numbered '.gz' file names that all
        start with 'prefix': <prefix>00000001.gz, <prefix>00000002.gz, ...
    """
    def __init__(self, prefix):
        # Number of names handed out so far.
        self.n = 0
        self.prefix = prefix

    def next(self):
        """
            Advance the counter and return the next file name, with
            the counter zero-padded to 8 digits.
        """
        print("Name next")
        self.n = self.n + 1
        numbered = "{:0>8d}.gz".format(self.n)
        return self.prefix + numbered
def main(args):
    """
        Build the processing pipeline, start it and wait for it.

        'args' is the list of command line arguments. If it is
        non-empty, its first element is removed and used as the chunk
        file prefix to read from disk; otherwise games are read from
        MongoDB. Every stage runs in its own process and the stages
        are connected by multiprocessing SimpleQueues.
    """
    procs = []

    def add_stage(target, *stage_args):
        # Register one pipeline stage; it is started later.
        procs.append(mp.Process(target=target, args=stage_args))

    # Stage 1: read from input.
    q_games = mp.SimpleQueue()
    if not args:
        print("Reading from MongoDB")
        #add_stage(fake_fetch_games, q_games, 20)
        add_stage(mongo_fetch_games, q_games, 275000)
    else:
        prefix = args.pop(0)
        print("Reading from chunkfiles {}".format(prefix))
        add_stage(disk_fetch_games, q_games, prefix)

    # Stage 2: split into train/test.
    q_test = mp.SimpleQueue()
    q_train = mp.SimpleQueue()
    add_stage(split_train_test, q_games, q_train, q_test)

    # Stage 3: convert v1 to v2 format and shuffle, writing 8192 moves
    # per chunk.
    q_write_train = mp.SimpleQueue()
    q_write_test = mp.SimpleQueue()
    # Shuffle buffer is ~ 2.2GB of RAM with 2^20 (~1e6) entries. A game
    # is ~500 moves, so there's ~2000 games in the shuffle buffer.
    # Selecting 8k moves gives an expected number of ~4 moves from the
    # same game in a given chunk file.
    #
    # The output files are later fed (in parse.py) through another 1e6
    # sized shuffle buffer. At 8192 moves per chunk, there's ~ 128
    # chunks in that shuffle buffer. With a batch size of 4096, the
    # expected max number of moves from the same game in the batch
    # is < 1.14
    add_stage(chunk_parser, q_train, q_write_train, 1 << 20, 8192)
    add_stage(chunk_parser, q_test, q_write_test, 1 << 16, 8192)

    # Stage 4: write to output files.
    add_stage(chunk_writer, q_write_train, NameSrc('train_'))
    add_stage(chunk_writer, q_write_test, NameSrc('test_'))

    # Start all the child processes running.
    for proc in procs:
        proc.start()
    # Wait for everything to finish.
    for proc in procs:
        proc.join()
    # All done!
if __name__ == "__main__":
    # Script entry point: select the 'spawn' start method for child
    # processes, then run the pipeline with the command line arguments
    # (program name stripped).
    mp.set_start_method('spawn')
    cli_args = sys.argv[1:]
    main(cli_args)