模型:

facebook/rag-token-nq

英文

RAG

这是Patrick Lewis、Ethan Perez、Aleksandara Piktus等人的论文 Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks 的RAG-Token模型。

该模型是一个非大小写敏感模型,即将大写字母转换为小写字母。

该模型包括一个question_encoder、一个retriever和一个generator。检索器从链接上方的wiki_dpr训练数据集中提取相关段落。question_encoder和retriever基于facebook/dpr-question_encoder-single-nq-base和facebook/bart-large,它们在wiki_dpr QA数据集上进行了联合微调,以端到端方式工作。

用法:

请注意:在下面的用法示例中,只使用了wiki_dpr的虚拟retriever,因为完整的遗留索引需要超过75GB的RAM。该模型可以回答任何事实型问题,方法如下:

from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

input_dict = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", return_tensors="pt") 

generated = model.generate(input_ids=input_dict["input_ids"]) 
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0]) 

# should give michael phelps => sounds reasonable