18import argparse
19import logging
20import os
21import pickle
22import re
23import sys
25import mxnet as mx
26import numpy as np
28import gluonnlp as nlp
29from utils import _split_dict, get_hash, to_gluon_kwargs, read_tf_checkpoint
def to_gluon_vocab(corpus):
    """Convert a TransformerXL corpus object to a GluonNLP Vocab.

    Parameters
    ----------
    corpus
        Unpickled Transformer-XL corpus object. Only ``corpus.vocab.idx2sym``,
        ``corpus.vocab.special`` and (if present) ``corpus.vocab.unk_idx``
        are read.

    Returns
    -------
    nlp.vocab.Vocab
        Vocabulary in which every symbol keeps the index it has in
        ``corpus.vocab.idx2sym``.

    Raises
    ------
    NotImplementedError
        If ``corpus.vocab.special`` is non-empty but is neither ``['<eos>']``
        nor accompanied by a ``'<S>'`` symbol in the vocabulary.
    """
    # Clean up latin-1 mis-encoding of words
    idx2sym = [w.encode('latin-1').decode('utf-8') for w in corpus.vocab.idx2sym]
    sym2idx = {sym: idx for idx, sym in enumerate(idx2sym)}

    # Every special token is explicitly disabled unless discovered below.
    # Leaving a keyword out would let Vocab fall back to its own default
    # token, which is not part of the converted symbol table.
    special_tokens = dict(unknown_token=None, padding_token=None, bos_token=None,
                          eos_token=None)
    if hasattr(corpus.vocab, 'unk_idx'):
        special_tokens['unknown_token'] = idx2sym[corpus.vocab.unk_idx]
    elif '<unk>' in sym2idx:
        special_tokens['unknown_token'] = '<unk>'
    elif '<UNK>' in sym2idx:
        special_tokens['unknown_token'] = '<UNK>'

    # Discover special tokens
    if ['<eos>'] == corpus.vocab.special:
        if '<eos>' in sym2idx:  # Only include if special token is actually used
            special_tokens['eos_token'] = '<eos>'
    elif '<S>' in sym2idx:
        # Special case for model trained on Google 1 Billion Word LM dataset
        special_tokens['eos_token'] = '<S>'
    elif corpus.vocab.special:
        raise NotImplementedError('Provided TransformerXL cache.pkl uses an unknown special token. '
                                  'You must extend the `to_gluon_vocab` method to support it.')

    counter = nlp.data.count_tokens(sym2idx.keys())
    vocab = nlp.vocab.Vocab(counter, token_to_idx=sym2idx, **special_tokens)
    return vocab
