如何从 local/colab 目录加载预训练的 BERT 模型?

数据挖掘 nlp 伯特
2021-10-15 02:14:37

嗨,我从这里下载了 BERT 预训练模型 ( https://storage.googleapis.com/bert_models/2018_10_18/cased_L-12_H-768_A-12.zip ) 并保存到 gogole colab 和本地的目录中。

当我尝试在 colab 中加载模型时,我得到“我们假设 '/content/drive/My Drive/bert_training/uncased_L-12_H-768_A-12/config.json”。试图在本地机器上加载模型并得到同样的错误。

这就是我加载模型的方式: from transformers import BertForMaskedLM BertNSP=BertForMaskedLM.from_pretrained('/content/drive/My Drive/bert_training/uncased_L-12_H-768_A-12/')

当我下载预训练模型时,这是从目录加载模型的正确方法吗?我收到错误“'/content/drive/My Drive/bert_training/uncased_L-12_H-768_A-12/config.json'”下载的模型有这些命名约定,其中文件名以 bert_ 开头,但 BertForMaskedLM 类需要文件名成为 config.json 。

bert_config.json bert_model.ckpt.data-00000-of-00001 bert_model.ckpt.index vocab.txt bert_model.ckpt.meta

完全错误:模型名称“/content/drive/My Drive/bert_training/uncased_L-12_H-768_A-12/”在模型名称列表中找不到(bert-base-uncased、bert-large-uncased、bert-base-cased , bert-large-case, bert-base-multilingual-uncased, bert-base-multilingual-cases, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert -large-case-whole-word-masking,bert-large-uncased-whole-word-masking-finetuned-squad,bert-large-case-whole-word-masking-finetuned-squad,bert-base-case-finetuned -mrpc、bert-base-german-dbmdz-cased、bert-base-german-dbmdz-uncased)。我们假设 '/content/drive/My Drive/bert_training/uncased_L-12_H-768_A-12/config.json' 是名为 config.json 的配置文件或包含此类文件但找不到的目录的路径或 URL此路径或 url 处的任何此类文件。

当我通过从所有 4 个文件名中删除 bert 来重命名上述 4 个文件时,即使存在“model.ckpt.index”文件,我也会收到此错误

错误:“OSError:在目录 /content/drive/My Drive/bert_training/uncased_L-12_H-768_A-12 中找不到名为 ['pytorch_model.bin', 'tf_model.h5', 'model.ckpt.index'] 的文件错误/ 或 from_tf 设置为 False"

2个回答

您可以使用以下代码行导入预训练的 bert 模型:

pip install pytorch_pretrained_bert

from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForNextSentencePrediction

BERT_CLASS = BertForNextSentencePrediction

# Make sure all the files are in same folder, i.e vocab , config and bin file
PRE_TRAINED_MODEL_NAME_OR_PATH = '/path/to/the/files/containing/models/files'

model = BERT_CLASS.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None)

您正在使用 HuggingFace 的 Transformers 库。

由于这个库最初是用 Pytorch 编写的,因此检查点与官方的 TF 检查点不同。但是您使用的是官方的 TF 检查点。

您需要从那里下载转换后的检查点。


注意:HuggingFace 也发布了 TF 模型。但我不确定它是否可以在没有从官方 TF 检查点转换的情况下工作。如果你想使用 HuggingFace 的 TF API,你需要做:

from transformers import TFBertForMaskedLM