1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""
18Compile Tensorflow Models
19=========================
This article is an introductory tutorial on deploying TensorFlow models with TVM.

To begin, the TensorFlow Python module must be installed.

Please refer to https://www.tensorflow.org/install for installation instructions.
25"""
26
27# tvm and nnvm
28import nnvm
29import tvm
30
31# os and numpy
32import numpy as np
33import os.path
34
35# Tensorflow imports
36import tensorflow as tf
37from tensorflow.core.framework import graph_pb2
38from tensorflow.python.framework import dtypes
39from tensorflow.python.framework import tensor_util
40
41# Tensorflow utility functions
42import tvm.relay.testing.tf as tf_testing
43
# Base location for model related files.
repo_base = ('https://github.com/dmlc/web-data/raw/master/'
             'tensorflow/models/InceptionV1/')

# Test image
img_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, img_name)

######################################################################
# Tutorials
# ---------
# .. note::
#
#   The protobuf should be exported with the :any:`add_shapes=True` option.
#   For existing models, the script at
#   https://github.com/dmlc/web-data/tree/master/tensorflow/scripts/tf-to-nnvm.py
#   can be used to add shapes.
#
# Please refer to docs/frontend/tensorflow.md for more details on importing
# various models from TensorFlow.

# Model protobuf, image label map, and human readable text for labels,
# all hosted next to each other under ``repo_base``.
model_name, map_proto, label_map = (
    'classify_image_graph_def-with_shapes.pb',
    'imagenet_2012_challenge_label_map_proto.pbtxt',
    'imagenet_synset_to_human_label_map.txt',
)
model_url, map_proto_url, label_map_url = (
    os.path.join(repo_base, fname)
    for fname in (model_name, map_proto, label_map)
)
73
# Target settings: build for the CPU through LLVM.
# To build for CUDA instead, replace the four assignments below with:
#   target = 'cuda'
#   target_host = 'llvm'
#   layout = "NCHW"
#   ctx = tvm.gpu(0)
target = target_host = 'llvm'
layout = None
ctx = tvm.cpu(0)

######################################################################
# Download required files
# -----------------------
# Download files listed above.
from tvm.contrib.download import download_testdata

img_path = download_testdata(image_url, img_name, module='data')
model_path = download_testdata(model_url, model_name, module=['tf', 'InceptionV1'])
# The label map and the human readable label text both go under 'data';
# they are fetched in that order.
map_proto_path, label_path = (
    download_testdata(url, fname, module='data')
    for url, fname in ((map_proto_url, map_proto), (label_map_url, label_map))
)
95
######################################################################
# Import model
# ------------
# Creates tensorflow graph definition from protobuf file.

# Parse the serialized GraphDef; the file is closed as soon as it is read.
with tf.gfile.FastGFile(model_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Import the graph definition into the default graph (no name prefix).
# import_graph_def is called only for this side effect: without
# ``return_elements`` it returns None, so the result is not bound.
tf.import_graph_def(graph_def, name='')
# Pass the GraphDef through the TVM testing helper.
# NOTE(review): presumably type-checks / normalizes the proto -- see
# tvm.relay.testing.tf.ProcessGraphDefParam.
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
# Add shapes to the graph, using a session over the default graph
# populated above.
with tf.Session() as sess:
    graph_def = tf_testing.AddShapesToGraphDef(sess, 'softmax')
110
######################################################################
# Decode image
# ------------
# .. note::
#
#   tensorflow frontend import doesn't support preprocessing ops like DecodeJpeg.
#   DecodeJpeg is bypassed (just return source node).
#   Hence we supply decoded frame to TVM instead.
#

from PIL import Image
# Image.open is lazy and keeps the file open. The context manager closes it
# once the resized copy exists. convert('RGB') guarantees a 3-channel array
# even for grayscale/CMYK/RGBA sources; for an RGB JPEG it is a plain copy.
with Image.open(img_path) as img:
    image = img.convert('RGB').resize((299, 299))

x = np.array(image)
125
######################################################################
# Import the graph to NNVM
# ------------------------
# Import tensorflow graph definition to nnvm.
#
# Results:
#   sym: nnvm graph for given tensorflow protobuf.
#   params: params converted from tensorflow params (tensor protobuf).
sym, params = nnvm.frontend.from_tensorflow(graph_def, layout=layout)

print("Tensorflow protobuf imported as nnvm graph")
######################################################################
# NNVM Compilation
# ----------------
# Compile the graph for the configured ``target`` with the given input
# specification (the decoded image fed to ``DecodeJpeg/contents``).
#
# Results:
#   graph: Final graph after compilation.
#   params: final params after compilation.
#   lib: target library which can be deployed on target with tvm runtime.

import nnvm.compiler
shape_dict = {'DecodeJpeg/contents': x.shape}
dtype_dict = {'DecodeJpeg/contents': 'uint8'}
graph, lib, params = nnvm.compiler.build(
    sym,
    shape=shape_dict,
    dtype=dtype_dict,
    target=target,
    target_host=target_host,
    params=params,
)
151
######################################################################
# Execute the portable graph on TVM
# ---------------------------------
# Now we can try deploying the NNVM compiled model on target.

from tvm.contrib import graph_runtime
dtype = 'uint8'
m = graph_runtime.create(graph, lib, ctx)
# Bind the decoded image first, then the compiled parameters.
m.set_input('DecodeJpeg/contents', tvm.nd.array(x.astype(dtype)))
m.set_input(**params)
# Execute the graph, then copy output 0 into a (1, 1008) float32 buffer.
m.run()
tvm_output = m.get_output(0, tvm.nd.empty((1, 1008), 'float32'))

######################################################################
# Process the output
# ------------------
# Process the model output to human readable text for InceptionV1.
predictions = np.squeeze(tvm_output.asnumpy())

# Creates node ID --> English string lookup.
node_lookup = tf_testing.NodeLookup(
    label_lookup_path=map_proto_path,
    uid_lookup_path=label_path,
)

# Print top 5 predictions from TVM output: indices of the five highest
# scores, best first.
top_k = np.argsort(predictions)[::-1][:5]
for node_id in top_k:
    human_string = node_lookup.id_to_string(node_id)
    score = predictions[node_id]
    print('%s (score = %.5f)' % (human_string, score))
185
186######################################################################
187# Inference on tensorflow
188# -----------------------
# Run the same model on TensorFlow to compare against the TVM results above.
190
def create_graph():
    """Import the saved GraphDef at ``model_path`` into the default graph.

    The protobuf is parsed and imported into the current default TensorFlow
    graph without a name prefix (``name=''``). It is then passed through
    ``tf_testing.ProcessGraphDefParam`` and the result is returned.

    Returns
    -------
    graph_def : tf.GraphDef
        The processed graph definition. Returning it is new; callers that
        ignore the return value (as this script does) are unaffected.
    """
    # Read and parse the serialized graph; the file is closed right after.
    with tf.gfile.FastGFile(model_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Import the graph definition into the default graph. Without
    # ``return_elements`` this returns None, so nothing is bound.
    tf.import_graph_def(graph_def, name='')
    return tf_testing.ProcessGraphDefParam(graph_def)
200
def run_inference_on_image(image):
    """Classify an image with the original TensorFlow model and print the top 5.

    This is the TensorFlow reference run; its printed labels and scores can be
    compared with the TVM predictions printed earlier in this script.

    Parameters
    ----------
    image : str
        Path of the image file. Its raw, still-encoded bytes are fed to the
        graph's ``DecodeJpeg/contents:0`` input, so it is presumably expected
        to be a JPEG file.

    Returns
    -------
    None
        The top-5 labels and scores are printed to stdout.
    """
    if not tf.gfile.Exists(image):
        # NOTE(review): tf.logging.fatal appears to only log here; if the
        # file is missing, the read below is what actually fails -- confirm.
        tf.logging.fatal('File does not exist %s', image)
    # Read the encoded bytes in a context manager so the handle is closed.
    with tf.gfile.FastGFile(image, 'rb') as f:
        image_data = f.read()

    # Creates graph from saved GraphDef.
    create_graph()

    with tf.Session() as sess:
        softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
        predictions = sess.run(softmax_tensor,
                               {'DecodeJpeg/contents:0': image_data})

    # sess.run already returned host-side values, so the session can be closed
    # before post-processing.
    predictions = np.squeeze(predictions)

    # Creates node ID --> English string lookup.
    node_lookup = tf_testing.NodeLookup(label_lookup_path=map_proto_path,
                                        uid_lookup_path=label_path)

    # Print top 5 predictions from tensorflow, highest score first.
    top_k = predictions.argsort()[-5:][::-1]
    print("===== TENSORFLOW RESULTS =======")
    for node_id in top_k:
        human_string = node_lookup.id_to_string(node_id)
        score = predictions[node_id]
        print('%s (score = %.5f)' % (human_string, score))


run_inference_on_image(img_path)
240