模型:
facebook/data2vec-audio-base-960h
基于Librispeech的960小时16kHz采样语音音频进行预训练和微调的基础模型。使用该模型时,请确保语音输入也以16Khz进行采样。
作者:Alexei Baevski,Wei-Ning Hsu,Qiantong Xu,Arun Babu,Jiatao Gu,Michael Auli
摘要
尽管跨模态的自监督学习的总体思想是相同的,但实际的算法和目标因不同模态而异。为了让我们更接近通用的自监督学习,我们提出了data2vec,这是一个框架,可以在语音、自然语言处理或计算机视觉中使用相同的学习方法。其核心思想是在自蒸馏设置中使用标准Transformer架构,基于输入的屏蔽视图来预测完整输入数据的潜在表示。data2vec不是预测特定于模态的目标,例如单词、视觉标记或人声单位,而是预测包含来自整个输入的信息的上下文化潜在表示。在语音识别、图像分类和自然语言理解的主要基准测试上进行的实验表明,data2vec取得了新的最先进或与主导方法相比具有竞争力的性能。
原始模型可在 https://github.com/pytorch/fairseq/tree/main/examples/data2vec 下找到。
 
  更多信息,请参阅 official paper 。
要转录音频文件,可以将模型用作独立的声学模型,如下所示:
 from transformers import Wav2Vec2Processor, Data2VecForCTC
 from datasets import load_dataset
 import torch
 
 # load model and processor
 processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h")
 model = Data2VecForCTC.from_pretrained("facebook/data2vec-audio-base-960h")
     
 # load dummy dataset and read soundfiles
 ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
 
 # tokenize
 input_values = processor(ds[0]["audio"]["array"],, return_tensors="pt", padding="longest").input_values  # Batch size 1
 
 # retrieve logits
 logits = model(input_values).logits
 
 # take argmax and decode
 predicted_ids = torch.argmax(logits, dim=-1)
 transcription = processor.batch_decode(predicted_ids)
 此代码片段显示如何在LibriSpeech的“clean”和“other”测试数据上评估facebook/data2vec-audio-base-960h。
 from transformers import Wav2Vec2Processor, Data2VecForCTC
 from datasets import load_dataset
 import torch
 from jiwer import wer
 
 # load model and processor
 processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h").to("cuda")
 model = Data2VecForCTC.from_pretrained("facebook/data2vec-audio-base-960h")
 
librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")
def map_to_pred(batch):
    input_values = processor(batch["audio"]["array"], return_tensors="pt", padding="longest").input_values
    with torch.no_grad():
        logits = model(input_values.to("cuda")).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    batch["transcription"] = transcription
    return batch
result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])
print("WER:", wer(result["text"], result["transcription"]))
 结果(WER):
| "clean" | "other" | 
|---|---|
| 2.77 | 7.08 |