def set_params(model, tf_tensors, kwargs, tie_r):
    """Copy the weights of a TensorFlow Transformer-XL checkpoint into ``model``.

    Parameters
    ----------
    model
        Gluon Transformer-XL model. Its ``_net.embedding``, ``_net.crit`` and
        ``_net.transformer_cells`` parameters are overwritten via ``set_data``.
    tf_tensors : dict
        Maps TensorFlow variable names to arrays. Tensors are popped from the
        mapping returned by ``_split_dict`` as they are consumed (whether that
        is the caller's dict depends on ``_split_dict`` - not visible here).
    kwargs : dict
        Model hyper-parameters. ``num_layers`` is always read; the presence of
        ``embed_cutoffs`` selects the adaptive embedding/softmax layout,
        otherwise ``embed_size`` and ``units`` are compared to decide whether
        an embedding projection exists.
    tie_r : bool
        If true, ``transformer/r_w_bias`` and ``transformer/r_r_bias`` are
        used as-is for every layer; otherwise they are indexed by layer.

    Raises
    ------
    KeyError
        If a TensorFlow tensor required by the model is missing.
    RuntimeError
        If the embedding or softmax block has a parameter with no known
        TensorFlow counterpart.
    """
    # Drop optimizer params
    _, tf_tensors = _split_dict(lambda k, v: k.endswith('Adam'), tf_tensors)
    _, tf_tensors = _split_dict(lambda k, v: k.endswith('Adam_1'), tf_tensors)
    # Training bookkeeping entries are discarded; tolerate checkpoints that do
    # not contain them instead of failing with a KeyError.
    for bookkeeping_key in ('global_step', 'beta1_power', 'beta2_power'):
        tf_tensors.pop(bookkeeping_key, None)

    # Splits a Gluon parameter name into (alphabetic purpose, index, postfix).
    name_pattern = re.compile(r'([a-zA-Z]*)(\d*)(.*)')

    loaded = set()  # Cache of processed parameters

    if 'embed_cutoffs' in kwargs:  # Adaptive Embedding and Softmax
        # Embedding
        for name, param in model._net.embedding._collect_params_with_prefix().items():
            purpose, i, postfix = name_pattern.match(name).groups()
            if purpose == 'embedding':
                assert postfix == '_weight'
                tf_param = tf_tensors.pop(
                    'transformer/adaptive_embed/cutoff_{}/lookup_table'.format(i))
            elif purpose == 'projection':
                assert postfix == '_weight'
                tf_param = tf_tensors.pop('transformer/adaptive_embed/cutoff_{}/proj_W'.format(i)).T
            else:
                raise RuntimeError('Embedding had unexpected parameter: {}'.format(name))

            param.set_data(mx.nd.array(tf_param))
            loaded.add(param)

        # Softmax
        for name, param in model._net.crit._collect_params_with_prefix().items():
            if param in loaded:
                continue  # Some parameters are shared between Embedding and Softmax

            purpose, i, postfix = name_pattern.match(name).groups()
            if purpose == 'outembedding':
                if postfix == '_weight':
                    tf_param = tf_tensors.pop(
                        'transformer/adaptive_softmax/cutoff_{}/lookup_table'.format(i))
                elif postfix == '_bias':
                    tf_param = tf_tensors.pop('transformer/adaptive_softmax/cutoff_{}/b'.format(i))
                else:
                    raise RuntimeError('Softmax had unexpected parameter: {}'.format(name))
            elif purpose == 'outprojection':
                assert postfix == '_weight'
                tf_param = tf_tensors.pop('transformer/adaptive_softmax/cutoff_{}/proj'.format(i)).T
            elif purpose == 'cluster':
                if postfix == '.weight':
                    tf_param = tf_tensors.pop('transformer/adaptive_softmax/cutoff_0/cluster_W')
                elif postfix == '.bias':
                    tf_param = tf_tensors.pop('transformer/adaptive_softmax/cutoff_0/cluster_b')
                else:
                    raise RuntimeError('Softmax had unexpected parameter: {}'.format(name))
            else:
                raise RuntimeError('Softmax had unexpected parameter: {}'.format(name))

            param.set_data(mx.nd.array(tf_param))
            loaded.add(param)
    else:  # Non-adaptive, (possibly) projected embedding and softmax
        # Embedding
        embedding = model._net.embedding
        tf_param = tf_tensors.pop('transformer/adaptive_embed/lookup_table')
        embedding.embedding_weight.set_data(mx.nd.array(tf_param))
        loaded.add(embedding.embedding_weight)
        if kwargs['embed_size'] != kwargs['units']:
            # NOTE(review): unlike the adaptive branch above, proj_W is not
            # transposed here - confirm against the projection_weight shape.
            tf_param = tf_tensors.pop('transformer/adaptive_embed/proj_W')
            embedding.projection_weight.set_data(mx.nd.array(tf_param))
            loaded.add(embedding.projection_weight)
            assert len(embedding.collect_params().keys()) == 2
        else:
            assert len(embedding.collect_params().keys()) == 1

        # Softmax
        for name, param in model._net.crit._collect_params_with_prefix().items():
            if param in loaded:
                continue  # Some parameters are shared between Embedding and Softmax

            purpose, i, postfix = name_pattern.match(name).groups()
            if purpose == 'outembedding':
                if postfix == '_weight':
                    tf_param = tf_tensors.pop('transformer/adaptive_softmax/lookup_table')
                elif postfix == '_bias':
                    tf_param = tf_tensors.pop('transformer/adaptive_softmax/bias')
                else:
                    raise RuntimeError('Softmax had unexpected parameter: {}'.format(name))
            elif purpose == 'outprojection':
                assert postfix == '_weight'
                tf_param = tf_tensors.pop('transformer/adaptive_softmax/proj').T
            else:
                raise RuntimeError('Softmax had unexpected parameter: {}'.format(name))

            param.set_data(mx.nd.array(tf_param))
            loaded.add(param)

    tf_r_r_bias = tf_tensors.pop('transformer/r_r_bias')
    tf_r_w_bias = tf_tensors.pop('transformer/r_w_bias')
    for layer_i in range(kwargs['num_layers']):
        cell = model._net.transformer_cells[layer_i]
        layer_prefix = 'transformer/layer_{}/'.format(layer_i)

        # Attention Cell
        attention_cell = cell.attention_cell
        # TODO(leezu): Duplicate tied parameters until parameter sharing
        # support is improved in Gluon 2. (It is currently impossible to share
        # only subsets of parameters between Blocks due to name clashes between
        # the non-shared parameters (due to same prefix))
        attention_cell.query_key_bias.set_data(
            mx.nd.array(tf_r_w_bias if tie_r else tf_r_w_bias[layer_i]))
        attention_cell.query_emb_bias.set_data(
            mx.nd.array(tf_r_r_bias if tie_r else tf_r_r_bias[layer_i]))
        # The fused qkv kernel is transposed and cut into three equal parts
        # along axis 0: query, key, value (in that order).
        tf_param = np.split(tf_tensors.pop(layer_prefix + 'rel_attn/qkv/kernel').T, 3, axis=0)
        attention_cell.proj_query.weight.set_data(mx.nd.array(tf_param[0]))
        attention_cell.proj_key.weight.set_data(mx.nd.array(tf_param[1]))
        attention_cell.proj_value.weight.set_data(mx.nd.array(tf_param[2]))
        tf_param = tf_tensors.pop(layer_prefix + 'rel_attn/r/kernel')
        attention_cell.proj_emb.weight.set_data(mx.nd.array(tf_param.T))

        # Projection
        tf_param = tf_tensors.pop(layer_prefix + 'rel_attn/o/kernel')
        cell.proj.weight.set_data(mx.nd.array(tf_param.T))

        # Layer Norm
        tf_param = tf_tensors.pop(layer_prefix + 'rel_attn/LayerNorm/beta')
        cell.layer_norm.beta.set_data(mx.nd.array(tf_param))
        tf_param = tf_tensors.pop(layer_prefix + 'rel_attn/LayerNorm/gamma')
        cell.layer_norm.gamma.set_data(mx.nd.array(tf_param))

        # FFN
        ffn = cell.ffn
        tf_param = tf_tensors.pop(layer_prefix + 'ff/LayerNorm/beta')
        ffn.layer_norm.beta.set_data(mx.nd.array(tf_param))
        tf_param = tf_tensors.pop(layer_prefix + 'ff/LayerNorm/gamma')
        ffn.layer_norm.gamma.set_data(mx.nd.array(tf_param))
        tf_param = tf_tensors.pop(layer_prefix + 'ff/layer_1/kernel')
        ffn.ffn_1.weight.set_data(mx.nd.array(tf_param.T))
        tf_param = tf_tensors.pop(layer_prefix + 'ff/layer_1/bias')
        ffn.ffn_1.bias.set_data(mx.nd.array(tf_param))
        tf_param = tf_tensors.pop(layer_prefix + 'ff/layer_2/kernel')
        ffn.ffn_2.weight.set_data(mx.nd.array(tf_param.T))
        tf_param = tf_tensors.pop(layer_prefix + 'ff/layer_2/bias')
        ffn.ffn_2.bias.set_data(mx.nd.array(tf_param))

    # Anything still left was never copied into the model; surface it instead
    # of silently producing an incompletely converted model.
    if tf_tensors:
        logging.warning('TensorFlow tensors not used during conversion: %s',
                        sorted(tf_tensors))
