#!/usr/bin/env python3
"""Command-line script that runs a VMAF quality runner on one video pair.

Positional arguments are ``fmt width height ref_path dis_path``; optional
arguments select the model file, output format, score pooling method,
phone-model score transform, confidence-interval runner and local
explanation plots.  Run the script without arguments to print the usage.
"""

import matplotlib
# Select the 'Agg' backend before the vmaf modules below are imported.
matplotlib.use('Agg')

import sys
import os

import numpy as np

from vmaf.config import VmafConfig, DisplayConfig
from vmaf.core.asset import Asset
from vmaf.core.quality_runner import VmafQualityRunner
from vmaf.tools.misc import get_file_name_without_extension, get_cmd_option, \
    cmd_option_exists
from vmaf.tools.stats import ListStats

__copyright__ = "Copyright 2016-2020, Netflix, Inc."
__license__ = "BSD+Patent"

# Pixel formats accepted as the first positional argument.
FMTS = ['yuv420p', 'yuv422p', 'yuv444p',
        'yuv420p10le', 'yuv422p10le', 'yuv444p10le',
        'yuv420p12le', 'yuv422p12le', 'yuv444p12le',
        'yuv420p16le', 'yuv422p16le', 'yuv444p16le',
        ]
# Display names only; main() validates --out-fmt against 'text', 'xml', 'json'.
OUT_FMTS = ['text (default)', 'xml', 'json']
# Values accepted by --pool.
POOL_METHODS = ['mean', 'harmonic_mean', 'min', 'median', 'perc5', 'perc10', 'perc20']


def print_usage():
    """Print the command-line synopsis and the accepted option values."""
    print("usage: " + os.path.basename(sys.argv[0])
          + " fmt width height ref_path dis_path [--model model_path] [--out-fmt out_fmt] "
            "[--pool pool_method] [--phone-model] [--ci] [--local-explain] "
            "[--save-plot plot_dir]\n")
    print("fmt:\n\t" + "\n\t".join(FMTS) + "\n")
    print("out_fmt:\n\t" + "\n\t".join(OUT_FMTS) + "\n")
    print("pool_method:\n\t" + "\n\t".join(POOL_METHODS) + "\n")


def main():
    """Parse ``sys.argv``, run the selected quality runner and print the result.

    Returns:
        int: 0 on success, 2 when the command line is invalid (too few
        arguments, non-integer or negative width/height, unknown pixel
        format, output format or pooling method, or ``--local-explain``
        combined with ``--ci``).
    """
    if len(sys.argv) < 6:
        print_usage()
        return 2

    fmt = sys.argv[1]
    ref_file = sys.argv[4]
    dis_file = sys.argv[5]
    try:
        width = int(sys.argv[2])
        height = int(sys.argv[3])
    except ValueError:
        print_usage()
        return 2

    if width < 0 or height < 0:
        print("width and height must be non-negative, but are {w} and {h}".format(w=width, h=height))
        print_usage()
        return 2

    if fmt not in FMTS:
        print_usage()
        return 2

    # Optional arguments are searched for from index 6 (after the positionals).
    argc = len(sys.argv)

    model_path = get_cmd_option(sys.argv, 6, argc, '--model')

    out_fmt = get_cmd_option(sys.argv, 6, argc, '--out-fmt')
    if out_fmt is not None and out_fmt not in ('xml', 'json', 'text'):
        print_usage()
        return 2

    pool_method = get_cmd_option(sys.argv, 6, argc, '--pool')
    if pool_method is not None and pool_method not in POOL_METHODS:
        print('--pool can only have option among {}'.format(', '.join(POOL_METHODS)))
        return 2

    show_local_explanation = cmd_option_exists(sys.argv, 6, argc, '--local-explain')

    phone_model = cmd_option_exists(sys.argv, 6, argc, '--phone-model')

    enable_conf_interval = cmd_option_exists(sys.argv, 6, argc, '--ci')

    save_plot_dir = get_cmd_option(sys.argv, 6, argc, '--save-plot')

    if show_local_explanation and enable_conf_interval:
        print('cannot set both --local-explain and --ci flags')
        return 2

    # NOTE: str hashes are salted per interpreter process unless
    # PYTHONHASHSEED is set, so this id is not stable across runs.
    # NOTE(review): asset_id is derived from the reference file name, exactly
    # like content_id -- confirm this is intended rather than dis_file.
    ref_name_id = abs(hash(get_file_name_without_extension(ref_file))) % (10 ** 16)

    asset = Asset(dataset="cmd",
                  content_id=ref_name_id,
                  asset_id=ref_name_id,
                  workdir_root=VmafConfig.workdir_path(),
                  ref_path=ref_file,
                  dis_path=dis_file,
                  asset_dict={'width': width, 'height': height, 'yuv_type': fmt}
                  )
    assets = [asset]

    # The alternative runners are imported lazily, only when requested.
    if show_local_explanation:
        from vmaf.core.quality_runner_extra import VmafQualityRunnerWithLocalExplainer
        runner_class = VmafQualityRunnerWithLocalExplainer
    elif enable_conf_interval:
        from vmaf.core.quality_runner import BootstrapVmafQualityRunner
        runner_class = BootstrapVmafQualityRunner
    else:
        runner_class = VmafQualityRunner

    # optional_dict stays None unless at least one option needs to be passed.
    if model_path is None:
        optional_dict = None
    else:
        optional_dict = {'model_filepath': model_path}

    if phone_model:
        if optional_dict is None:
            optional_dict = {}
        optional_dict['enable_transform_score'] = True

    runner = runner_class(
        assets, None, fifo_mode=True,
        delete_workdir=True,
        result_store=None,
        optional_dict=optional_dict,
        optional_dict2=None,
    )

    # run
    runner.run()
    result = runner.results[0]

    # pooling: None and 'mean' are absent from the table, so the result's
    # existing aggregate method is left untouched for them.
    pool_aggregators = {
        'harmonic_mean': ListStats.harmonic_mean,
        'min': np.min,
        'median': np.median,
        'perc5': ListStats.perc5,
        'perc10': ListStats.perc10,
        'perc20': ListStats.perc20,
    }
    aggregate_method = pool_aggregators.get(pool_method)
    if aggregate_method is not None:
        result.set_score_aggregate_method(aggregate_method)

    # output
    if out_fmt == 'xml':
        print(result.to_xml())
    elif out_fmt == 'json':
        print(result.to_json())
    else:  # None or 'text'
        print(str(result))

    # local explanation
    if show_local_explanation:
        runner.show_local_explanations([result])

        if save_plot_dir is None:
            DisplayConfig.show()
        else:
            DisplayConfig.show(write_to_dir=save_plot_dir)

    return 0


if __name__ == "__main__":
    # sys.exit rather than the site-provided exit(), which is not defined
    # when the interpreter is started without the site module.
    sys.exit(main())