Model:
airesearch/wav2vec2-large-xlsr-53-th
Fine-tuning wav2vec2-large-xlsr-53 on Thai
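We fine-tune wav2vec2-large-xlsr-53 on the Thai examples of Common Voice Corpus 7.0, following Fine-tuning Wav2Vec2 for English ASR. The notebooks and scripts can be found in vistec-ai/wav2vec2-large-xlsr-53-th. The pretrained model and processor can be found at airesearch/wav2vec2-large-xlsr-53-th.
We added the syllable tokenizer syllable_tokenize and the word tokenizer word_tokenize (PyThaiNLP), as well as the deepcut tokenizer, to eval.py from robust-speech-event.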
> python eval.py --model_id ./ --dataset mozilla-foundation/common_voice_7_0 --config th --split test --log_outputs --thai_tokenizer newmm/syllable/deepcut/cer
| | WER PyThaiNLP 2.3.1 | WER deepcut | SER | CER |
|---|---|---|---|---|
| Only Tokenization | 0.9524% | 2.5316% | 1.2346% | 0.1623% |
| Cleaning rules and Tokenization | TBD | TBD | TBD | TBD |
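The columns of this table differ mainly in the unit the text is split into before scoring (words, presumably syllables for SER, and characters for CER). As a minimal sketch, and not the actual eval.py code, the function below computes an edit-distance error rate over any token sequence. The example strings and the use of `str.split` are assumptions for illustration; in practice the word and syllable units would come from PyThaiNLP or deepcut.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two token sequences."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def error_rate(ref_tokens, hyp_tokens):
    """Edit distance normalised by the reference length."""
    if not ref_tokens:
        raise ValueError("reference must not be empty")
    return edit_distance(ref_tokens, hyp_tokens) / len(ref_tokens)


# Hypothetical, already-tokenized example (space-separated for illustration).
reference = "a b c d"
hypothesis = "a x c"

# Token-level rate (WER-style) and character-level rate (CER-style).
wer = error_rate(reference.split(), hypothesis.split())
cer = error_rate(list(reference.replace(" ", "")),
                 list(hypothesis.replace(" ", "")))
```

Because the same function serves every unit, swapping the tokenizer is the only change needed to move between the columns, which is also why the word-level numbers depend on which tokenizer is used.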
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# load pretrained processor and model
processor = Wav2Vec2Processor.from_pretrained("airesearch/wav2vec2-large-xlsr-53-th")
model = Wav2Vec2ForCTC.from_pretrained("airesearch/wav2vec2-large-xlsr-53-th")

# function to resample to 16_000
def speech_file_to_array_fn(batch,
                            text_col="sentence",
                            fname_col="path",
                            resampling_to=16000):
    speech_array, sampling_rate = torchaudio.load(batch[fname_col])
    resampler = torchaudio.transforms.Resample(sampling_rate, resampling_to)
    # keep the first channel only and convert it to a NumPy array
    batch["speech"] = resampler(speech_array)[0].numpy()
    batch["sampling_rate"] = resampling_to
    batch["target_text"] = batch[text_col]
    return batch

# get 2 examples as sample input
# (test_dataset is assumed to be an already loaded dataset with "path" and "sentence" columns)
test_dataset = test_dataset.map(speech_file_to_array_fn)
inputs = processor(test_dataset["speech"][:2], sampling_rate=16_000, return_tensors="pt", padding=True)

# infer
with torch.no_grad():
    logits = model(inputs.input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)

print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset["sentence"][:2])
>> Prediction: ['และ เขา ก็ สัมผัส ดีบุก', 'คุณ สามารถ รับทราบ เมื่อ ข้อความ นี้ ถูก อ่าน แล้ว']
>> Reference: ['และเขาก็สัมผัสดีบุก', 'คุณสามารถรับทราบเมื่อข้อความนี้ถูกอ่านแล้ว']
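The predictions in the sample output are space-separated while the references are not, presumably because the training transcripts were pre-tokenized into words. A minimal sketch of one way to compare them at the string level is to strip whitespace from the prediction first; `strip_spaces` is a hypothetical helper, not part of the released scripts.

```python
# Strings copied from the sample output above.
predictions = [
    "และ เขา ก็ สัมผัส ดีบุก",
    "คุณ สามารถ รับทราบ เมื่อ ข้อความ นี้ ถูก อ่าน แล้ว",
]
references = [
    "และเขาก็สัมผัสดีบุก",
    "คุณสามารถรับทราบเมื่อข้อความนี้ถูกอ่านแล้ว",
]


def strip_spaces(text):
    """Remove all whitespace so a prediction can be compared to an unsegmented reference."""
    return "".join(text.split())


normalized = [strip_spaces(p) for p in predictions]
# One boolean per example: True when the prediction equals the reference exactly.
exact_match = [n == r for n, r in zip(normalized, references)]
print(exact_match)
```

This only makes sense for character-level comparison; word-level scoring would instead re-tokenize both sides with the same tokenizer.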
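[Common Voice Corpus 7.0](https://commonvoice.mozilla.org/en/datasets) contains 5 GB of validated Thai data, 255 hours in total. We pre-tokenize with pythainlp.tokenize.word_tokenize. We preprocess the data with the cleaning rules in notebooks/cv-preprocess.ipynb, then deduplicate and split it with ekapolc/Thai_commonvoice_split. This avoids the data leakage that a random split after cleaning Common Voice Corpus 7.0 would cause, while keeping most of the data as the training set. The dataset loading script is scripts/th_common_voice_70.py. You can use train_cleand.tsv, validation_cleaned.tsv and test_cleaned.tsv together with this script to obtain the same splits as we do. The resulting dataset is as follows: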
DatasetDict({
    train: Dataset({
        features: ['path', 'sentence'],
        num_rows: 86586
    })
    test: Dataset({
        features: ['path', 'sentence'],
        num_rows: 2502
    })
    validation: Dataset({
        features: ['path', 'sentence'],
        num_rows: 3027
    })
})
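To make "most of the data as the training set" concrete, the short calculation below derives each split's share from the row counts printed above. It uses only those three numbers and makes no assumption about clip durations.

```python
# Row counts copied from the DatasetDict printed above.
num_rows = {"train": 86586, "test": 2502, "validation": 3027}

total = sum(num_rows.values())
shares = {split: n / total for split, n in num_rows.items()}

for split, share in shares.items():
    print(f"{split}: {num_rows[split]} rows ({share:.1%})")
```

By row count, roughly 94% of the examples end up in the training split, with the remainder divided between test and validation.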
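We fine-tune on a single V100 GPU with the following configuration and choose the checkpoint with the lowest validation loss. The fine-tuning script is scripts/wav2vec2_finetune.py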
from transformers import TrainingArguments, Wav2Vec2ForCTC

# create model
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-xlsr-53",
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.1,
    gradient_checkpointing=True,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer)
)
# keep the convolutional feature extractor fixed during fine-tuning
# (newer transformers releases name this method freeze_feature_encoder())
model.freeze_feature_extractor()

training_args = TrainingArguments(
    output_dir="../data/wav2vec2-large-xlsr-53-thai",
    group_by_length=True,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    per_device_eval_batch_size=16,
    metric_for_best_model='wer',
    evaluation_strategy="steps",
    eval_steps=1000,
    logging_strategy="steps",
    logging_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    num_train_epochs=100,
    fp16=True,
    learning_rate=1e-4,
    warmup_steps=1000,
    save_total_limit=3,
    report_to="tensorboard"
)
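The step-based settings above are easier to read once translated into epochs. The sketch below does that arithmetic from the training-set size and the arguments shown; it assumes the last, smaller batch of each epoch is kept and that training runs for the full 100 epochs, neither of which is confirmed by the text.

```python
import math

# Values taken from the dataset summary and TrainingArguments above.
num_train_rows = 86586
per_device_train_batch_size = 32
gradient_accumulation_steps = 1
num_train_epochs = 100
eval_steps = 1000
warmup_steps = 1000
num_gpus = 1  # the text states a single V100

effective_batch = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
# Assumption: the final partial batch of each epoch is not dropped.
steps_per_epoch = math.ceil(num_train_rows / effective_batch)
# Assumption: training is not stopped before num_train_epochs.
total_steps = steps_per_epoch * num_train_epochs

evals_per_epoch = steps_per_epoch / eval_steps
warmup_fraction = warmup_steps / total_steps

print(f"{steps_per_epoch} steps per epoch, {total_steps} steps in total")
print(f"about {evals_per_epoch:.1f} evaluations per epoch")
print(f"warmup covers {warmup_fraction:.2%} of all steps")
```

Under these assumptions an evaluation, a log entry and a checkpoint are produced a little under three times per epoch, and save_total_limit=3 keeps only the most recent few checkpoints on disk.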
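We evaluate WER and CER on the test set, tokenized into words with PyThaiNLP 2.3.1 and deepcut. We also measure performance when spell correction with TNC n-grams is applied. The evaluation code can be found in notebooks/wav2vec2_finetuning_tutorial.ipynb. The benchmark is performed on the test-unique split.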
| | WER PyThaiNLP 2.3.1 | WER deepcut | CER |
|---|---|---|---|
| 12321321 | 23.04 | 7.57 | |
| Ours without spell correction | 13.634024 | 8.152052 | 2.813019 |
| Ours with spell correction | 17.996397 | 14.167975 | 5.225761 |
| Google Web Speech API※ | 13.711234 | 10.860058 | 7.357340 |
| Microsoft Bing Speech API※ | 12.578819 | 9.620991 | 5.016620 |
| Amazon Transcribe※ | 21.86334 | 14.487553 | 7.077562 |
| NECTEC AI for Thai Partii API※ | 20.105887 | 15.515631 | 9.551027 |
※ APIs are not fine-tuned with Common Voice 7.0 data
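As a quick way to read the benchmark, the sketch below picks the lowest-error system for each metric. The numbers are copied from the table; only rows with a named system are included, and the variable names are illustrative.

```python
# Scores copied from the benchmark table above (lower is better).
metrics = ("WER PyThaiNLP 2.3.1", "WER deepcut", "CER")
results = {
    "Ours without spell correction": (13.634024, 8.152052, 2.813019),
    "Ours with spell correction": (17.996397, 14.167975, 5.225761),
    "Google Web Speech API": (13.711234, 10.860058, 7.357340),
    "Microsoft Bing Speech API": (12.578819, 9.620991, 5.016620),
    "Amazon Transcribe": (21.86334, 14.487553, 7.077562),
    "NECTEC AI for Thai Partii API": (20.105887, 15.515631, 9.551027),
}

# For each metric, find the system with the smallest error.
best = {
    metric: min(results, key=lambda name: results[name][i])
    for i, metric in enumerate(metrics)
}

for metric, name in best.items():
    print(f"{metric}: {name}")
```

On these figures the fine-tuned model without spell correction has the lowest deepcut WER and CER, the Microsoft API has the lowest PyThaiNLP WER, and spell correction raises all three error rates for the fine-tuned model.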