1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18"""Tuning a single dense operator"""
19
20from collections import namedtuple
21import logging
22import os
23
24import tvm
25from tvm import autotvm
26from tvm.contrib.util import get_lower_ir
27import topi
28import vta
29import vta.testing
30
# VTA environment; this script reads its block sizes (BATCH, BLOCK_IN,
# BLOCK_OUT), tensor dtypes and target names.
env = vta.get_env()

# One dense workload: batch size, number of input features, number of
# output features.
Workload = namedtuple("DenseWorkload", "batch in_filter out_filter")

# (name, workload) pairs tuned by this script, named after LSTM dense layers.
dense_wkls = [
    ('lstm.dense.1', Workload(batch=1, in_filter=256, out_filter=128)),
    ('lstm.dense.4', Workload(batch=4, in_filter=256, out_filter=128)),
]
40
@tvm.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max):
    """Clip tensor ``x`` into ``[a_min, a_max]`` using two compute stages.

    Unlike topi's clip at the time of writing, the bounds are applied in
    two separate stages: the upper bound first (stage "clipA"), then the
    lower bound (stage "clipB"). Both stages carry the elementwise tag
    through ``tag_scope``.

    Parameters
    ----------
    x : tvm.Tensor
        Input tensor.
    a_min, a_max :
        Lower and upper bounds. They are converted to constants of
        ``x.dtype``.

    Returns
    -------
    tvm.Tensor
        The clipped tensor, with the same shape and dtype as ``x``.
    """
    lower = tvm.const(a_min, x.dtype)
    upper = tvm.const(a_max, x.dtype)
    capped = tvm.compute(x.shape,
                         lambda *idx: tvm.min(x(*idx), upper),
                         name="clipA")
    return tvm.compute(capped.shape,
                       lambda *idx: tvm.max(capped(*idx), lower),
                       name="clipB")
49
def dense(N, CI, CO):
    """Build the VTA dense workload and its schedule (AutoTVM template).

    The computed result is
    ``int8(clip(right_shift(dense(data, kernel), 8), 0, 127))``. The
    operands are declared in VTA's tiled layout:
    data is ``(N/BATCH, CI/BLOCK_IN, BATCH, BLOCK_IN)`` and kernel is
    ``(CO/BLOCK_OUT, CI/BLOCK_IN, BLOCK_OUT, BLOCK_IN)``.

    Parameters
    ----------
    N : int
        Batch size. Must be a multiple of ``env.BATCH``.
    CI : int
        Number of input features. Must be a multiple of ``env.BLOCK_IN``.
    CO : int
        Number of output features. Must be a multiple of ``env.BLOCK_OUT``.

    Returns
    -------
    tuple
        ``(schedule, [data, kernel, res])``. The VTA schedule is used when
        the active target is VTA; otherwise a default schedule is used.

    Raises
    ------
    ValueError
        If a dimension is not a multiple of its VTA block size. Floor
        division would otherwise silently drop the remainder and tune a
        different workload than the one requested.
    """
    for dim_name, dim, block in (("N", N, env.BATCH),
                                 ("CI", CI, env.BLOCK_IN),
                                 ("CO", CO, env.BLOCK_OUT)):
        if dim % block != 0:
            raise ValueError("%s=%s is not a multiple of the VTA block size %s"
                             % (dim_name, dim, block))

    data_shape = (N//env.BATCH, CI//env.BLOCK_IN, env.BATCH, env.BLOCK_IN)
    kernel_shape = (CO//env.BLOCK_OUT, CI//env.BLOCK_IN, env.BLOCK_OUT, env.BLOCK_IN)

    data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
    kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)

    # Declare the compute under the VTA target so topi dispatches to VTA's
    # dense implementation.
    with tvm.target.vta():
        res = topi.nn.dense(data, kernel, None, 'int32')
        res = topi.right_shift(res, 8)
        res = my_clip(res, 0, 127)
        res = topi.cast(res, "int8")

    # The schedule choice depends on the caller's active target (e.g. the
    # one AutoTVM enters while instantiating the template). current_target()
    # may be None when no target is active, so guard before reading
    # device_name.
    target = tvm.target.current_target()
    if target is not None and target.device_name == 'vta':
        s = topi.generic.schedule_dense([res])
    else:
        s = tvm.create_schedule([res.op])

    return s, [data, kernel, res]
69
if __name__ == '__main__':

    # Logging config (for printing tuning log to the screen)
    logging.basicConfig()
    # Uncomment for verbose AutoTVM output:
    # logging.getLogger('autotvm').setLevel(logging.DEBUG)

    # The final log keeps only the best record per workload. The temporary
    # log accumulates every measurement and is deleted once the best
    # records have been picked.
    log_file = "%s.dense.log" % (env.TARGET)
    tmp_log_file = log_file + ".tmp"
    # log_to_file appends to its file. Clear leftovers from an earlier
    # (possibly interrupted) run so stale records do not reach pick_best.
    for stale_file in (log_file, tmp_log_file):
        if os.path.exists(stale_file):
            os.remove(stale_file)

    # Get tracker info from env
    tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
    tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
    if not tracker_host or not tracker_port:
        print("Set your AutoTVM tracker node host and port variables to run the autotuner")
        # Exit with a non-zero status: this is a configuration error.
        raise SystemExit(1)

    # Measurement options do not depend on the workload, so build them once
    # and share them across all tuning tasks.
    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.RPCRunner(
            env.TARGET, host=tracker_host, port=int(tracker_port),
            number=5, timeout=60,
            check_correctness=True))

    n_tasks = len(dense_wkls)
    for idx, (wl_name, wl) in enumerate(dense_wkls, start=1):

        # 1-based so the last task reads "[Task  n/ n]".
        prefix = "[Task %2d/%2d] " % (idx, n_tasks)

        # Workload parameters
        N = wl.batch
        CI = wl.in_filter
        CO = wl.out_filter

        task = autotvm.task.create(dense, args=(N, CI, CO),
                                   target=tvm.target.vta(),
                                   target_host=env.target_host,
                                   template_key='direct')
        print(task.config_space)

        # Random search with n_trial equal to the config-space size.
        n_trial = len(task.config_space)
        tuner = autotvm.tuner.RandomTuner(task)
        tuner.tune(
            n_trial=n_trial,
            early_stopping=None,
            measure_option=measure_option,
            callbacks=[
                autotvm.callback.progress_bar(n_trial, prefix=prefix),
                autotvm.callback.log_to_file(tmp_log_file)])

    # Pick best records to a cache file. The temporary log only exists if
    # at least one measurement was recorded.
    if os.path.exists(tmp_log_file):
        autotvm.record.pick_best(tmp_log_file, log_file)
        os.remove(tmp_log_file)