# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint:disable=redefined-outer-name
"""Script for converting a Fairseq RoBERTa model to Gluon."""
import argparse
import io
import os

import numpy as np
import torch
from fairseq.models.roberta import RobertaModel

import mxnet as mx
import gluonnlp as nlp
from gluonnlp.model import BERTEncoder, BERTModel
from gluonnlp.model.bert import bert_hparams
from gluonnlp.data.utils import _load_pretrained_vocab

parser = argparse.ArgumentParser(description='Conversion script for Fairseq RoBERTa model',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--ckpt_dir', type=str, help='Full path to the roberta folder',
                    default='/home/ubuntu/roberta/roberta.base')
parser.add_argument('--model', type=str, help='Model type',
                    choices=['roberta_12_768_12', 'roberta_24_1024_16'],
                    default='roberta_12_768_12')
parser.add_argument('--verbose', action='store_true', help='Verbose logging')
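# Example invocation (the file name below is only illustrative):
#   python convert_fairseq_model.py --ckpt_dir ~/roberta/roberta.base --model roberta_12_768_12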

args = parser.parse_args()

ckpt_dir = os.path.expanduser(args.ckpt_dir)

ckpt = torch.load(os.path.join(ckpt_dir, 'model.pt'))
pytorch_params = ckpt['model']

if args.verbose:
    print(ckpt['args'])
    for k, v in pytorch_params.items():
        print(k, v.shape)

# Load the model in fairseq
roberta = RobertaModel.from_pretrained(ckpt_dir)
roberta.eval()

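# fairseq's dict.txt lists one entry per line in the form '<token> <count>',
# where <token> is usually a numeric GPT-2 BPE id; the helper below converts
# that dictionary into a GluonNLP Vocab with matching indices.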
def fairseq_vocab_to_gluon_vocab(torch_vocab):
    """Convert a fairseq Dictionary into a GluonNLP Vocab.

    The first four indices are reserved for the special tokens <s>, <pad>,
    </s> and <unk>. Numeric dict.txt entries are GPT-2 BPE token ids and are
    resolved to their surface form through the pretrained 'openai_webtext'
    vocabulary, so the resulting Vocab maps the same tokens to the same
    indices as the fairseq dictionary.
    """
    index_to_words = [None] * len(torch_vocab)

    bos_idx = torch_vocab.bos()
    pad_idx = torch_vocab.pad()
    eos_idx = torch_vocab.eos()
    unk_idx = torch_vocab.unk()

    index_to_words[bos_idx] = torch_vocab.symbols[bos_idx]
    index_to_words[pad_idx] = torch_vocab.symbols[pad_idx]
    index_to_words[eos_idx] = torch_vocab.symbols[eos_idx]
    index_to_words[unk_idx] = torch_vocab.symbols[unk_idx]

    specials = [bos_idx, pad_idx, eos_idx, unk_idx]

    openai_to_roberta = {}
    openai_vocab = _load_pretrained_vocab('openai_webtext', '.')

    with io.open(os.path.join(ckpt_dir, 'dict.txt'), encoding='utf-8') as f:
        for i, line in enumerate(f):
            token, _count = line.split(' ')
            try:
                int(token)  # numeric entries are GPT-2 BPE token ids
                openai_to_roberta[token] = i + len(specials)
            except ValueError:
                # literal tokens are kept as-is
                index_to_words[i + len(specials)] = token

    for idx, token in enumerate(openai_vocab.idx_to_token):
        if str(idx) in openai_to_roberta:
            index_to_words[openai_to_roberta[str(idx)]] = token
        else:
            assert token == u'<mask>', token

    mask_idx = torch_vocab.index(u'<mask>')
    index_to_words[mask_idx] = torch_vocab.string([mask_idx])
    assert None not in index_to_words
    word2idx = {}
    for idx, token in enumerate(index_to_words):
        word2idx[token] = idx

    vocab = nlp.vocab.Vocab(word2idx, token_to_idx=word2idx,
                            unknown_token=index_to_words[unk_idx],
                            padding_token=index_to_words[pad_idx],
                            bos_token=index_to_words[bos_idx],
                            eos_token=index_to_words[eos_idx],
                            mask_token=u'<mask>')
    return vocab

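# Indices 0-3 of the converted vocabulary are the fairseq special tokens
# (<s>, <pad>, </s>, <unk>), and '<mask>' keeps the index it has in the fairseq
# dictionary, so both tokenizers produce identical token ids (checked below).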
vocab = fairseq_vocab_to_gluon_vocab(roberta.task.dictionary)

predefined_args = bert_hparams[args.model]

# BERT encoder
encoder = BERTEncoder(attention_cell=predefined_args['attention_cell'],
                      num_layers=predefined_args['num_layers'], units=predefined_args['units'],
                      hidden_size=predefined_args['hidden_size'],
                      max_length=predefined_args['max_length'],
                      num_heads=predefined_args['num_heads'], scaled=predefined_args['scaled'],
                      dropout=predefined_args['dropout'],
                      use_residual=predefined_args['use_residual'],
                      layer_norm_eps=predefined_args['layer_norm_eps'])

# BERT model (RoBERTa variant: no token type embeddings, no pooler, no NSP classifier)
bert = BERTModel(encoder, len(vocab),
                 units=predefined_args['units'], embed_size=predefined_args['embed_size'],
                 word_embed=predefined_args['word_embed'], use_pooler=False,
                 use_token_type_embed=False, use_classifier=False)

bert.initialize(init=mx.init.Normal(0.02))

# Run a dummy forward pass so that deferred initialization allocates the
# parameter arrays before they are overwritten below.
ones = mx.nd.ones((2, 8))
out = bert(ones, None, mx.nd.array([5, 6]), mx.nd.array([[1], [2]]))
params = bert._collect_params_with_prefix()

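# The collected Gluon parameter names look like, e.g., 'word_embed.0.weight' or
# 'encoder.transformer_cells.0.attention_cell.proj_query.weight'; the table below
# rewrites fragments of such names into the corresponding fairseq parameter names.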
mapping = {
    'decoder.2': 'decoder.lm_head.layer_norm',
    'decoder.0': 'decoder.lm_head.dense',
    'decoder.3': 'decoder.lm_head',
    'encoder.layer_norm': 'decoder.sentence_encoder.emb_layer_norm',
    'encoder.position_weight': 'decoder.sentence_encoder.embed_positions.weight',
    'encoder.transformer_cells': 'decoder.sentence_encoder.layers',
    'attention_cell.proj_key.': 'self_attn.in_proj_',
    'attention_cell.proj_value.': 'self_attn.in_proj_',
    'attention_cell.proj_query.': 'self_attn.in_proj_',
    'ffn.ffn_1': 'fc1',
    'ffn.ffn_2': 'fc2',
    'layer_norm.gamma': 'layer_norm.weight',
    'layer_norm.beta': 'layer_norm.bias',
    'ffn.layer_norm': 'final_layer_norm',
    'word_embed.0.weight': 'decoder.sentence_encoder.embed_tokens.weight',
}

# Per-layer substitutions; 24 covers both the 12-layer base and 24-layer large models.
for i in range(24):
    mapping['{}.layer_norm'.format(i)] = '{}.self_attn_layer_norm'.format(i)
    mapping['{}.proj'.format(i)] = '{}.self_attn.out_proj'.format(i)

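# For example, the Gluon parameter name
# 'encoder.transformer_cells.0.attention_cell.proj_query.weight' is rewritten by
# the substitutions above to 'decoder.sentence_encoder.layers.0.self_attn.in_proj_weight'.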
# Copy parameter data from the PyTorch checkpoint into the Gluon model
loaded_params = {}
visited_pytorch_params = {}
for name in params:
    pytorch_name = name
    for source, dest in mapping.items():
        pytorch_name = pytorch_name.replace(source, dest)

    assert pytorch_name in pytorch_params, \
        'Key ' + pytorch_name + ' for ' + name + ' not found.'
    torch_arr = pytorch_params[pytorch_name].cpu()
    # fairseq reserves the first two position indices, so its positional
    # embedding table starts at index 2
    if pytorch_name == 'decoder.sentence_encoder.embed_positions.weight':
        torch_arr = torch_arr[2:]

    arr = mx.nd.array(torch_arr)
    if 'attention_cell.proj' in name:
        # fairseq fuses the query/key/value projections into a single in_proj
        # parameter; split it into three equal parts along the output axis
        unfused = ['query', 'key', 'value']
        arrs = arr.split(num_outputs=3, axis=0)
        for i, p in enumerate(unfused):
            if p in name:
                arr = arrs[i]
    else:
        assert arr.shape == params[name].shape, (arr.shape, params[name].shape, name, pytorch_name)
    params[name].set_data(arr)
    loaded_params[name] = True
    visited_pytorch_params[pytorch_name] = True

assert len(params) == len(loaded_params)
assert len(visited_pytorch_params) == len(pytorch_params), \
    'Gluon model does not match the PyTorch model. Please fix the BERTModel ' \
    'hyperparameters: {} vs. {}'.format(len(visited_pytorch_params), len(pytorch_params))

# Sanity check: tokenize a sample sentence with both tokenizers and compare
# the fairseq features against the converted Gluon model's output.
texts = 'Hello world. abc, def and 中文!'
torch_tokens = roberta.encode(texts)

torch_features = roberta.extract_features(torch_tokens)
pytorch_out = torch_features.detach().numpy()

mx_tokenizer = nlp.data.GPT2BPETokenizer()
mx_tokens = [vocab.bos_token] + mx_tokenizer(texts) + [vocab.eos_token]
mx_data = vocab[mx_tokens]
print(mx_tokens)
print(vocab[mx_tokens])
print(torch_tokens)
assert mx_data == torch_tokens.tolist()

mx_out = bert(mx.nd.array([mx_data]))
print('stdev = ', np.std(mx_out.asnumpy() - pytorch_out))
mx.test_utils.assert_almost_equal(mx_out.asnumpy(), pytorch_out, atol=1e-3, rtol=1e-3)
mx.test_utils.assert_almost_equal(mx_out.asnumpy(), pytorch_out, atol=5e-6, rtol=5e-6)

bert.save_parameters(os.path.join(ckpt_dir, args.model + '.params'))
with io.open(os.path.join(ckpt_dir, args.model + '.vocab'), 'w', encoding='utf-8') as f:
    f.write(vocab.to_json())
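
# A rough sketch of how the converted artifacts could be loaded back later
# (not executed here; the paths mirror the save calls above):
#   with io.open(os.path.join(ckpt_dir, args.model + '.vocab'), encoding='utf-8') as f:
#       vocab = nlp.Vocab.from_json(f.read())
#   bert.load_parameters(os.path.join(ckpt_dir, args.model + '.params'))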