# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import argparse
import logging
import os
import pickle
import re
import sys

import mxnet as mx
import numpy as np

import gluonnlp as nlp
from utils import _split_dict, get_hash, to_gluon_kwargs, read_tf_checkpoint


def to_gluon_vocab(corpus):
    """Convert a TransformerXL corpus object to a GluonNLP Vocab."""
    # Clean up latin-1 mis-encoding of words
    idx2sym = [w.encode('latin-1').decode('utf-8') for w in corpus.vocab.idx2sym]
    sym2idx = {sym: idx for idx, sym in enumerate(idx2sym)}

    special_tokens = dict(unknown_token=None, padding_token=None, bos_token=None)
    if hasattr(corpus.vocab, 'unk_idx'):
        special_tokens['unknown_token'] = idx2sym[corpus.vocab.unk_idx]
    elif '<unk>' in sym2idx:
        special_tokens['unknown_token'] = '<unk>'
    elif '<UNK>' in sym2idx:
        special_tokens['unknown_token'] = '<UNK>'

    # Discover special tokens
    if ['<eos>'] == corpus.vocab.special:
        if '<eos>' in sym2idx:  # Only include if special token is actually used
            special_tokens['eos_token'] = '<eos>'
    elif '<S>' in sym2idx:
        # Special case for model trained on Google 1 Billion Word LM dataset
        special_tokens['eos_token'] = '<S>'
    elif corpus.vocab.special:
        raise NotImplementedError('Provided TransformerXL cache.pkl uses an unknown special token. '
                                  'You must extend the `to_gluon_vocab` method to support it.')
    else:
        special_tokens['eos_token'] = None

    counter = nlp.data.count_tokens(sym2idx.keys())
    vocab = nlp.vocab.Vocab(counter, token_to_idx=sym2idx, **special_tokens)
    return vocab


def set_params(model, tf_tensors, kwargs, tie_r):
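    """Copy tensors from the TensorFlow checkpoint into the Gluon model's parameters.

    Consumes ``tf_tensors``: every tensor mapped to a Gluon parameter is popped
    from the dict, so any entries left afterwards were not converted.
    """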
    # Drop optimizer params
    _, tf_tensors = _split_dict(lambda k, v: k.endswith('Adam'), tf_tensors)
    _, tf_tensors = _split_dict(lambda k, v: k.endswith('Adam_1'), tf_tensors)
    del tf_tensors['global_step']
    del tf_tensors['beta1_power']
    del tf_tensors['beta2_power']

    loaded = set()  # Cache of processed parameters

    if 'embed_cutoffs' in kwargs:  # Adaptive Embedding and Softmax
        # Embedding
        for name, param in model._net.embedding._collect_params_with_prefix().items():
            purpose, i, postfix = re.match(r'([a-zA-Z]*)(\d*)(.*)', name).groups()
            if purpose == 'embedding':
                assert postfix == '_weight'
                tf_param = tf_tensors.pop(
                    'transformer/adaptive_embed/cutoff_{}/lookup_table'.format(i))
            elif purpose == 'projection':
                assert postfix == '_weight'
                tf_param = tf_tensors.pop('transformer/adaptive_embed/cutoff_{}/proj_W'.format(i)).T
            else:
                raise RuntimeError('Embedding had unexpected parameter: {}'.format(name))

            param.set_data(mx.nd.array(tf_param))
            loaded.add(param)

        # Softmax
        for name, param in model._net.crit._collect_params_with_prefix().items():
            if param in loaded:
                continue  # Some parameters are shared between Embedding and Softmax

            purpose, i, postfix = re.match(r'([a-zA-Z]*)(\d*)(.*)', name).groups()
            if purpose == 'outembedding':
                if postfix == '_weight':
                    tf_param = tf_tensors.pop(
                        'transformer/adaptive_softmax/cutoff_{}/lookup_table'.format(i))
                elif postfix == '_bias':
                    tf_param = tf_tensors.pop('transformer/adaptive_softmax/cutoff_{}/b'.format(i))
                else:
                    raise RuntimeError('Softmax had unexpected parameter: {}'.format(name))
            elif purpose == 'outprojection':
                assert postfix == '_weight'
                tf_param = tf_tensors.pop('transformer/adaptive_softmax/cutoff_{}/proj'.format(i)).T
            elif purpose == 'cluster':
                if postfix == '.weight':
                    tf_param = tf_tensors.pop('transformer/adaptive_softmax/cutoff_0/cluster_W')
                elif postfix == '.bias':
                    tf_param = tf_tensors.pop('transformer/adaptive_softmax/cutoff_0/cluster_b')
                else:
                    raise RuntimeError('Softmax had unexpected parameter: {}'.format(name))
            else:
                raise RuntimeError('Softmax had unexpected parameter: {}'.format(name))

            param.set_data(mx.nd.array(tf_param))
            loaded.add(param)
    else:  # Non-adaptive, (possibly) projected embedding and softmax
        # Embedding
        tf_param = tf_tensors.pop('transformer/adaptive_embed/lookup_table')
        model._net.embedding.embedding_weight.set_data(mx.nd.array(tf_param))
        loaded.add(model._net.embedding.embedding_weight)
        if kwargs['embed_size'] != kwargs['units']:
            tf_param = tf_tensors.pop('transformer/adaptive_embed/proj_W')
            model._net.embedding.projection_weight.set_data(mx.nd.array(tf_param))
            loaded.add(model._net.embedding.projection_weight)
            assert len(model._net.embedding.collect_params().keys()) == 2
        else:
            assert len(model._net.embedding.collect_params().keys()) == 1

        # Softmax
        for name, param in model._net.crit._collect_params_with_prefix().items():
            if param in loaded:
                continue  # Some parameters are shared between Embedding and Softmax

            purpose, i, postfix = re.match(r'([a-zA-Z]*)(\d*)(.*)', name).groups()
            if purpose == 'outembedding':
                if postfix == '_weight':
                    tf_param = tf_tensors.pop('transformer/adaptive_softmax/lookup_table')
                elif postfix == '_bias':
                    tf_param = tf_tensors.pop('transformer/adaptive_softmax/bias')
                else:
                    raise RuntimeError('Softmax had unexpected parameter: {}'.format(name))
            elif purpose == 'outprojection':
                assert postfix == '_weight'
                tf_param = tf_tensors.pop('transformer/adaptive_softmax/proj').T
            else:
                raise RuntimeError('Softmax had unexpected parameter: {}'.format(name))

            param.set_data(mx.nd.array(tf_param))
            loaded.add(param)

    tf_r_r_bias = tf_tensors.pop('transformer/r_r_bias')
    tf_r_w_bias = tf_tensors.pop('transformer/r_w_bias')
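    # r_w_bias and r_r_bias are the relative-attention biases (u and v in the
    # Transformer-XL paper). With tie_r a single bias is shared across all layers;
    # otherwise each layer uses its own slice.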
    for layer_i in range(kwargs['num_layers']):
        # Attention Cell
        attention_cell = model._net.transformer_cells[layer_i].attention_cell
        # TODO(leezu): Duplicate tied parameters until parameter sharing support
        # is improved in Gluon 2. (It is currently impossible to share only a
        # subset of parameters between Blocks, because the non-shared parameters
        # clash due to their common prefix.)
        attention_cell.query_key_bias.set_data(
            mx.nd.array(tf_r_w_bias if tie_r else tf_r_w_bias[layer_i]))
        attention_cell.query_emb_bias.set_data(
            mx.nd.array(tf_r_r_bias if tie_r else tf_r_r_bias[layer_i]))
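        # TF stores Dense kernels as (in_units, out_units) while Gluon Dense weights
        # are (out_units, in_units), hence the transposes below. The fused qkv kernel
        # is additionally split into separate query, key and value projections.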
        tf_param = np.split(
            tf_tensors.pop('transformer/layer_{}/rel_attn/qkv/kernel'.format(layer_i)).T, 3, axis=0)
        attention_cell.proj_query.weight.set_data(mx.nd.array(tf_param[0]))
        attention_cell.proj_key.weight.set_data(mx.nd.array(tf_param[1]))
        attention_cell.proj_value.weight.set_data(mx.nd.array(tf_param[2]))
        tf_param = tf_tensors.pop('transformer/layer_{}/rel_attn/r/kernel'.format(layer_i))
        attention_cell.proj_emb.weight.set_data(mx.nd.array(tf_param.T))

        # Projection
        tf_param = tf_tensors.pop('transformer/layer_{}/rel_attn/o/kernel'.format(layer_i))
        model._net.transformer_cells[layer_i].proj.weight.set_data(mx.nd.array(tf_param.T))

        # Layer Norm
        tf_param = tf_tensors.pop('transformer/layer_{}/rel_attn/LayerNorm/beta'.format(layer_i))
        model._net.transformer_cells[layer_i].layer_norm.beta.set_data(mx.nd.array(tf_param))
        tf_param = tf_tensors.pop('transformer/layer_{}/rel_attn/LayerNorm/gamma'.format(layer_i))
        model._net.transformer_cells[layer_i].layer_norm.gamma.set_data(mx.nd.array(tf_param))

        # FFN
        ffn = model._net.transformer_cells[layer_i].ffn
        tf_param = tf_tensors.pop('transformer/layer_{}/ff/LayerNorm/beta'.format(layer_i))
        ffn.layer_norm.beta.set_data(mx.nd.array(tf_param))
        tf_param = tf_tensors.pop('transformer/layer_{}/ff/LayerNorm/gamma'.format(layer_i))
        ffn.layer_norm.gamma.set_data(mx.nd.array(tf_param))
        tf_param = tf_tensors.pop('transformer/layer_{}/ff/layer_1/kernel'.format(layer_i))
        ffn.ffn_1.weight.set_data(mx.nd.array(tf_param.T))
        tf_param = tf_tensors.pop('transformer/layer_{}/ff/layer_1/bias'.format(layer_i))
        ffn.ffn_1.bias.set_data(mx.nd.array(tf_param))
        tf_param = tf_tensors.pop('transformer/layer_{}/ff/layer_2/kernel'.format(layer_i))
        ffn.ffn_2.weight.set_data(mx.nd.array(tf_param.T))
        tf_param = tf_tensors.pop('transformer/layer_{}/ff/layer_2/bias'.format(layer_i))
        ffn.ffn_2.bias.set_data(mx.nd.array(tf_param))


def convert_transformerxl(args):
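    """Convert a TensorFlow Transformer-XL checkpoint to GluonNLP vocab and parameter files.

    The vocab and parameter files are written to ``args.out_dir`` and named after the
    short hash of their contents.
    """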
    # Load tf model and vocab
    with open(args.cache_pkl, 'rb') as f:
        corpus = pickle.load(f, encoding='latin1')
    vocab = to_gluon_vocab(corpus)
    tf_checkpoint_file = os.path.expanduser(
        os.path.join(args.tf_checkpoint_dir, args.tf_model_prefix))
    tf_tensors = read_tf_checkpoint(tf_checkpoint_file)

    # Initialize Gluon model
    kwargs, tie_r = to_gluon_kwargs(tf_tensors)
    model = TransformerXL(vocab_size=len(vocab), **kwargs)
    model.initialize(init=mx.init.Normal(0.02))

    # Shape inference based on forward pass
    batch_size, seq_len = 2, 16
    mem_length = 100
    mems = model.begin_mems(batch_size, mem_length, context=mx.cpu())
    x = mx.nd.ones(shape=(batch_size, seq_len))
    model(x, x, mems)

    # Convert parameters
    set_params(model, tf_tensors, kwargs, tie_r)

    # Serialization
    tmp_file_path = os.path.expanduser(os.path.join(args.out_dir, 'tmp'))
    with open(tmp_file_path, 'w') as f:
        f.write(vocab.to_json())
    hash_full, hash_short = get_hash(tmp_file_path)
    gluon_vocab_path = os.path.expanduser(os.path.join(args.out_dir, hash_short + '.vocab'))
    with open(gluon_vocab_path, 'w') as f:
        f.write(vocab.to_json())
        logging.info('vocab file saved to %s. hash = %s', gluon_vocab_path, hash_full)
    model.save_parameters(tmp_file_path)
    hash_full, hash_short = get_hash(tmp_file_path)
    os.remove(tmp_file_path)
    gluon_param_path = os.path.expanduser(os.path.join(args.out_dir, hash_short + '.params'))
    logging.info('param saved to %s. hash = %s', gluon_param_path, hash_full)
    model.save_parameters(gluon_param_path)
    mx.nd.waitall()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Conversion script for TensorFlow Transformer-XL model',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--transformer-xl-repo', type=str, required=True,
                        help='Path to https://github.com/kimiyoung/transformer-xl repo.')
    parser.add_argument('--tf-checkpoint-dir', type=str, required=True,
                        help='Path to TensorFlow checkpoint folder.')
    parser.add_argument(
        '--tf-model-prefix', type=str, required=True, help='Prefix of the checkpoint files. '
        'For example, model.ckpt-0 or model.ckpt-1191000.')
    parser.add_argument('--cache-pkl', type=str, required=True,
                        help='Path to TransformerXL cache.pkl file.')
    parser.add_argument('--out-dir', type=str, required=True,
                        help='Path to output folder. The folder must exist.')
    parser.add_argument('--debug', action='store_true', help='Debugging mode.')
    args = parser.parse_args()
    logging.getLogger().setLevel(logging.DEBUG if args.debug else logging.INFO)
    logging.info(args)

    # Import modules required for unpickling the corpus
    sys.path.append(os.path.join(args.transformer_xl_repo, 'tf'))
    import vocabulary  # pylint: disable=unused-import
    import data_utils  # pylint: disable=unused-import

    sys.path.append(os.path.abspath(os.path.join(__file__, os.pardir, os.pardir)))
    from transformer import TransformerXL

    convert_transformerxl(args)