1__copyright__ = "Copyright 2016-2020, Netflix, Inc."
2__license__ = "BSD+Patent"
3
4import os
5
6import numpy as np
7
8from vmaf.config import VmafConfig
9from vmaf.core.executor import run_executors_in_parallel
10from vmaf.core.raw_extractor import DisYUVRawVideoExtractor
11from vmaf.core.nn_train_test_model import ToddNoiseClassifierTrainTestModel
12from vmaf.routine import read_dataset
13from vmaf.tools.misc import import_python_file
14
15
def main(num_train=500, num_test=50, n_epochs=30, seed=0):
    """Train and evaluate a ToddNoiseClassifierTrainTestModel on BSDS500 noisy assets.

    Reads the ``BSDS500_noisy_dataset.py`` resource and shuffles its assets
    with ``seed``. It keeps the first ``num_train + num_test`` of them and runs
    ``DisYUVRawVideoExtractor`` on them, handing the extractor a temporary
    HDF5 file. It then trains the model on the first ``num_train`` results and
    prints F1 / error rate on the remaining ones. Both temporary HDF5 files in
    the VMAF workdir are closed and deleted on exit, including when an
    exception is raised part-way through.

    :param num_train: number of shuffled assets used for training.
    :param num_test: number of shuffled assets (after the training ones)
        used for evaluation.
    :param n_epochs: number of training epochs, forwarded to the model.
    :param seed: seed passed to ``np.random.seed`` (controls the asset
        shuffle) and to the model's ``param_dict``. ``None`` makes the
        shuffle non-deterministic; what the model does with ``None`` is up
        to ToddNoiseClassifierTrainTestModel.
    :raises ValueError: if, after the training split, no assets remain
        for evaluation.
    """
    import contextlib  # local import keeps this script entry point self-contained

    # read input dataset
    dataset_path = VmafConfig.resource_path('dataset', 'BSDS500_noisy_dataset.py')
    dataset = import_python_file(dataset_path)
    assets = read_dataset(dataset)

    # shuffle assets in place with the global numpy RNG, then keep only as
    # many as the train + test splits need
    np.random.seed(seed)
    np.random.shuffle(assets)
    assets = assets[:(num_train + num_test)]
    if len(assets) <= num_train:
        # fail before the expensive extraction instead of evaluating on nothing
        raise ValueError(
            "dataset has %d assets, need more than num_train=%d to leave a "
            "test split" % (len(assets), num_train))

    with contextlib.ExitStack() as cleanup:
        raw_video_h5py_filepath = VmafConfig.workdir_path('rawvideo.hdf5')
        raw_video_h5py_file = DisYUVRawVideoExtractor.open_h5py_file(raw_video_h5py_filepath)
        # ExitStack runs callbacks LIFO: the handle is closed before the file
        # is deleted, even if a later step raises.
        cleanup.callback(os.remove, raw_video_h5py_filepath)
        cleanup.callback(DisYUVRawVideoExtractor.close_h5py_file, raw_video_h5py_file)

        print('======================== Extract raw YUVs ==============================')

        _, raw_yuvs = run_executors_in_parallel(
            DisYUVRawVideoExtractor,
            assets,
            fifo_mode=True,
            delete_workdir=True,
            # CAN ONLY USE SERIAL MODE FOR DisYUVRawVideoExtractor -- presumably
            # because all executors share the single h5py file handle below
            parallelize=False,
            result_store=None,
            optional_dict=None,
            optional_dict2={'h5py_file': raw_video_h5py_file})

        patch_h5py_filepath = VmafConfig.workdir_path('patch.hdf5')
        patch_h5py_file = ToddNoiseClassifierTrainTestModel.open_h5py_file(patch_h5py_filepath)
        cleanup.callback(os.remove, patch_h5py_filepath)
        cleanup.callback(ToddNoiseClassifierTrainTestModel.close_h5py_file, patch_h5py_file)

        model = ToddNoiseClassifierTrainTestModel(
            param_dict={
                'seed': seed,
                'n_epochs': n_epochs,
            },
            logger=None,
            optional_dict2={  # for options that won't impact the result
                # 'checkpoints_dir': VmafConfig.workspace_path('checkpoints_dir'),
                'h5py_file': patch_h5py_file,
            })

        print('============================ Train model ===============================')
        xys = ToddNoiseClassifierTrainTestModel.get_xys_from_results(raw_yuvs[:num_train])
        model.train(xys)

        print('=========================== Evaluate model =============================')
        test_results = raw_yuvs[num_train:]
        xs = ToddNoiseClassifierTrainTestModel.get_xs_from_results(test_results)
        ys = ToddNoiseClassifierTrainTestModel.get_ys_from_results(test_results)
        result = model.evaluate(xs, ys)

        print("")
        print("f1 test %g, errorrate test %g" % (result['f1'], result['errorrate']))
        # leaving the with-block closes and removes both HDF5 work files

    print('Done.')
80
81
if __name__ == '__main__':
    # Only train/evaluate when executed as a script, not on import.
    main()
84