// This simple IO library works with the Halide::Buffer<T> type or any
// other image type with the same API.
3 
4 #ifndef HALIDE_IMAGE_IO_H
5 #define HALIDE_IMAGE_IO_H
6 
7 #include <algorithm>
8 #include <cctype>
9 #include <cstdarg>
10 #include <cstddef>
11 #include <cstdio>
12 #include <cstdlib>
13 #include <functional>
14 #include <map>
15 #include <set>
16 #include <string>
17 #include <vector>
18 
19 #ifndef HALIDE_NO_PNG
20 #include "png.h"
21 #endif
22 
23 #ifndef HALIDE_NO_JPEG
24 #ifdef _WIN32
25 #ifndef NOMINMAX
26 #define NOMINMAX
27 #endif
28 #include <windows.h>
29 #endif
30 #include "jpeglib.h"
31 #endif
32 
33 #include "HalideRuntime.h"  // for halide_type_t
34 
35 namespace Halide {
36 namespace Tools {
37 
38 struct FormatInfo {
39     halide_type_t type;
40     int dimensions;
41 
42     bool operator<(const FormatInfo &other) const {
43         if (type.code < other.type.code) {
44             return true;
45         } else if (type.code > other.type.code) {
46             return false;
47         }
48         if (type.bits < other.type.bits) {
49             return true;
50         } else if (type.bits > other.type.bits) {
51             return false;
52         }
53         if (type.lanes < other.type.lanes) {
54             return true;
55         } else if (type.lanes > other.type.lanes) {
56             return false;
57         }
58         return (dimensions < other.dimensions);
59     }
60 };
61 
62 namespace Internal {
63 
64 // Must be constexpr to allow use in case clauses.
halide_type_code(halide_type_code_t code,int bits)65 inline constexpr int halide_type_code(halide_type_code_t code, int bits) {
66     return (((int)code) << 8) | bits;
67 }
68 
using CheckFunc = bool (*)(bool condition, const char *msg);  // failure policy: returns `condition`, or aborts
70 
// Fatal check policy: when `condition` is false, print `msg` to stderr and
// abort(); otherwise return `condition` (i.e. true).
inline bool CheckFail(bool condition, const char *msg) {
    if (condition) {
        return condition;
    }
    fprintf(stderr, "%s\n", msg);
    abort();
}
78 
// Non-fatal check policy: report the outcome to the caller and ignore `msg`.
inline bool CheckReturn(bool condition, const char *msg) {
    (void)msg;  // unused by this policy
    return condition;
}
82 
83 template<typename To, typename From>
84 To convert(const From &from);
85 
86 // Convert to bool
87 template<>
convert(const bool & in)88 inline bool convert(const bool &in) {
89     return in;
90 }
91 template<>
convert(const uint8_t & in)92 inline bool convert(const uint8_t &in) {
93     return in != 0;
94 }
95 template<>
convert(const uint16_t & in)96 inline bool convert(const uint16_t &in) {
97     return in != 0;
98 }
99 template<>
convert(const uint32_t & in)100 inline bool convert(const uint32_t &in) {
101     return in != 0;
102 }
103 template<>
convert(const uint64_t & in)104 inline bool convert(const uint64_t &in) {
105     return in != 0;
106 }
107 template<>
convert(const int8_t & in)108 inline bool convert(const int8_t &in) {
109     return in != 0;
110 }
111 template<>
convert(const int16_t & in)112 inline bool convert(const int16_t &in) {
113     return in != 0;
114 }
115 template<>
convert(const int32_t & in)116 inline bool convert(const int32_t &in) {
117     return in != 0;
118 }
119 template<>
convert(const int64_t & in)120 inline bool convert(const int64_t &in) {
121     return in != 0;
122 }
123 template<>
convert(const float & in)124 inline bool convert(const float &in) {
125     return in != 0;
126 }
127 template<>
convert(const double & in)128 inline bool convert(const double &in) {
129     return in != 0;
130 }
131 
132 // Convert to u8
133 template<>
convert(const bool & in)134 inline uint8_t convert(const bool &in) {
135     return in;
136 }
137 template<>
convert(const uint8_t & in)138 inline uint8_t convert(const uint8_t &in) {
139     return in;
140 }
141 template<>
convert(const uint16_t & in)142 inline uint8_t convert(const uint16_t &in) {
143     uint32_t tmp = (uint32_t)(in) + 0x80;
144     // Fast approximation of div-by-257: see http://research.swtch.com/divmult
145     return ((tmp * 255 + 255) >> 16);
146 }
147 template<>
convert(const uint32_t & in)148 inline uint8_t convert(const uint32_t &in) {
149     return (uint8_t)((((uint64_t)in) + 0x00808080) / 0x01010101);
150 }
151 // uint64 -> 8 just discards the lower 32 bits: if you were expecting more precision, well, sorry
152 template<>
convert(const uint64_t & in)153 inline uint8_t convert(const uint64_t &in) {
154     return convert<uint8_t, uint32_t>(uint32_t(in >> 32));
155 }
156 template<>
convert(const int8_t & in)157 inline uint8_t convert(const int8_t &in) {
158     return convert<uint8_t, uint8_t>(in);
159 }
160 template<>
convert(const int16_t & in)161 inline uint8_t convert(const int16_t &in) {
162     return convert<uint8_t, uint16_t>(in);
163 }
164 template<>
convert(const int32_t & in)165 inline uint8_t convert(const int32_t &in) {
166     return convert<uint8_t, uint32_t>(in);
167 }
168 template<>
convert(const int64_t & in)169 inline uint8_t convert(const int64_t &in) {
170     return convert<uint8_t, uint64_t>(in);
171 }
172 template<>
convert(const float & in)173 inline uint8_t convert(const float &in) {
174     return (uint8_t)(in * 255.0f + 0.5f);
175 }
176 template<>
convert(const double & in)177 inline uint8_t convert(const double &in) {
178     return (uint8_t)(in * 255.0 + 0.5);
179 }
180 
181 // Convert to u16
182 template<>
convert(const bool & in)183 inline uint16_t convert(const bool &in) {
184     return in;
185 }
186 template<>
convert(const uint8_t & in)187 inline uint16_t convert(const uint8_t &in) {
188     return uint16_t(in) * 0x0101;
189 }
190 template<>
convert(const uint16_t & in)191 inline uint16_t convert(const uint16_t &in) {
192     return in;
193 }
194 template<>
convert(const uint32_t & in)195 inline uint16_t convert(const uint32_t &in) {
196     return in >> 16;
197 }
198 template<>
convert(const uint64_t & in)199 inline uint16_t convert(const uint64_t &in) {
200     return in >> 48;
201 }
202 template<>
convert(const int8_t & in)203 inline uint16_t convert(const int8_t &in) {
204     return convert<uint16_t, uint8_t>(in);
205 }
206 template<>
convert(const int16_t & in)207 inline uint16_t convert(const int16_t &in) {
208     return convert<uint16_t, uint16_t>(in);
209 }
210 template<>
convert(const int32_t & in)211 inline uint16_t convert(const int32_t &in) {
212     return convert<uint16_t, uint32_t>(in);
213 }
214 template<>
convert(const int64_t & in)215 inline uint16_t convert(const int64_t &in) {
216     return convert<uint16_t, uint64_t>(in);
217 }
218 template<>
convert(const float & in)219 inline uint16_t convert(const float &in) {
220     return (uint16_t)(in * 65535.0f + 0.5f);
221 }
222 template<>
convert(const double & in)223 inline uint16_t convert(const double &in) {
224     return (uint16_t)(in * 65535.0 + 0.5);
225 }
226 
227 // Convert to u32
228 template<>
convert(const bool & in)229 inline uint32_t convert(const bool &in) {
230     return in;
231 }
232 template<>
convert(const uint8_t & in)233 inline uint32_t convert(const uint8_t &in) {
234     return uint32_t(in) * 0x01010101;
235 }
236 template<>
convert(const uint16_t & in)237 inline uint32_t convert(const uint16_t &in) {
238     return uint32_t(in) * 0x00010001;
239 }
240 template<>
convert(const uint32_t & in)241 inline uint32_t convert(const uint32_t &in) {
242     return in;
243 }
244 template<>
convert(const uint64_t & in)245 inline uint32_t convert(const uint64_t &in) {
246     return (uint32_t)(in >> 32);
247 }
248 template<>
convert(const int8_t & in)249 inline uint32_t convert(const int8_t &in) {
250     return convert<uint32_t, uint8_t>(in);
251 }
252 template<>
convert(const int16_t & in)253 inline uint32_t convert(const int16_t &in) {
254     return convert<uint32_t, uint16_t>(in);
255 }
256 template<>
convert(const int32_t & in)257 inline uint32_t convert(const int32_t &in) {
258     return convert<uint32_t, uint32_t>(in);
259 }
260 template<>
convert(const int64_t & in)261 inline uint32_t convert(const int64_t &in) {
262     return convert<uint32_t, uint64_t>(in);
263 }
264 template<>
convert(const float & in)265 inline uint32_t convert(const float &in) {
266     return (uint32_t)(in * 4294967295.0 + 0.5);
267 }
268 template<>
convert(const double & in)269 inline uint32_t convert(const double &in) {
270     return (uint32_t)(in * 4294967295.0 + 0.5f);
271 }
272 
273 // Convert to u64
274 template<>
convert(const bool & in)275 inline uint64_t convert(const bool &in) {
276     return in;
277 }
278 template<>
convert(const uint8_t & in)279 inline uint64_t convert(const uint8_t &in) {
280     return uint64_t(in) * 0x0101010101010101LL;
281 }
282 template<>
convert(const uint16_t & in)283 inline uint64_t convert(const uint16_t &in) {
284     return uint64_t(in) * 0x0001000100010001LL;
285 }
286 template<>
convert(const uint32_t & in)287 inline uint64_t convert(const uint32_t &in) {
288     return uint64_t(in) * 0x0000000100000001LL;
289 }
290 template<>
convert(const uint64_t & in)291 inline uint64_t convert(const uint64_t &in) {
292     return in;
293 }
294 template<>
convert(const int8_t & in)295 inline uint64_t convert(const int8_t &in) {
296     return convert<uint64_t, uint8_t>(in);
297 }
298 template<>
convert(const int16_t & in)299 inline uint64_t convert(const int16_t &in) {
300     return convert<uint64_t, uint16_t>(in);
301 }
302 template<>
convert(const int32_t & in)303 inline uint64_t convert(const int32_t &in) {
304     return convert<uint64_t, uint64_t>(in);
305 }
306 template<>
convert(const int64_t & in)307 inline uint64_t convert(const int64_t &in) {
308     return convert<uint64_t, uint64_t>(in);
309 }
310 template<>
convert(const float & in)311 inline uint64_t convert(const float &in) {
312     return convert<uint64_t, uint32_t>((uint32_t)(in * 4294967295.0 + 0.5));
313 }
314 template<>
convert(const double & in)315 inline uint64_t convert(const double &in) {
316     return convert<uint64_t, uint32_t>((uint32_t)(in * 4294967295.0 + 0.5));
317 }
318 
319 // Convert to i8
320 template<>
convert(const bool & in)321 inline int8_t convert(const bool &in) {
322     return in;
323 }
324 template<>
convert(const uint8_t & in)325 inline int8_t convert(const uint8_t &in) {
326     return convert<uint8_t, uint8_t>(in);
327 }
328 template<>
convert(const uint16_t & in)329 inline int8_t convert(const uint16_t &in) {
330     return convert<uint8_t, uint16_t>(in);
331 }
332 template<>
convert(const uint32_t & in)333 inline int8_t convert(const uint32_t &in) {
334     return convert<uint8_t, uint32_t>(in);
335 }
336 template<>
convert(const uint64_t & in)337 inline int8_t convert(const uint64_t &in) {
338     return convert<uint8_t, uint64_t>(in);
339 }
340 template<>
convert(const int8_t & in)341 inline int8_t convert(const int8_t &in) {
342     return convert<uint8_t, int8_t>(in);
343 }
344 template<>
convert(const int16_t & in)345 inline int8_t convert(const int16_t &in) {
346     return convert<uint8_t, int16_t>(in);
347 }
348 template<>
convert(const int32_t & in)349 inline int8_t convert(const int32_t &in) {
350     return convert<uint8_t, int32_t>(in);
351 }
352 template<>
convert(const int64_t & in)353 inline int8_t convert(const int64_t &in) {
354     return convert<uint8_t, int64_t>(in);
355 }
356 template<>
convert(const float & in)357 inline int8_t convert(const float &in) {
358     return convert<uint8_t, float>(in);
359 }
360 template<>
convert(const double & in)361 inline int8_t convert(const double &in) {
362     return convert<uint8_t, double>(in);
363 }
364 
365 // Convert to i16
366 template<>
convert(const bool & in)367 inline int16_t convert(const bool &in) {
368     return in;
369 }
370 template<>
convert(const uint8_t & in)371 inline int16_t convert(const uint8_t &in) {
372     return convert<uint16_t, uint8_t>(in);
373 }
374 template<>
convert(const uint16_t & in)375 inline int16_t convert(const uint16_t &in) {
376     return convert<uint16_t, uint16_t>(in);
377 }
378 template<>
convert(const uint32_t & in)379 inline int16_t convert(const uint32_t &in) {
380     return convert<uint16_t, uint32_t>(in);
381 }
382 template<>
convert(const uint64_t & in)383 inline int16_t convert(const uint64_t &in) {
384     return convert<uint16_t, uint64_t>(in);
385 }
386 template<>
convert(const int8_t & in)387 inline int16_t convert(const int8_t &in) {
388     return convert<uint16_t, int8_t>(in);
389 }
390 template<>
convert(const int16_t & in)391 inline int16_t convert(const int16_t &in) {
392     return convert<uint16_t, int16_t>(in);
393 }
394 template<>
convert(const int32_t & in)395 inline int16_t convert(const int32_t &in) {
396     return convert<uint16_t, int32_t>(in);
397 }
398 template<>
convert(const int64_t & in)399 inline int16_t convert(const int64_t &in) {
400     return convert<uint16_t, int64_t>(in);
401 }
402 template<>
convert(const float & in)403 inline int16_t convert(const float &in) {
404     return convert<uint16_t, float>(in);
405 }
406 template<>
convert(const double & in)407 inline int16_t convert(const double &in) {
408     return convert<uint16_t, double>(in);
409 }
410 
411 // Convert to i32
412 template<>
convert(const bool & in)413 inline int32_t convert(const bool &in) {
414     return in;
415 }
416 template<>
convert(const uint8_t & in)417 inline int32_t convert(const uint8_t &in) {
418     return convert<uint32_t, uint8_t>(in);
419 }
420 template<>
convert(const uint16_t & in)421 inline int32_t convert(const uint16_t &in) {
422     return convert<uint32_t, uint16_t>(in);
423 }
424 template<>
convert(const uint32_t & in)425 inline int32_t convert(const uint32_t &in) {
426     return convert<uint32_t, uint32_t>(in);
427 }
428 template<>
convert(const uint64_t & in)429 inline int32_t convert(const uint64_t &in) {
430     return convert<uint32_t, uint64_t>(in);
431 }
432 template<>
convert(const int8_t & in)433 inline int32_t convert(const int8_t &in) {
434     return convert<uint32_t, int8_t>(in);
435 }
436 template<>
convert(const int16_t & in)437 inline int32_t convert(const int16_t &in) {
438     return convert<uint32_t, int16_t>(in);
439 }
440 template<>
convert(const int32_t & in)441 inline int32_t convert(const int32_t &in) {
442     return convert<uint32_t, int32_t>(in);
443 }
444 template<>
convert(const int64_t & in)445 inline int32_t convert(const int64_t &in) {
446     return convert<uint32_t, int64_t>(in);
447 }
448 template<>
convert(const float & in)449 inline int32_t convert(const float &in) {
450     return convert<uint32_t, float>(in);
451 }
452 template<>
convert(const double & in)453 inline int32_t convert(const double &in) {
454     return convert<uint32_t, double>(in);
455 }
456 
457 // Convert to i64
458 template<>
convert(const bool & in)459 inline int64_t convert(const bool &in) {
460     return in;
461 }
462 template<>
convert(const uint8_t & in)463 inline int64_t convert(const uint8_t &in) {
464     return convert<uint64_t, uint8_t>(in);
465 }
466 template<>
convert(const uint16_t & in)467 inline int64_t convert(const uint16_t &in) {
468     return convert<uint64_t, uint16_t>(in);
469 }
470 template<>
convert(const uint32_t & in)471 inline int64_t convert(const uint32_t &in) {
472     return convert<uint64_t, uint32_t>(in);
473 }
474 template<>
convert(const uint64_t & in)475 inline int64_t convert(const uint64_t &in) {
476     return convert<uint64_t, uint64_t>(in);
477 }
478 template<>
convert(const int8_t & in)479 inline int64_t convert(const int8_t &in) {
480     return convert<uint64_t, int8_t>(in);
481 }
482 template<>
convert(const int16_t & in)483 inline int64_t convert(const int16_t &in) {
484     return convert<uint64_t, int16_t>(in);
485 }
486 template<>
convert(const int32_t & in)487 inline int64_t convert(const int32_t &in) {
488     return convert<uint64_t, int32_t>(in);
489 }
490 template<>
convert(const int64_t & in)491 inline int64_t convert(const int64_t &in) {
492     return convert<uint64_t, int64_t>(in);
493 }
494 template<>
convert(const float & in)495 inline int64_t convert(const float &in) {
496     return convert<uint64_t, float>(in);
497 }
498 template<>
convert(const double & in)499 inline int64_t convert(const double &in) {
500     return convert<uint64_t, double>(in);
501 }
502 
503 // Convert to f32
504 template<>
convert(const bool & in)505 inline float convert(const bool &in) {
506     return in;
507 }
508 template<>
convert(const uint8_t & in)509 inline float convert(const uint8_t &in) {
510     return in / 255.0f;
511 }
512 template<>
convert(const uint16_t & in)513 inline float convert(const uint16_t &in) {
514     return in / 65535.0f;
515 }
516 template<>
convert(const uint32_t & in)517 inline float convert(const uint32_t &in) {
518     return (float)(in / 4294967295.0);
519 }
520 template<>
convert(const uint64_t & in)521 inline float convert(const uint64_t &in) {
522     return convert<float, uint32_t>(uint32_t(in >> 32));
523 }
524 template<>
convert(const int8_t & in)525 inline float convert(const int8_t &in) {
526     return convert<float, uint8_t>(in);
527 }
528 template<>
convert(const int16_t & in)529 inline float convert(const int16_t &in) {
530     return convert<float, uint16_t>(in);
531 }
532 template<>
convert(const int32_t & in)533 inline float convert(const int32_t &in) {
534     return convert<float, uint64_t>(in);
535 }
536 template<>
convert(const int64_t & in)537 inline float convert(const int64_t &in) {
538     return convert<float, uint64_t>(in);
539 }
540 template<>
convert(const float & in)541 inline float convert(const float &in) {
542     return in;
543 }
544 template<>
convert(const double & in)545 inline float convert(const double &in) {
546     return (float)in;
547 }
548 
549 // Convert to f64
550 template<>
convert(const bool & in)551 inline double convert(const bool &in) {
552     return in;
553 }
554 template<>
convert(const uint8_t & in)555 inline double convert(const uint8_t &in) {
556     return in / 255.0f;
557 }
558 template<>
convert(const uint16_t & in)559 inline double convert(const uint16_t &in) {
560     return in / 65535.0f;
561 }
562 template<>
convert(const uint32_t & in)563 inline double convert(const uint32_t &in) {
564     return (double)(in / 4294967295.0);
565 }
566 template<>
convert(const uint64_t & in)567 inline double convert(const uint64_t &in) {
568     return convert<double, uint32_t>(uint32_t(in >> 32));
569 }
570 template<>
convert(const int8_t & in)571 inline double convert(const int8_t &in) {
572     return convert<double, uint8_t>(in);
573 }
574 template<>
convert(const int16_t & in)575 inline double convert(const int16_t &in) {
576     return convert<double, uint16_t>(in);
577 }
578 template<>
convert(const int32_t & in)579 inline double convert(const int32_t &in) {
580     return convert<double, uint64_t>(in);
581 }
582 template<>
convert(const int64_t & in)583 inline double convert(const int64_t &in) {
584     return convert<double, uint64_t>(in);
585 }
586 template<>
convert(const float & in)587 inline double convert(const float &in) {
588     return (double)in;
589 }
590 template<>
convert(const double & in)591 inline double convert(const double &in) {
592     return in;
593 }
594 
// Return a copy of `s` with each character passed through std::tolower()
// (ASCII letters only in the default "C" locale). Each char is routed
// through unsigned char first: handing a negative char (a byte >= 0x80 where
// char is signed) to tolower() is undefined behavior.
inline std::string to_lowercase(const std::string &s) {
    std::string r = s;
    std::transform(r.begin(), r.end(), r.begin(), [](unsigned char c) {
        return static_cast<char>(std::tolower(c));
    });
    return r;
}

// Return the lower-cased text after the last '.' in the final component of
// `path`, or "" if that component has no '.'. A dot that appears only in a
// directory name (e.g. "dir.d/file") does not count as an extension.
inline std::string get_lowercase_extension(const std::string &path) {
    const size_t last_dot = path.rfind('.');
    if (last_dot == std::string::npos) {
        return "";
    }
    const size_t last_sep = path.find_last_of("/\\");
    if (last_sep != std::string::npos && last_sep > last_dot) {
        return "";
    }
    return to_lowercase(path.substr(last_dot + 1));
}
608 
// Decode one ElemType stored most-significant byte first starting at `src`.
template<typename ElemType>
ElemType read_big_endian(const uint8_t *src);

template<>
inline uint8_t read_big_endian(const uint8_t *src) {
    return src[0];
}

template<>
inline uint16_t read_big_endian(const uint8_t *src) {
    const uint16_t high = src[0];
    const uint16_t low = src[1];
    return static_cast<uint16_t>((high << 8) | low);
}
621 
// Encode `src` into `dst` most-significant byte first; writes sizeof(ElemType) bytes.
template<typename ElemType>
void write_big_endian(const ElemType &src, uint8_t *dst);

template<>
inline void write_big_endian(const uint8_t &src, uint8_t *dst) {
    dst[0] = src;
}

template<>
inline void write_big_endian(const uint16_t &src, uint8_t *dst) {
    dst[0] = static_cast<uint8_t>(src >> 8);    // high byte first
    dst[1] = static_cast<uint8_t>(src & 0xffu);  // then low byte
}
635 
// Owns a C stdio FILE* for the lifetime of the object: the constructor opens
// `filename` with fopen() and the destructor fclose()s it if the open
// succeeded. `f` is nullptr when fopen() fails; callers must check it before
// using any helper below, none of which checks it.
struct FileOpener {
    FileOpener(const std::string &filename, const char *mode)
        : f(fopen(filename.c_str(), mode)) {
        // nothing
    }

    // Not copyable: a copy would hold the same FILE* and fclose() it a
    // second time from its own destructor.
    FileOpener(const FileOpener &) = delete;
    FileOpener &operator=(const FileOpener &) = delete;

    ~FileOpener() {
        if (f != nullptr) {
            fclose(f);
        }
    }

    // Read one line (at most maxlen - 1 chars, as fgets() does) into buf,
    // skipping lines that begin with '#'. Returns buf, or nullptr at EOF or
    // on a read error.
    char *read_line(char *buf, int maxlen) {
        while (char *status = fgets(buf, maxlen, f)) {
            if (buf[0] != '#') {
                return status;
            }
            // Comment line. If fgets() stopped before its newline, discard
            // the remainder so the tail is not mistaken for a data line.
            const char *p = buf;
            while (*p != '\0' && *p != '\n') {
                ++p;
            }
            if (*p != '\n') {
                int ch;
                do {
                    ch = fgetc(f);
                } while (ch != EOF && ch != '\n');
            }
        }
        return nullptr;
    }

    // Call read_line() and run vsscanf() with `fmt` on the line read.
    // Returns vsscanf()'s result (number of items assigned, or EOF), or 0 if
    // no line could be read.
    int scan_line(const char *fmt, ...) {
        char buf[1024];
        if (!read_line(buf, static_cast<int>(sizeof(buf)))) {
            return 0;
        }
        va_list args;
        va_start(args, fmt);
        int result = vsscanf(buf, fmt, args);
        va_end(args);
        return result;
    }

    // Read exactly `count` bytes; true only if all of them were read.
    bool read_bytes(void *data, size_t count) {
        return fread(data, 1, count, f) == count;
    }

    template<typename T, size_t N>
    bool read_array(T (&data)[N]) {
        return read_bytes(&data[0], sizeof(T) * N);
    }

    // Fill the vector's current contents (v->size() elements) from the file.
    template<typename T>
    bool read_vector(std::vector<T> *v) {
        return read_bytes(v->data(), v->size() * sizeof(T));
    }

    // Write exactly `count` bytes; true only if all of them were written.
    bool write_bytes(const void *data, size_t count) {
        return fwrite(data, 1, count, f) == count;
    }

    template<typename T>
    bool write_vector(const std::vector<T> &v) {
        return write_bytes(v.data(), v.size() * sizeof(T));
    }

    template<typename T, size_t N>
    bool write_array(const T (&data)[N]) {
        return write_bytes(&data[0], sizeof(T) * N);
    }

    FILE *const f;
};
700 
701 // Read a row of ElemTypes from a byte buffer and copy them into a specific image row.
702 // Multibyte elements are assumed to be big-endian.
703 template<typename ElemType, typename ImageType>
read_big_endian_row(const uint8_t * src,int y,ImageType * im)704 void read_big_endian_row(const uint8_t *src, int y, ImageType *im) {
705     auto im_typed = im->template as<ElemType>();
706     const int xmin = im_typed.dim(0).min();
707     const int xmax = im_typed.dim(0).max();
708     if (im_typed.dimensions() > 2) {
709         const int cmin = im_typed.dim(2).min();
710         const int cmax = im_typed.dim(2).max();
711         for (int x = xmin; x <= xmax; x++) {
712             for (int c = cmin; c <= cmax; c++) {
713                 im_typed(x, y, c + cmin) = read_big_endian<ElemType>(src);
714                 src += sizeof(ElemType);
715             }
716         }
717     } else {
718         for (int x = xmin; x <= xmax; x++) {
719             im_typed(x, y) = read_big_endian<ElemType>(src);
720             src += sizeof(ElemType);
721         }
722     }
723 }
724 
725 // Copy a row from an image into a byte buffer.
726 // Multibyte elements are written in big-endian layout.
727 template<typename ElemType, typename ImageType>
write_big_endian_row(const ImageType & im,int y,uint8_t * dst)728 void write_big_endian_row(const ImageType &im, int y, uint8_t *dst) {
729     auto im_typed = im.template as<typename std::add_const<ElemType>::type>();
730     const int xmin = im_typed.dim(0).min();
731     const int xmax = im_typed.dim(0).max();
732     if (im_typed.dimensions() > 2) {
733         const int cmin = im_typed.dim(2).min();
734         const int cmax = im_typed.dim(2).max();
735         for (int x = xmin; x <= xmax; x++) {
736             for (int c = cmin; c <= cmax; c++) {
737                 write_big_endian<ElemType>(im_typed(x, y, c), dst);
738                 dst += sizeof(ElemType);
739             }
740         }
741     } else {
742         for (int x = xmin; x <= xmax; x++) {
743             write_big_endian<ElemType>(im_typed(x, y), dst);
744             dst += sizeof(ElemType);
745         }
746     }
747 }
748 
749 #ifndef HALIDE_NO_PNG
750 
751 template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
load_png(const std::string & filename,ImageType * im)752 bool load_png(const std::string &filename, ImageType *im) {
753     static_assert(!ImageType::has_static_halide_type, "");
754 
755     /* open file and test for it being a png */
756     Internal::FileOpener f(filename, "rb");
757     if (!check(f.f != nullptr, "File could not be opened for reading")) {
758         return false;
759     }
760     png_byte header[8];
761     if (!check(f.read_array(header), "File ended before end of header")) {
762         return false;
763     }
764     if (!check(!png_sig_cmp(header, 0, 8), "File is not recognized as a PNG file")) {
765         return false;
766     }
767 
768     /* initialize stuff */
769     png_structp png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, NULL, NULL, NULL);
770     if (!check(png_ptr != nullptr, "png_create_read_struct failed")) {
771         return false;
772     }
773 
774     png_infop info_ptr = png_create_info_struct(png_ptr);
775     if (!check(info_ptr != nullptr, "png_create_info_struct failed")) {
776         return false;
777     }
778 
779     if (!check(!setjmp(png_jmpbuf(png_ptr)), "Error loading PNG")) {
780         return false;
781     }
782 
783     png_init_io(png_ptr, f.f);
784     png_set_sig_bytes(png_ptr, 8);
785 
786     png_read_info(png_ptr, info_ptr);
787 
788     const int width = png_get_image_width(png_ptr, info_ptr);
789     const int height = png_get_image_height(png_ptr, info_ptr);
790     const int channels = png_get_channels(png_ptr, info_ptr);
791     const int bit_depth = png_get_bit_depth(png_ptr, info_ptr);
792 
793     const halide_type_t im_type(halide_type_uint, bit_depth);
794     std::vector<int> im_dimensions = {width, height};
795     if (channels != 1) {
796         im_dimensions.push_back(channels);
797     }
798 
799     *im = ImageType(im_type, im_dimensions);
800 
801     png_read_update_info(png_ptr, info_ptr);
802 
803     auto copy_to_image = bit_depth == 8 ?
804                              Internal::read_big_endian_row<uint8_t, ImageType> :
805                              Internal::read_big_endian_row<uint16_t, ImageType>;
806 
807     std::vector<uint8_t> row(png_get_rowbytes(png_ptr, info_ptr));
808     const int ymin = im->dim(1).min();
809     const int ymax = im->dim(1).max();
810     for (int y = ymin; y <= ymax; ++y) {
811         png_read_row(png_ptr, row.data(), nullptr);
812         copy_to_image(row.data(), y, im);
813     }
814 
815     png_destroy_read_struct(&png_ptr, &info_ptr, NULL);
816 
817     return true;
818 }
819 
query_png()820 inline const std::set<FormatInfo> &query_png() {
821     static std::set<FormatInfo> info = {
822         {halide_type_t(halide_type_uint, 8), 2},
823         {halide_type_t(halide_type_uint, 16), 2},
824         {halide_type_t(halide_type_uint, 8), 3},
825         {halide_type_t(halide_type_uint, 16), 3}};
826     return info;
827 }
828 
829 // "im" is not const-ref because copy_to_host() is not const.
830 template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
save_png(ImageType & im,const std::string & filename)831 bool save_png(ImageType &im, const std::string &filename) {
832     static_assert(!ImageType::has_static_halide_type, "");
833 
834     im.copy_to_host();
835 
836     const int width = im.width();
837     const int height = im.height();
838     const int channels = im.channels();
839 
840     if (!check(channels >= 1 && channels <= 4,
841                "Can't write PNG files that have other than 1, 2, 3, or 4 channels")) {
842         return false;
843     }
844 
845     const png_byte color_types[4] = {
846         PNG_COLOR_TYPE_GRAY,
847         PNG_COLOR_TYPE_GRAY_ALPHA,
848         PNG_COLOR_TYPE_RGB,
849         PNG_COLOR_TYPE_RGB_ALPHA};
850     png_byte color_type = color_types[channels - 1];
851 
852     // open file
853     Internal::FileOpener f(filename, "wb");
854     if (!check(f.f != nullptr, "[write_png_file] File could not be opened for writing")) {
855         return false;
856     }
857 
858     // initialize stuff
859     png_structp png_ptr = png_create_write_struct(PNG_LIBPNG_VER_STRING, NULL, NULL, NULL);
860     if (!check(png_ptr != nullptr, "[write_png_file] png_create_write_struct failed")) {
861         return false;
862     }
863 
864     png_infop info_ptr = png_create_info_struct(png_ptr);
865     if (!check(info_ptr != nullptr, "[write_png_file] png_create_info_struct failed")) {
866         return false;
867     }
868 
869     if (!check(!setjmp(png_jmpbuf(png_ptr)), "Error saving PNG")) {
870         return false;
871     }
872 
873     png_init_io(png_ptr, f.f);
874 
875     const halide_type_t im_type = im.type();
876     const int bit_depth = im_type.bits;
877 
878     png_set_IHDR(png_ptr, info_ptr, width, height,
879                  bit_depth, color_type, PNG_INTERLACE_NONE,
880                  PNG_COMPRESSION_TYPE_BASE, PNG_FILTER_TYPE_BASE);
881 
882     png_write_info(png_ptr, info_ptr);
883 
884     auto copy_from_image = bit_depth == 8 ?
885                                Internal::write_big_endian_row<uint8_t, ImageType> :
886                                Internal::write_big_endian_row<uint16_t, ImageType>;
887 
888     std::vector<uint8_t> row(png_get_rowbytes(png_ptr, info_ptr));
889     const int ymin = im.dim(1).min();
890     const int ymax = im.dim(1).max();
891     for (int y = ymin; y <= ymax; ++y) {
892         copy_from_image(im, y, row.data());
893         png_write_row(png_ptr, row.data());
894     }
895     png_write_end(png_ptr, NULL);
896     png_destroy_write_struct(&png_ptr, &info_ptr);
897 
898     return true;
899 }
900 
901 #endif  // not HALIDE_NO_PNG
902 
903 template<Internal::CheckFunc check>
read_pnm_header(Internal::FileOpener & f,const std::string & hdr_fmt,int * width,int * height,int * bit_depth)904 bool read_pnm_header(Internal::FileOpener &f, const std::string &hdr_fmt, int *width, int *height, int *bit_depth) {
905     if (!check(f.f != nullptr, "File could not be opened for reading")) {
906         return false;
907     }
908 
909     char header[256];
910     if (!check(f.scan_line("%255s", header) == 1, "Could not read header")) {
911         return false;
912     }
913 
914     if (!check(to_lowercase(hdr_fmt) == to_lowercase(header), "Unexpected file header")) {
915         return false;
916     }
917 
918     if (!check(f.scan_line("%d %d\n", width, height) == 2, "Could not read width and height")) {
919         return false;
920     }
921 
922     int maxval;
923     if (!check(f.scan_line("%d", &maxval) == 1, "Could not read max value")) {
924         return false;
925     }
926     if (maxval == 255) {
927         *bit_depth = 8;
928     } else if (maxval == 65535) {
929         *bit_depth = 16;
930     } else {
931         *bit_depth = 0;
932         return check(false, "Invalid bit depth");
933     }
934 
935     return true;
936 }
937 
938 template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
load_pnm(const std::string & filename,int channels,ImageType * im)939 bool load_pnm(const std::string &filename, int channels, ImageType *im) {
940     static_assert(!ImageType::has_static_halide_type, "");
941 
942     const char *hdr_fmt = channels == 3 ? "P6" : "P5";
943 
944     Internal::FileOpener f(filename, "rb");
945     int width, height, bit_depth;
946     if (!Internal::read_pnm_header<check>(f, hdr_fmt, &width, &height, &bit_depth)) {
947         return false;
948     }
949 
950     const halide_type_t im_type(halide_type_uint, bit_depth);
951     std::vector<int> im_dimensions = {width, height};
952     if (channels > 1) {
953         im_dimensions.push_back(channels);
954     }
955     *im = ImageType(im_type, im_dimensions);
956 
957     auto copy_to_image = bit_depth == 8 ?
958                              Internal::read_big_endian_row<uint8_t, ImageType> :
959                              Internal::read_big_endian_row<uint16_t, ImageType>;
960 
961     std::vector<uint8_t> row(width * channels * (bit_depth / 8));
962     const int ymin = im->dim(1).min();
963     const int ymax = im->dim(1).max();
964     for (int y = ymin; y <= ymax; ++y) {
965         if (!check(f.read_vector(&row), "Could not read data")) {
966             return false;
967         }
968         copy_to_image(row.data(), y, im);
969     }
970 
971     return true;
972 }
973 
974 template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
save_pnm(ImageType & im,const int channels,const std::string & filename)975 bool save_pnm(ImageType &im, const int channels, const std::string &filename) {
976     static_assert(!ImageType::has_static_halide_type, "");
977 
978     if (!check(im.channels() == channels, "Wrong number of channels")) {
979         return false;
980     }
981 
982     im.copy_to_host();
983 
984     const halide_type_t im_type = im.type();
985     const int width = im.width();
986     const int height = im.height();
987     const int bit_depth = im_type.bits;
988 
989     Internal::FileOpener f(filename, "wb");
990     if (!check(f.f != nullptr, "File could not be opened for writing")) {
991         return false;
992     }
993     const char *hdr_fmt = channels == 3 ? "P6" : "P5";
994     fprintf(f.f, "%s\n%d %d\n%d\n", hdr_fmt, width, height, (1 << bit_depth) - 1);
995 
996     auto copy_from_image = bit_depth == 8 ?
997                                Internal::write_big_endian_row<uint8_t, ImageType> :
998                                Internal::write_big_endian_row<uint16_t, ImageType>;
999 
1000     std::vector<uint8_t> row(width * channels * (bit_depth / 8));
1001     const int ymin = im.dim(1).min();
1002     const int ymax = im.dim(1).max();
1003     for (int y = ymin; y <= ymax; ++y) {
1004         copy_from_image(im, y, row.data());
1005         if (!check(f.write_vector(row), "Could not write data")) {
1006             return false;
1007         }
1008     }
1009 
1010     return true;
1011 }
1012 
1013 template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
load_pgm(const std::string & filename,ImageType * im)1014 bool load_pgm(const std::string &filename, ImageType *im) {
1015     return Internal::load_pnm<ImageType, check>(filename, 1, im);
1016 }
1017 
query_pgm()1018 inline const std::set<FormatInfo> &query_pgm() {
1019     static std::set<FormatInfo> info = {
1020         {halide_type_t(halide_type_uint, 8), 2},
1021         {halide_type_t(halide_type_uint, 16), 2}};
1022     return info;
1023 }
1024 
1025 // "im" is not const-ref because copy_to_host() is not const.
1026 template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
save_pgm(ImageType & im,const std::string & filename)1027 bool save_pgm(ImageType &im, const std::string &filename) {
1028     return Internal::save_pnm<ImageType, check>(im, 1, filename);
1029 }
1030 
1031 template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
load_ppm(const std::string & filename,ImageType * im)1032 bool load_ppm(const std::string &filename, ImageType *im) {
1033     return Internal::load_pnm<ImageType, check>(filename, 3, im);
1034 }
1035 
query_ppm()1036 inline const std::set<FormatInfo> &query_ppm() {
1037     static std::set<FormatInfo> info = {
1038         {halide_type_t(halide_type_uint, 8), 3},
1039         {halide_type_t(halide_type_uint, 16), 3}};
1040     return info;
1041 }
1042 
1043 // "im" is not const-ref because copy_to_host() is not const.
1044 template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
save_ppm(ImageType & im,const std::string & filename)1045 bool save_ppm(ImageType &im, const std::string &filename) {
1046     return Internal::save_pnm<ImageType, check>(im, 3, filename);
1047 }
1048 
1049 #ifndef HALIDE_NO_JPEG
1050 
1051 template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
load_jpg(const std::string & filename,ImageType * im)1052 bool load_jpg(const std::string &filename, ImageType *im) {
1053     static_assert(!ImageType::has_static_halide_type, "");
1054 
1055     Internal::FileOpener f(filename, "rb");
1056     if (!check(f.f != nullptr, "File could not be opened for reading")) {
1057         return false;
1058     }
1059 
1060     struct jpeg_decompress_struct cinfo;
1061     struct jpeg_error_mgr jerr;
1062     cinfo.err = jpeg_std_error(&jerr);
1063     jpeg_create_decompress(&cinfo);
1064     jpeg_stdio_src(&cinfo, f.f);
1065     jpeg_read_header(&cinfo, TRUE);
1066     jpeg_start_decompress(&cinfo);
1067 
1068     const int width = cinfo.output_width;
1069     const int height = cinfo.output_height;
1070     const int channels = cinfo.output_components;
1071 
1072     const halide_type_t im_type(halide_type_uint, 8);
1073     std::vector<int> im_dimensions = {width, height};
1074     if (channels > 1) {
1075         im_dimensions.push_back(channels);
1076     }
1077     *im = ImageType(im_type, im_dimensions);
1078 
1079     auto copy_to_image = Internal::read_big_endian_row<uint8_t, ImageType>;
1080 
1081     std::vector<uint8_t> row(width * channels);
1082     const int ymin = im->dim(1).min();
1083     const int ymax = im->dim(1).max();
1084     for (int y = ymin; y <= ymax; ++y) {
1085         uint8_t *src = row.data();
1086         jpeg_read_scanlines(&cinfo, &src, 1);
1087         copy_to_image(row.data(), y, im);
1088     }
1089 
1090     jpeg_finish_decompress(&cinfo);
1091     jpeg_destroy_decompress(&cinfo);
1092 
1093     return true;
1094 }
1095 
query_jpg()1096 inline const std::set<FormatInfo> &query_jpg() {
1097     static std::set<FormatInfo> info = {
1098         {halide_type_t(halide_type_uint, 8), 2},
1099         {halide_type_t(halide_type_uint, 8), 3},
1100     };
1101     return info;
1102 }
1103 
1104 template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
save_jpg(ImageType & im,const std::string & filename)1105 bool save_jpg(ImageType &im, const std::string &filename) {
1106     static_assert(!ImageType::has_static_halide_type, "");
1107 
1108     im.copy_to_host();
1109 
1110     const int width = im.width();
1111     const int height = im.height();
1112     const int channels = im.channels();
1113     if (!check(channels == 1 || channels == 3, "Wrong number of channels")) {
1114         return false;
1115     }
1116 
1117     Internal::FileOpener f(filename, "wb");
1118     if (!check(f.f != nullptr, "File could not be opened for writing")) {
1119         return false;
1120     }
1121 
1122     // TODO: Make this an argument?
1123     constexpr int quality = 99;
1124 
1125     struct jpeg_compress_struct cinfo;
1126     struct jpeg_error_mgr jerr;
1127     cinfo.err = jpeg_std_error(&jerr);
1128     jpeg_create_compress(&cinfo);
1129     jpeg_stdio_dest(&cinfo, f.f);
1130     cinfo.image_width = width;
1131     cinfo.image_height = height;
1132     cinfo.input_components = channels;
1133     cinfo.in_color_space = (channels == 3) ? JCS_RGB : JCS_GRAYSCALE;
1134     jpeg_set_defaults(&cinfo);
1135     jpeg_set_quality(&cinfo, quality, TRUE);
1136     jpeg_start_compress(&cinfo, TRUE);
1137 
1138     auto copy_from_image = Internal::write_big_endian_row<uint8_t, ImageType>;
1139 
1140     std::vector<uint8_t> row(width * channels);
1141     const int ymin = im.dim(1).min();
1142     const int ymax = im.dim(1).max();
1143     for (int y = ymin; y <= ymax; ++y) {
1144         uint8_t *dst = row.data();
1145         copy_from_image(im, y, dst);
1146         jpeg_write_scanlines(&cinfo, &dst, 1);
1147     }
1148 
1149     jpeg_finish_compress(&cinfo);
1150     jpeg_destroy_compress(&cinfo);
1151 
1152     return true;
1153 }
1154 
1155 #endif  // not HALIDE_NO_JPEG
1156 
1157 constexpr int kNumTmpCodes = 10;
1158 
tmp_code_to_halide_type()1159 inline const halide_type_t *tmp_code_to_halide_type() {
1160     static const halide_type_t tmp_code_to_halide_type_[kNumTmpCodes] = {
1161         {halide_type_float, 32},
1162         {halide_type_float, 64},
1163         {halide_type_uint, 8},
1164         {halide_type_int, 8},
1165         {halide_type_uint, 16},
1166         {halide_type_int, 16},
1167         {halide_type_uint, 32},
1168         {halide_type_int, 32},
1169         {halide_type_uint, 64},
1170         {halide_type_int, 64}};
1171     return tmp_code_to_halide_type_;
1172 }
1173 
1174 // return true iff the buffer storage has no padding between
1175 // any elements, and is in strictly planar order.
1176 template<typename ImageType>
buffer_is_compact_planar(ImageType & im)1177 bool buffer_is_compact_planar(ImageType &im) {
1178     const halide_type_t im_type = im.type();
1179     const size_t elem_size = (im_type.bits / 8);
1180     if (((const uint8_t *)im.begin() + (im.number_of_elements() * elem_size)) != (const uint8_t *)im.end()) {
1181         return false;
1182     }
1183     for (int d = 1; d < im.dimensions(); ++d) {
1184         if (im.dim(d - 1).stride() > im.dim(d).stride()) {
1185             return false;
1186         }
1187         // Strides can only match if the previous dimension has extent 1
1188         // (this can happen when artificially adding dimension(s), e.g.
1189         // to write a .tmp file)
1190         if (im.dim(d - 1).stride() == im.dim(d).stride() && im.dim(d - 1).extent() != 1) {
1191             return false;
1192         }
1193     }
1194     return true;
1195 }
1196 
1197 // ".tmp" is a file format used by the ImageStack tool (see https://github.com/abadams/ImageStack)
1198 template<typename ImageType, CheckFunc check = CheckReturn>
load_tmp(const std::string & filename,ImageType * im)1199 bool load_tmp(const std::string &filename, ImageType *im) {
1200     static_assert(!ImageType::has_static_halide_type, "");
1201 
1202     FileOpener f(filename, "rb");
1203     if (!check(f.f != nullptr, "File could not be opened for reading")) {
1204         return false;
1205     }
1206 
1207     int32_t header[5];
1208     if (!check(f.read_array(header), "Count not read .tmp header")) {
1209         return false;
1210     }
1211 
1212     if (!check(header[0] > 0 && header[1] > 0 && header[2] > 0 && header[3] > 0 &&
1213                    header[4] >= 0 && header[4] < kNumTmpCodes,
1214                "Bad header on .tmp file")) {
1215         return false;
1216     }
1217 
1218     const halide_type_t im_type = tmp_code_to_halide_type()[header[4]];
1219     std::vector<int> im_dimensions = {header[0], header[1], header[2], header[3]};
1220     *im = ImageType(im_type, im_dimensions);
1221 
1222     // This should never fail unless the default Buffer<> constructor behavior changes.
1223     if (!check(buffer_is_compact_planar(*im), "load_tmp() requires compact planar images")) {
1224         return false;
1225     }
1226 
1227     if (!check(f.read_bytes(im->begin(), im->size_in_bytes()), "Count not read .tmp payload")) {
1228         return false;
1229     }
1230 
1231     im->set_host_dirty();
1232     return true;
1233 }
1234 
query_tmp()1235 inline const std::set<FormatInfo> &query_tmp() {
1236     // TMP files require exactly 4 dimensions.
1237     static std::set<FormatInfo> info = {
1238         {halide_type_t(halide_type_float, 32), 4},
1239         {halide_type_t(halide_type_float, 64), 4},
1240         {halide_type_t(halide_type_uint, 8), 4},
1241         {halide_type_t(halide_type_int, 8), 4},
1242         {halide_type_t(halide_type_uint, 16), 4},
1243         {halide_type_t(halide_type_int, 16), 4},
1244         {halide_type_t(halide_type_uint, 32), 4},
1245         {halide_type_t(halide_type_int, 32), 4},
1246         {halide_type_t(halide_type_uint, 64), 4},
1247         {halide_type_t(halide_type_int, 64), 4},
1248     };
1249     return info;
1250 }
1251 
1252 template<typename ImageType, CheckFunc check = CheckReturn>
write_planar_payload(ImageType & im,FileOpener & f)1253 bool write_planar_payload(ImageType &im, FileOpener &f) {
1254     if (im.dimensions() == 0 || buffer_is_compact_planar(im)) {
1255         // Contiguous buffer! Write it all in one swell foop.
1256         if (!check(f.write_bytes(im.begin(), im.size_in_bytes()), "Count not write .tmp payload")) {
1257             return false;
1258         }
1259     } else {
1260         // We have to do this the hard way.
1261         int d = im.dimensions() - 1;
1262         for (int i = im.dim(d).min(); i <= im.dim(d).max(); i++) {
1263             auto slice = im.sliced(d, i);
1264             if (!write_planar_payload(slice, f)) {
1265                 return false;
1266             }
1267         }
1268     }
1269     return true;
1270 }
1271 
1272 // ".tmp" is a file format used by the ImageStack tool (see https://github.com/abadams/ImageStack)
1273 template<typename ImageType, CheckFunc check = CheckReturn>
save_tmp(ImageType & im,const std::string & filename)1274 bool save_tmp(ImageType &im, const std::string &filename) {
1275     static_assert(!ImageType::has_static_halide_type, "");
1276 
1277     im.copy_to_host();
1278 
1279     int32_t header[5] = {1, 1, 1, 1, -1};
1280     for (int i = 0; i < im.dimensions(); ++i) {
1281         header[i] = im.dim(i).extent();
1282     }
1283     auto *table = tmp_code_to_halide_type();
1284     for (int i = 0; i < kNumTmpCodes; i++) {
1285         if (im.type() == table[i]) {
1286             header[4] = i;
1287             break;
1288         }
1289     }
1290     if (!check(header[4] >= 0, "Unsupported type for .tmp file")) {
1291         return false;
1292     }
1293 
1294     FileOpener f(filename, "wb");
1295     if (!check(f.f != nullptr, "File could not be opened for writing")) {
1296         return false;
1297     }
1298     if (!check(f.write_array(header), "Could not write .tmp header")) {
1299         return false;
1300     }
1301 
1302     if (!write_planar_payload<ImageType, check>(im, f)) {
1303         return false;
1304     }
1305 
1306     return true;
1307 }
1308 
// ".mat" is the matlab level 5 format documented here:
// http://www.mathworks.com/help/pdf_doc/matlab/matfile_format.pdf

// Data-element type tags: the first 32-bit word of each data element's tag.
// Values are fixed by the file format; enumerators are grouped by kind.
enum MatlabTypeCode {
    // Signed integers
    miINT8 = 1,
    miINT16 = 3,
    miINT32 = 5,
    miINT64 = 12,
    // Unsigned integers
    miUINT8 = 2,
    miUINT16 = 4,
    miUINT32 = 6,
    miUINT64 = 13,
    // IEEE floating point
    miSINGLE = 7,
    miDOUBLE = 9,
    // Containers and text
    miMATRIX = 14,
    miCOMPRESSED = 15,
    miUTF8 = 16,
    miUTF16 = 17,
    miUTF32 = 18
};
1329 
// MATLAB array classes, stored in the array-flags sub-element of an miMATRIX
// element. Values are fixed by the file format; enumerators are grouped by kind.
enum MatlabClassCode {
    // Text
    mxCHAR_CLASS = 3,
    // Floating point
    mxDOUBLE_CLASS = 6,
    mxSINGLE_CLASS = 7,
    // Signed integers
    mxINT8_CLASS = 8,
    mxINT16_CLASS = 10,
    mxINT32_CLASS = 12,
    mxINT64_CLASS = 14,
    // Unsigned integers
    mxUINT8_CLASS = 9,
    mxUINT16_CLASS = 11,
    mxUINT32_CLASS = 13,
    mxUINT64_CLASS = 15
};
1343 
1344 template<typename ImageType, CheckFunc check = CheckReturn>
load_mat(const std::string & filename,ImageType * im)1345 bool load_mat(const std::string &filename, ImageType *im) {
1346     static_assert(!ImageType::has_static_halide_type, "");
1347 
1348     FileOpener f(filename, "rb");
1349     if (!check(f.f != nullptr, "File could not be opened for reading")) {
1350         return false;
1351     }
1352 
1353     uint8_t header[128];
1354     if (!check(f.read_array(header), "Could not read .mat header\n")) {
1355         return false;
1356     }
1357 
1358     // Matrix header
1359     uint32_t matrix_header[2];
1360     if (!check(f.read_array(matrix_header), "Could not read .mat header\n")) {
1361         return false;
1362     }
1363     if (!check(matrix_header[0] == miMATRIX, "Could not parse this .mat file: bad matrix header\n")) {
1364         return false;
1365     }
1366 
1367     // Array flags
1368     uint32_t flags[4];
1369     if (!check(f.read_array(flags), "Could not read .mat header\n")) {
1370         return false;
1371     }
1372     if (!check(flags[0] == miUINT32 && flags[1] == 8, "Could not parse this .mat file: bad flags\n")) {
1373         return false;
1374     }
1375 
1376     // Shape
1377     uint32_t shape_header[2];
1378     if (!check(f.read_array(shape_header), "Could not read .mat header\n")) {
1379         return false;
1380     }
1381     if (!check(shape_header[0] == miINT32, "Could not parse this .mat file: bad shape header\n")) {
1382         return false;
1383     }
1384     int dims = shape_header[1] / 4;
1385     std::vector<int> extents(dims);
1386     if (!check(f.read_vector(&extents), "Could not read .mat header\n")) {
1387         return false;
1388     }
1389     if (dims & 1) {
1390         uint32_t padding;
1391         if (!check(f.read_bytes(&padding, 4), "Could not read .mat header\n")) {
1392             return false;
1393         }
1394     }
1395 
1396     // Skip over the name
1397     uint32_t name_header[2];
1398     if (!check(f.read_array(name_header), "Could not read .mat header\n")) {
1399         return false;
1400     }
1401 
1402     if (name_header[0] >> 16) {
1403         // Name must be fewer than 4 chars, and so the whole name
1404         // field was stored packed into 8 bytes
1405     } else {
1406         if (!check(name_header[0] == miINT8, "Could not parse this .mat file: bad name header\n")) {
1407             return false;
1408         }
1409         std::vector<uint64_t> scratch((name_header[1] + 7) / 8);
1410         if (!check(f.read_vector(&scratch), "Could not read .mat header\n")) {
1411             return false;
1412         }
1413     }
1414 
1415     // Payload header
1416     uint32_t payload_header[2];
1417     if (!check(f.read_array(payload_header), "Could not read .mat header\n")) {
1418         return false;
1419     }
1420     halide_type_t type;
1421     switch (payload_header[0]) {
1422     case miINT8:
1423         type = halide_type_of<int8_t>();
1424         break;
1425     case miINT16:
1426         type = halide_type_of<int16_t>();
1427         break;
1428     case miINT32:
1429         type = halide_type_of<int32_t>();
1430         break;
1431     case miINT64:
1432         type = halide_type_of<int64_t>();
1433         break;
1434     case miUINT8:
1435         type = halide_type_of<uint8_t>();
1436         break;
1437     case miUINT16:
1438         type = halide_type_of<uint16_t>();
1439         break;
1440     case miUINT32:
1441         type = halide_type_of<uint32_t>();
1442         break;
1443     case miUINT64:
1444         type = halide_type_of<uint64_t>();
1445         break;
1446     case miSINGLE:
1447         type = halide_type_of<float>();
1448         break;
1449     case miDOUBLE:
1450         type = halide_type_of<double>();
1451         break;
1452     }
1453 
1454     *im = ImageType(type, extents);
1455 
1456     // This should never fail unless the default Buffer<> constructor behavior changes.
1457     if (!check(buffer_is_compact_planar(*im), "load_mat() requires compact planar images")) {
1458         return false;
1459     }
1460 
1461     if (!check(f.read_bytes(im->begin(), im->size_in_bytes()), "Could not read .tmp payload")) {
1462         return false;
1463     }
1464 
1465     im->set_host_dirty();
1466     return true;
1467 }
1468 
query_mat()1469 inline const std::set<FormatInfo> &query_mat() {
1470     // MAT files must have at least 2 dimensions, but there's no upper
1471     // bound. Our support arbitrarily stops at 16 dimensions.
1472     static std::set<FormatInfo> info = []() {
1473         std::set<FormatInfo> s;
1474         for (int i = 2; i < 16; i++) {
1475             s.insert({halide_type_t(halide_type_float, 32), i});
1476             s.insert({halide_type_t(halide_type_float, 64), i});
1477             s.insert({halide_type_t(halide_type_uint, 8), i});
1478             s.insert({halide_type_t(halide_type_int, 8), i});
1479             s.insert({halide_type_t(halide_type_uint, 16), i});
1480             s.insert({halide_type_t(halide_type_int, 16), i});
1481             s.insert({halide_type_t(halide_type_uint, 32), i});
1482             s.insert({halide_type_t(halide_type_int, 32), i});
1483             s.insert({halide_type_t(halide_type_uint, 64), i});
1484             s.insert({halide_type_t(halide_type_int, 64), i});
1485         }
1486         return s;
1487     }();
1488     return info;
1489 }
1490 
1491 template<typename ImageType, CheckFunc check = CheckReturn>
save_mat(ImageType & im,const std::string & filename)1492 bool save_mat(ImageType &im, const std::string &filename) {
1493     static_assert(!ImageType::has_static_halide_type, "");
1494 
1495     im.copy_to_host();
1496 
1497     uint32_t class_code = 0, type_code = 0;
1498     switch (im.raw_buffer()->type.code) {
1499     case halide_type_int:
1500         switch (im.raw_buffer()->type.bits) {
1501         case 8:
1502             class_code = mxINT8_CLASS;
1503             type_code = miINT8;
1504             break;
1505         case 16:
1506             class_code = mxINT16_CLASS;
1507             type_code = miINT16;
1508             break;
1509         case 32:
1510             class_code = mxINT32_CLASS;
1511             type_code = miINT32;
1512             break;
1513         case 64:
1514             class_code = mxINT64_CLASS;
1515             type_code = miINT64;
1516             break;
1517         default:
1518             check(false, "unreachable");
1519         };
1520         break;
1521     case halide_type_uint:
1522         switch (im.raw_buffer()->type.bits) {
1523         case 8:
1524             class_code = mxUINT8_CLASS;
1525             type_code = miUINT8;
1526             break;
1527         case 16:
1528             class_code = mxUINT16_CLASS;
1529             type_code = miUINT16;
1530             break;
1531         case 32:
1532             class_code = mxUINT32_CLASS;
1533             type_code = miUINT32;
1534             break;
1535         case 64:
1536             class_code = mxUINT64_CLASS;
1537             type_code = miUINT64;
1538             break;
1539         default:
1540             check(false, "unreachable");
1541         };
1542         break;
1543     case halide_type_float:
1544         switch (im.raw_buffer()->type.bits) {
1545         case 16:
1546             check(false, "float16 not supported by .mat");
1547             break;
1548         case 32:
1549             class_code = mxSINGLE_CLASS;
1550             type_code = miSINGLE;
1551             break;
1552         case 64:
1553             class_code = mxDOUBLE_CLASS;
1554             type_code = miDOUBLE;
1555             break;
1556         default:
1557             check(false, "unreachable");
1558         };
1559         break;
1560     case halide_type_bfloat:
1561         check(false, "bfloat not supported by .mat");
1562         break;
1563     default:
1564         check(false, "unreachable");
1565     }
1566 
1567     FileOpener f(filename, "wb");
1568     if (!check(f.f != nullptr, "File could not be opened for writing")) {
1569         return false;
1570     }
1571 
1572     // Pick a name for the array
1573     size_t idx = filename.rfind('.');
1574     std::string name = filename.substr(0, idx);
1575     idx = filename.rfind('/');
1576     if (idx != std::string::npos) {
1577         name = name.substr(idx + 1);
1578     }
1579 
1580     // Matlab variable names conform to similar rules as C
1581     if (name.empty() || !std::isalpha(name[0])) {
1582         name = "v" + name;
1583     }
1584     for (size_t i = 0; i < name.size(); i++) {
1585         if (!std::isalnum(name[i])) {
1586             name[i] = '_';
1587         }
1588     }
1589 
1590     uint32_t name_size = (int)name.size();
1591     while (name.size() & 0x7)
1592         name += '\0';
1593 
1594     char header[128] = "MATLAB 5.0 MAT-file, produced by Halide";
1595     int len = strlen(header);
1596     memset(header + len, ' ', sizeof(header) - len);
1597 
1598     // Version
1599     *((uint16_t *)(header + 124)) = 0x0100;
1600 
1601     // Endianness check
1602     header[126] = 'I';
1603     header[127] = 'M';
1604 
1605     uint64_t payload_bytes = im.size_in_bytes();
1606 
1607     if (!check((payload_bytes >> 32) == 0, "Buffer too large to save as .mat")) {
1608         return false;
1609     }
1610 
1611     int dims = im.dimensions();
1612     if (dims < 2) {
1613         dims = 2;
1614     }
1615     int padded_dims = dims + (dims & 1);
1616 
1617     uint32_t padding_bytes = 7 - ((payload_bytes - 1) & 7);
1618 
1619     // Matrix header
1620     uint32_t matrix_header[2] = {
1621         miMATRIX, 40 + padded_dims * 4 + (uint32_t)name.size() + (uint32_t)payload_bytes + padding_bytes};
1622 
1623     // Array flags
1624     uint32_t flags[4] = {
1625         miUINT32, 8, class_code, 1};
1626 
1627     // Shape
1628     int32_t shape[2] = {
1629         miINT32,
1630         im.dimensions() * 4,
1631     };
1632     std::vector<int> extents(im.dimensions());
1633     for (int d = 0; d < im.dimensions(); d++) {
1634         extents[d] = im.dim(d).extent();
1635     }
1636     while ((int)extents.size() < dims) {
1637         extents.push_back(1);
1638     }
1639     while ((int)extents.size() < padded_dims) {
1640         extents.push_back(0);
1641     }
1642 
1643     // Name
1644     uint32_t name_header[2] = {
1645         miINT8, name_size};
1646 
1647     // Payload header
1648     uint32_t payload_header[2] = {
1649         type_code, (uint32_t)payload_bytes};
1650 
1651     bool success =
1652         f.write_array(header) &&
1653         f.write_array(matrix_header) &&
1654         f.write_array(flags) &&
1655         f.write_array(shape) &&
1656         f.write_vector(extents) &&
1657         f.write_array(name_header) &&
1658         f.write_bytes(&name[0], name.size()) &&
1659         f.write_array(payload_header);
1660 
1661     if (!check(success, "Could not write .mat header")) {
1662         return false;
1663     }
1664 
1665     if (!write_planar_payload<ImageType, check>(im, f)) {
1666         return false;
1667     }
1668 
1669     // Padding
1670     if (!check(padding_bytes < 8, "Too much padding!\n")) {
1671         return false;
1672     }
1673     uint64_t padding = 0;
1674     if (!f.write_bytes(&padding, padding_bytes)) {
1675         return false;
1676     }
1677 
1678     return true;
1679 }
1680 
1681 template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
load_tiff(const std::string & filename,ImageType * im)1682 bool load_tiff(const std::string &filename, ImageType *im) {
1683     static_assert(!ImageType::has_static_halide_type, "");
1684     check(false, "Reading TIFF is not yet supported");
1685     return false;
1686 }
1687 
query_tiff()1688 inline const std::set<FormatInfo> &query_tiff() {
1689     auto build_set = []() -> std::set<FormatInfo> {
1690         std::set<FormatInfo> s;
1691         for (halide_type_code_t code : {halide_type_int, halide_type_uint, halide_type_float}) {
1692             for (int bits : {8, 16, 32, 64}) {
1693                 for (int dims : {1, 2, 3, 4}) {
1694                     if (code == halide_type_float && bits < 32) {
1695                         continue;
1696                     }
1697                     s.insert({halide_type_t(code, bits), dims});
1698                 }
1699             }
1700         }
1701         return s;
1702     };
1703 
1704     static std::set<FormatInfo> info = build_set();
1705     return info;
1706 }
1707 
// These structs are written to disk byte-for-byte, so pack them to 2-byte
// alignment: otherwise the compiler would pad before `entries` (offset 10)
// and `ifd0_end` in halide_tiff_header.
#pragma pack(push)
#pragma pack(2)

// One 12-byte TIFF IFD entry: tag id, field type, value count, and a 4-byte
// value field holding the value itself or, when it does not fit, an offset.
struct halide_tiff_tag {
    uint16_t tag_code;
    int16_t type_code;
    int32_t count;
    union {
        int8_t i8;
        int16_t i16;
        int32_t i32;
    } value;

    // Fill this entry as a SHORT (TIFF type 3) field. The 16-bit value goes
    // in the first two bytes of the value field.
    void assign16(uint16_t tag_code, int32_t count, int16_t value) {
        this->type_code = 3;  // SHORT
        this->tag_code = tag_code;
        this->count = count;
        this->value.i16 = value;
    }

    // Fill this entry as a LONG (TIFF type 4) field.
    void assign32(uint16_t tag_code, int32_t count, int32_t value) {
        const int16_t long_type = 4;  // LONG
        assign32(tag_code, long_type, count, value);
    }

    // Fill this entry with a caller-chosen TIFF field type and a 32-bit value.
    void assign32(uint16_t tag_code, int16_t type_code, int32_t count, int32_t value) {
        this->type_code = type_code;
        this->tag_code = tag_code;
        this->count = count;
        this->value.i32 = value;
    }
};

// Fixed prefix of the TIFF file emitted by save_tiff: the byte-order marker
// and version, the offset of the first IFD, a single IFD with 15 entries, and
// trailing words.
struct halide_tiff_header {
    int16_t byte_order_marker;
    int16_t version;
    int32_t ifd0_offset;
    int16_t entry_count;
    halide_tiff_tag entries[15];
    // Presumably the IFD's next-IFD offset (0 = none); TODO confirm in save_tiff.
    int32_t ifd0_end;
    // Presumably RATIONAL (numerator, denominator) pairs referenced by the
    // X/YResolution entries; TODO confirm in save_tiff.
    int32_t width_resolution[2];
    int32_t height_resolution[2];
};

#pragma pack(pop)
1755 
1756 template<typename ElemType, int BUFFER_SIZE = 1024>
1757 struct ElemWriter {
ElemWriterElemWriter1758     ElemWriter(FileOpener *f)
1759         : f(f), next(&buf[0]), ok(true) {
1760     }
~ElemWriterElemWriter1761     ~ElemWriter() {
1762         flush();
1763     }
1764 
operatorElemWriter1765     void operator()(const ElemType &elem) {
1766         if (!ok) return;
1767 
1768         *next++ = elem;
1769         if (next == &buf[BUFFER_SIZE]) {
1770             flush();
1771         }
1772     }
1773 
flushElemWriter1774     void flush() {
1775         if (!ok) return;
1776 
1777         if (next > buf) {
1778             if (!f->write_bytes(buf, (next - buf) * sizeof(ElemType))) {
1779                 ok = false;
1780             }
1781             next = buf;
1782         }
1783     }
1784 
1785     FileOpener *const f;
1786     ElemType buf[BUFFER_SIZE];
1787     ElemType *next;
1788     bool ok;
1789 };
1790 
1791 // Note that this is a fairly simpleminded TIFF writer that doesn't
1792 // do any compression. It would be desirable to (optionally) support using libtiff
1793 // here instead, which would also allow us to provide a useful implementation
1794 // for TIFF reading.
1795 template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
save_tiff(ImageType & im,const std::string & filename)1796 bool save_tiff(ImageType &im, const std::string &filename) {
1797     static_assert(!ImageType::has_static_halide_type, "");
1798 
1799     im.copy_to_host();
1800 
1801     if (!check(im.dimensions() <= 4, "Can only save TIFF files with <= 4 dimensions")) {
1802         return false;
1803     }
1804 
1805     FileOpener f(filename, "wb");
1806     if (!check(f.f != nullptr, "File could not be opened for writing")) {
1807         return false;
1808     }
1809 
1810     const size_t elements = im.number_of_elements();
1811     halide_dimension_t shape[4];
1812     for (int i = 0; i < im.dimensions() && i < 4; i++) {
1813         const auto &d = im.dim(i);
1814         shape[i].min = d.min();
1815         shape[i].extent = d.extent();
1816         shape[i].stride = d.stride();
1817     }
1818     for (int i = im.dimensions(); i < 4; i++) {
1819         shape[i].min = 0;
1820         shape[i].extent = 1;
1821         shape[i].stride = 0;
1822     }
1823     const halide_type_t im_type = im.type();
1824     if (!check(im_type.code >= 0 && im_type.code < 3, "Unsupported image type")) {
1825         return false;
1826     }
1827     const int32_t bytes_per_element = im_type.bytes();
1828     const int32_t width = shape[0].extent;
1829     const int32_t height = shape[1].extent;
1830     int32_t depth = shape[2].extent;
1831     int32_t channels = shape[3].extent;
1832 
1833     if ((channels == 0 || channels == 1) && (depth < 5)) {
1834         channels = depth;
1835         depth = 1;
1836     }
1837 
1838     // TIFF sample type values are:
1839     //     0 => Signed int
1840     //     1 => Unsigned int
1841     //     2 => Floating-point
1842     static const int16_t type_code_to_tiff_sample_type[] = {
1843         2, 1, 3};
1844 
1845     struct halide_tiff_header header;
1846     memset(&header, 0, sizeof(header));
1847 
1848     const int32_t MMII = 0x4d4d4949;
1849     // Select the appropriate two bytes signaling byte order automatically
1850     const char *c = (const char *)&MMII;
1851     header.byte_order_marker = (c[0] << 8) | c[1];
1852     header.version = 42;
1853     header.ifd0_offset = offsetof(halide_tiff_header, entry_count);
1854     header.entry_count = sizeof(header.entries) / sizeof(header.entries[0]);
1855 
1856     static_assert(sizeof(halide_tiff_tag) == 12, "Unexpected halide_tiff_tag packing");
1857     halide_tiff_tag *tag = &header.entries[0];
1858     tag++->assign32(256, 1, width);                           // ImageWidth
1859     tag++->assign32(257, 1, height);                          // ImageLength
1860     tag++->assign16(258, 1, int16_t(bytes_per_element * 8));  // BitsPerSample
1861     tag++->assign16(259, 1, 1);                               // Compression -- none
1862     tag++->assign16(262, 1, channels >= 3 ? 2 : 1);           // PhotometricInterpretation -- black is zero or RGB
1863     tag++->assign32(273, channels, sizeof(header));           // StripOffsets
1864     tag++->assign16(277, 1, int16_t(channels));               // SamplesPerPixel
1865     tag++->assign32(278, 1, height);                          // RowsPerStrip
1866     tag++->assign32(279, channels,                            // StripByteCounts
1867                     (channels == 1) ?
1868                         elements * bytes_per_element :
1869                         sizeof(header) + channels * sizeof(int32_t));  // for channels > 1, this is an offset
1870     tag++->assign32(282, 5, 1,
1871                     offsetof(halide_tiff_header, width_resolution));  // XResolution
1872     tag++->assign32(283, 5, 1,
1873                     offsetof(halide_tiff_header, height_resolution));      // YResolution
1874     tag++->assign16(284, 1, 2);                                            // PlanarConfiguration -- planar
1875     tag++->assign16(296, 1, 1);                                            // ResolutionUnit -- none
1876     tag++->assign16(339, 1, type_code_to_tiff_sample_type[im_type.code]);  // SampleFormat
1877     tag++->assign32(32997, 1, depth);                                      // Image depth
1878 
1879     // Verify we used exactly the number we declared
1880     assert(tag == &header.entries[header.entry_count]);
1881 
1882     header.ifd0_end = 0;
1883     header.width_resolution[0] = 1;
1884     header.width_resolution[1] = 1;
1885     header.height_resolution[0] = 1;
1886     header.height_resolution[1] = 1;
1887 
1888     if (!check(f.write_bytes(&header, sizeof(header)), "TIFF write failed")) {
1889         return false;
1890     }
1891 
1892     if (channels > 1) {
1893         // Fill in the values for StripOffsets
1894         int32_t offset = sizeof(header) + channels * sizeof(int32_t) * 2;
1895         for (int32_t i = 0; i < channels; i++) {
1896             if (!check(f.write_bytes(&offset, sizeof(offset)), "TIFF write failed")) {
1897                 return false;
1898             }
1899             offset += width * height * depth * bytes_per_element;
1900         }
1901         // Fill in the values for StripByteCounts
1902         int32_t count = width * height * depth * bytes_per_element;
1903         for (int32_t i = 0; i < channels; i++) {
1904             if (!check(f.write_bytes(&count, sizeof(count)), "TIFF write failed")) {
1905                 return false;
1906             }
1907         }
1908     }
1909 
1910     // If image is dense, we can write it in one fell swoop
1911     if (elements * bytes_per_element == im.size_in_bytes()) {
1912         if (!check(f.write_bytes(im.data(), im.size_in_bytes()), "TIFF write failed")) {
1913             return false;
1914         }
1915         return true;
1916     }
1917 
1918     // Otherwise, write it out via manual traversal.
1919 #define HANDLE_CASE(CODE, BITS, TYPE)                    \
1920     case halide_type_code(CODE, BITS): {                 \
1921         ElemWriter<TYPE> ew(&f);                         \
1922         im.template as<const TYPE>().for_each_value(ew); \
1923         if (!check(ew.ok, "TIFF write failed")) {        \
1924             return false;                                \
1925         }                                                \
1926         break;                                           \
1927     }
1928 
1929     switch (halide_type_code((halide_type_code_t)im_type.code, im_type.bits)) {
1930         HANDLE_CASE(halide_type_float, 32, float)
1931         HANDLE_CASE(halide_type_float, 64, double)
1932         HANDLE_CASE(halide_type_int, 8, int8_t)
1933         HANDLE_CASE(halide_type_int, 16, int16_t)
1934         HANDLE_CASE(halide_type_int, 32, int32_t)
1935         HANDLE_CASE(halide_type_int, 64, int64_t)
1936         HANDLE_CASE(halide_type_uint, 1, bool)
1937         HANDLE_CASE(halide_type_uint, 8, uint8_t)
1938         HANDLE_CASE(halide_type_uint, 16, uint16_t)
1939         HANDLE_CASE(halide_type_uint, 32, uint32_t)
1940         HANDLE_CASE(halide_type_uint, 64, uint64_t)
1941     // Note that we don't attempt to handle halide_type_handle here.
1942     default:
1943         assert(false && "Unsupported type");
1944         return false;
1945     }
1946 #undef HANDLE_CASE
1947 
1948     return true;
1949 }
1950 
// Maps an image type to the same image type with a different element type,
// e.g. (ImageType<Foo>, Bar) -> ImageType<Bar>. The result is exactly the
// type returned by `as<ElemType>()` when called on an rvalue ImageType.
template<typename ImageType, typename ElemType>
struct ImageTypeWithElemType {
    typedef decltype(std::declval<ImageType>().template as<ElemType>()) type;
};
1956 
// Like ImageTypeWithElemType, but const-qualifies the new element type:
// (ImageType<Foo>, Bar) -> ImageType<const Bar>.
template<typename ImageType, typename ElemType>
struct ImageTypeWithConstElemType {
private:
    typedef typename std::add_const<ElemType>::type ConstElemType;

public:
    typedef decltype(std::declval<ImageType>().template as<ConstElemType>()) type;
};
1962 
1963 template<typename ImageType, Internal::CheckFunc check>
1964 struct ImageIO {
1965     using ConstImageType = typename ImageTypeWithConstElemType<ImageType, typename ImageType::ElemType>::type;
1966 
1967     std::function<bool(const std::string &, ImageType *)> load;
1968     std::function<bool(ConstImageType &im, const std::string &)> save;
1969     std::function<const std::set<FormatInfo> &()> query;
1970 };
1971 
1972 template<typename ImageType, Internal::CheckFunc check>
find_imageio(const std::string & filename,ImageIO<ImageType,check> * result)1973 bool find_imageio(const std::string &filename, ImageIO<ImageType, check> *result) {
1974     static_assert(!ImageType::has_static_halide_type, "");
1975     using ConstImageType = typename ImageTypeWithConstElemType<ImageType, typename ImageType::ElemType>::type;
1976 
1977     const std::map<std::string, ImageIO<ImageType, check>> m = {
1978 #ifndef HALIDE_NO_JPEG
1979         {"jpeg", {load_jpg<ImageType, check>, save_jpg<ConstImageType, check>, query_jpg}},
1980         {"jpg", {load_jpg<ImageType, check>, save_jpg<ConstImageType, check>, query_jpg}},
1981 #endif
1982         {"pgm", {load_pgm<ImageType, check>, save_pgm<ConstImageType, check>, query_pgm}},
1983 #ifndef HALIDE_NO_PNG
1984         {"png", {load_png<ImageType, check>, save_png<ConstImageType, check>, query_png}},
1985 #endif
1986         {"ppm", {load_ppm<ImageType, check>, save_ppm<ConstImageType, check>, query_ppm}},
1987         {"tmp", {load_tmp<ImageType, check>, save_tmp<ConstImageType, check>, query_tmp}},
1988         {"mat", {load_mat<ImageType, check>, save_mat<ConstImageType, check>, query_mat}},
1989         {"tiff", {load_tiff<ImageType, check>, save_tiff<ConstImageType, check>, query_tiff}},
1990     };
1991     std::string ext = Internal::get_lowercase_extension(filename);
1992     auto it = m.find(ext);
1993     if (it != m.end()) {
1994         *result = it->second;
1995         return true;
1996     }
1997 
1998     std::string err = "unsupported file extension \"" + ext + "\", supported are:";
1999     for (auto &it : m) {
2000         err += " " + it.first;
2001     }
2002     err += "\n";
2003     return check(false, err.c_str());
2004 }
2005 
2006 template<typename ImageType>
best_save_format(const ImageType & im,const std::set<FormatInfo> & info)2007 FormatInfo best_save_format(const ImageType &im, const std::set<FormatInfo> &info) {
2008     // A bit ad hoc, but will do for now:
2009     // Perfect score is zero (exact match).
2010     // The larger the score, the worse the match.
2011     int best_score = 0x7fffffff;
2012     FormatInfo best{};
2013     const halide_type_t im_type = im.type();
2014     const int im_dimensions = im.dimensions();
2015     for (auto &f : info) {
2016         int score = 0;
2017         // If format has too-few dimensions, that's very bad.
2018         score += std::max(0, im_dimensions - f.dimensions) * 1024;
2019         // If format has too-few bits, that's pretty bad.
2020         score += std::max(0, im_type.bits - f.type.bits) * 8;
2021         // If format has too-many bits, that's a little bad.
2022         score += std::max(0, f.type.bits - im_type.bits);
2023         // If format has different code, that's a little bad.
2024         score += (f.type.code != im_type.code) ? 1 : 0;
2025         if (score < best_score) {
2026             best_score = score;
2027             best = f;
2028         }
2029     }
2030 
2031     return best;
2032 }
2033 
2034 }  // namespace Internal
2035 
2036 struct ImageTypeConversion {
2037     // Convert an Image from one ElemType to another, where the src and
2038     // dst types are statically known (e.g. Buffer<uint8_t> -> Buffer<float>).
2039     // Note that this does conversion with scaling -- intepreting integers
2040     // as fixed-point numbers between 0 and 1 -- not merely C-style casting.
2041     //
2042     // You'd normally call this with an explicit type for DstElemType and
2043     // allow ImageType to be inferred, e.g.
2044     //     Buffer<uint8_t> src = ...;
2045     //     Buffer<float> dst = convert_image<float>(src);
2046     template<typename DstElemType, typename ImageType,
2047              typename std::enable_if<ImageType::has_static_halide_type && !std::is_void<DstElemType>::value>::type * = nullptr>
2048     static auto convert_image(const ImageType &src) ->
2049         typename Internal::ImageTypeWithElemType<ImageType, DstElemType>::type {
2050         // The enable_if ensures this will never fire; this is here primarily
2051         // as documentation and a backstop against breakage.
2052         static_assert(ImageType::has_static_halide_type,
2053                       "This variant of convert_image() requires a statically-typed image");
2054 
2055         using SrcImageType = ImageType;
2056         using SrcElemType = typename SrcImageType::ElemType;
2057 
2058         using DstImageType = typename Internal::ImageTypeWithElemType<ImageType, DstElemType>::type;
2059 
2060         DstImageType dst = DstImageType::make_with_shape_of(src);
2061         const auto converter = [](DstElemType &dst_elem, SrcElemType src_elem) {
2062             dst_elem = Internal::convert<DstElemType>(src_elem);
2063         };
2064         dst.for_each_value(converter, src);
2065         dst.set_host_dirty();
2066 
2067         return dst;
2068     }
2069 
2070     // Convert an Image from one ElemType to another, where the dst type is statically
2071     // known but the src type is not (e.g. Buffer<> -> Buffer<float>).
2072     // You'd normally call this with an explicit type for DstElemType and
2073     // allow ImageType to be inferred, e.g.
2074     //     Buffer<uint8_t> src = ...;
2075     //     Buffer<float> dst = convert_image<float>(src);
2076     template<typename DstElemType, typename ImageType,
2077              typename std::enable_if<!ImageType::has_static_halide_type && !std::is_void<DstElemType>::value>::type * = nullptr>
2078     static auto convert_image(const ImageType &src) ->
2079         typename Internal::ImageTypeWithElemType<ImageType, DstElemType>::type {
2080         // The enable_if ensures this will never fire; this is here primarily
2081         // as documentation and a backstop against breakage.
2082         static_assert(!ImageType::has_static_halide_type,
2083                       "This variant of convert_image() requires a dynamically-typed image");
2084 
2085         const halide_type_t src_type = src.type();
2086         switch (Internal::halide_type_code((halide_type_code_t)src_type.code, src_type.bits)) {
2087         case Internal::halide_type_code(halide_type_float, 32):
2088             return convert_image<DstElemType>(src.template as<float>());
2089         case Internal::halide_type_code(halide_type_float, 64):
2090             return convert_image<DstElemType>(src.template as<double>());
2091         case Internal::halide_type_code(halide_type_int, 8):
2092             return convert_image<DstElemType>(src.template as<int8_t>());
2093         case Internal::halide_type_code(halide_type_int, 16):
2094             return convert_image<DstElemType>(src.template as<int16_t>());
2095         case Internal::halide_type_code(halide_type_int, 32):
2096             return convert_image<DstElemType>(src.template as<int32_t>());
2097         case Internal::halide_type_code(halide_type_int, 64):
2098             return convert_image<DstElemType>(src.template as<int64_t>());
2099         case Internal::halide_type_code(halide_type_uint, 1):
2100             return convert_image<DstElemType>(src.template as<bool>());
2101         case Internal::halide_type_code(halide_type_uint, 8):
2102             return convert_image<DstElemType>(src.template as<uint8_t>());
2103         case Internal::halide_type_code(halide_type_uint, 16):
2104             return convert_image<DstElemType>(src.template as<uint16_t>());
2105         case Internal::halide_type_code(halide_type_uint, 32):
2106             return convert_image<DstElemType>(src.template as<uint32_t>());
2107         case Internal::halide_type_code(halide_type_uint, 64):
2108             return convert_image<DstElemType>(src.template as<uint64_t>());
2109         default:
2110             assert(false && "Unsupported type");
2111             using DstImageType = typename Internal::ImageTypeWithElemType<ImageType, DstElemType>::type;
2112             return DstImageType();
2113         }
2114     }
2115 
2116     // Convert an Image from one ElemType to another, where the src type
2117     // is statically known but the dst type is not
2118     // (e.g. Buffer<uint8_t> -> Buffer<>(halide_type_t)).
2119     template<typename DstElemType = void,
2120              typename ImageType,
2121              typename std::enable_if<ImageType::has_static_halide_type && std::is_void<DstElemType>::value>::type * = nullptr>
2122     static auto convert_image(const ImageType &src, const halide_type_t &dst_type) ->
2123         typename Internal::ImageTypeWithElemType<ImageType, void>::type {
2124         // The enable_if ensures this will never fire; this is here primarily
2125         // as documentation and a backstop against breakage.
2126         static_assert(ImageType::has_static_halide_type,
2127                       "This variant of convert_image() requires a statically-typed image");
2128 
2129         // Call the appropriate static-to-static conversion routine
2130         // based on the desired dst type.
2131         switch (Internal::halide_type_code((halide_type_code_t)dst_type.code, dst_type.bits)) {
2132         case Internal::halide_type_code(halide_type_float, 32):
2133             return convert_image<float>(src);
2134         case Internal::halide_type_code(halide_type_float, 64):
2135             return convert_image<double>(src);
2136         case Internal::halide_type_code(halide_type_int, 8):
2137             return convert_image<int8_t>(src);
2138         case Internal::halide_type_code(halide_type_int, 16):
2139             return convert_image<int16_t>(src);
2140         case Internal::halide_type_code(halide_type_int, 32):
2141             return convert_image<int32_t>(src);
2142         case Internal::halide_type_code(halide_type_int, 64):
2143             return convert_image<int64_t>(src);
2144         case Internal::halide_type_code(halide_type_uint, 1):
2145             return convert_image<bool>(src);
2146         case Internal::halide_type_code(halide_type_uint, 8):
2147             return convert_image<uint8_t>(src);
2148         case Internal::halide_type_code(halide_type_uint, 16):
2149             return convert_image<uint16_t>(src);
2150         case Internal::halide_type_code(halide_type_uint, 32):
2151             return convert_image<uint32_t>(src);
2152         case Internal::halide_type_code(halide_type_uint, 64):
2153             return convert_image<uint64_t>(src);
2154         default:
2155             assert(false && "Unsupported type");
2156             return ImageType();
2157         }
2158     }
2159 
2160     // Convert an Image from one ElemType to another, where neither src type
2161     // nor dst type are statically known
2162     // (e.g. Buffer<>(halide_type_t) -> Buffer<>(halide_type_t)).
2163     template<typename DstElemType = void,
2164              typename ImageType,
2165              typename std::enable_if<!ImageType::has_static_halide_type && std::is_void<DstElemType>::value>::type * = nullptr>
2166     static auto convert_image(const ImageType &src, const halide_type_t &dst_type) ->
2167         typename Internal::ImageTypeWithElemType<ImageType, void>::type {
2168         // The enable_if ensures this will never fire; this is here primarily
2169         // as documentation and a backstop against breakage.
2170         static_assert(!ImageType::has_static_halide_type,
2171                       "This variant of convert_image() requires a dynamically-typed image");
2172 
2173         // Sniff the runtime type of src, coerce it to that type using as<>(),
2174         // and call the static-to-dynamic variant of this method. (Note that
2175         // this forces instantiation of the complete any-to-any conversion
2176         // matrix of code.)
2177         const halide_type_t src_type = src.type();
2178         switch (Internal::halide_type_code((halide_type_code_t)src_type.code, src_type.bits)) {
2179         case Internal::halide_type_code(halide_type_float, 32):
2180             return convert_image(src.template as<float>(), dst_type);
2181         case Internal::halide_type_code(halide_type_float, 64):
2182             return convert_image(src.template as<double>(), dst_type);
2183         case Internal::halide_type_code(halide_type_int, 8):
2184             return convert_image(src.template as<int8_t>(), dst_type);
2185         case Internal::halide_type_code(halide_type_int, 16):
2186             return convert_image(src.template as<int16_t>(), dst_type);
2187         case Internal::halide_type_code(halide_type_int, 32):
2188             return convert_image(src.template as<int32_t>(), dst_type);
2189         case Internal::halide_type_code(halide_type_int, 64):
2190             return convert_image(src.template as<int64_t>(), dst_type);
2191         case Internal::halide_type_code(halide_type_uint, 1):
2192             return convert_image(src.template as<bool>(), dst_type);
2193         case Internal::halide_type_code(halide_type_uint, 8):
2194             return convert_image(src.template as<uint8_t>(), dst_type);
2195         case Internal::halide_type_code(halide_type_uint, 16):
2196             return convert_image(src.template as<uint16_t>(), dst_type);
2197         case Internal::halide_type_code(halide_type_uint, 32):
2198             return convert_image(src.template as<uint32_t>(), dst_type);
2199         case Internal::halide_type_code(halide_type_uint, 64):
2200             return convert_image(src.template as<uint64_t>(), dst_type);
2201         default:
2202             assert(false && "Unsupported type");
2203             return ImageType();
2204         }
2205     }
2206 };
2207 
2208 // Load the Image from the given file.
2209 // If output Image has a static type, and the loaded image cannot be stored
2210 // in such an image without losing data, fail.
2211 // Returns false upon failure.
2212 template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
load(const std::string & filename,ImageType * im)2213 bool load(const std::string &filename, ImageType *im) {
2214     using DynamicImageType = typename Internal::ImageTypeWithElemType<ImageType, void>::type;
2215     Internal::ImageIO<DynamicImageType, check> imageio;
2216     if (!Internal::find_imageio<DynamicImageType, check>(filename, &imageio)) {
2217         return false;
2218     }
2219     using DynamicImageType = typename Internal::ImageTypeWithElemType<ImageType, void>::type;
2220     DynamicImageType im_d;
2221     if (!imageio.load(filename, &im_d)) {
2222         return false;
2223     }
2224     // Allow statically-typed images to be passed as the out-param, but do
2225     // a runtime check to ensure
2226     if (ImageType::has_static_halide_type) {
2227         const halide_type_t expected_type = ImageType::static_halide_type();
2228         if (!check(im_d.type() == expected_type, "Image loaded did not match the expected type")) {
2229             return false;
2230         }
2231     }
2232     *im = im_d.template as<typename ImageType::ElemType>();
2233     im->set_host_dirty();
2234     return true;
2235 }
2236 
2237 // Save the Image in the format associated with the filename's extension.
2238 // If the format can't represent the Image without losing data, fail.
2239 // Returns false upon failure.
2240 template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
save(ImageType & im,const std::string & filename)2241 bool save(ImageType &im, const std::string &filename) {
2242     using DynamicImageType = typename Internal::ImageTypeWithElemType<ImageType, void>::type;
2243     Internal::ImageIO<DynamicImageType, check> imageio;
2244     if (!Internal::find_imageio<DynamicImageType, check>(filename, &imageio)) {
2245         return false;
2246     }
2247     if (!check(imageio.query().count({im.type(), im.dimensions()}) > 0, "Image cannot be saved in this format")) {
2248         return false;
2249     }
2250 
2251     // Allow statically-typed images to be passed in, but quietly pass them on
2252     // as dynamically-typed images.
2253     auto im_d = im.template as<const void>();
2254     return imageio.save(im_d, filename);
2255 }
2256 
2257 // Return a set of FormatInfo structs that contain the legal type-and-dimensions
2258 // that can be saved in this format. Most applications won't ever need to use
2259 // this call. Returns false upon failure.
2260 template<typename ImageType, Internal::CheckFunc check = Internal::CheckReturn>
save_query(const std::string & filename,std::set<FormatInfo> * info)2261 bool save_query(const std::string &filename, std::set<FormatInfo> *info) {
2262     using DynamicImageType = typename Internal::ImageTypeWithElemType<ImageType, void>::type;
2263     Internal::ImageIO<DynamicImageType, check> imageio;
2264     if (!Internal::find_imageio<DynamicImageType, check>(filename, &imageio)) {
2265         return false;
2266     }
2267     *info = imageio.query();
2268     return true;
2269 }
2270 
2271 // Fancy wrapper to call load() with CheckFail, inferring the return type;
2272 // this allows you to simply use
2273 //
2274 //    Image im = load_image("filename");
2275 //
2276 // without bothering to check error results (all errors simply abort).
2277 //
2278 // Note that if the image being loaded doesn't match the static type and
2279 // dimensions of of the image on the LHS, a runtime error will occur.
2280 class load_image {
2281 public:
load_image(const std::string & f)2282     load_image(const std::string &f)
2283         : filename(f) {
2284     }
2285 
2286     template<typename ImageType>
ImageType()2287     operator ImageType() {
2288         using DynamicImageType = typename Internal::ImageTypeWithElemType<ImageType, void>::type;
2289         DynamicImageType im_d;
2290         (void)load<DynamicImageType, Internal::CheckFail>(filename, &im_d);
2291         Internal::CheckFail(ImageType::can_convert_from(im_d),
2292                             "Type mismatch assigning the result of load_image. "
2293                             "Did you mean to use load_and_convert_image?");
2294         return im_d.template as<typename ImageType::ElemType>();
2295     }
2296 
2297 private:
2298     const std::string filename;
2299 };
2300 
2301 // Like load_image, but quietly convert the loaded image to the type of the LHS
2302 // if necessary, discarding information if necessary.
2303 class load_and_convert_image {
2304 public:
load_and_convert_image(const std::string & f)2305     load_and_convert_image(const std::string &f)
2306         : filename(f) {
2307     }
2308 
2309     template<typename ImageType>
ImageType()2310     inline operator ImageType() {
2311         using DynamicImageType = typename Internal::ImageTypeWithElemType<ImageType, void>::type;
2312         DynamicImageType im_d;
2313         (void)load<DynamicImageType, Internal::CheckFail>(filename, &im_d);
2314         const halide_type_t expected_type = ImageType::static_halide_type();
2315         if (im_d.type() == expected_type) {
2316             return im_d.template as<typename ImageType::ElemType>();
2317         } else {
2318             return ImageTypeConversion::convert_image<typename ImageType::ElemType>(im_d);
2319         }
2320     }
2321 
2322 private:
2323     const std::string filename;
2324 };
2325 
2326 // Fancy wrapper to call save() with CheckFail; this allows you to simply use
2327 //
2328 //    save_image(im, "filename");
2329 //
2330 // without bothering to check error results (all errors simply abort).
2331 //
2332 // If the specified image file format cannot represent the image without
2333 // losing data (e.g, a float32 or 4-dimensional image saved as a JPEG),
2334 // a runtime error will occur.
2335 template<typename ImageType, Internal::CheckFunc check = Internal::CheckFail>
save_image(ImageType & im,const std::string & filename)2336 void save_image(ImageType &im, const std::string &filename) {
2337     (void)save<ImageType, check>(im, filename);
2338 }
2339 
2340 // Like save_image, but quietly convert the saved image to a type that the
2341 // specified image file format can hold, discarding information if necessary.
2342 // (Note that the input image is unaffected!)
2343 template<typename ImageType, Internal::CheckFunc check = Internal::CheckFail>
convert_and_save_image(ImageType & im,const std::string & filename)2344 void convert_and_save_image(ImageType &im, const std::string &filename) {
2345     // We'll be doing any conversion on the CPU
2346     im.copy_to_host();
2347 
2348     std::set<FormatInfo> info;
2349     (void)save_query<ImageType, check>(filename, &info);
2350     const FormatInfo best = Internal::best_save_format(im, info);
2351     if (best.type == im.type() && best.dimensions == im.dimensions()) {
2352         // It's an exact match, we can save as-is.
2353         (void)save<ImageType, check>(im, filename);
2354     } else {
2355         using DynamicImageType = typename Internal::ImageTypeWithElemType<ImageType, void>::type;
2356         DynamicImageType im_converted = ImageTypeConversion::convert_image(im, best.type);
2357         while (im_converted.dimensions() < best.dimensions) {
2358             im_converted.add_dimension();
2359         }
2360         (void)save<DynamicImageType, check>(im_converted, filename);
2361     }
2362 }
2363 
2364 }  // namespace Tools
2365 }  // namespace Halide
2366 
2367 #endif  // HALIDE_IMAGE_IO_H
2368