
wav2vec2-large-xlsr-53-th

Finetuning wav2vec2-large-xlsr-53 on Thai Common Voice 7.0

Read more on our blog

We finetune wav2vec2-large-xlsr-53 on Thai examples of Common Voice Corpus 7.0, based on Fine-tuning Wav2Vec2 for English ASR. The notebooks and scripts can be found in vistec-ai/wav2vec2-large-xlsr-53-th. The pretrained model and processor can be found at airesearch/wav2vec2-large-xlsr-53-th.

Robust Speech Event

Adds the syllable tokenizer syllable_tokenize, the word tokenizer word_tokenize (PyThaiNLP) and the deepcut tokenizer to eval.py of robust-speech-event:

> python eval.py --model_id ./ --dataset mozilla-foundation/common_voice_7_0 --config th --split test --log_outputs --thai_tokenizer newmm/syllable/deepcut/cer
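
As a sketch of what the added tokenizers do before scoring (assuming the jiwer library for the raw WER computation; this mirrors, but is not, the exact eval.py code):

import jiwer
from pythainlp.tokenize import syllable_tokenize, word_tokenize

def thai_wer(reference: str, hypothesis: str, tokenizer: str = "newmm") -> float:
    #re-tokenize both strings with a Thai tokenizer so WER is computed
    #over Thai words or syllables instead of whitespace-separated tokens
    if tokenizer == "syllable":
        ref, hyp = syllable_tokenize(reference), syllable_tokenize(hypothesis)
    else:
        #engine can be "newmm" (PyThaiNLP default) or "deepcut"
        ref = word_tokenize(reference, engine=tokenizer)
        hyp = word_tokenize(hypothesis, engine=tokenizer)
    return jiwer.wer(" ".join(ref), " ".join(hyp))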

Evaluation results on the Common Voice 7 "test" set:

                                 WER PyThaiNLP 2.3.1   WER deepcut   SER       CER
Only Tokenization                0.9524%               2.5316%       1.2346%   0.1623%
Cleaning rules and Tokenization  TBD                   TBD           TBD       TBD

Usage

import torch
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

#load pretrained processor and model
processor = Wav2Vec2Processor.from_pretrained("airesearch/wav2vec2-large-xlsr-53-th")
model = Wav2Vec2ForCTC.from_pretrained("airesearch/wav2vec2-large-xlsr-53-th")

#function to resample to 16_000
def speech_file_to_array_fn(batch, 
                            text_col="sentence", 
                            fname_col="path",
                            resampling_to=16000):
    speech_array, sampling_rate = torchaudio.load(batch[fname_col])
    resampler = torchaudio.transforms.Resample(sampling_rate, resampling_to)
    batch["speech"] = resampler(speech_array)[0].numpy()
    batch["sampling_rate"] = resampling_to
    batch["target_text"] = batch[text_col]
    return batch

#get 2 examples as sample input
#(test_dataset is assumed to be loaded already, e.g. with the dataset
#loading script shown in the Dataset section below)
test_dataset = test_dataset.map(speech_file_to_array_fn)
inputs = processor(test_dataset["speech"][:2], sampling_rate=16_000, return_tensors="pt", padding=True)

#infer
with torch.no_grad():
    logits = model(inputs.input_values).logits

predicted_ids = torch.argmax(logits, dim=-1)

print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset["sentence"][:2])

>> Prediction: ['และ เขา ก็ สัมผัส ดีบุก', 'คุณ สามารถ รับทราบ เมื่อ ข้อความ นี้ ถูก อ่าน แล้ว']
>> Reference: ['และเขาก็สัมผัสดีบุก', 'คุณสามารถรับทราบเมื่อข้อความนี้ถูกอ่านแล้ว']
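
Because the training transcripts were pre-tokenized with pythainlp's word_tokenize (see the Dataset section below), predictions come out space-separated while the raw references are not. A minimal sketch for a direct comparison is to strip the spaces:

#remove the tokenizer spaces to compare against the untokenized references
predictions = ["".join(p.split()) for p in processor.batch_decode(predicted_ids)]
print(predictions)
#['และเขาก็สัมผัสดีบุก', 'คุณสามารถรับทราบเมื่อข้อความนี้ถูกอ่านแล้ว']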

Dataset

Common Voice Corpus 7.0 ( https://commonvoice.mozilla.org/en/datasets ) contains 5GB of validated Thai data with a total duration of 255 hours. We pre-tokenize with pythainlp.tokenize.word_tokenize. We preprocess the data with the cleaning rules in notebooks/cv-preprocess.ipynb, then deduplicate and split with ekapolc/Thai_commonvoice_split in order to avoid the data leakage that a random split after cleaning would cause in Common Voice Corpus 7.0, while preserving the majority of the data for the training set. The dataset loading script is scripts/th_common_voice_70.py. You can use it together with train_cleand.tsv, validation_cleaned.tsv and test_cleaned.tsv to obtain the same splits as ours. The resulting dataset is as follows:

DatasetDict({
    train: Dataset({
        features: ['path', 'sentence'],
        num_rows: 86586
    })
    test: Dataset({
        features: ['path', 'sentence'],
        num_rows: 2502
    })
    validation: Dataset({
        features: ['path', 'sentence'],
        num_rows: 3027
    })
})
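
For example, a minimal sketch of loading these splits, assuming the loading script follows the standard datasets loading-script interface (check scripts/th_common_voice_70.py for the exact arguments and data paths it expects):

from datasets import load_dataset

#load the cleaned splits with the provided loading script
ds = load_dataset("scripts/th_common_voice_70.py")
test_dataset = ds["test"]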

Training

We finetune on a single V100 GPU with the following configuration and choose the checkpoint with the lowest validation loss. The finetuning script is scripts/wav2vec2_finetune.py.

from transformers import TrainingArguments, Wav2Vec2ForCTC

# create model
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-xlsr-53",
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.1,
    gradient_checkpointing=True,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer)
)
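#freeze the convolutional feature extractor so that only the transformer
#layers and the new CTC head are finetuned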
model.freeze_feature_extractor()
training_args = TrainingArguments(
    output_dir="../data/wav2vec2-large-xlsr-53-thai",
    group_by_length=True,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    per_device_eval_batch_size=16,
    metric_for_best_model='wer',
    evaluation_strategy="steps",
    eval_steps=1000,
    logging_strategy="steps",
    logging_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    num_train_epochs=100,
    fp16=True,
    learning_rate=1e-4,
    warmup_steps=1000,
    save_total_limit=3,
    report_to="tensorboard"
)
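
The snippet above only covers the model and training arguments; a sketch of wiring them into the standard Trainer follows, where data_collator, train_dataset, eval_dataset and compute_metrics are assumed to be defined elsewhere, e.g. as in scripts/wav2vec2_finetune.py:

from transformers import Trainer

#assumes data_collator, compute_metrics, train_dataset and eval_dataset
#are defined as in the finetuning script
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)
trainer.train()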

Evaluation

We benchmark on the test set using WER with words tokenized by PyThaiNLP 2.3.1 and deepcut, and CER. We also measure performance when applying spell correction with TNC n-grams. The evaluation code can be found in notebooks/wav2vec2_finetuning_tutorial.ipynb. The benchmark is performed on the test-unique split.

                                 WER PyThaiNLP 2.3.1   WER deepcut   CER
Kaldi from scratch               23.04                               7.57
Ours without spell correction    13.634024             8.152052      2.813019
Ours with spell correction       17.996397             14.167975     5.225761
Google Web Speech API※           13.711234             10.860058     7.357340
Microsoft Bing Speech API※       12.578819             9.620991      5.016620
Amazon Transcribe※               21.86334              14.487553     7.077562
NECTEC AI for Thai Partii API※   20.105887             15.515631     9.551027

※ APIs are not finetuned with Common Voice 7.0 data
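
For reference, a minimal sketch of the character error rate reported above, assuming the jiwer library (the actual evaluation code is in notebooks/wav2vec2_finetuning_tutorial.ipynb):

import jiwer

#CER treats every character as a token; tokenizer spaces are removed
#first so they are not counted as errors
def thai_cer(reference: str, hypothesis: str) -> float:
    return jiwer.cer(reference.replace(" ", ""), hypothesis.replace(" ", ""))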

License

cc-by-sa 4.0

Acknowledgements