1#!/usr/bin/env python3
2
3import glob
4import gzip
5import os
6import re
7import subprocess
8import sys
9import tempfile
10import threading
11import time
12
# Site-specific settings -- adjust these to your liking.
n_threads = 8  # number of worker threads pulling jobs concurrently
datadirs = "/tmp/rdcost/data/"  # scanned for "<seq>.yuv-qpNN" input directories
resultdir = "/tmp/rdcost/coeff_buckets"  # receives one "<job>.result" file per job

# Command lines of the external programs.
gzargs = ["gzip", "-d"]  # decompressor command (the visible code uses the gzip module instead)
filtargs = ["./frcosts_matrix"]  # first pass: reads decompressed data on stdin
octargs = ["octave-cli", "invert_matrix.m"]  # consumes the first pass's stdout
filt2args = ["./ols_2ndpart"]  # second pass: the temp-file path is appended as last arg
22
class MultiPipeManager:
    """Create one named FIFO per destination QP inside a directory.

    Use as a context manager. Entering creates ``odpath`` (if needed) and a
    fresh FIFO for every QP in ``dest_qps``, replacing any stale file of the
    same name. Exiting removes those FIFOs; ones that are already gone are
    ignored, so a cleanup problem never masks the exception that ended the
    ``with`` body.
    """

    # File name of the FIFO for a given QP, e.g. 7 -> "07.txt".
    pipe_fn_template = "%02i.txt"

    def __init__(self, odpath, dest_qps):
        self.odpath = odpath
        self.dest_qps = dest_qps
        self.pipe_fns = [
            os.path.join(self.odpath, self.pipe_fn_template % qp)
            for qp in dest_qps
        ]

    def __enter__(self):
        os.makedirs(self.odpath, exist_ok=True)
        created = []
        try:
            for pipe_fn in self.pipe_fns:
                # Remove a leftover file from an earlier run so mkfifo succeeds.
                try:
                    os.unlink(pipe_fn)
                except FileNotFoundError:
                    pass
                os.mkfifo(pipe_fn)
                created.append(pipe_fn)
        except BaseException:
            # __exit__ is not called when __enter__ raises, so remove the
            # FIFOs made so far rather than leaving them behind.
            self._remove(created)
            raise
        return self

    def __exit__(self, *_):
        self._remove(self.pipe_fns)

    @staticmethod
    def _remove(pipe_fns):
        """Unlink every path in ``pipe_fns``, ignoring ones that no longer exist."""
        for pipe_fn in pipe_fns:
            try:
                os.unlink(pipe_fn)
            except FileNotFoundError:
                pass

    def items(self):
        """Yield the FIFO path of each destination QP, in ``dest_qps`` order."""
        yield from self.pipe_fns
52
class MTSafeIterable:
    """Share one iterable between threads, handing out each item exactly once.

    Every ``next()`` call is serialised by a lock. This also makes it safe to
    share a generator, which Python refuses to advance from two threads at
    the same time.
    """

    def __init__(self, iterable):
        self.lock = threading.Lock()
        # iter() accepts any iterable (list, set, ...). For an object that is
        # already an iterator it returns that same object, so callers passing
        # iter(...) get identical behaviour.
        self.iterable = iter(iterable)

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.iterable)
64
def read_in_blocks(f):
    """Yield successive chunks of at most 64 KiB read from file object ``f``.

    Iteration ends at the first empty read, i.e. end of file.
    """
    chunk_size = 65536
    chunk = f.read(chunk_size)
    while len(chunk) != 0:
        yield chunk
        chunk = f.read(chunk_size)
73
def exhaust_gzs(sink_f, gzs):
    """Decompress each gzip file in ``gzs``, in order, and write it to ``sink_f``.

    ``sink_f`` must be a binary, writable file object (e.g. a subprocess's
    stdin pipe). It is flushed after every input file, so the consumer sees
    each file's data without waiting for the sink to close. It is not closed
    here.
    """
    import shutil

    for gz in gzs:
        with gzip.open(gz, "rb") as f:
            print("  Doing %s ..." % gz)
            shutil.copyfileobj(f, sink_f, 65536)
        sink_f.flush()
