这个模型是用于预测两个短文本之间的逻辑关系(包括蕴含、矛盾或中立)的Fine-tuned RuBERT。
如何运行NLI模型:
# !pip install transformers sentencepiece --quiet
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
model_checkpoint = 'cointegrated/rubert-base-cased-nli-threeway'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
if torch.cuda.is_available():
    model.cuda()
text1 = 'Сократ - человек, а все люди смертны.'
text2 = 'Сократ никогда не умрёт.'
with torch.inference_mode():
    out = model(**tokenizer(text1, text2, return_tensors='pt').to(model.device))
    proba = torch.softmax(out.logits, -1).cpu().numpy()[0]
print({v: proba[k] for k, v in model.config.id2label.items()})
# {'entailment': 0.009525929, 'contradiction': 0.9332064, 'neutral': 0.05726764} 
 您还可以将此模型用于零样本短文本分类(仅通过标签),例如情感分析:
def predict_zero_shot(text, label_texts, model, tokenizer, label='entailment', normalize=True):
    label_texts
    tokens = tokenizer([text] * len(label_texts), label_texts, truncation=True, return_tensors='pt', padding=True)
    with torch.inference_mode():
        result = torch.softmax(model(**tokens.to(model.device)).logits, -1)
    proba = result[:, model.config.label2id[label]].cpu().numpy()
    if normalize:
        proba /= sum(proba)
    return proba
classes = ['Я доволен', 'Я недоволен']
predict_zero_shot('Какая гадость эта ваша заливная рыба!', classes, model, tokenizer)
# array([0.05609814, 0.9439019 ], dtype=float32)
predict_zero_shot('Какая вкусная эта ваша заливная рыба!', classes, model, tokenizer)
# array([0.9059292 , 0.09407079], dtype=float32)
 或者,您可以使用 Huggingface pipelines 进行推理。
该模型是在一系列已从英文自动翻译成俄文的NLI数据集上训练的。
大多数数据集来自于以下来源: from the repo of Felipe Salvatore , JOCI , MNLI , MPE , SICK , SNLI 。
一些数据集来自原始来源: ANLI , NLI-style FEVER , IMPPRES 。
下表显示了五个模型在相应开发集上的ROC AUC(对一类与其他类进行比较):
| model | add_one_rte | anli_r1 | anli_r2 | anli_r3 | copa | fever | help | iie | imppres | joci | mnli | monli | mpe | scitail | sick | snli | terra | total | 
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| n_observations | 387 | 1000 | 1000 | 1200 | 200 | 20474 | 3355 | 31232 | 7661 | 939 | 19647 | 269 | 1000 | 2126 | 500 | 9831 | 307 | 101128 | 
| tiny/entailment | 0.77 | 0.59 | 0.52 | 0.53 | 0.53 | 0.90 | 0.81 | 0.78 | 0.93 | 0.81 | 0.82 | 0.91 | 0.81 | 0.78 | 0.93 | 0.95 | 0.67 | 0.77 | 
| twoway/entailment | 0.89 | 0.73 | 0.61 | 0.62 | 0.58 | 0.96 | 0.92 | 0.87 | 0.99 | 0.90 | 0.90 | 0.99 | 0.91 | 0.96 | 0.97 | 0.97 | 0.87 | 0.86 | 
| threeway/entailment | 0.91 | 0.75 | 0.61 | 0.61 | 0.57 | 0.96 | 0.56 | 0.61 | 0.99 | 0.90 | 0.91 | 0.67 | 0.92 | 0.84 | 0.98 | 0.98 | 0.90 | 0.80 | 
| vicgalle-xlm/entailment | 0.88 | 0.79 | 0.63 | 0.66 | 0.57 | 0.93 | 0.56 | 0.62 | 0.77 | 0.80 | 0.90 | 0.70 | 0.83 | 0.84 | 0.91 | 0.93 | 0.93 | 0.78 | 
| facebook-bart/entailment | 0.51 | 0.41 | 0.43 | 0.47 | 0.50 | 0.74 | 0.55 | 0.57 | 0.60 | 0.63 | 0.70 | 0.52 | 0.56 | 0.68 | 0.67 | 0.72 | 0.64 | 0.58 | 
| threeway/contradiction | 0.71 | 0.64 | 0.61 | 0.97 | 1.00 | 0.77 | 0.92 | 0.89 | 0.99 | 0.98 | 0.85 | |||||||
| threeway/neutral | 0.79 | 0.70 | 0.62 | 0.91 | 0.99 | 0.68 | 0.86 | 0.79 | 0.96 | 0.96 | 0.83 | 
评估(以及用于训练 tiny 和 twoway 模型)还使用了一些额外的数据集:从 the repo of Felipe Salvatore 中提取和翻译的 Add-one RTE 、 CoPA 、 IIE 和 SCITAIL ,以及从原始来源翻译的 HELP 和 MoNLI ,以及俄文 TERRa 。