1 #include <stdlib.h>
2 #include <string.h>
3 
4 #include "feature/alias.h"
5 #include "log.h"
6 #include "libvmaf/libvmaf.h"
7 #include "model.h"
8 
log_fmt_map(const char * log_fmt)9 static enum VmafOutputFormat log_fmt_map(const char *log_fmt)
10 {
11     if (log_fmt) {
12         if (!strcmp(log_fmt, "xml"))
13             return VMAF_OUTPUT_FORMAT_XML;
14         if (!strcmp(log_fmt, "json"))
15             return VMAF_OUTPUT_FORMAT_JSON;
16         if (!strcmp(log_fmt, "csv"))
17             return VMAF_OUTPUT_FORMAT_CSV;
18         if (!strcmp(log_fmt, "sub"))
19             return VMAF_OUTPUT_FORMAT_SUB;
20     }
21 
22     return VMAF_OUTPUT_FORMAT_NONE;
23 }
24 
pool_method_map(const char * pool_method)25 static enum VmafPoolingMethod pool_method_map(const char *pool_method)
26 {
27     if (pool_method) {
28         if (!strcmp(pool_method, "min"))
29             return VMAF_POOL_METHOD_MIN;
30         if (!strcmp(pool_method, "mean"))
31             return VMAF_POOL_METHOD_MEAN;
32         if (!strcmp(pool_method, "harmonic_mean"))
33             return VMAF_POOL_METHOD_HARMONIC_MEAN;
34     }
35 
36     return VMAF_POOL_METHOD_MEAN;
37 }
38 
pix_fmt_map(char * fmt)39 static int pix_fmt_map(char *fmt)
40 {
41     if (fmt) {
42         if (!strcmp(fmt, "yuv420p"))
43             return VMAF_PIX_FMT_YUV420P;
44         if (!strcmp(fmt, "yuv422p"))
45             return VMAF_PIX_FMT_YUV422P;
46         if (!strcmp(fmt, "yuv444p"))
47             return VMAF_PIX_FMT_YUV444P;
48         if (!strcmp(fmt, "yuv420p10le"))
49             return VMAF_PIX_FMT_YUV420P;
50         if (!strcmp(fmt, "yuv420p12le"))
51             return VMAF_PIX_FMT_YUV420P;
52         if (!strcmp(fmt, "yuv420p16le"))
53             return VMAF_PIX_FMT_YUV420P;
54         if (!strcmp(fmt, "yuv422p10le"))
55             return VMAF_PIX_FMT_YUV422P;
56         if (!strcmp(fmt, "yuv422p10le"))
57             return VMAF_PIX_FMT_YUV422P;
58         if (!strcmp(fmt, "yuv444p10le"))
59             return VMAF_PIX_FMT_YUV444P;
60     }
61 
62     return VMAF_PIX_FMT_UNKNOWN;
63 
64 }
65 
/**
 * Derive the bits-per-component from a pixel format name.
 *
 * @param fmt  Pixel format name; may be NULL.
 * @return 10 for "yuv420p10le", "yuv422p10le" and "yuv444p10le",
 *         12 for "yuv420p12le", 16 for "yuv420p16le", and 8 for
 *         everything else (including NULL).
 */
static int bitdepth_map(char *fmt)
{
    /* strcmp() on a NULL pointer is undefined behaviour; treat a missing
     * format like any other unrecognised name and use the 8-bit default. */
    if (!fmt)
        return 8;

    if (!strcmp(fmt, "yuv420p10le"))
        return 10;
    if (!strcmp(fmt, "yuv422p10le"))
        return 10;
    if (!strcmp(fmt, "yuv444p10le"))
        return 10;
    if (!strcmp(fmt, "yuv420p12le"))
        return 12;
    if (!strcmp(fmt, "yuv420p16le"))
        return 16;

    return 8;
}
81 
/**
 * Copy a plane of float samples into plane 0 of an 8-bit picture.
 *
 * Each float sample is converted to uint8_t by plain assignment (no
 * rounding or clamping is applied here).
 *
 * @param src         First sample of the source plane.
 * @param dst         Destination picture; only data[0] / stride[0] are used.
 * @param width       Number of samples copied per row.
 * @param height      Number of rows copied.
 * @param src_stride  Distance between source rows, in bytes.
 */
static void copy_data(float *src, VmafPicture *dst, unsigned width,
                      unsigned height, int src_stride)
{
    /* The source stride is given in bytes; rows are walked in floats. */
    const size_t src_step = src_stride / sizeof(float);
    const float *in_row = src;
    uint8_t *out_row = dst->data[0];

    for (unsigned y = 0; y < height; y++) {
        for (unsigned x = 0; x < width; x++)
            out_row[x] = in_row[x];

        in_row += src_step;
        out_row += dst->stride[0];
    }
}
95 
/**
 * Copy a plane of float samples into plane 0 of a high-bit-depth picture.
 *
 * Each sample is multiplied by 2^(bpc - 8) and stored as uint16_t by
 * plain assignment (no rounding or clamping is applied here).
 *
 * @param src         First sample of the source plane.
 * @param dst         Destination picture; only data[0] / stride[0] are used.
 * @param width       Number of samples copied per row.
 * @param height      Number of rows copied.
 * @param src_stride  Distance between source rows, in bytes.
 * @param bpc         Bits per component of the destination; the caller in
 *                    this file only passes values greater than 8.
 */