def convert_transformerxl(args):
    """Convert a TensorFlow Transformer-XL checkpoint to GluonNLP files.

    Writes ``<short hash>.vocab`` and ``<short hash>.params`` into
    ``args.out_dir``; each file is named after the hash of its own content as
    returned by ``get_hash``.

    Parameters
    ----------
    args
        Parsed command line arguments. The attributes ``cache_pkl``,
        ``tf_checkpoint_dir``, ``tf_model_prefix`` and ``out_dir`` are read.
    """
    # Load tf model and vocab
    # NOTE: pickle.load can execute arbitrary code while unpickling; only
    # point --cache-pkl at files from a trusted source.
    with open(args.cache_pkl, 'rb') as f:
        corpus = pickle.load(f, encoding='latin1')
    vocab = to_gluon_vocab(corpus)
    tf_checkpoint_file = os.path.expanduser(
        os.path.join(args.tf_checkpoint_dir, args.tf_model_prefix))
    tf_tensors = read_tf_checkpoint(tf_checkpoint_file)

    # Initialize Gluon model
    kwargs, tie_r = to_gluon_kwargs(tf_tensors)
    model = TransformerXL(vocab_size=len(vocab), **kwargs)
    model.initialize(init=mx.init.Normal(0.02))

    # Shape inference based on forward pass
    batch_size, seq_len = 2, 16
    mem_length = 100
    mems = model.begin_mems(batch_size, mem_length, context=mx.cpu())
    x = mx.nd.ones(shape=(batch_size, seq_len))
    model(x, x, mems)

    # Convert parameters
    set_params(model, tf_tensors, kwargs, tie_r)

    # Serialization
    def out_path(file_name):
        return os.path.expanduser(os.path.join(args.out_dir, file_name))

    # Each artifact is written once to a temporary file, hashed, and then
    # moved to its hash-derived name, so the hash always describes exactly the
    # file that is kept and nothing is serialized twice.
    tmp_file_path = out_path('tmp')
    try:
        # Pin the encoding so the written bytes (and therefore the hash) do
        # not depend on the platform's default locale.
        with open(tmp_file_path, 'w', encoding='utf-8') as f:
            f.write(vocab.to_json())
        hash_full, hash_short = get_hash(tmp_file_path)
        gluon_vocab_path = out_path(hash_short + '.vocab')
        os.replace(tmp_file_path, gluon_vocab_path)
        logging.info('vocab file saved to %s. hash = %s', gluon_vocab_path, hash_full)

        model.save_parameters(tmp_file_path)
        hash_full, hash_short = get_hash(tmp_file_path)
        gluon_param_path = out_path(hash_short + '.params')
        os.replace(tmp_file_path, gluon_param_path)
        logging.info('param saved to %s. hash = %s', gluon_param_path, hash_full)
    finally:
        # Do not leave a stray temporary file behind if any step failed.
        if os.path.exists(tmp_file_path):
            os.remove(tmp_file_path)
    mx.nd.waitall()
