该模型使用了来自8个自然语言推理(NLI)数据集的1,279,665个假设-前提对进行训练,其中包括 MultiNLI 、 Fever-NLI 、 LingNLI 和 DocNLI (包括 ANLI 、QNLI、DUC、CNN/DailyMail和Curation)。
这是模型库中唯一经过8个NLI数据集训练的模型,包括使用超长文本进行长距离推理的DocNLI。请注意,该模型通过将“neural”和“contradiction”的类别合并为“not-entailment”来创建更多的训练数据。
基本模型是 DeBERTa-v3-small from Microsoft 。DeBERTa的v3变体通过使用不同的预训练目标显著优于之前的模型版本,请参阅原始 DeBERTa paper 的附录11以及 DeBERTa-V3 paper 。
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
model_name = "MoritzLaurer/DeBERTa-v3-small-mnli-fever-docnli-ling-2c"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
premise = "I first thought that I liked the movie, but upon second thought it was actually disappointing."
hypothesis = "The movie was good."
input = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")
output = model(input["input_ids"].to(device)) # device = "cuda:0" or "cpu"
prediction = torch.softmax(output["logits"][0], -1).tolist()
label_names = ["entailment", "neutral", "contradiction"]
prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, label_names)}
print(prediction)
该模型使用了来自8个NLI数据集的1,279,665个假设-前提对进行训练,其中包括 MultiNLI 、 Fever-NLI 、 LingNLI 和 DocNLI (包括 ANLI 、QNLI、DUC、CNN/DailyMail和Curation)。
DeBERTa-v3-small-mnli-fever-docnli-ling-2c使用Hugging Face训练器进行训练,使用了以下超参数。
training_args = TrainingArguments(
num_train_epochs=3, # total number of training epochs
learning_rate=2e-05,
per_device_train_batch_size=32, # batch size per device during training
per_device_eval_batch_size=32, # batch size for evaluation
warmup_ratio=0.1, # number of warmup steps for learning rate scheduler
weight_decay=0.06, # strength of weight decay
fp16=True # mixed precision training
)
该模型使用MultiNLI和ANLI的二进制测试集以及Fever-NLI的二进制开发集进行评估(两个类别而不是三个类别)。所使用的度量标准是准确度。
| mnli-m-2c | mnli-mm-2c | fever-nli-2c | anli-all-2c | anli-r3-2c |
|---|---|---|---|---|
| 0.927 | 0.921 | 0.892 | 0.684 | 0.673 |
关于潜在偏见,请参考原始DeBERTa论文和不同NLI数据集的文献。
如果您要引用此模型,请引用原始的DeBERTa论文、相关的NLI数据集,并包含Hugging Face模型库上该模型的链接。
如果您有任何问题或合作想法,请通过m{dot}laurer{at}vu{dot}nl或 LinkedIn 与我联系。
请注意,DeBERTa-v3是最近发布的,较旧版本的HF Transformers似乎存在运行该模型的问题(例如与标记器相关的问题)。使用Transformers==4.13可能可以解决部分问题。