static void copy_data_hbd(float *src, VmafPicture *dst, unsigned width,
                          unsigned height, int src_stride, unsigned bpc)
{
    /* The source stride is given in bytes; rows are walked in floats. */
    const size_t src_step = src_stride / sizeof(float);
    const float *in_row = src;
    uint16_t *out_row = dst->data[0];

    for (unsigned y = 0; y < height; y++) {
        for (unsigned x = 0; x < width; x++)
            out_row[x] = in_row[x] * (1 << (bpc - 8));

        in_row += src_step;
        /* The destination stride is divided by the sample size to step
         * in uint16_t units. */
        out_row += dst->stride[0] / sizeof(uint16_t);
    }
}
109 
compute_vmaf(double * vmaf_score,char * fmt,int width,int height,int (* read_frame)(float * ref_data,float * main_data,float * temp_data,int stride_byte,void * user_data),void * user_data,char * model_path,char * log_path,char * log_fmt,int disable_clip,int disable_avx,int enable_transform,int phone_model,int do_psnr,int do_ssim,int do_ms_ssim,char * pool_method,int n_thread,int n_subsample,int enable_conf_interval)110 int compute_vmaf(double* vmaf_score, char* fmt, int width, int height,
111                  int (*read_frame)(float *ref_data, float *main_data,
112                                    float *temp_data, int stride_byte,
113                                    void *user_data),
114                  void *user_data, char *model_path, char *log_path,
115                  char *log_fmt, int disable_clip, int disable_avx,
116                  int enable_transform, int phone_model, int do_psnr,
117                  int do_ssim, int do_ms_ssim, char *pool_method,
118                  int n_thread, int n_subsample, int enable_conf_interval)
119 {
120 
121     vmaf_set_log_level(VMAF_LOG_LEVEL_INFO);
122     vmaf_log(VMAF_LOG_LEVEL_INFO, "`compute_vmaf()` is deprecated "
123              "and will be removed in a future libvmaf version\n");
124 
125     int err = 0;
126 
127     VmafConfiguration cfg = {
128         .log_level = VMAF_LOG_LEVEL_INFO,
129         .n_threads = n_thread,
130         .n_subsample = n_subsample,
131         .cpumask = disable_avx ? -1 : 0,
132     };
133 
134     VmafContext *vmaf;
135     err = vmaf_init(&vmaf, cfg);
136     if (err) {
137         vmaf_log(VMAF_LOG_LEVEL_ERROR, "problem initializing VMAF context\n");
138         return -1;
139     }
140 
141     enum VmafModelFlags flags = VMAF_MODEL_FLAGS_DEFAULT;
142     if (disable_clip)
143         flags |= VMAF_MODEL_FLAG_DISABLE_CLIP;
144     if (enable_transform || phone_model)
145         flags |= VMAF_MODEL_FLAG_ENABLE_TRANSFORM;
146 
147     VmafModelConfig model_cfg = {
148         .name = "vmaf",
149         .flags = flags,
150     };
151 
152     VmafModel *model = NULL;
153     VmafModelCollection *model_collection = NULL;
154 
155     if (enable_conf_interval) {
156         err = vmaf_model_collection_load_from_path(&model, &model_collection,
157                                                    &model_cfg, model_path);
158         if (err) {
159             vmaf_log(VMAF_LOG_LEVEL_ERROR,
160                      "problem loading model file: %s\n", model_path);
161             goto end;
162         }
163         err = vmaf_use_features_from_model_collection(vmaf, model_collection);
164         if (err) {
165             vmaf_log(VMAF_LOG_LEVEL_ERROR,
166                     "problem loading feature extractors from model file: %s\n",
167                     model_path);
168             goto end;
169         }
170     } else {
171         err = vmaf_model_load_from_path(&model, &model_cfg, model_path);
172         if (err) {
173             vmaf_log(VMAF_LOG_LEVEL_ERROR,
174                      "problem loading model file: %s\n", model_path);
175             goto end;
176         }
177         err = vmaf_use_features_from_model(vmaf, model);
178         if (err) {
179             vmaf_log(VMAF_LOG_LEVEL_ERROR,
180                     "problem loading feature extractors from model file: %s\n",
181                     model_path);
182             goto end;
183         }
184     }
185 
186     if (do_psnr) {
187         VmafFeatureDictionary *d = NULL;
188         vmaf_feature_dictionary_set(&d, "enable_chroma", "false");
189 
190         err = vmaf_use_feature(vmaf, "psnr", d);
191         if (err) {
192             vmaf_log(VMAF_LOG_LEVEL_ERROR,
193                      "problem loading feature extractor: psnr\n");
194             goto end;
195         }
196     }
197 
198     if (do_ssim) {
199         err = vmaf_use_feature(vmaf, "float_ssim", NULL);
200         if (err) {
201             vmaf_log(VMAF_LOG_LEVEL_ERROR,
202                      "problem loading feature extractor: ssim\n");
203             goto end;
204         }
205     }
206 
207     if (do_ms_ssim) {
208         err = vmaf_use_feature(vmaf, "float_ms_ssim", NULL);
209         if (err) {
210             vmaf_log(VMAF_LOG_LEVEL_ERROR,
211                      "problem loading feature extractor: ms_ssim\n");
212             goto end;
213         }
214     }
215 
216     int stride = width * sizeof(float);
217     float *ref_data = malloc(height * stride);
218     float *main_data = malloc(height * stride);
219     float *temp_data = malloc(height * stride);
220     if (!ref_data | !main_data | !temp_data) {
221         vmaf_log(VMAF_LOG_LEVEL_ERROR, "problem allocating picture memory\n");
222         err = -1;
223         goto free_data;
224     }
225 
226     unsigned picture_index;
227     for (picture_index = 0 ;; picture_index++) {
228         err = read_frame(ref_data, main_data, temp_data, stride, user_data);
229         if (err == 1) {
230             vmaf_log(VMAF_LOG_LEVEL_ERROR, "problem during read_frame\n");
231             goto free_data;
232         } else if (err == 2) {
233             break; //EOF
234         }
235 
236         VmafPicture pic_ref, pic_dist;
237         err = vmaf_picture_alloc(&pic_ref, pix_fmt_map(fmt),
238                                  bitdepth_map(fmt), width, height);
239         err |= vmaf_picture_alloc(&pic_dist, pix_fmt_map(fmt),
240                                   bitdepth_map(fmt), width, height);
241         if (err) {
242             vmaf_log(VMAF_LOG_LEVEL_ERROR,
243                      "problem allocating picture memory\n");
244             vmaf_picture_unref(&pic_ref);
245             vmaf_picture_unref(&pic_dist);
246             goto free_data;
247         }
248 
249         const unsigned bpc = bitdepth_map(fmt);
250         if (bpc > 8) {
251             copy_data_hbd(ref_data, &pic_ref, width, height, stride, bpc);
252             copy_data_hbd(main_data, &pic_dist, width, height, stride, bpc);
253         } else {
254             copy_data(ref_data, &pic_ref, width, height, stride);
255             copy_data(main_data, &pic_dist, width, height, stride);
256         }
257 
258         err = vmaf_read_pictures(vmaf, &pic_ref, &pic_dist, picture_index);
259         if (err) {
260             vmaf_log(VMAF_LOG_LEVEL_ERROR, "problem reading pictures\n");
261             break;
262         }
263     }
264 
265     err = vmaf_read_pictures(vmaf, NULL, NULL, 0);
266     if (err) {
267         vmaf_log(VMAF_LOG_LEVEL_ERROR, "problem flushing context\n");
268         return err;
269     }
270 
271      if (enable_conf_interval) {
272          VmafModelCollectionScore model_collection_score;
273          err = vmaf_score_pooled_model_collection(vmaf, model_collection,
274                                                   pool_method_map(pool_method),
275                                                   &model_collection_score, 0,
276                                                   picture_index - 1);
277          if (err) {
278              vmaf_log(VMAF_LOG_LEVEL_ERROR,
279                       "problem generating pooled VMAF score\n");
280              goto free_data;
281          }
282     }
283 
284      err = vmaf_score_pooled(vmaf, model, pool_method_map(pool_method),
285                              vmaf_score, 0, picture_index - 1);
286      if (err) {
287          vmaf_log(VMAF_LOG_LEVEL_ERROR,
288                   "problem generating pooled VMAF score\n");
289          goto free_data;
290      }
291 
292     enum VmafOutputFormat output_fmt = log_fmt_map(log_fmt);
293     if (output_fmt == VMAF_OUTPUT_FORMAT_NONE && log_path) {
294         output_fmt = VMAF_OUTPUT_FORMAT_XML;
295         vmaf_log(VMAF_LOG_LEVEL_WARNING, "use default log_fmt xml");
296     }
297     if (output_fmt) {
298         vmaf_use_vmafossexec_aliases();
299         err = vmaf_write_output(vmaf, log_path, output_fmt);
300         if (err) {
301             vmaf_log(VMAF_LOG_LEVEL_ERROR,
302                      "could not write output: %s\n", log_path);
303             goto free_data;
304         }
305     }
306 
307 free_data:
308     if (ref_data) free(ref_data);
309     if (main_data) free(main_data);
310     if (temp_data) free(temp_data);
311 end:
312     vmaf_model_destroy(model);
313     vmaf_model_collection_destroy(model_collection);
314     vmaf_close(vmaf);
315     return err;
316 }
317