Model: ITG/DialoGPT-medium-spanish-chitchat

This is a transformer-decoder model adapted for the single-turn dialogue task in Spanish. We fine-tuned Microsoft's 345M-parameter DialoGPT-medium model with the causal language modeling (CLM) objective.

We used one of the datasets available in the Bot Framework Tools repository. We processed the professional-styled personality chat dataset in Spanish, which is available to download here.
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

CHAT_TURNS = 5
MAX_LENGTH = 1000

model = AutoModelForCausalLM.from_pretrained('ITG/DialoGPT-medium-spanish-chitchat')
tokenizer = AutoTokenizer.from_pretrained('ITG/DialoGPT-medium-spanish-chitchat')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

for i in range(CHAT_TURNS):
    user_input = input(f"Step - {i} >> user prompt ->")
    with torch.no_grad():
        # User turn, where "user_input" is the question (single-turn dialogue task)
        user_inputs_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
        user_inputs_ids = user_inputs_ids.to(device)
        # The chat history contains the prompt tokens plus the generated answer tokens
        chat_history = model.generate(user_inputs_ids, max_length=MAX_LENGTH, pad_token_id=tokenizer.eos_token_id)
        # Decode only the newly generated tokens (do not repeat the user prompt)
        step_model_answer = tokenizer.decode(chat_history[:, user_inputs_ids.shape[-1]:][0], skip_special_tokens=True)
    print(f"Step - {i} >> DialoGPT-spanish model answer -> {step_model_answer}")
```
| Role | Message |
|---|---|
| User | Hola, buenos dias |
| Bot | Buenos días |
| User | ¿Qué tal estás? |
| Bot | Fenomenal, gracias. |
| User | ¡Tengo hambre! |
| Bot | Quizá picar algo te ayude. |
| User | Estoy triste |
| Bot | Siento escuchar eso. |
| User | ¿Tienes personalidad? |
| Bot | Soy un bot. |
| User | ¡Cuenta un chiste! |
| Bot | En realidad no tengo tanta gracia. |
| User | Jajajajajajajaja |
| Bot | ¡Me alegro de que te rías! |
Fine-tuning hyper-parameters:

| Hyper-parameter | Value |
|---|---|
| Validation partition | 20% |
| Training batch size | 8 |
| Learning rate | 5e-4 |
| Max training epochs | 20 |
| Warmup training steps | 6% |
| Weight decay | 0.01 |
| Optimiser (beta1, beta2, epsilon) | AdamW (0.9, 0.999, 1e-08) |
| Monitoring metric (delta, patience) | Validation loss (0.1, 3) |
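For reference, these settings map roughly onto the Hugging Face `Trainer` API as sketched below. This is our reading of the table rather than the exact training script; `train_ds` and `eval_ds` are placeholder dataset names.

```python
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback

training_args = TrainingArguments(
    output_dir="dialogpt-medium-spanish-chitchat",  # placeholder output directory
    per_device_train_batch_size=8,      # training batch size
    learning_rate=5e-4,
    num_train_epochs=20,                # max training epochs
    warmup_ratio=0.06,                  # 6% of training steps used for warmup
    weight_decay=0.01,
    adam_beta1=0.9,                     # AdamW is the Trainer default optimiser
    adam_beta2=0.999,
    adam_epsilon=1e-8,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",  # monitor validation loss
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,             # placeholders: an 80/20 train/validation split
    eval_dataset=eval_ds,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3,
                                     early_stopping_threshold=0.1)],
)
```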
If you want to fine-tune your own dialogue model, we recommend starting from the DialoGPT model. You can check the original GitHub repository.
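As a starting point, DialoGPT-style CLM training packs each single-turn pair into `question <eos> answer <eos>`. Below is a minimal sketch of that preprocessing, assuming a tab-separated file of question/answer pairs similar to the personality chat dataset; the file name and two-column layout are hypothetical.

```python
import csv

def build_training_examples(tsv_path, tokenizer, max_length=128):
    """Encode each question/answer pair as `question <eos> answer <eos>` token ids."""
    examples = []
    with open(tsv_path, encoding="utf-8") as f:
        for question, answer in csv.reader(f, delimiter="\t"):  # hypothetical two-column layout
            text = question + tokenizer.eos_token + answer + tokenizer.eos_token
            examples.append(tokenizer.encode(text, truncation=True, max_length=max_length))
    return examples

# Hypothetical file name for the Spanish professional-styled dataset:
# train_examples = build_training_examples("qna_chitchat_professional_es.tsv", tokenizer)
```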
Since our approach can assign a probability to any Unicode string, we can evaluate our language model on any dataset, regardless of its preprocessing, tokenization, or vocabulary size.
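Concretely, this means the CLM loss, and hence perplexity, can be computed directly on any raw string with no dataset-specific preprocessing. A minimal sketch, reusing `model`, `tokenizer`, and `device` from the snippet above:

```python
def perplexity(text: str) -> float:
    """Perplexity of the model on an arbitrary Unicode string."""
    input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
    with torch.no_grad():
        # With labels=input_ids, the model returns the mean causal LM loss.
        loss = model(input_ids, labels=input_ids).loss
    return torch.exp(loss).item()

print(perplexity("¿Qué tal estás?"))
```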