#!/usr/bin/env python3
#
# Used to dump training games in V2 format from MongoDB or V1 chunk files.
#
# Usage: v2_write_training [chunk_prefix]
#   If run without a chunk_prefix it reads from MongoDB.
#   With a chunk_prefix, it uses all chunk files with that prefix
#   as input.
#
# Sets up a dataflow pipeline that:
# 1. Reads from input (MongoDB or v1 chunk files).
# 2. Splits the games into a test set and a training set.
# 3. Converts from v1 format to v2 format.
# 4. Shuffles the V2 records.
# 5. Writes them out to compressed v2 chunk files.
#
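# For example, with a hypothetical chunk prefix, "./v2_write_training.py /data/chunks/run1_"
# reads /data/chunks/run1_*.gz and writes train_*.gz / test_*.gz chunk files into the
# current working directory.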

from chunkparser import ChunkParser
import glob
import gzip
import itertools
import multiprocessing as mp
import numpy as np
import pymongo
import sys

def mongo_fetch_games(q_out, num_games):
    """
        Read V1 format games from MongoDB and put them
        in the output queue (q_out).

        Reads the network list from MongoDB, most recent first,
        and then reads games produced by those networks until
        'num_games' games have been read.
    """
    client = pymongo.MongoClient()
    db = client.test
    # MongoDB closes idle cursors after 10 minutes unless specific
    # options are given. That means this query will time out before
    # we finish. Rather than keeping it alive, increase the default
    # batch size so we're sure to get all networks in the first fetch.
    networks = db.networks.find(None, {"_id": False, "hash": True}).\
        sort("_id", pymongo.DESCENDING).batch_size(5000)

    game_count = 0
    for net in networks:
        print("Searching for {}".format(net['hash']))

        games = db.games.\
            find({"networkhash": net['hash']},
                 {"_id": False, "data": True})

        for game in games:
            game_data = game['data']
            q_out.put(game_data.encode("ascii"))

            game_count += 1
            if game_count >= num_games:
                q_out.put('STOP')
                return
            if game_count % 1000 == 0:
                print("{} games".format(game_count))
    # If we run out of networks before reaching 'num_games', still signal
    # the downstream consumers to stop rather than leaving them waiting.
    q_out.put('STOP')

def disk_fetch_games(q_out, prefix):
    """
        Fetch chunk files off disk.

        Chunk files can be either v1 or v2 format.
    """
    files = glob.glob(prefix + "*.gz")
    for f in files:
        with gzip.open(f, 'rb') as chunk_file:
            v = chunk_file.read()
            q_out.put(v)
            print("In {}".format(f))
    q_out.put('STOP')

def fake_fetch_games(q_out, num_games):
    """
        Generate V1 format fake games. Used for testing and benchmarking.
    """
    for _ in range(num_games):
        # Generate one random move, then repeat it to make a 200 move 'game'.
        # 1. 18 binary planes of length 361
        planes = [np.random.randint(2, size=361).tolist() for plane in range(16)]
        stm = float(np.random.randint(2))
        planes.append([stm] * 361)
        planes.append([1. - stm] * 361)
        # 2. 362 probs
        probs = np.random.randint(3, size=362).tolist()
        # 3. And a winner: 1 or -1
        winner = [ 2 * float(np.random.randint(2)) - 1 ]

        # Convert that to a v1 text record.
        items = []
        for p in range(16):
            # generate first 360 bits
            h = np.packbits([int(x) for x in planes[p][0:360]]).tobytes().hex()
            # then add the stray single bit
            h += str(planes[p][360]) + "\n"
            items.append(h)
        # then side to move
        items.append(str(int(planes[17][0])) + "\n")
        # then probabilities
        items.append(' '.join([str(x) for x in probs]) + "\n")
        # and finally if the side to move is a winner
        items.append(str(int(winner[0])) + "\n")
        game = ''.join(items)
        game = game * 200
        game = game.encode('ascii')

        q_out.put(game)
    q_out.put('STOP')

def queue_gen(q, out_qs):
    """
        Turn a queue into a generator.

        Yields items pulled from 'q' until a 'STOP' item is seen.
        The STOP item will be propagated to all the queues in
        the list 'out_qs' (if any).
    """
    while True:
        try:
            item = q.get()
        except:
            break
        if item == 'STOP':
            break
        yield item
    # There might be multiple workers reading from this queue,
    # and they all need to be stopped, so put the STOP token
    # back in the queue.
    q.put('STOP')
    # Stop any listed output queues as well
    for x in out_qs:
        x.put('STOP')

def split_train_test(q_in, q_train, q_test):
    """
        Split a stream of chunks into separate train and test
        pools. 10% of the chunks are assigned to test.

        Uses hash sharding, so a given game is always assigned
        to the same pool.
    """
    for item in queue_gen(q_in, [q_train, q_test]):
        # Use the hash of the game to determine the split. This means
        # that test games will never be used for training.
        h = hash(item) & 0xfff
        if h < 0.1*0xfff:
            # a test game.
            q_test.put(item)
        else:
            q_train.put(item)

class QueueChunkSrc:
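    """
        Wrap a queue in the next()-style chunk source interface that
        ChunkParser consumes: next() returns the next chunk, or None
        once no more chunks are available.
    """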
    def __init__(self, q):
        self.q = q
        self.gen = None
    def next(self):
        print("Queue next")
        if self.gen is None:
            self.gen = queue_gen(self.q, [])
        try:
            return next(self.gen)
        except:
            return None


def chunk_parser(q_in, q_out, shuffle_size, chunk_size):
    """
        Parse input chunks from 'q_in', shuffle, and put
        chunks of moves in v2 format into 'q_out'.

        Each output chunk contains 'chunk_size' moves.
        Moves are shuffled in a buffer of 'shuffle_size' moves.
        (A 2^20 item shuffle buffer is ~ 2.2GB of RAM.)
    """
    workers = max(1, mp.cpu_count() - 2)
    parse = ChunkParser(QueueChunkSrc(q_in),
                        shuffle_size=shuffle_size,
                        workers=workers)
    gen = parse.v2_gen()
    while True:
        s = list(itertools.islice(gen, chunk_size))
        if not len(s):
            break
        s = b''.join(s)
        q_out.put(s)
    q_out.put('STOP')

def chunk_writer(q_in, namesrc):
    """
        Write a batch of moves out to disk as a compressed file.

        Filenames are taken from the name source 'namesrc'.
    """
    for chunk in queue_gen(q_in, []):
        filename = namesrc.next()
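        # compresslevel 1 favours write speed over compression ratio.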
        with gzip.open(filename, 'w', 1) as chunk_file:
            chunk_file.write(chunk)
    print("chunk_writer completed")

class NameSrc:
    """
        Generates a sequence of names, starting with 'prefix'.
    """
    def __init__(self, prefix):
        self.prefix = prefix
        self.n = 0
    def next(self):
        print("Name next")
        self.n += 1
        return self.prefix + "{:0>8d}.gz".format(self.n)

def main(args):
    # Build the pipeline.
    procs = []
    # Read from input.
    q_games = mp.SimpleQueue()
    if args:
        prefix = args.pop(0)
        print("Reading from chunkfiles {}".format(prefix))
        procs.append(mp.Process(target=disk_fetch_games, args=(q_games, prefix)))
    else:
        print("Reading from MongoDB")
        #procs.append(mp.Process(target=fake_fetch_games, args=(q_games, 20)))
        procs.append(mp.Process(target=mongo_fetch_games, args=(q_games, 275000)))
    # Split into train/test
    q_test = mp.SimpleQueue()
    q_train = mp.SimpleQueue()
    procs.append(mp.Process(target=split_train_test, args=(q_games, q_train, q_test)))
    # Convert v1 to v2 format and shuffle, writing 8192 moves per chunk.
    q_write_train = mp.SimpleQueue()
    q_write_test = mp.SimpleQueue()
    # Shuffle buffer is ~ 2.2GB of RAM with 2^20 (~1e6) entries. A game is ~500 moves, so
    # there's ~2000 games in the shuffle buffer. Selecting 8k moves gives an expected
    # number of ~4 moves from the same game in a given chunk file.
    #
    # The output files are later consumed by parse.py via another 1e6 sized shuffle buffer.
    # At 8192 moves per chunk, there's ~ 128 chunks in the shuffle buffer. With a batch size
    # of 4096, the expected max number of moves from the same game in the batch is < 1.14.
    procs.append(mp.Process(target=chunk_parser, args=(q_train, q_write_train, 1<<20, 8192)))
    procs.append(mp.Process(target=chunk_parser, args=(q_test, q_write_test, 1<<16, 8192)))
    # Write to output files
    procs.append(mp.Process(target=chunk_writer, args=(q_write_train, NameSrc('train_'))))
    procs.append(mp.Process(target=chunk_writer, args=(q_write_test, NameSrc('test_'))))

    # Start all the child processes running.
    for p in procs:
        p.start()
    # Wait for everything to finish.
    for p in procs:
        p.join()
    # All done!

if __name__ == "__main__":
    mp.set_start_method('spawn')
    main(sys.argv[1:])