如何使用 BERT 获得句子嵌入?

数据挖掘 张量流 nlp 火炬 伯特
2021-09-21 22:09:15

如何使用 BERT 获得句子嵌入?

from transformers import BertTokenizer
tokenizer=BertTokenizer.from_pretrained('bert-base-uncased')
sentence='I really enjoyed this movie a lot.'
#1.Tokenize the sequence:
tokens=tokenizer.tokenize(sentence)
print(tokens)
print(type(tokens))

2.添加[CLS]和[SEP]令牌:

tokens = ['[CLS]'] + tokens + ['[SEP]']
print(" Tokens are \n {} ".format(tokens))

3.填充输入:

T=15
padded_tokens=tokens +['[PAD]' for _ in range(T-len(tokens))]
print("Padded tokens are \n {} ".format(padded_tokens))
attn_mask=[ 1 if token != '[PAD]' else 0 for token in padded_tokens  ]
print("Attention Mask are \n {} ".format(attn_mask))

4. 维护一个段令牌列表:

seg_ids=[0 for _ in range(len(padded_tokens))]
print("Segment Tokens are \n {}".format(seg_ids))

5. 获取 BERT 词汇表中标记的索引:

sent_ids=tokenizer.convert_tokens_to_ids(padded_tokens)
print("senetence idexes \n {} ".format(sent_ids))
token_ids = torch.tensor(sent_ids).unsqueeze(0) 
attn_mask = torch.tensor(attn_mask).unsqueeze(0) 
seg_ids   = torch.tensor(seg_ids).unsqueeze(0)

将它们喂给 BERT

hidden_reps, cls_head = bert_model(token_ids, attention_mask = attn_mask,token_type_ids = seg_ids)
print(type(hidden_reps))
print(hidden_reps.shape ) #hidden states of each token in inout sequence 
print(cls_head.shape ) #hidden states of each [cls]

output:
hidden_reps size 
torch.Size([1, 15, 768])

cls_head size
torch.Size([1, 768])

哪个向量代表这里的句子嵌入?hidden_reps 还是cls_head

有没有其他方法可以从 BERT 获取句子嵌入,以便与其他句子进行相似性检查?

4个回答

实际上有一篇学术论文可以这样做。它被称为S-BERT 或 Sentence-BERT
他们还有一个易于使用的github 存储库。

哪个向量代表这里的句子嵌入?hidden_reps还是cls_head

如果我们查看 BERT 模型的forward()方法,我们会看到以下几行解释返回类型:

outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)

所以元组的第一个元素是“句子输出”——输入中的每个标记都嵌入到这个张量中。在您的示例中,您有 1 个输入序列,长度为 15 个标记,每个标记都嵌入到 768 维空间中。

元组的第二个元素是“池化输出”。您会注意到“序列”维度已被压缩,因此这表示输入序列的池化嵌入。

所以它们都代表句子嵌入。您可以将其hidden_reps视为“详细”表示,其中嵌入了每个标记。您可以将其cls_head视为一种浓缩表示,其中整个序列已被汇集。

有没有其他方法可以从 BERT 获取句子嵌入,以便与其他句子进行相似性检查?

使用该transformers库是我所知道的从 BERT 获取句子嵌入的最简单方法。

然而,有很多方法可以测量嵌入句子之间的相似性。最简单的方法是测量cls_head每个句子的池化嵌入 ( ) 之间的欧几里得距离。

有一个非常酷的工具叫做bert-as-service可以为你完成这项工作。它根据您使用的预训练模型将句子映射到固定长度的词嵌入。它还允许进行大量参数调整,这在文档中进行了广泛介绍。

在您的示例中,对应于第一个标记 ( [CLS])的隐藏状态hidden_reps可以用作句子嵌入。

相比之下,hidden states of each [cls]在我的实验中,池化输出(在您的代码中被错误地称为)被证明是句子嵌入的不良代理。