if __name__ == '__main__':
    # Command line interface: five required string options plus a debug flag.
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Conversion script for Tensorflow Transformer-XL model')
    required_options = (
        ('--transformer-xl-repo',
         'Path to https://github.com/kimiyoung/transformer-xl repo.'),
        ('--tf-checkpoint-dir',
         'Path to Tensorflow checkpoint folder.'),
        ('--tf-model-prefix',
         'Prefix of the checkpoint files. '
         'For example model.ckpt-0 or model.ckpt-1191000'),
        ('--cache-pkl',
         'Path to TransformerXL cache.pkl file.'),
        ('--out-dir',
         'Path to output folder. The folder must exist.'),
    )
    for option, description in required_options:
        cli.add_argument(option, type=str, required=True, help=description)
    cli.add_argument('--debug', action='store_true', help='debugging mode')
    args = cli.parse_args()

    # Root logger verbosity follows the --debug switch.
    if args.debug:
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO
    logging.getLogger().setLevel(log_level)
    logging.info(args)

    # Load stuff required for unpickling
    tf_source_dir = os.path.join((args.transformer_xl_repo), 'tf')
    sys.path.append(tf_source_dir)
    import vocabulary  # pylint: disable=unused-import
    import data_utils  # pylint: disable=unused-import

    # Make the directory two levels above this file importable before
    # importing the Gluon model definition.
    grandparent_dir = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))
    sys.path.append(grandparent_dir)
    from transformer import TransformerXL

    convert_transformerxl(args)