#!/usr/bin/env python3
# Copyright (c) Glow Contributors. See CONTRIBUTORS file.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# imagenet-process : Runs preprocessing of standard imagenet images
# to work with a pretrained model (e.g. resnet)
# through glow
# usage: python3 imagenet-process "images/*.JPEG" processed/
#   Quote the glob so that this script expands it; an unquoted glob is
#   expanded by the shell into many positional arguments, which argparse
#   rejects.
import argparse
import glob
import os
import sys

import PIL.Image
import torchvision

# Per-channel mean/std applied when --normalize is given.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def parse_args(argv=None):
    """Parse the command line.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``input`` (glob pattern), ``output``
        (directory, default ``"./"``) and ``normalize`` (bool).
    """
    parser = argparse.ArgumentParser(
        description="imagenet preprocessor")
    parser.add_argument("input", metavar="input", help="glob to input images")
    # A positional only uses its default when nargs="?" makes it optional.
    parser.add_argument("output", metavar="output", nargs="?", default="./",
                        help="directory to put output images")
    parser.add_argument("--normalize", action='store_true')
    return parser.parse_args(argv)


def build_transform(normalize):
    """Build the preprocessing pipeline.

    Resize(256) -> CenterCrop(224) -> ToTensor, followed by a
    per-channel Normalize when ``normalize`` is true. The result is
    always a tensor, so it can be passed to ``save_image``.
    """
    steps = [
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
    ]
    if normalize:
        steps.append(torchvision.transforms.Normalize(mean=IMAGENET_MEAN,
                                                      std=IMAGENET_STD))
    return torchvision.transforms.Compose(steps)


def output_path(input_path, output_dir):
    """Return ``<output_dir>/<input basename without extension>.png``."""
    name = os.path.splitext(os.path.basename(input_path))[0]
    return os.path.join(output_dir, name + ".png")


def main(argv=None):
    """Preprocess every image matching the input glob into PNG files."""
    args = parse_args(argv)

    # Create the output dir if necessary. A real failure (e.g. the path
    # exists as a regular file, or permission is denied) is reported
    # rather than silently ignored.
    os.makedirs(args.output, exist_ok=True)

    # The pipeline is identical for every image, so build it once.
    transform = build_transform(args.normalize)

    input_paths = sorted(glob.glob(args.input))
    if not input_paths:
        print("warning: no files match", repr(args.input), file=sys.stderr)

    for ifn in input_paths:
        outputname = output_path(ifn, args.output)
        name = os.path.splitext(os.path.basename(ifn))[0]
        print("processing", name, "as", outputname)

        # convert() returns a new image; the result must be kept so that
        # grayscale/CMYK inputs really become 3-channel RGB. The `with`
        # closes the source file handle.
        with PIL.Image.open(ifn) as im:
            rgb = im.convert("RGB")

        # Add a leading batch dimension of size 1.
        processed_im = transform(rgb).unsqueeze(0)

        torchvision.utils.save_image(processed_im, outputname)


if __name__ == "__main__":
    main()