1local gm = require 'graphicsmagick'
2require 'pl'
3
-- Module table: image decode/encode helpers built on the graphicsmagick binding.
local image_loader = {}
5
--- Decode an encoded image blob into a float tensor scaled to [0, 1].
-- Delegates the actual decoding to image_loader.decode_byte, then converts
-- the byte tensor to float and divides every element by 255.
-- @param blob string holding the encoded image bytes
-- @return FloatTensor with values in [0, 1], or nil when decoding fails
function image_loader.decode_float(blob)
   local byte_im = image_loader.decode_byte(blob)
   if not byte_im then
      -- decoding failed; propagate the falsy result unchanged
      return byte_im
   end
   return byte_im:float():div(255)
end
--- Encode an image tensor as PNG.
-- The tensor is handed to graphicsmagick as RGB channels in "DHW"
-- (channel-first) layout.
-- @param tensor image tensor in channel-first RGB layout
-- @return whatever gm's Image:toBlob() returns (the PNG-encoded data)
function image_loader.encode_png(tensor)
   local png_image = gm.Image(tensor, "RGB", "DHW")
   png_image:format("png")
   -- tail call keeps every value toBlob() returns
   return png_image:toBlob()
end
--- Decode an encoded image blob into an RGB byte tensor.
-- PNG and GIF inputs (recognised by their magic bytes) are decoded with an
-- alpha channel, which is then composited onto a white background; every
-- other format is decoded directly as RGB.
-- @param blob string holding the encoded image bytes
-- @return ByteTensor in channel-first ("DHW") RGB layout, or
--   nil plus the error raised during decoding when decoding fails
function image_loader.decode_byte(blob)
   local function load_image()
      local im = gm.Image()
      im:fromBlob(blob, #blob)
      -- FIXME: How to detect that an image has an alpha channel?
      -- For now only PNG and GIF are treated as possibly carrying alpha.
      -- "\137" is the decimal escape for byte 0x89; unlike "\x89" it is
      -- understood by every Lua version (plain Lua 5.1 has no \x escapes).
      if blob:sub(1, 4) == "\137PNG" or blob:sub(1, 3) == "GIF" then
         local rgba = im:toTensor('float', 'RGBA', 'DHW')
         local alpha = rgba[4]
         -- weight of the white background at each pixel: 1 - alpha
         local background = alpha * -1 + 1
         local rgb = torch.FloatTensor(3, rgba:size(2), rgba:size(3))
         for c = 1, 3 do
            -- composite onto white: color * alpha + 1 * (1 - alpha)
            rgb[c]:copy(rgba[c]):cmul(alpha):add(background)
         end
         -- :byte() truncates, so float error (e.g. 199.99998) would lose a
         -- level; adding 0.5 first rounds to the nearest byte value instead.
         return rgb:mul(255):add(0.5):byte()
      end
      return im:toTensor('byte', 'RGB', 'DHW')
   end
   local ok, ret = pcall(load_image)
   if ok then
      return ret
   end
   -- keep the nil-on-failure contract, but expose the reason as a 2nd value
   return nil, ret
end
--- Read the entire contents of a file in binary mode.
-- Raises an error (with io.open's message, which names the file) when the
-- file cannot be opened, instead of failing later on a nil handle.
-- @param file path of the file to read
-- @return string with the raw file contents
local function read_file_binary(file)
   local fp, err = io.open(file, "rb")
   if not fp then
      -- level 3 attributes the error to the caller of load_float/load_byte
      error(err, 3)
   end
   local buff = fp:read("*a")
   fp:close()
   return buff
end

--- Load an image file as a float tensor scaled to [0, 1].
-- @param file path of the image file
-- @return result of image_loader.decode_float (nil when decoding fails)
function image_loader.load_float(file)
   return image_loader.decode_float(read_file_binary(file))
end

--- Load an image file as an RGB byte tensor.
-- @param file path of the image file
-- @return result of image_loader.decode_byte (nil when decoding fails)
function image_loader.load_byte(file)
   return image_loader.decode_byte(read_file_binary(file))
end
--- Manual smoke test: load two sample images and display them.
-- For ./a.jpg the minimum and maximum values are printed before display;
-- ./b.png is only displayed. Images that fail to decode are skipped.
local function test()
   require 'image'
   -- each entry: { path, print value range before displaying? }
   local cases = {
      { "./a.jpg", true },
      { "./b.png", false },
   }
   for _, case in ipairs(cases) do
      local img = image_loader.load_float(case[1])
      if img then
         if case[2] then
            print(img:min())
            print(img:max())
         end
         image.display(img)
      end
   end
end
-- Uncomment to run the manual smoke test above; it reads ./a.jpg and
-- ./b.png and requires the 'image' package.
--test()
return image_loader
74