模型:

facebook/rag-sequence-base

英文

RAG

这是一份非微调版本的RAG-Sequence模型,根据Patrick Lewis、Ethan Perez、Aleksandara Piktus等人的论文 Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks 创建。

RAG由问题编码器、检索器和生成器组成。检索器应该是一个RagRetriever实例。问题编码器可以是可以使用AutoModel加载的任何模型,生成器可以是可以使用AutoModelForSeq2SeqLM加载的任何模型。

这个模型是一个非微调的RAG-Sequence模型,创建过程如下:

from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, AutoTokenizer

model = RagSequenceForGeneration.from_pretrained_question_encoder_generator("facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large")

question_encoder_tokenizer = AutoTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
generator_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer)
model.config.use_dummy_dataset = True
model.config.index_name = "exact"
retriever = RagRetriever(model.config, question_encoder_tokenizer, generator_tokenizer)

model.save_pretrained("./")
tokenizer.save_pretrained("./")
retriever.save_pretrained("./")

请注意,该模型是非区分大小写的,因此所有大写输入字母都会转换为小写。

用法:

注意:模型默认使用虚拟检索器。通过设置config.index_name="legacy"和config.use_dummy_dataset=False可以使用完整的检索器获得更好的结果。可以按以下方式对模型进行微调:

from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration

tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-base")
model = RagTokenForGeneration.from_pretrained("facebook/rag-sequence-base", retriever=retriever)

input_dict = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", "michael phelps", return_tensors="pt") 

outputs = model(input_dict["input_ids"], labels=input_dict["labels"])

loss = outputs.loss

# train on loss