# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import argparse
import logging
import os
import pickle
import re
import sys

import mxnet as mx
import numpy as np

import gluonnlp as nlp
from utils import _split_dict, get_hash, to_gluon_kwargs, read_tf_checkpoint


def to_gluon_vocab(corpus):
    """Convert a TransformerXL corpus object to a GluonNLP Vocab."""
    # Clean up latin-1 mis-encoding of words
    idx2sym = [w.encode('latin-1').decode('utf-8') for w in corpus.vocab.idx2sym]
    sym2idx = {sym: idx for idx, sym in enumerate(idx2sym)}

    special_tokens = dict(unknown_token=None, padding_token=None, bos_token=None)
    if hasattr(corpus.vocab, 'unk_idx'):
        special_tokens['unknown_token'] = idx2sym[corpus.vocab.unk_idx]
    elif '<unk>' in sym2idx:
        special_tokens['unknown_token'] = '<unk>'
    elif '<UNK>' in sym2idx:
        special_tokens['unknown_token'] = '<UNK>'

    # Discover special tokens
    if ['<eos>'] == corpus.vocab.special:
        if '<eos>' in sym2idx:  # Only include if special token is actually used
            special_tokens['eos_token'] = '<eos>'
        elif '<S>' in sym2idx:
            # Special case for model trained on Google 1 Billion Word LM dataset
            special_tokens['eos_token'] = '<S>'
    elif corpus.vocab.special:
        raise NotImplementedError('Provided TransformerXL cache.pkl uses an unknown special token. '
                                  'You must extend the `to_gluon_vocab` method to support it.')
    else:
        special_tokens['eos_token'] = None

    counter = nlp.data.count_tokens(sym2idx.keys())
    vocab = nlp.vocab.Vocab(counter, token_to_idx=sym2idx, **special_tokens)
    return vocab

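# Note on the Vocab construction above: passing `token_to_idx=sym2idx` keeps the exact
# index assignment of the TF corpus, so converted embedding rows stay aligned with the
# GluonNLP vocabulary. A minimal, purely illustrative sketch with made-up tokens (the
# `toy_*` names are hypothetical and not used by this script):
#
#   toy_sym2idx = {'<eos>': 0, 'hello': 1, 'world': 2}
#   toy_vocab = nlp.vocab.Vocab(nlp.data.count_tokens(toy_sym2idx.keys()),
#                               token_to_idx=toy_sym2idx, unknown_token=None,
#                               padding_token=None, bos_token=None, eos_token='<eos>')
#   # toy_vocab['world'] should resolve to index 2, matching the TF-side assignment.
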
def set_params(model, tf_tensors, kwargs, tie_r):
    """Copy the TF checkpoint tensors into the parameters of the Gluon model."""
    # Drop optimizer params
    _, tf_tensors = _split_dict(lambda k, v: k.endswith('Adam'), tf_tensors)
    _, tf_tensors = _split_dict(lambda k, v: k.endswith('Adam_1'), tf_tensors)
    del tf_tensors['global_step']
    del tf_tensors['beta1_power']
    del tf_tensors['beta2_power']

    loaded = set()  # Cache of processed parameters

    if 'embed_cutoffs' in kwargs:  # Adaptive Embedding and Softmax
        # Embedding
        for name, param in model._net.embedding._collect_params_with_prefix().items():
            purpose, i, postfix = re.match(r'([a-zA-Z]*)(\d*)(.*)', name).groups()
            if purpose == 'embedding':
                assert postfix == '_weight'
                tf_param = tf_tensors.pop(
                    'transformer/adaptive_embed/cutoff_{}/lookup_table'.format(i))
            elif purpose == 'projection':
                assert postfix == '_weight'
                tf_param = tf_tensors.pop('transformer/adaptive_embed/cutoff_{}/proj_W'.format(i)).T
            else:
                raise RuntimeError('Embedding had unexpected parameter: {}'.format(name))

            param.set_data(mx.nd.array(tf_param))
            loaded.add(param)

        # Softmax
        for name, param in model._net.crit._collect_params_with_prefix().items():
            if param in loaded:
                continue  # Some parameters are shared between Embedding and Softmax

            purpose, i, postfix = re.match(r'([a-zA-Z]*)(\d*)(.*)', name).groups()
            if purpose == 'outembedding':
                if postfix == '_weight':
                    tf_param = tf_tensors.pop(
                        'transformer/adaptive_softmax/cutoff_{}/lookup_table'.format(i))
                elif postfix == '_bias':
                    tf_param = tf_tensors.pop('transformer/adaptive_softmax/cutoff_{}/b'.format(i))
                else:
                    raise RuntimeError('Softmax had unexpected parameter: {}'.format(name))
            elif purpose == 'outprojection':
                assert postfix == '_weight'
                tf_param = tf_tensors.pop('transformer/adaptive_softmax/cutoff_{}/proj'.format(i)).T
            elif purpose == 'cluster':
                if postfix == '.weight':
                    tf_param = tf_tensors.pop('transformer/adaptive_softmax/cutoff_0/cluster_W')
                elif postfix == '.bias':
                    tf_param = tf_tensors.pop('transformer/adaptive_softmax/cutoff_0/cluster_b')
                else:
                    raise RuntimeError('Softmax had unexpected parameter: {}'.format(name))
            else:
                raise RuntimeError('Softmax had unexpected parameter: {}'.format(name))

            param.set_data(mx.nd.array(tf_param))
            loaded.add(param)
    else:  # Non-adaptive, (possibly) projected embedding and softmax
        # Embedding
        tf_param = tf_tensors.pop('transformer/adaptive_embed/lookup_table')
        model._net.embedding.embedding_weight.set_data(mx.nd.array(tf_param))
        loaded.add(model._net.embedding.embedding_weight)
        if kwargs['embed_size'] != kwargs['units']:
            tf_param = tf_tensors.pop('transformer/adaptive_embed/proj_W')
            model._net.embedding.projection_weight.set_data(mx.nd.array(tf_param))
            loaded.add(model._net.embedding.projection_weight)
            assert len(model._net.embedding.collect_params().keys()) == 2
        else:
            assert len(model._net.embedding.collect_params().keys()) == 1

        # Softmax
        for name, param in model._net.crit._collect_params_with_prefix().items():
            if param in loaded:
                continue  # Some parameters are shared between Embedding and Softmax

            purpose, i, postfix = re.match(r'([a-zA-Z]*)(\d*)(.*)', name).groups()
            if purpose == 'outembedding':
                if postfix == '_weight':
                    tf_param = tf_tensors.pop('transformer/adaptive_softmax/lookup_table')
                elif postfix == '_bias':
                    tf_param = tf_tensors.pop('transformer/adaptive_softmax/bias')
                else:
                    raise RuntimeError('Softmax had unexpected parameter: {}'.format(name))
            elif purpose == 'outprojection':
                assert postfix == '_weight'
                tf_param = tf_tensors.pop('transformer/adaptive_softmax/proj').T
            else:
                raise RuntimeError('Softmax had unexpected parameter: {}'.format(name))

            param.set_data(mx.nd.array(tf_param))
            loaded.add(param)

    tf_r_r_bias = tf_tensors.pop('transformer/r_r_bias')
    tf_r_w_bias = tf_tensors.pop('transformer/r_w_bias')
    for layer_i in range(kwargs['num_layers']):
        # Attention Cell
        attention_cell = model._net.transformer_cells[layer_i].attention_cell
        # TODO(leezu): Duplicate tied parameters until parameter sharing
        # support is improved in Gluon 2. (It is currently impossible to share
        # only subsets of parameters between Blocks due to name clashes between
        # the non-shared parameters (due to same prefix))
        attention_cell.query_key_bias.set_data(
            mx.nd.array(tf_r_w_bias if tie_r else tf_r_w_bias[layer_i]))
        attention_cell.query_emb_bias.set_data(
            mx.nd.array(tf_r_r_bias if tie_r else tf_r_r_bias[layer_i]))
        tf_param = np.split(
            tf_tensors.pop('transformer/layer_{}/rel_attn/qkv/kernel'.format(layer_i)).T, 3, axis=0)
        attention_cell.proj_query.weight.set_data(mx.nd.array(tf_param[0]))
        attention_cell.proj_key.weight.set_data(mx.nd.array(tf_param[1]))
        attention_cell.proj_value.weight.set_data(mx.nd.array(tf_param[2]))
        tf_param = tf_tensors.pop('transformer/layer_{}/rel_attn/r/kernel'.format(layer_i))
        attention_cell.proj_emb.weight.set_data(mx.nd.array(tf_param.T))

        # Projection
        tf_param = tf_tensors.pop('transformer/layer_{}/rel_attn/o/kernel'.format(layer_i))
        model._net.transformer_cells[layer_i].proj.weight.set_data(mx.nd.array(tf_param.T))

        # Layer Norm
        tf_param = tf_tensors.pop('transformer/layer_{}/rel_attn/LayerNorm/beta'.format(layer_i))
        model._net.transformer_cells[layer_i].layer_norm.beta.set_data(mx.nd.array(tf_param))
        tf_param = tf_tensors.pop('transformer/layer_{}/rel_attn/LayerNorm/gamma'.format(layer_i))
        model._net.transformer_cells[layer_i].layer_norm.gamma.set_data(mx.nd.array(tf_param))

        # FFN
        ffn = model._net.transformer_cells[layer_i].ffn
        tf_param = tf_tensors.pop('transformer/layer_{}/ff/LayerNorm/beta'.format(layer_i))
        ffn.layer_norm.beta.set_data(mx.nd.array(tf_param))
        tf_param = tf_tensors.pop('transformer/layer_{}/ff/LayerNorm/gamma'.format(layer_i))
        ffn.layer_norm.gamma.set_data(mx.nd.array(tf_param))
        tf_param = tf_tensors.pop('transformer/layer_{}/ff/layer_1/kernel'.format(layer_i))
        ffn.ffn_1.weight.set_data(mx.nd.array(tf_param.T))
        tf_param = tf_tensors.pop('transformer/layer_{}/ff/layer_1/bias'.format(layer_i))
        ffn.ffn_1.bias.set_data(mx.nd.array(tf_param))
        tf_param = tf_tensors.pop('transformer/layer_{}/ff/layer_2/kernel'.format(layer_i))
        ffn.ffn_2.weight.set_data(mx.nd.array(tf_param.T))
        tf_param = tf_tensors.pop('transformer/layer_{}/ff/layer_2/bias'.format(layer_i))
        ffn.ffn_2.bias.set_data(mx.nd.array(tf_param))

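# Note on the parameter copies above (informational only): TF dense kernels are stored
# as (in_units, out_units) whereas Gluon Dense weights are (out_units, in_units), which
# is why kernels are transposed with `.T` before `set_data`. The fused QKV kernel is
# additionally split into three equal chunks along axis 0 after transposition to obtain
# the query, key and value projection weights. Roughly (`qkv_kernel` is illustrative):
#
#   q_w, k_w, v_w = np.split(qkv_kernel.T, 3, axis=0)  # each: (num_heads * head_dim, units)
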
def convert_transformerxl(args):
    """Convert a TF Transformer-XL checkpoint into a GluonNLP vocab and parameter file."""
    # Load tf model and vocab
    with open(args.cache_pkl, 'rb') as f:
        corpus = pickle.load(f, encoding='latin1')
    vocab = to_gluon_vocab(corpus)
    tf_checkpoint_file = os.path.expanduser(
        os.path.join(args.tf_checkpoint_dir, args.tf_model_prefix))
    tf_tensors = read_tf_checkpoint(tf_checkpoint_file)

    # Initialize Gluon model
    kwargs, tie_r = to_gluon_kwargs(tf_tensors)
    model = TransformerXL(vocab_size=len(vocab), **kwargs)
    model.initialize(init=mx.init.Normal(0.02))

    # Shape inference based on forward pass
    batch_size, seq_len = 2, 16
    mem_length = 100
    mems = model.begin_mems(batch_size, mem_length, context=mx.cpu())
    x = mx.nd.ones(shape=(batch_size, seq_len))
    model(x, x, mems)

    # Convert parameters
    set_params(model, tf_tensors, kwargs, tie_r)

    # Serialization
    tmp_file_path = os.path.expanduser(os.path.join(args.out_dir, 'tmp'))
    with open(tmp_file_path, 'w') as f:
        f.write(vocab.to_json())
    hash_full, hash_short = get_hash(tmp_file_path)
    gluon_vocab_path = os.path.expanduser(os.path.join(args.out_dir, hash_short + '.vocab'))
    with open(gluon_vocab_path, 'w') as f:
        f.write(vocab.to_json())
    logging.info('vocab file saved to %s. hash = %s', gluon_vocab_path, hash_full)
    model.save_parameters(tmp_file_path)
    hash_full, hash_short = get_hash(tmp_file_path)
    os.remove(tmp_file_path)
    gluon_param_path = os.path.expanduser(os.path.join(args.out_dir, hash_short + '.params'))
    logging.info('param saved to %s. hash = %s', gluon_param_path, hash_full)
    model.save_parameters(gluon_param_path)
    mx.nd.waitall()

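# Example invocation (the file name `convert_transformer_xl.py` and the paths are
# illustrative assumptions; substitute the actual name and location of this script
# and of your checkpoint files):
#
#   python convert_transformer_xl.py \
#       --transformer-xl-repo ~/transformer-xl \
#       --tf-checkpoint-dir ~/txl_checkpoint \
#       --tf-model-prefix model.ckpt-1191000 \
#       --cache-pkl ~/txl_checkpoint/cache.pkl \
#       --out-dir ~/txl_gluon
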
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Conversion script for Tensorflow Transformer-XL model',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--transformer-xl-repo', type=str, required=True,
                        help='Path to https://github.com/kimiyoung/transformer-xl repo.')
    parser.add_argument('--tf-checkpoint-dir', type=str, required=True,
                        help='Path to Tensorflow checkpoint folder.')
    parser.add_argument(
        '--tf-model-prefix', type=str, required=True, help='Prefix of the checkpoint files. '
        'For example model.ckpt-0 or model.ckpt-1191000')
    parser.add_argument('--cache-pkl', type=str, required=True,
                        help='Path to TransformerXL cache.pkl file.')
    parser.add_argument('--out-dir', type=str, required=True,
                        help='Path to output folder. The folder must exist.')
    parser.add_argument('--debug', action='store_true', help='debugging mode')
    args = parser.parse_args()
    logging.getLogger().setLevel(logging.DEBUG if args.debug else logging.INFO)
    logging.info(args)

    # Load the modules required for unpickling the TransformerXL cache.pkl
    sys.path.append(os.path.join(args.transformer_xl_repo, 'tf'))
    import vocabulary  # pylint: disable=unused-import
    import data_utils  # pylint: disable=unused-import

    sys.path.append(os.path.abspath(os.path.join(__file__, os.pardir, os.pardir)))
    from transformer import TransformerXL

    convert_transformerxl(args)
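
# Rough sketch of consuming the converted artifacts later (not executed by this script;
# it assumes access to the same `kwargs` produced by `to_gluon_kwargs` during conversion):
#
#   with open(gluon_vocab_path, 'r') as f:
#       vocab = nlp.Vocab.from_json(f.read())
#   model = TransformerXL(vocab_size=len(vocab), **kwargs)
#   model.load_parameters(gluon_param_path)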