bart-large-mnli

This is the checkpoint for bart-large after being trained on the MultiNLI (MNLI) dataset.

Additional information about this model:

NLI-based Zero-Shot Text Classification

Yin et al. proposed a method for using pre-trained NLI models as ready-made zero-shot sequence classifiers. The method works by posing the sequence to be classified as the NLI premise and constructing a hypothesis from each candidate label. For example, if we want to evaluate whether a sequence belongs to the class "politics", we could construct the hypothesis "This text is about politics." The probabilities for entailment and contradiction are then converted into label probabilities.
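
As a minimal illustration of the label-to-hypothesis construction described above (the template string below is just one reasonable choice, not a fixed part of the method):

# build one NLI hypothesis per candidate label; the NLI model then scores
# each (sequence, hypothesis) pair, and the entailment probability is read
# off as the probability of that label
candidate_labels = ['politics', 'economics', 'sports']
hypotheses = [f'This text is about {label}.' for label in candidate_labels]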

This method is surprisingly effective in many cases, particularly when used with larger pre-trained models like BART and RoBERTa. See this blog post for a more expansive introduction to this and other zero-shot methods, and see the code snippets below for examples of using this model for zero-shot classification, both with Hugging Face's built-in pipeline and with native Transformers/PyTorch code.

With the zero-shot classification pipeline

The model can be loaded with the zero-shot-classification pipeline like so:

from transformers import pipeline
classifier = pipeline("zero-shot-classification",
                      model="facebook/bart-large-mnli")

You can then use this pipeline to classify sequences into any of the class names you specify.

sequence_to_classify = "one day I will see the world"
candidate_labels = ['travel', 'cooking', 'dancing']
classifier(sequence_to_classify, candidate_labels)
#{'labels': ['travel', 'dancing', 'cooking'],
# 'scores': [0.9938651323318481, 0.0032737774308770895, 0.002861034357920289],
# 'sequence': 'one day I will see the world'}
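
The pipeline's default hypothesis template is "This example is {}." If your labels read more naturally in a different sentence frame, you can pass a custom hypothesis_template (the template below is just an illustrative choice):

classifier(sequence_to_classify, candidate_labels,
           hypothesis_template="This text is about {}.")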

If more than one candidate label can be correct, pass multi_label=True to calculate each class independently:

candidate_labels = ['travel', 'cooking', 'dancing', 'exploration']
classifier(sequence_to_classify, candidate_labels, multi_label=True)
#{'labels': ['travel', 'exploration', 'dancing', 'cooking'],
# 'scores': [0.9945111274719238,
#  0.9383890628814697,
#  0.0057061901316046715,
#  0.0018193122232332826],
# 'sequence': 'one day I will see the world'}
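
Note that with multi_label=True each label is scored independently from its own entailment-vs-contradiction probabilities, so the scores no longer sum to 1 (both "travel" and "exploration" receive high scores above).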

With manual PyTorch

# pose the sequence as an NLI premise and each label as a hypothesis
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
nli_model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli').to(device)
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')

sequence = 'one day I will see the world'  # text to classify
label = 'travel'                           # candidate label to score

premise = sequence
hypothesis = f'This example is {label}.'

# run through the model pre-trained on MNLI
x = tokenizer.encode(premise, hypothesis, return_tensors='pt',
                     truncation='only_first')
logits = nli_model(x.to(device)).logits

# we throw away "neutral" (dim 1) and take the probability of
# "entailment" (2) as the probability of the label being true
entail_contradiction_logits = logits[:, [0, 2]]
probs = entail_contradiction_logits.softmax(dim=1)
prob_label_is_true = probs[:, 1]
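
To score several candidate labels by hand, you can run one premise-hypothesis pair per label and normalize the entailment logits across labels, which is roughly what the pipeline does in its default single-label mode. The following is a minimal sketch under the same hypothesis-template assumption as above:

# score every candidate label, then normalize the entailment logits
# across labels so the scores sum to 1
candidate_labels = ['travel', 'cooking', 'dancing']
hypotheses = [f'This example is {label}.' for label in candidate_labels]

inputs = tokenizer([premise] * len(hypotheses), hypotheses,
                   return_tensors='pt', padding=True,
                   truncation='only_first').to(device)
with torch.no_grad():
    logits = nli_model(**inputs).logits

# index 2 is "entailment" for facebook/bart-large-mnli
label_probs = logits[:, 2].softmax(dim=0)
print(dict(zip(candidate_labels, label_probs.tolist())))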