__copyright__ = "Copyright 2016-2020, Netflix, Inc."
__license__ = "BSD+Patent"

import os

import numpy as np

from vmaf.config import VmafConfig
from vmaf.core.executor import run_executors_in_parallel
from vmaf.core.raw_extractor import DisYUVRawVideoExtractor
from vmaf.core.nn_train_test_model import ToddNoiseClassifierTrainTestModel
from vmaf.routine import read_dataset
from vmaf.tools.misc import import_python_file


def main(num_train=500, num_test=50, n_epochs=30, seed=0):
    """Train a ToddNoiseClassifierTrainTestModel on the BSDS500 noisy dataset and
    print its test F1 score and error rate.

    The dataset's assets are shuffled with ``np.random`` seeded by ``seed``.
    The first ``num_train`` shuffled assets are used for training and the next
    ``num_test`` for evaluation. Raw YUVs are extracted into a temporary
    ``rawvideo.hdf5`` in the VMAF workdir. The model is given a temporary
    ``patch.hdf5``. Both files are closed and deleted before returning, even
    if extraction, training or evaluation raises.

    :param num_train: number of shuffled assets used for training.
    :param num_test: number of shuffled assets (after the training ones) used
        for evaluation.
    :param n_epochs: passed to the model as its ``'n_epochs'`` parameter.
    :param seed: seeds ``np.random`` before shuffling and is passed to the model
        as its ``'seed'`` parameter. ``None`` gives a non-reproducible shuffle.
        NOTE(review): whether the model itself accepts ``seed=None`` is not
        visible here; confirm.
    """
    # read input dataset
    dataset_path = VmafConfig.resource_path('dataset', 'BSDS500_noisy_dataset.py')
    dataset = import_python_file(dataset_path)
    assets = read_dataset(dataset)

    # shuffle assets (deterministically when seed is not None), then keep only
    # as many as the train + test split needs
    np.random.seed(seed)
    np.random.shuffle(assets)
    assets = assets[:(num_train + num_test)]

    raw_video_h5py_filepath = VmafConfig.workdir_path('rawvideo.hdf5')
    raw_video_h5py_file = DisYUVRawVideoExtractor.open_h5py_file(raw_video_h5py_filepath)
    try:
        print('======================== Extract raw YUVs ==============================')

        _, raw_yuvs = run_executors_in_parallel(
            DisYUVRawVideoExtractor,
            assets,
            fifo_mode=True,
            delete_workdir=True,
            parallelize=False,  # CAN ONLY USE SERIAL MODE FOR DisYUVRawVideoExtractor
            result_store=None,
            optional_dict=None,
            optional_dict2={'h5py_file': raw_video_h5py_file})

        patch_h5py_filepath = VmafConfig.workdir_path('patch.hdf5')
        patch_h5py_file = ToddNoiseClassifierTrainTestModel.open_h5py_file(patch_h5py_filepath)
        try:
            model = ToddNoiseClassifierTrainTestModel(
                param_dict={
                    'seed': seed,
                    'n_epochs': n_epochs,
                },
                logger=None,
                optional_dict2={  # for options that won't impact the result
                    # 'checkpoints_dir': VmafConfig.workspace_path('checkpoints_dir'),
                    'h5py_file': patch_h5py_file,
                })

            print('============================ Train model ===============================')
            xys = ToddNoiseClassifierTrainTestModel.get_xys_from_results(raw_yuvs[:num_train])
            model.train(xys)

            print('=========================== Evaluate model =============================')
            test_results = raw_yuvs[num_train:]
            xs = ToddNoiseClassifierTrainTestModel.get_xs_from_results(test_results)
            ys = ToddNoiseClassifierTrainTestModel.get_ys_from_results(test_results)
            result = model.evaluate(xs, ys)

            print("")
            print("f1 test %g, errorrate test %g" % (result['f1'], result['errorrate']))
        finally:
            # tear down the patch file even if training/evaluation failed
            ToddNoiseClassifierTrainTestModel.close_h5py_file(patch_h5py_file)
            os.remove(patch_h5py_filepath)
    finally:
        # tear down the raw-video file even if anything above failed
        DisYUVRawVideoExtractor.close_h5py_file(raw_video_h5py_file)
        os.remove(raw_video_h5py_filepath)

    print('Done.')


if __name__ == "__main__":
    main()