1#!/usr/bin/env python3
2# Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16
17# imagenet-process : Runs preprocessing of standard imagenet images
18#                    to work with a pretrained model (e.g. resnet)
19#                    through glow
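# usage: python3 imagenet-process "images/*.JPEG" processed/
#        (quote the glob: the script expands it itself, and an unquoted glob
#        would be expanded by the shell into several arguments, which the
#        single "input" argument does not accept)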
21import PIL.Image
22import torchvision
23import glob
24import os
25import argparse
26
# Command-line interface: an input glob (expanded by the script via glob, so
# it should be quoted on the shell command line), an optional output
# directory, and an optional --normalize flag.
parser = argparse.ArgumentParser(
    description="imagenet preprocessor")
parser.add_argument("input", metavar="input", help="glob to input images")
# nargs="?" makes this positional optional. Without it argparse always
# requires the argument, and default= would never take effect.
parser.add_argument("output", metavar="output", nargs="?", default="./",
                    help="directory to put output images")
parser.add_argument("--normalize", action='store_true',
                    help="normalize with per-channel mean/std after cropping")

args = parser.parse_args()
35
# Create the output directory (and any missing parents) if necessary. An
# already-existing directory is fine. Any other failure (e.g. the path is an
# existing regular file, or permission is denied) is raised here rather than
# being silently swallowed and resurfacing later as a confusing save error.
# An empty output path is skipped because os.makedirs("") would raise.
if args.output:
    os.makedirs(args.output, exist_ok=True)
41
# Resize so the shorter side is 256 pixels, then take the central 224x224
# crop. The pipelines are the same for every image, so they are built once
# instead of once per file.
resize = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
])

# torchvision.utils.save_image() and unsqueeze() need a tensor, so ToTensor is
# always applied. Previously the default (non --normalize) path left a PIL
# image here and crashed. Per-channel mean/std normalization is added only
# when --normalize is given.
tensor_steps = [torchvision.transforms.ToTensor()]
if args.normalize:
    tensor_steps.append(
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]))
to_tensor = torchvision.transforms.Compose(tensor_steps)

input_files = glob.glob(args.input)
if not input_files:
    print("no input images match", args.input)

for ifn in input_files:
    # The output keeps the input's base name, with the extension replaced
    # by .png, inside the output directory.
    name = os.path.splitext(os.path.basename(ifn))[0]
    outputname = os.path.join(args.output, name + ".png")
    print("processing", name, "as", outputname)

    # convert() returns a new image rather than converting in place, so its
    # result must be kept. The with-block closes the underlying file handle.
    with PIL.Image.open(ifn) as im:
        rgb_im = im.convert("RGB")

    processed_im = to_tensor(resize(rgb_im))

    # Prepend a batch dimension of size 1 before saving.
    processed_im = processed_im.unsqueeze(0)

    # NOTE(review): save_image writes 8-bit pixels and clamps values to
    # [0, 1]. Normalized tensors fall partly outside that range, so the
    # --normalize output is lossy. Confirm the downstream consumer expects it.
    torchvision.utils.save_image(processed_im, outputname)
67