"""Generate a mock register-allocation priority model for LLVM tests.

The generated model is not a neural net - it is just a tf.function with the
correct input and output parameters.
"""
## By construction, the mock model's 'priority' output is the sum of all of its
## input features (each cast to float32); it is not a trained policy.
6
7import os
8import sys
9import tensorflow as tf
# Key under which the model's 'action' signature returns its single output.
# It matches the "logging_name" recorded in POLICY_OUTPUT_SPEC.
POLICY_DECISION_LABEL = 'priority'

# JSON text written verbatim to output_spec.json next to the saved model,
# describing the model's one output tensor.
POLICY_OUTPUT_SPEC = """
[
    {
        "logging_name": "priority",
        "tensor_spec": {
            "name": "StatefulPartitionedCall",
            "port": 0,
            "type": "float",
            "shape": [
                1
            ]
        }
    }
]
"""

# Per-live-interval input features, split by the dtype get_input_signature()
# assigns them.
PER_LIVEINTERVAL_INT64_FEATURE_LIST = ['li_size', 'stage']
PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST = ['weight']
# All per-live-interval features: the float32 ones first, then the int64 ones.
PER_LIVEINTERVAL_FEATURE_LIST = [
    *PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST,
    *PER_LIVEINTERVAL_INT64_FEATURE_LIST,
]
# Context inputs that accompany the per-live-interval features.
CONTEXT_FEATURE_LIST = ('discount', 'reward', 'step_type')
34
35
def get_input_signature():
  """Return the input signature of the mock model's 'action' function.

  Returns:
    A dict mapping every feature name to a scalar (shape ``()``)
    ``tf.TensorSpec`` carrying that same name: int64 for
    PER_LIVEINTERVAL_INT64_FEATURE_LIST, float32 for
    PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST, and, for CONTEXT_FEATURE_LIST,
    float32 ('discount', 'reward') or int32 ('step_type').

  Raises:
    KeyError: if CONTEXT_FEATURE_LIST names a feature that has no dtype in
      the table below.
  """

  def scalar_spec(key, dtype):
    return tf.TensorSpec(dtype=dtype, shape=(), name=key)

  # Context feature names come from CONTEXT_FEATURE_LIST (the same list
  # build_mock_model sums over), so the signature cannot silently drift from
  # what the model reads; only their dtypes are recorded here.
  context_dtypes = {
      'discount': tf.float32,
      'reward': tf.float32,
      'step_type': tf.int32,
  }
  inputs = {
      key: scalar_spec(key, tf.int64)
      for key in PER_LIVEINTERVAL_INT64_FEATURE_LIST
  }
  inputs.update({
      key: scalar_spec(key, tf.float32)
      for key in PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST
  })
  inputs.update({
      key: scalar_spec(key, context_dtypes[key])
      for key in CONTEXT_FEATURE_LIST
  })
  return inputs
52
53
def get_output_spec_path(path):
  """Return the path of the output-spec JSON file inside directory `path`."""
  spec_filename = 'output_spec.json'
  return os.path.join(path, spec_filename)
56
57
def build_mock_model(path):
  """Build the mock priority model and save it, plus its output spec, to `path`.

  The saved module exposes a single 'action' signature whose inputs are the
  specs returned by get_input_signature(). Its only output, keyed by
  POLICY_DECISION_LABEL, is the float32 sum of every feature named in
  PER_LIVEINTERVAL_FEATURE_LIST and CONTEXT_FEATURE_LIST (each cast to
  float32) plus a variable that is always 0.

  Args:
    path: Directory to save the SavedModel into. POLICY_OUTPUT_SPEC is then
      written to get_output_spec_path(path).
  """
  module = tf.Module()
  # We have to set this useless variable in order for the TF C API to correctly
  # intake it.
  module.var = tf.Variable(0, dtype=tf.float32)

  def action(*inputs):
    # The concrete function below is traced with the dict returned by
    # get_input_signature() as its only positional argument.
    features = inputs[0]
    per_interval_sum = tf.reduce_sum(
        [tf.cast(features[key], tf.float32)
         for key in PER_LIVEINTERVAL_FEATURE_LIST],
        axis=0)
    context_sum = tf.reduce_sum(
        [tf.cast(features[key], tf.float32) for key in CONTEXT_FEATURE_LIST])
    # The priority is just the sum of all features. module.var is initialized
    # to 0 and never updated, so adding it does not change the value.
    result = per_interval_sum + context_sum + module.var
    return {POLICY_DECISION_LABEL: result}

  module.action = tf.function(action)
  signatures = {
      'action': module.action.get_concrete_function(get_input_signature())
  }

  tf.saved_model.save(module, path, signatures=signatures)
  output_spec_path = get_output_spec_path(path)
  with open(output_spec_path, 'w', encoding='utf-8') as f:
    print(f'Writing output spec to {output_spec_path}.')
    f.write(POLICY_OUTPUT_SPEC)
86
87
def main(argv):
  """Build the mock model at the directory named on the command line.

  Args:
    argv: The full argument vector, as in ``sys.argv``; ``argv[1]`` is the
      directory to save the model into.

  Raises:
    SystemExit: with a usage message if ``argv`` does not hold exactly one
      argument after the program name.
  """
  # Validate explicitly: an `assert` would be stripped when run with -O.
  if len(argv) != 2:
    prog = argv[0] if argv else '<script>'
    sys.exit(f'Usage: {prog} <model_path>')
  model_path = argv[1]
  build_mock_model(model_path)
92
93
if __name__ == '__main__':
  # Script entry point: hand the whole argument vector to main().
  main(argv=sys.argv)
96