模型:
setu4993/LEALLA-small
LEALLA是从 LaBSE 个训练模型中精简而来的一套支持109种语言的轻量级语言无关句子嵌入模型。该模型可用于获取多语言句子嵌入以及双语文本检索。
此模型从TF Hub的v1模型迁移而来。两个版本的模型产生的嵌入结果均为 equivalent 。不过,对于某些语言(如日语),LEALLA模型在比较嵌入和相似度时似乎需要更高的容差。
使用该模型:
import torch
from transformers import BertModel, BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("setu4993/LEALLA-small")
model = BertModel.from_pretrained("setu4993/LEALLA-small")
model = model.eval()
english_sentences = [
"dog",
"Puppies are nice.",
"I enjoy taking long walks along the beach with my dog.",
]
english_inputs = tokenizer(english_sentences, return_tensors="pt", padding=True)
with torch.no_grad():
english_outputs = model(**english_inputs)
要获得句子嵌入,请使用pooler输出:
english_embeddings = english_outputs.pooler_output
其他语言的输出:
italian_sentences = [
"cane",
"I cuccioli sono carini.",
"Mi piace fare lunghe passeggiate lungo la spiaggia con il mio cane.",
]
japanese_sentences = ["犬", "子犬はいいです", "私は犬と一緒にビーチを散歩するのが好きです"]
italian_inputs = tokenizer(italian_sentences, return_tensors="pt", padding=True)
japanese_inputs = tokenizer(japanese_sentences, return_tensors="pt", padding=True)
with torch.no_grad():
italian_outputs = model(**italian_inputs)
japanese_outputs = model(**japanese_inputs)
italian_embeddings = italian_outputs.pooler_output
japanese_embeddings = japanese_outputs.pooler_output
对于句子之间的相似度计算,建议在计算相似度之前进行L2范数归一化:
import torch.nn.functional as F
def similarity(embeddings_1, embeddings_2):
normalized_embeddings_1 = F.normalize(embeddings_1, p=2)
normalized_embeddings_2 = F.normalize(embeddings_2, p=2)
return torch.matmul(
normalized_embeddings_1, normalized_embeddings_2.transpose(0, 1)
)
print(similarity(english_embeddings, italian_embeddings))
print(similarity(english_embeddings, japanese_embeddings))
print(similarity(italian_embeddings, japanese_embeddings))
关于数据、训练、评估和性能指标的详细信息,请参阅 original paper 。
@misc{mao2023lealla,
title={LEALLA: Learning Lightweight Language-agnostic Sentence Embeddings with Knowledge Distillation},
author={Zhuoyuan Mao and Tetsuji Nakagawa},
year={2023},
eprint={2302.08387},
archivePrefix={arXiv},
primaryClass={cs.CL}
}