Model:

facebook/tart-full-flan-t5-xl

English

Task-aware retrieval with instructions

Official repository: github.com/facebookresearch/tart

Model Description

facebook/tart-full-flan-t5-xl is a multi-task cross-encoder trained with instruction tuning on roughly 40 retrieval tasks, initialized from google/flan-t5-xl.

TART-full is a 1.5-billion-parameter cross-encoder that can rerank top documents given a query and a natural language instruction (e.g., "find a Wikipedia passage that answers this question."). Experimental results on the widely used BEIR and LOTTE benchmarks, as well as on our new evaluation, X^2-Retrieval, show that TART-full outperforms previous state-of-the-art methods by leveraging natural language instructions.

More details on modeling and training can be found in our paper, Task-aware Retrieval with Instructions.

Installation

git clone https://github.com/facebookresearch/tart
cd tart
pip install -r requirements.txt
cd TART

How to use?

TART-full can be loaded through our customized EncT5 model.

from src.modeling_enc_t5 import EncT5ForSequenceClassification
from src.tokenization_enc_t5 import EncT5Tokenizer
import torch
import torch.nn.functional as F
import numpy as np

# load TART full and tokenizer
model = EncT5ForSequenceClassification.from_pretrained("facebook/tart-full-flan-t5-xl")
tokenizer = EncT5Tokenizer.from_pretrained("facebook/tart-full-flan-t5-xl")
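# switch to evaluation mode (disables dropout) before scoring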
model.eval()

q = "What is the population of Tokyo?"
in_answer = "retrieve a passage that answers this question from Wikipedia"

p_1 = "The population of Japan's capital, Tokyo, dropped by about 48,600 people to just under 14 million at the start of 2022, the first decline since 1996, the metropolitan government reported Monday."
p_2 = "Tokyo, officially the Tokyo Metropolis (東京都, Tōkyō-to), is the capital and largest city of Japan."

# 1. TART-full can identify the more relevant paragraph.
features = tokenizer(['{0} [SEP] {1}'.format(in_answer, q), '{0} [SEP] {1}'.format(in_answer, q)], [p_1, p_2], padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    scores = model(**features).logits
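    # each row of logits is (not relevant, relevant); softmax index 1 gives a relevance probability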
    normalized_scores = [float(score[1]) for score in F.softmax(scores, dim=1)]

print([p_1, p_2][np.argmax(normalized_scores)]) # "The population of Japan's capital, Tokyo, dropped by about 48,600 people to just under 14 million ... "

# 2. TART-full can identify the document that is more relevant AND follows instructions.
in_sim = "You need to find duplicated questions in Wiki forum. Could you find a question that is similar to this question"
q_1 = "How many people live in Tokyo?"
features = tokenizer(['{0} [SEP] {1}'.format(in_sim, q), '{0} [SEP] {1}'.format(in_sim, q)], [p_1, q_1], padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    scores = model(**features).logits
    normalized_scores = [float(score[1]) for score in F.softmax(scores, dim=1)]

print([p_1, q_1][np.argmax(normalized_scores)]) #  "How many people live in Tokyo?"
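The two snippets above score a fixed pair of candidates; the same pattern extends to reranking an arbitrary top-k list of passages. Below is a minimal sketch of such a loop that reuses the model and tokenizer loaded above; the rerank helper and its batch_size parameter are our own assumptions for illustration, not part of the official TART API.

# hypothetical helper: score every candidate against "instruction [SEP] query" and sort by relevance
def rerank(model, tokenizer, instruction, query, passages, batch_size=8):
    query_text = '{0} [SEP] {1}'.format(instruction, query)
    scores = []
    for i in range(0, len(passages), batch_size):
        batch = passages[i:i + batch_size]
        features = tokenizer([query_text] * len(batch), batch,
                             padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            logits = model(**features).logits
            # softmax index 1 = probability that the passage is relevant
            scores.extend(float(s[1]) for s in F.softmax(logits, dim=1))
    # highest-scoring passages first
    return sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)

# usage: rerank the two example passages for the question above
for passage, score in rerank(model, tokenizer, in_answer, q, [p_1, p_2]):
    print(round(score, 3), passage[:60])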