"""Generate a mock model for LLVM tests for Register Allocation.

The generated model is not a neural net - it is just a tf.function with the
correct input and output parameters.

By construction, the mock model's 'priority' output is the sum of all of its
input features (the per-liverange features and the context features), each
cast to float32.
"""

import os
import sys

import tensorflow as tf

POLICY_DECISION_LABEL = 'priority'
# Written verbatim to <model_dir>/output_spec.json by build_mock_model().
POLICY_OUTPUT_SPEC = """
[
    {
        "logging_name": "priority",
        "tensor_spec": {
            "name": "StatefulPartitionedCall",
            "port": 0,
            "type": "float",
            "shape": [
                1
            ]
        }
    }
]
"""
PER_LIVEINTERVAL_INT64_FEATURE_LIST = ['li_size', 'stage']
PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST = ['weight']
PER_LIVEINTERVAL_FEATURE_LIST = (
    PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST + PER_LIVEINTERVAL_INT64_FEATURE_LIST)
CONTEXT_FEATURE_LIST = ('discount', 'reward', 'step_type')


def get_input_signature():
    """Return the input signature of the mock priority model.

    Returns:
      A dict mapping each input feature name to a scalar (shape ``()``)
      ``tf.TensorSpec`` whose ``name`` is the feature name:
        * int64 for PER_LIVEINTERVAL_INT64_FEATURE_LIST,
        * float32 for PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST,
        * float32 for 'discount' and 'reward',
        * int32 for 'step_type'.
    """
    # NOTE(review): inputs are declared with shape () while POLICY_OUTPUT_SPEC
    # declares the output with shape [1]; confirm this matches the feature
    # specs expected on the LLVM side before changing either.
    inputs = {
        key: tf.TensorSpec(dtype=tf.int64, shape=(), name=key)
        for key in PER_LIVEINTERVAL_INT64_FEATURE_LIST
    }
    inputs.update({
        key: tf.TensorSpec(dtype=tf.float32, shape=(), name=key)
        for key in PER_LIVEINTERVAL_FLOAT32_FEATURE_LIST
    })
    inputs.update({
        key: tf.TensorSpec(dtype=tf.float32, shape=(), name=key)
        for key in ('discount', 'reward')
    })
    inputs['step_type'] = tf.TensorSpec(
        dtype=tf.int32, shape=(), name='step_type')
    return inputs


def get_output_spec_path(path):
    """Return the path of the output-spec JSON file inside model dir `path`."""
    return os.path.join(path, 'output_spec.json')


def build_mock_model(path):
    """Build and save the mock model with the given signature.

    Saves a ``tf.Module`` to `path` with a single serving signature named
    'action'. That signature takes the features described by
    get_input_signature() and returns ``{POLICY_DECISION_LABEL: sum}``, where
    ``sum`` is the float32 sum of every input feature. POLICY_OUTPUT_SPEC is
    then written to ``get_output_spec_path(path)``.

    Args:
      path: Directory the SavedModel (and its output_spec.json) is written to.
    """
    module = tf.Module()
    # We have to set this useless variable in order for the TF C API to
    # correctly intake it.
    module.var = tf.Variable(0, dtype=tf.float32)

    def action(*inputs):
        # The concrete function is traced with one positional argument: the
        # feature dict returned by get_input_signature().
        features = inputs[0]
        s1 = tf.reduce_sum(
            [tf.cast(features[key], tf.float32)
             for key in PER_LIVEINTERVAL_FEATURE_LIST],
            axis=0)
        s2 = tf.reduce_sum(
            [tf.cast(features[key], tf.float32)
             for key in CONTEXT_FEATURE_LIST])
        # The priority is just the sum of all features. module.var (initialized
        # to 0) is added so the variable is used by the traced function —
        # presumably what the TF C API note above requires.
        result = s1 + s2 + module.var
        return {POLICY_DECISION_LABEL: result}

    module.action = tf.function()(action)
    signatures = {
        'action': module.action.get_concrete_function(get_input_signature())
    }

    tf.saved_model.save(module, path, signatures=signatures)
    output_spec_path = get_output_spec_path(path)
    with open(output_spec_path, 'w', encoding='utf-8') as f:
        print(f'Writing output spec to {output_spec_path}.')
        f.write(POLICY_OUTPUT_SPEC)


def main(argv):
    """Command-line entry point: ``<script> <model_path>``.

    Args:
      argv: Full argument vector (as in ``sys.argv``); ``argv[1]`` is the
        directory the mock model is written to.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(argv) != 2:
        prog = argv[0] if argv else 'script'
        sys.exit(f'usage: {prog} <model_path>')
    model_path = argv[1]
    build_mock_model(model_path)


if __name__ == '__main__':
    main(sys.argv)