# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint:disable=redefined-outer-name,logging-format-interpolation
"""Script for converting a Fairseq RoBERTa model to Gluon."""
import argparse
import logging
import os
import sys
import io
import numpy as np

import torch
from fairseq.models.roberta import RobertaModel

import mxnet as mx
import gluonnlp as nlp
from gluonnlp.model import BERTEncoder, BERTModel
from gluonnlp.model.bert import bert_hparams
from gluonnlp.data.utils import _load_pretrained_vocab

from utils import get_hash, load_text_vocab, tf_vocab_to_gluon_vocab

parser = argparse.ArgumentParser(description='Conversion script for Fairseq RoBERTa model',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--ckpt_dir', type=str, help='Full path to the roberta folder',
                    default='/home/ubuntu/roberta/roberta.base')
parser.add_argument('--model', type=str, help='Model type.',
                    choices=['roberta_12_768_12', 'roberta_24_1024_16'],
                    default='roberta_12_768_12')
parser.add_argument('--verbose', action='store_true', help='Verbose logging')

args = parser.parse_args()

ckpt_dir = os.path.expanduser(args.ckpt_dir)

ckpt = torch.load(os.path.join(ckpt_dir, 'model.pt'))
pytorch_params = ckpt['model']

if args.verbose:
    print(ckpt['args'])
    for k, v in pytorch_params.items():
        print(k, v.shape)

# Load the model in fairseq
roberta = RobertaModel.from_pretrained(ckpt_dir)
roberta.eval()


def fairseq_vocab_to_gluon_vocab(torch_vocab):
    index_to_words = [None] * len(torch_vocab)

    bos_idx = torch_vocab.bos()
    pad_idx = torch_vocab.pad()
    eos_idx = torch_vocab.eos()
    unk_idx = torch_vocab.unk()

    index_to_words[bos_idx] = torch_vocab.symbols[bos_idx]
    index_to_words[pad_idx] = torch_vocab.symbols[pad_idx]
    index_to_words[eos_idx] = torch_vocab.symbols[eos_idx]
    index_to_words[unk_idx] = torch_vocab.symbols[unk_idx]

    specials = [bos_idx, pad_idx, eos_idx, unk_idx]

    openai_to_roberta = {}
    openai_vocab = _load_pretrained_vocab('openai_webtext', '.')

    with io.open(os.path.join(ckpt_dir, 'dict.txt'), encoding='utf-8') as f:
        for i, line in enumerate(f):
            token, count = line.split(' ')
            try:
                # Numeric symbols in dict.txt are OpenAI GPT-2 BPE token ids;
                # remember their position so they can be mapped back to text
                # form via the openai_webtext vocab below.
                int(token)
                openai_to_roberta[token] = i + len(specials)
            except ValueError:
                # Non-numeric symbols (e.g. madeupword0000) are kept verbatim
                index_to_words[i + len(specials)] = token

    for idx, token in enumerate(openai_vocab.idx_to_token):
        if str(idx) in openai_to_roberta:
            index_to_words[openai_to_roberta[str(idx)]] = token
        else:
            assert token == u'<mask>', token

    mask_idx = torch_vocab.index(u'<mask>')
    index_to_words[mask_idx] = torch_vocab.string([mask_idx])
    assert None not in index_to_words
    word2idx = {}
    for idx, token in enumerate(index_to_words):
        word2idx[token] = idx

    vocab = nlp.vocab.Vocab(word2idx, token_to_idx=word2idx,
                            unknown_token=index_to_words[unk_idx],
                            padding_token=index_to_words[pad_idx],
                            bos_token=index_to_words[bos_idx],
                            eos_token=index_to_words[eos_idx],
                            mask_token=u'<mask>')
    return vocab


vocab = fairseq_vocab_to_gluon_vocab(roberta.task.dictionary)
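
# Optional sanity check (an illustrative sketch, not part of the conversion
# itself): the converted Gluon vocab should agree with the fairseq dictionary
# on size and on the indices of the special tokens. Assumes `roberta` and
# `vocab` from above are in scope.
fairseq_dict = roberta.task.dictionary
assert len(vocab) == len(fairseq_dict)
for special in (vocab.bos_token, vocab.padding_token, vocab.eos_token,
                vocab.unknown_token, u'<mask>'):
    assert vocab[special] == fairseq_dict.index(special), special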
predefined_args = bert_hparams[args.model]

# BERT encoder
encoder = BERTEncoder(attention_cell=predefined_args['attention_cell'],
                      num_layers=predefined_args['num_layers'], units=predefined_args['units'],
                      hidden_size=predefined_args['hidden_size'],
                      max_length=predefined_args['max_length'],
                      num_heads=predefined_args['num_heads'], scaled=predefined_args['scaled'],
                      dropout=predefined_args['dropout'],
                      use_residual=predefined_args['use_residual'],
                      layer_norm_eps=predefined_args['layer_norm_eps'])

# BERT model
bert = BERTModel(encoder, len(vocab),
                 units=predefined_args['units'], embed_size=predefined_args['embed_size'],
                 word_embed=predefined_args['word_embed'], use_pooler=False,
                 use_token_type_embed=False, use_classifier=False)

bert.initialize(init=mx.init.Normal(0.02))

# Dummy forward pass to trigger deferred parameter shape inference
ones = mx.nd.ones((2, 8))
out = bert(ones, None, mx.nd.array([5, 6]), mx.nd.array([[1], [2]]))
params = bert._collect_params_with_prefix()

# Mapping from Gluon parameter name fragments to fairseq parameter names
mapping = {
    'decoder.2' : 'decoder.lm_head.layer_norm',
    'decoder.0' : 'decoder.lm_head.dense',
    'decoder.3' : 'decoder.lm_head',
    'encoder.layer_norm' : 'decoder.sentence_encoder.emb_layer_norm',
    'encoder.position_weight' : 'decoder.sentence_encoder.embed_positions.weight',
    'encoder.transformer_cells': 'decoder.sentence_encoder.layers',
    'attention_cell.proj_key.' : 'self_attn.in_proj_',
    'attention_cell.proj_value.' : 'self_attn.in_proj_',
    'attention_cell.proj_query.' : 'self_attn.in_proj_',
    'ffn.ffn_1' : 'fc1',
    'ffn.ffn_2' : 'fc2',
    'layer_norm.gamma' : 'layer_norm.weight',
    'layer_norm.beta' : 'layer_norm.bias',
    'ffn.layer_norm' : 'final_layer_norm',
    'word_embed.0.weight' : 'decoder.sentence_encoder.embed_tokens.weight',
}

# Per-layer renames; range(24) covers both the 12- and 24-layer configurations
for i in range(24):
    mapping['{}.layer_norm'.format(i)] = '{}.self_attn_layer_norm'.format(i)
    mapping['{}.proj'.format(i)] = '{}.self_attn.out_proj'.format(i)

# set parameter data
loaded_params = {}
visited_pytorch_params = {}
for name in params:
    pytorch_name = name
    for source, dest in mapping.items():
        pytorch_name = pytorch_name.replace(source, dest)

    assert pytorch_name in pytorch_params.keys(), \
        'Key ' + pytorch_name + ' for ' + name + ' not found.'
    torch_arr = pytorch_params[pytorch_name].cpu()
    # fairseq positional embeddings start at index 2 (indices 0 and 1 are
    # reserved for padding), so drop the first two rows
    if pytorch_name == 'decoder.sentence_encoder.embed_positions.weight':
        torch_arr = torch_arr[2:]

    arr = mx.nd.array(torch_arr)
    if 'attention_cell.proj' in name:
        # fairseq fuses the query/key/value projections into a single
        # in_proj parameter; slice out the block this Gluon parameter needs
        unfused = ['query', 'key', 'value']
        arrs = arr.split(num_outputs=3, axis=0)
        for i, p in enumerate(unfused):
            if p in name:
                arr = arrs[i]
    else:
        assert arr.shape == params[name].shape, (arr.shape, params[name].shape, name, pytorch_name)
    params[name].set_data(arr)
    loaded_params[name] = True
    visited_pytorch_params[pytorch_name] = True

assert len(params) == len(loaded_params)
assert len(visited_pytorch_params) == len(pytorch_params), \
    'Gluon model does not match PyTorch model. ' \
    'Please fix the BERTModel hyperparameters\n' + \
    str(len(visited_pytorch_params)) + ' vs. ' + str(len(pytorch_params))
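
# Illustrative sketch of the fused-projection split performed above: fairseq
# stacks the query/key/value projections row-wise in one in_proj parameter of
# shape (3 * units, units), and split(num_outputs=3, axis=0) recovers the
# three (units, units) blocks in query, key, value order. Variable names
# prefixed with `_` are local to this demonstration.
_units = predefined_args['units']
_fused = mx.nd.arange(3 * _units * _units).reshape((3 * _units, _units))
_q, _k, _v = _fused.split(num_outputs=3, axis=0)
assert _q.shape == (_units, _units)
assert (_q.asnumpy() == _fused[:_units].asnumpy()).all()
assert (_v.asnumpy() == _fused[2 * _units:].asnumpy()).all()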
" \ 189 "Please fix the BERTModel hyperparameters\n" + str(len(visited_pytorch_params)) + ' v.s. ' + str(len(pytorch_params)) 190 191 192texts = 'Hello world. abc, def and 中文!' 193torch_tokens = roberta.encode(texts) 194 195torch_features = roberta.extract_features(torch_tokens) 196pytorch_out = torch_features.detach().numpy() 197 198mx_tokenizer = nlp.data.GPT2BPETokenizer() 199mx_tokens = [vocab.bos_token] + mx_tokenizer(texts) + [vocab.eos_token] 200mx_data = vocab[mx_tokens] 201print(mx_tokens) 202print(vocab[mx_tokens]) 203print(torch_tokens) 204assert mx_data == torch_tokens.tolist() 205 206mx_out = bert(mx.nd.array([mx_data])) 207print('stdev = ', np.std(mx_out.asnumpy() - pytorch_out)) 208mx.test_utils.assert_almost_equal(mx_out.asnumpy(), pytorch_out, atol=1e-3, rtol=1e-3) 209mx.test_utils.assert_almost_equal(mx_out.asnumpy(), pytorch_out, atol=5e-6, rtol=5e-6) 210 211bert.save_parameters(os.path.join(ckpt_dir, args.model + '.params')) 212with io.open(os.path.join(ckpt_dir, args.model + '.vocab'), 'w', encoding='utf-8') as f: 213 f.write(vocab.to_json()) 214