local gm = require 'graphicsmagick'
require 'pl'

--- Image decode/encode helpers built on the graphicsmagick binding.
-- Decoded images are torch tensors in channel-first ('DHW') RGB layout.
local image_loader = {}

-- PNG files start with the byte 0x89 followed by "PNG". A decimal escape is
-- used on purpose: Lua 5.1 has no "\x" escapes and would read "\x89PNG" as
-- the literal text "x89PNG", which never matches a real PNG header.
local PNG_MAGIC = "\137PNG"
local GIF_MAGIC = "GIF"

--- Heuristic: does the blob's header identify a format that may carry alpha?
-- FIXME: this only looks at the container format (PNG/GIF); it does not
-- check whether the decoded image actually has an alpha channel.
local function may_have_alpha(blob)
   return blob:sub(1, #PNG_MAGIC) == PNG_MAGIC
      or blob:sub(1, #GIF_MAGIC) == GIF_MAGIC
end

--- Composite an RGBA float tensor onto a white background.
-- @param rgba float tensor of size 4 x H x W, as produced by
--   `toTensor('float', 'RGBA', 'DHW')` (values assumed in [0, 1] — TODO confirm
--   against the graphicsmagick binding)
-- @return byte tensor of size 3 x H x W
local function blend_on_white(rgba)
   local alpha = rgba[4]
   local inv_alpha = rgba[4] * -1 + 1
   local rgb = torch.FloatTensor(3, rgba:size(2), rgba:size(3))
   for c = 1, 3 do
      -- color * alpha + white * (1 - alpha), with white == 1.0
      rgb[c]:copy(rgba[c]):cmul(alpha):add(inv_alpha)
   end
   -- byte() truncates toward zero, so float noise such as 254.9999 would
   -- become 254; adding 0.5 first rounds to the nearest integer instead.
   return rgb:mul(255):add(0.5):byte()
end

--- Read a whole file in binary mode.
-- Raises an error (reported at the caller of the public loader) when the
-- file cannot be opened, instead of crashing on a nil file handle.
local function read_file(file)
   local fp, err = io.open(file, "rb")
   if not fp then
      error(string.format("failed to load image: %s", tostring(err)), 3)
   end
   local buff = fp:read("*a")
   fp:close()
   return buff
end

--- Decode an encoded image blob into a float tensor scaled to [0, 1].
-- @param blob string with the encoded image bytes
-- @return float tensor (3 x H x W) on success; `nil, err` on failure
function image_loader.decode_float(blob)
   local im, err = image_loader.decode_byte(blob)
   if not im then
      return nil, err
   end
   return im:float():div(255)
end

--- Encode a 3 x H x W tensor as PNG.
-- @param tensor RGB image tensor in 'DHW' layout
-- @return the PNG-encoded blob produced by graphicsmagick's `toBlob`
function image_loader.encode_png(tensor)
   local im = gm.Image(tensor, "RGB", "DHW")
   im:format("png")
   return im:toBlob()
end

--- Decode an encoded image blob into a byte RGB tensor.
-- For PNG/GIF input the alpha channel is composited onto white.
-- @param blob string with the encoded image bytes
-- @return byte tensor (3 x H x W) on success; `nil, err` on failure
function image_loader.decode_byte(blob)
   if type(blob) ~= "string" or #blob == 0 then
      return nil, "empty or invalid image data"
   end
   local ok, ret = pcall(function()
      local im = gm.Image()
      im:fromBlob(blob, #blob)
      if may_have_alpha(blob) then
         return blend_on_white(im:toTensor('float', 'RGBA', 'DHW'))
      end
      return im:toTensor('byte', 'RGB', 'DHW')
   end)
   if ok then
      return ret
   end
   -- Keep the decoder's message instead of silently discarding it.
   return nil, ret
end

--- Load an image file as a float tensor scaled to [0, 1].
-- Raises an error if the file cannot be opened; returns `nil, err` if it
-- cannot be decoded.
function image_loader.load_float(file)
   local buff = read_file(file)
   return image_loader.decode_float(buff)
end

--- Load an image file as a byte tensor.
-- Raises an error if the file cannot be opened; returns `nil, err` if it
-- cannot be decoded.
function image_loader.load_byte(file)
   local buff = read_file(file)
   return image_loader.decode_byte(buff)
end

-- Manual smoke test: displays ./a.jpg and ./b.png (requires the 'image'
-- package and a display). Not run on module load; uncomment the call below.
local function test()
   require 'image'
   local img
   img = image_loader.load_float("./a.jpg")
   if img then
      print(img:min())
      print(img:max())
      image.display(img)
   end
   img = image_loader.load_float("./b.png")
   if img then
      image.display(img)
   end
end
--test()
return image_loader