1#!/usr/bin/env python3
2
3import matplotlib
4matplotlib.use('Agg')
5
6import sys
7import os
8
9import numpy as np
10
11from vmaf.config import VmafConfig, DisplayConfig
12from vmaf.core.asset import Asset
13from vmaf.core.quality_runner import VmafQualityRunner
14from vmaf.tools.misc import get_file_name_without_extension, get_cmd_option, \
15    cmd_option_exists
16from vmaf.tools.stats import ListStats
17
__copyright__ = "Copyright 2016-2020, Netflix, Inc."
__license__ = "BSD+Patent"

# Accepted values for the ``fmt`` positional argument, in the order they are
# listed by the usage text: the 420 / 422 / 444 variants for each of the
# suffixes '', '10le', '12le' and '16le'
# (i.e. 'yuv420p', 'yuv422p', 'yuv444p', 'yuv420p10le', ..., 'yuv444p16le').
FMTS = ['yuv{}p{}'.format(subsampling, depth_suffix)
        for depth_suffix in ('', '10le', '12le', '16le')
        for subsampling in ('420', '422', '444')]

# Output formats shown in the usage text for ``--out-fmt``.
OUT_FMTS = [
    'text (default)',
    'xml',
    'json',
]

# Accepted values for ``--pool`` (per-frame score aggregation).
POOL_METHODS = [
    'mean',
    'harmonic_mean',
    'min',
    'median',
    'perc5',
    'perc10',
    'perc20',
]
28
29
def print_usage():
    """Print command-line usage for this script to stdout.

    Lists the five positional arguments, every optional flag that ``main``
    understands, and the accepted values for ``fmt``, ``--out-fmt`` and
    ``--pool``. ``--save-plot`` only has an effect together with
    ``--local-explain``.
    """
    prog = os.path.basename(sys.argv[0])
    print("usage: " + prog
          + " fmt width height ref_path dis_path [--model model_path] [--out-fmt out_fmt] "
            "[--pool pool_method] [--phone-model] [--ci] [--local-explain] "
            "[--save-plot plot_dir]\n")
    print("fmt:\n\t" + "\n\t".join(FMTS) + "\n")
    print("out_fmt:\n\t" + "\n\t".join(OUT_FMTS) + "\n")
    # Previously the valid --pool values only appeared in main()'s error message.
    print("pool_method:\n\t" + "\n\t".join(POOL_METHODS) + "\n")
36
37
def _apply_pool_method(result, pool_method):
    """Set the per-frame score aggregation method on ``result``.

    ``pool_method`` is one of POOL_METHODS or None. For None and 'mean' the
    result's aggregation method is left unchanged (the original code treats
    that as mean pooling; presumably the result's default is the mean).
    """
    aggregate_methods = {
        'harmonic_mean': ListStats.harmonic_mean,
        'min': np.min,
        'median': np.median,
        'perc5': ListStats.perc5,
        'perc10': ListStats.perc10,
        'perc20': ListStats.perc20,
    }
    method = aggregate_methods.get(pool_method)
    if method is not None:
        result.set_score_aggregate_method(method)


def main():
    """Run VMAF on one reference/distorted YUV pair given on the command line.

    Arguments are read from ``sys.argv`` (see ``print_usage``). The result is
    printed as text, XML or JSON. With ``--local-explain`` the local
    explanations are also shown, or written to ``--save-plot`` if given.

    Returns:
        Process exit status: 0 on success, 2 on invalid command-line input.
    """
    argv = sys.argv
    argc = len(argv)
    if argc < 6:
        print_usage()
        return 2

    fmt, ref_file, dis_file = argv[1], argv[4], argv[5]
    try:
        width = int(argv[2])
        height = int(argv[3])
    except ValueError:
        print_usage()
        return 2

    # A zero-sized frame cannot be meaningful input; reject it up front with a
    # clear message instead of handing it to the runner.
    if width <= 0 or height <= 0:
        print("width and height must be positive, but are {w} and {h}".format(w=width, h=height))
        print_usage()
        return 2

    if fmt not in FMTS:
        print_usage()
        return 2

    # Optional flags follow the five positional arguments.
    model_path = get_cmd_option(argv, 6, argc, '--model')

    out_fmt = get_cmd_option(argv, 6, argc, '--out-fmt')
    if out_fmt not in (None, 'text', 'xml', 'json'):
        print_usage()
        return 2

    pool_method = get_cmd_option(argv, 6, argc, '--pool')
    if pool_method is not None and pool_method not in POOL_METHODS:
        print('--pool can only have option among {}'.format(', '.join(POOL_METHODS)))
        return 2

    show_local_explanation = cmd_option_exists(argv, 6, argc, '--local-explain')
    phone_model = cmd_option_exists(argv, 6, argc, '--phone-model')
    enable_conf_interval = cmd_option_exists(argv, 6, argc, '--ci')
    # Only used together with --local-explain (see the end of this function).
    save_plot_dir = get_cmd_option(argv, 6, argc, '--save-plot')

    if show_local_explanation and enable_conf_interval:
        print('cannot set both --local-explain and --ci flags')
        return 2

    # content_id and asset_id are both derived from the reference file name,
    # so compute the value once. Note: Python salts str hashes per process
    # (PYTHONHASHSEED), so these ids are not stable across runs.
    asset_id = abs(hash(get_file_name_without_extension(ref_file))) % (10 ** 16)
    asset = Asset(dataset="cmd",
                  content_id=asset_id,
                  asset_id=asset_id,
                  workdir_root=VmafConfig.workdir_path(),
                  ref_path=ref_file,
                  dis_path=dis_file,
                  asset_dict={'width': width, 'height': height, 'yuv_type': fmt})

    if show_local_explanation:
        from vmaf.core.quality_runner_extra import VmafQualityRunnerWithLocalExplainer
        runner_class = VmafQualityRunnerWithLocalExplainer
    elif enable_conf_interval:
        from vmaf.core.quality_runner import BootstrapVmafQualityRunner
        runner_class = BootstrapVmafQualityRunner
    else:
        runner_class = VmafQualityRunner

    optional_dict = {}
    if model_path is not None:
        optional_dict['model_filepath'] = model_path
    if phone_model:
        optional_dict['enable_transform_score'] = True

    runner = runner_class(
        [asset], None, fifo_mode=True,
        delete_workdir=True,
        result_store=None,
        # None (not {}) when no option was given, as before.
        optional_dict=optional_dict or None,
        optional_dict2=None,
    )

    runner.run()
    result = runner.results[0]

    _apply_pool_method(result, pool_method)

    if out_fmt == 'xml':
        print(result.to_xml())
    elif out_fmt == 'json':
        print(result.to_json())
    else:  # None or 'text'
        print(str(result))

    if show_local_explanation:
        runner.show_local_explanations([result])
        if save_plot_dir is None:
            DisplayConfig.show()
        else:
            DisplayConfig.show(write_to_dir=save_plot_dir)

    return 0
165
166
if __name__ == "__main__":
    # Use sys.exit rather than the exit() builtin: exit() is injected by the
    # site module for interactive use and is unavailable under `python -S`.
    sys.exit(main())
170