83
def run_job(jobname, input_gzs):
    """Run the two-pass pipeline for one job and write its result file.

    Pass 1 streams the decompressed ``input_gzs`` into ``filtargs``. Its
    stdout is piped into ``octargs``, whose output is captured in a temporary
    file. Pass 2 streams the same inputs into ``filt2args``, with the
    temporary file's path appended as the last argument, and writes its
    stdout to ``<resultdir>/<jobname>.result``.

    Raises:
        OSError: if feeding the first pass fails (logged to stderr, then re-raised).
        RuntimeError: if any pipeline process exits with a non-zero status.
    """
    resultpath = os.path.join(resultdir, "%s.result" % jobname)
    print("Running job %s" % jobname)

    with tempfile.NamedTemporaryFile() as tf:
        filt = subprocess.Popen(filtargs, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
        octa = subprocess.Popen(octargs, stdin=filt.stdout, stdout=tf)
        # octa now holds the read end. Closing the parent's copy lets filt get
        # SIGPIPE/EPIPE if octa exits early, instead of blocking on a full pipe.
        filt.stdout.close()

        try:
            exhaust_gzs(filt.stdin, input_gzs)
        except OSError as e:
            print("OSError %s" % e, file=sys.stderr)
            raise

        filt.stdin.close()
        filt.wait()
        octa.wait()

        # The first pass is the whole filt | octa pipeline, so both exit
        # statuses matter; a failed octa would leave tf empty or partial.
        if filt.returncode != 0 or octa.returncode != 0:
            print("First stage failed: %s" % jobname, file=sys.stderr)
            raise RuntimeError("First stage failed: %s" % jobname)

        with open(resultpath, "w") as rf:
            f2a = filt2args + [tf.name]
            f2 = subprocess.Popen(f2a, stdin=subprocess.PIPE, stdout=rf)
            exhaust_gzs(f2.stdin, input_gzs)
            # communicate() closes f2's stdin and waits for it to exit.
            f2.communicate()
            # Check the second-pass process itself, not filt.
            if f2.returncode != 0:
                print("Second stage failed: %s" % jobname, file=sys.stderr)
                raise RuntimeError("Second stage failed: %s" % jobname)

    print("Job %s done" % jobname)
116
def threadfunc(joblist):
    """Worker body: run every ``(jobname, input_gzs)`` pair taken from ``joblist``.

    Keeps pulling entries until ``joblist`` is exhausted. An exception from
    ``run_job`` propagates and ends this worker.
    """
    for entry in joblist:
        name, gz_files = entry
        run_job(name, gz_files)
120
def scan_datadirs(path, qps=range(51)):
    """Yield ``(job_name, gz_paths)`` for every sequence/QP combination.

    ``path`` is scanned for sub-directories named ``<seq>.yuv-qp<N>``, where
    ``<seq>`` matches ``[A-Za-z0-9_]+`` and ``<N>`` is one or two digits. For
    each distinct ``<seq>.yuv`` and each QP in ``qps`` one job is yielded:

    * ``job_name`` is ``"<seq>.yuv-qp%02i" % qp``.
    * ``gz_paths`` lists every ``"%02i.txt.gz" % qp`` file found in any of
      that sequence's directories. The list may be empty.

    Sequence names and input paths are sorted, so job order and input order
    are reproducible between runs.

    NOTE(review): the default ``range(51)`` covers QPs 0-50 only. Pass
    ``range(52)`` if data for QP 51 exists.
    """
    dir_re = re.compile(r"^([A-Za-z0-9_]+\.yuv)-qp[0-9]{1,2}$")
    seq_names = set()
    for dirent in os.scandir(path):
        if not dirent.is_dir():
            continue
        match = dir_re.search(dirent.name)
        if match is not None:
            seq_names.add(match.group(1))

    for seq_name in sorted(seq_names):
        # Escape the literal parts so glob metacharacters in them cannot
        # change which directories match.
        seq_glob = os.path.join(glob.escape(path), glob.escape(seq_name) + "-qp*/")

        for qp in qps:
            job_name = seq_name + "-qp%02i" % qp
            qp_fn = "%02i.txt.gz" % qp
            yield job_name, sorted(glob.glob(os.path.join(seq_glob, qp_fn)))
138
def main():
    """Create the data/result directories, then process all jobs in parallel.

    Jobs from ``scan_datadirs(datadirs)`` are shared among ``n_threads``
    worker threads through a lock-protected iterator. Returns once every
    worker has been joined.
    """
    os.makedirs(datadirs, exist_ok=True)
    os.makedirs(resultdir, exist_ok=True)

    shared_jobs = MTSafeIterable(iter(scan_datadirs(datadirs)))

    workers = []
    for _ in range(n_threads):
        workers.append(threading.Thread(target=threadfunc, args=(shared_jobs,)))

    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
152
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
155