Comparing OpenAI Whisper Inference Performance on CPU

November 29, 2024, by alex

Whisper is an open-source, Transformer-based ASR model released by OpenAI. In my case, the model was fine-tuned on a dataset of speech recordings from people with speech impairments.


I have tried the following options for CPU inference:

  1. HuggingFace pipeline
  2. ONNX Runtime
  3. OpenVino Runtime
  4. PyTorch inference


Summary

Here are the final results:

  • PyTorch (16 cores) ≈ 3.5 s
  • OpenVino int4 ≈ 4.0 s
  • OpenVino int8 ≈ 4.2 s
  • PyTorch (8 cores) ≈ 4.2 s
  • PyTorch (4 cores) ≈ 8.0 s
  • HF pipeline ≈ 18.0 s


Inference with the HuggingFace pipeline

Since our model was trained with the transformers library and stored on the HuggingFace Hub, the first and most straightforward option is to use the built-in pipeline.


import os
import asyncio

import torch
from huggingface_hub import login
from peft import PeftConfig, PeftModel
from transformers import (
    AutomaticSpeechRecognitionPipeline,
    PreTrainedModel,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizer,
    logging as transformers_log,
)

from src import log
from src.utils import utils


class WhisperService:
    _initialized = False
    def __init__(self, language='en'):
        if not WhisperService._initialized:
            os.environ["TRANSFORMERS_VERBOSITY"] = "error"
            transformers_log.set_verbosity_error()
            self.model_name = utils.MODEL_NAME
            self.language = language
            self.task = utils.TASK
            try:
                # Initialize model and related components
                log.info("Starting Whisper service...")
                self.peft_config = self.generate_model_config()
                self.model = self.get_whisper_model_from_hf(self.peft_config)
                self.tokenizer = self.create_tokenizer(self.peft_config)
                self.processor = self.create_processor(self.peft_config)
                self.pipeline_asr, self.forced_decoder_ids = self.create_whisper_pipeline(
                    self.model, self.tokenizer, self.processor
                )
                WhisperService._initialized = True
                log.info("Whisper service started with success!")
            except Exception as e:
                log.error(f"Error during Whisper service init: {str(e)}")
                raise
    def generate_model_config(self) -> PeftConfig:
        """
        """
        try:
            login(token=os.environ['API_TOKEN'])
            config = PeftConfig.from_pretrained(self.model_name)
            log.info("Model config generated")
            return config
        except Exception as e:
            log.error(f"Error during model config generation: {str(e)}")
            raise
    def get_whisper_model_from_hf(self, peft_config: PeftConfig) -> PeftModel:
        """
        """
        try:
            model = WhisperForConditionalGeneration.from_pretrained(
                    peft_config.base_model_name_or_path
                )
            # Check if GPU is available
            if torch.cuda.is_available():
                log.info("Model loaded on GPU")
            else:
                log.info("Model loaded on CPU")
            model = PeftModel.from_pretrained(model, self.model_name)
            log.info("Whisper model configured with PeftModel")
            return model
        except Exception as e:
            log.error(f"Error during Whisper model loading: {str(e)}")
            raise
    def create_processor(self, peft_config: PeftConfig) -> WhisperProcessor:
        """
        """
        try:
            processor = WhisperProcessor.from_pretrained(
                peft_config.base_model_name_or_path,
                language=self.language,
                task=self.task
            )
            log.info("WhisperProcessor created")
            return processor
        except Exception as e:
            log.error(f"Error during WhisperProcessor creation: {str(e)}")
            raise
    def create_tokenizer(self, peft_config: PeftConfig) -> WhisperTokenizer:
        """
        """
        try:
            tokenizer = WhisperTokenizer.from_pretrained(
                peft_config.base_model_name_or_path,
                language=self.language,
                task=self.task
            )
            log.info("WhisperTokenizer created")
            return tokenizer
        except Exception as e:
            log.error(f"Error during WhisperTokenizer creation: {str(e)}")
            raise
    def create_whisper_pipeline(self, model: PreTrainedModel, tokenizer: WhisperTokenizer,
                                processor: WhisperProcessor) -> tuple:
        """
        """
        try:
            feature_extractor = processor.feature_extractor
            pipe_lora = AutomaticSpeechRecognitionPipeline(
                model=model,
                tokenizer=tokenizer,
                feature_extractor=feature_extractor
            )
            forced_decoder_ids = processor.get_decoder_prompt_ids(language=self.language, task=self.task)
            log.info("Pipeline created")
            return pipe_lora, forced_decoder_ids
        except Exception as e:
            log.error(f"Error during Pipeline creation: {str(e)}")
            raise
    async def transcribe(self, audio_path: str) -> str:
        """
        """
        try:
            loop = asyncio.get_event_loop()
            log.info(f"Transcribing the following file audio: {audio_path}")
            with torch.cuda.amp.autocast():
                text = await loop.run_in_executor(
                    None,
                    lambda:
                    self.pipeline_asr(audio_path, generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids},
                                      max_new_tokens=255)["text"]
                )
            log.info("Transcription completed!")
            return text
        except Exception as e:
            log.error(f"Error during transcription: {str(e)}")
            raise


We fetch the model from the HuggingFace Hub (utils.MODEL_NAME is the HuggingFace model identifier, e.g. "miosipof/asr_EN_medium_v1").


The processing pipeline itself is created with the following code:


pipe_lora = AutomaticSpeechRecognitionPipeline(
                model=model,
                tokenizer=tokenizer,
                feature_extractor=feature_extractor
            )
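

Putting the pieces together, here is a minimal usage sketch of the WhisperService class defined above (the audio path is a placeholder, and this snippet is an illustration rather than part of the original service code):


import asyncio

# Hypothetical usage of the WhisperService class above
service = WhisperService(language="en")
text = asyncio.run(service.transcribe("./path_to/audio.wav"))
print(text)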


ONNX Runtime


Converting the model to ONNX format

Let's import a few libraries:


from onnxruntime.quantization import quantize_dynamic, QuantType
import onnx
import numpy as np
import onnxruntime as ort
import torch
import torchaudio
import whisper
from transformers import WhisperProcessor


Next, we convert the model from HuggingFace to ONNX format using the transformers optimum library and its CLI:


pip install optimum[exporters]


optimum-cli export onnx --model local_path --task automatic-speech-recognition local_model_folder/


This takes the original model stored at local_path and creates a set of files in local_model_folder/.
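
Before wiring up an inference session, it can be worth sanity-checking the exported graphs. Here is a small sketch (the file names below are assumptions; check which .onnx files optimum actually wrote to local_model_folder/):


import onnx

# Hypothetical file names - verify against the contents of local_model_folder/
for name in ["encoder_model.onnx", "decoder_model.onnx"]:
    onnx_graph = onnx.load(f"local_model_folder/{name}")
    onnx.checker.check_model(onnx_graph)
    print(name, "passed the ONNX checker")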


Let's set up an ONNX session:


session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
session_options.intra_op_num_threads = 44
session_options.inter_op_num_threads = 16


We handle the encoder and the decoder separately:


sess_encoder = ort.InferenceSession("./path_to/encoder_q.onnx", sess_options=session_options)
sess_decoder = ort.InferenceSession("./path_to/decoder_q.onnx", sess_options=session_options)


To improve performance, we define a model quantization function and then apply it to both the encoder and the decoder:


def quantize_onnx_model(onnx_model_path, quantized_model_path):
    onnx_opt_model = onnx.load(onnx_model_path)
    # Dynamic quantization of the exported graph to uint8 weights
    quantize_dynamic(onnx_model_path,
                     quantized_model_path,
                     weight_type=QuantType.QUInt8)  # change QInt8 to QUInt8

quantize_onnx_model("./path_to/encoder.onnx", "./path_to/encoder_q.onnx")
quantize_onnx_model("./path_to/decoder.onnx", "./path_to/decoder_q.onnx")


Inference with the ONNX model

Let's initialize the processor and the tokenizer:


processor = WhisperProcessor.from_pretrained("./path_to/q_whisper_onnx")
# tokenizer = processor.tokenizer
# `model` here is an original Whisper model (e.g. whisper.load_model("medium")), used only for its multilingual flag
tokenizer = whisper.decoding.get_tokenizer(
    model.is_multilingual, 
    task="transcribe", 
    language="en",
)


An audio preprocessing routine (similar to Whisper's log_mel_spectrogram() function) that converts a .wav file into a log-Mel spectrogram array:


def preprocessing_torchaudio(audio_path):
    waveform, sample_rate = torchaudio.load(audio_path)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    mel = processor.feature_extractor(waveform[0], sampling_rate=16000).input_features
    return torch.tensor(mel, dtype=torch.float32)


For a sample .wav file, the audio array x_mel would be:


x_mel = preprocessing_torchaudio("./path_to/audio.wav")


Finally, a custom loop that encodes and decodes the sequence with our quantized ONNX models:


max_tokens = 448
out_encoder, = sess_encoder.run(["last_hidden_state"], {"input_features": x_mel.numpy()})
next_token = tokenizer.sot
# next_token = "<|startoftranscript|>"
x_tokens = torch.tensor([[tokenizer.sot]])  # start the decoded sequence with the <|startoftranscript|> token
while x_tokens.shape[1] <= max_tokens and next_token != tokenizer.eot:
    out_decoder, = sess_decoder.run(
        ["logits"], 
        {
            "input_ids": x_tokens.numpy(), 
            "encoder_hidden_states": out_encoder,
        },
    )
    next_token = out_decoder[0, -1].argmax()
    next_token = torch.tensor(next_token)
    
    print(next_token,next_token.shape,x_tokens.shape)
    
    x_tokens = torch.concat(
        [x_tokens, next_token.reshape(1, 1)], 
        axis=1,
    )

print(tokenizer.decode(x_tokens[0].tolist()))


I left this code in a rather suboptimal shape because ONNX inference performance was consistently much worse than inference through OpenVino or PyTorch, probably because the ONNX format was originally developed for convolutional neural networks and may not be the best choice for optimizing transformers.


OpenVino Runtime

Implementing inference with OpenVino is much simpler.


First, import the necessary libraries:


import os
from pathlib import Path
from transformers import WhisperProcessor, logging as transformers_log
from optimum.intel.openvino import OVModelForSpeechSeq2Seq
import torchaudio
import torch
import numpy as np
import time
from src import log
from src.utils import utils
import asyncio


Converting the model to OpenVino format

We will use the transformers optimum library to export our HuggingFace model to OpenVino format (you can replace openai/whisper-medium with your own model or any other Whisper model hosted on the HuggingFace Hub):


pip install optimum[openvino,nncf]

optimum-cli export openvino --model openai/whisper-medium --weight-format int8 asr_openvino_int8


Note that we used int8 quantization for the export. I also tried int4 quantization, but in my case it noticeably hurt transcription quality.
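
For reference, the int4 export uses the same command with a different weight format (the output folder name here is arbitrary):


optimum-cli export openvino --model openai/whisper-medium --weight-format int4 asr_openvino_int4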


Here is the method we will use to load the OpenVino model:


    def get_openvino_model(self):
        ov_config = {"CACHE_DIR": ""}
        self.model = OVModelForSpeechSeq2Seq.from_pretrained(self.ov_model_name, ov_config=ov_config, compile=False)
        log.info("OpenVino model loaded from " + str(self.ov_model_name))
        try:
            ov_model_path = Path("src/model/" + self.model_name.replace("/", "_"))
            ov_config = {"CACHE_DIR": ""}
        
            if not ov_model_path.exists():
                self.model = OVModelForSpeechSeq2Seq.from_pretrained(
                    self.model_name,
                    ov_config=ov_config,
                    export=True,
                    compile=False,
                    load_in_8bit=False,
                )
                self.model.half()
                self.model.save_pretrained(ov_model_path)
                log.info("HF model converted to OpenVino and saved in " + str(ov_model_path))
            else:
                self.model = OVModelForSpeechSeq2Seq.from_pretrained(ov_model_path, ov_config=ov_config, compile=False)
                log.info("OpenVino model loaded from " + str(ov_model_path))
        
        except Exception as e:
            log.error(f"Error during OpenVino model loading: {str(e)}")
            raise
        return self.model


Here, self.ov_model_name is the asr_openvino_int8 folder (plus its path) that we passed to the optimum CLI command earlier. I used the somewhat inelegant self.model_name.replace("/", "_") call to turn the HuggingFace model URL into a local model name.
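
For context, the constants referenced from utils might look roughly like this (all values below are illustrative assumptions, not the actual project configuration):


# src/utils/utils.py - hypothetical values for illustration
MODEL_NAME = "miosipof/asr_EN_medium_v1"                  # PEFT adapter repo on the HuggingFace Hub
MERGED_MODEL_NAME = "your-username/your-merged-whisper"   # full merged model repo (placeholder)
OV_MODEL = "./asr_openvino_int8"                          # folder produced by the optimum CLI export
TASK = "transcribe"
SAMPLING_RATE = 16000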


Next, the OpenVino model needs to be compiled, since it will be loaded directly by the OpenVino runtime:


    def compile_openvino_model(self):
        """
        """
        try:
            if torch.cuda.is_available():
                log.info("Model loaded on GPU")
                self.device = "GPU"
            else:
                log.info("Model loaded on CPU")
                self.device = "CPU"
            self.model.to(self.device)
            self.model.compile()
            log.info("OpenVino model compiled successfully")
        except Exception as e:
            log.error(f"Error during OpenVino model compilation: {str(e)}")
            raise
        return self.model


Inference with the OpenVino model

Now we define two helper functions: one creates the Whisper processor used for encoding (which takes negligible time compared to the forward pass), and the other handles audio preprocessing:


    def create_processor(self):
        """
        """
        try:
            processor = WhisperProcessor.from_pretrained(
                self.model_name,
                language=self.language,
                task=self.task
            )
            log.info("WhisperProcessor created")
            return processor
        except Exception as e:
            log.error(f"Error during WhisperProcessor creation: {str(e)}")
            raise

    def preprocess_audio(self, waveform):
        """
        """
        # compute log-Mel input features from input audio array
        audio_features = self.processor.feature_extractor(waveform, sampling_rate=self.sr).input_features[0]
        audio_features = torch.tensor(np.array([audio_features]))
        return audio_features


Finally, we define the pipeline, i.e. an async transcription function, similar to the HuggingFace pipeline implementation:


    def openvino_pipeline(self, audio_path):
        print("1 - starting audio load:", time.time())
        waveform, sample_rate = torchaudio.load(audio_path)
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sr)(waveform)[0]
        print("2 - starting preprocessing:", time.time())
        audio_features = self.preprocess_audio(waveform)
        print("3 - starting forward pass:", time.time())
        predicted_ids = self.model.generate(audio_features, max_new_tokens=224)
        print("4 - starting decoding:", time.time())
        transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
        return transcription[0]

    async def transcribe(self, audio_path: str) -> str:
        """
        """
        try:
            loop = asyncio.get_event_loop()
            log.info(f"Transcribing the following file audio: {audio_path}")
            print("0 - starting the loop:",time.time())
            text = await loop.run_in_executor(
                None,
                lambda: self.openvino_pipeline(audio_path)
                )
            print("5 - all done:", time.time())
            log.info("Transcription completed!")
            return text
        except Exception as e:
            log.error(f"Error during transcription: {str(e)}")
            raise


Here is the complete code of the OpenVino inference class:


class OpenVinoService:
    _initialized = False
    def __init__(self, language='en'):
        if not OpenVinoService._initialized:
            os.environ["TRANSFORMERS_VERBOSITY"] = "error"
            transformers_log.set_verbosity_error()
            self.model_name = utils.MERGED_MODEL_NAME
            self.ov_model_name = utils.OV_MODEL
            self.language = language
            self.task = utils.TASK
            self.device = "CPU"
            self.sr = utils.SAMPLING_RATE
            try:
                # Initialize model and related components
                log.info("Starting OpenVino service...")
                self.model = self.get_openvino_model()
                self.compile_openvino_model()
                self.processor = self.create_processor()
                OpenVinoService._initialized = True
                log.info("OpenVino service started with success!")
            except Exception as e:
                log.error(f"Error during OpenVino service init: {str(e)}")
                raise
    def get_openvino_model(self):
        """
        """
        ov_config = {"CACHE_DIR": ""}
        self.model = OVModelForSpeechSeq2Seq.from_pretrained(self.ov_model_name, ov_config=ov_config, compile=False)
        log.info("OpenVino model loaded from " + str(self.ov_model_name))
        try:
            ov_model_path = Path("src/model/" + self.model_name.replace("/", "_"))
            ov_config = {"CACHE_DIR": ""}
        
            if not ov_model_path.exists():
                self.model = OVModelForSpeechSeq2Seq.from_pretrained(
                    self.model_name,
                    ov_config=ov_config,
                    export=True,
                    compile=False,
                    load_in_8bit=False,
                )
                self.model.half()
                self.model.save_pretrained(ov_model_path)
                log.info("HF model converted to OpenVino and saved in " + str(ov_model_path))
            else:
                self.model = OVModelForSpeechSeq2Seq.from_pretrained(ov_model_path, ov_config=ov_config, compile=False)
                log.info("OpenVino model loaded from " + str(ov_model_path))
        
        except Exception as e:
            log.error(f"Error during OpenVino model loading: {str(e)}")
            raise
        return self.model

    def compile_openvino_model(self):
        """
        """
        try:
            if torch.cuda.is_available():
                log.info("Model loaded on GPU")
                self.device = "GPU"
            else:
                log.info("Model loaded on CPU")
                self.device = "CPU"
            self.model.to(self.device)
            self.model.compile()
            log.info("OpenVino model compiled successfully")
        except Exception as e:
            log.error(f"Error during OpenVino model compilation: {str(e)}")
            raise
        return self.model

    def create_processor(self):
        """
        """
        try:
            processor = WhisperProcessor.from_pretrained(
                self.model_name,
                language=self.language,
                task=self.task
            )
            log.info("WhisperProcessor created")
            return processor
        except Exception as e:
            log.error(f"Error during WhisperProcessor creation: {str(e)}")
            raise

    def preprocess_audio(self, waveform):
        """
        """
        # compute log-Mel input features from input audio array
        audio_features = self.processor.feature_extractor(waveform, sampling_rate=self.sr).input_features[0]
        audio_features = torch.tensor(np.array([audio_features]))
        return audio_features

    def openvino_pipeline(self,audio_path):
        print("1 - starting audio load:", time.time())
        waveform, sample_rate = torchaudio.load(audio_path)
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sr)(waveform)[0]
        print("2 - starting preprocessing:", time.time())
        audio_features = self.preprocess_audio(waveform)
        print("3 - starting forward pass:", time.time())
        predicted_ids = self.model.generate(audio_features, max_new_tokens=224)
        print("4 - starting decoding:", time.time())
        transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
        return transcription[0]

    async def transcribe(self, audio_path: str) -> str:
        """
        """
        try:
            loop = asyncio.get_event_loop()
            log.info(f"Transcribing the following file audio: {audio_path}")
            print("0 - starting the loop:",time.time())
            text = await loop.run_in_executor(
                None,
                lambda: self.openvino_pipeline(audio_path)
                )
            print("5 - all done:", time.time())
            log.info("Transcription completed!")
            return text
        except Exception as e:
            log.error(f"Error during transcription: {str(e)}")
            raise
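
Usage mirrors the HuggingFace pipeline class above; here is a minimal sketch (the audio path is a placeholder):


import asyncio

# Hypothetical usage of the OpenVinoService class above
service = OpenVinoService(language="en")
print(asyncio.run(service.transcribe("./path_to/audio.wav")))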


PyTorch inference

Running inference through a direct PyTorch implementation of Whisper involves several steps:

  1. In my case, the fine-tuned model used for inference lives on the HuggingFace Hub, so I first need to fetch it from there;
  2. We also need the original Whisper base model from OpenAI's GitHub (its size should match our fine-tuned model, Whisper-Medium in my case);
  3. The fine-tuned HuggingFace model has to be mapped to OpenAI's format;
  4. Our pretrained weights are then applied to the base model;
  5. Finally, we simply put the model in evaluation mode and run inference.


Let's start by fetching the model from the HuggingFace Hub:


    def get_hf_model(self):
        """
        """
        try:
            merged_model = WhisperForConditionalGeneration.from_pretrained(self.model_name)
            pt_model_name = os.path.basename(self.model_name) + ".pth"
            pt_dir_name = os.path.join("assets","pt_models")
            self.pretrained_model_path = os.path.join(pt_dir_name, pt_model_name)
            if not os.path.exists(pt_dir_name):
                os.makedirs(pt_dir_name)
                log.info(f"Directory {pt_dir_name} created and will be used to store PyTorch models")
            else:
                log.info(f"Directory {pt_dir_name} exists, using it to save PyTorch model")
            torch.save(merged_model.state_dict(), self.pretrained_model_path)
            log.info(f"HF model saved to {self.pretrained_model_path} in PyTorch format for conversion")
        except Exception as e:
            log.error(f"Error during HuggingFace model loading: {str(e)}")
            raise


Here, self.model_name is my model ID on HuggingFace (note that it should be the full merged model, not the adapter).


Converting the model from HuggingFace to PyTorch

The layer names used in the transformers implementation of Whisper differ from those used in OpenAI's original repository.


The mapping function from HuggingFace to OpenAI looks like this:


    def map_hf_to_pt(self, pretrained_weights):
        def rename_key(key):
            new_key = key
            for k, v in self.mapping:
                new_key = new_key.replace(k, v)
            return new_key
        # Rename the keys in the state_dict
        updated_weights = {rename_key(k): v for k, v in pretrained_weights.items()}
        updated_weights.pop('proj_out.weight', None)
        return updated_weights


Here, self.mapping is the mapping list:


self.mapping = [ ('model.', ''),
           ('decoder.layers', 'decoder.blocks'), 
           ('encoder.layers', 'encoder.blocks'), 
           
           ('encoder.embed_positions.weight', 'encoder.positional_embedding'), 
           
           ('self_attn.k_proj', 'attn.key'),
           ('self_attn.q_proj', 'attn.query'),
           ('self_attn.v_proj', 'attn.value'),
           ('self_attn.out_proj', 'attn.out'),
           ('self_attn_layer_norm', 'attn_ln'),
           ('final_layer_norm', 'mlp_ln'),
           ('fc1', 'mlp.0'),
           ('fc2', 'mlp.2'),
           ('encoder_attn.k_proj','cross_attn.key'),
           ('encoder_attn.v_proj','cross_attn.value'),
           ('encoder_attn.q_proj','cross_attn.query'),
           ('encoder_attn.out_proj','cross_attn.out'),
           ('encoder_attn_layer_norm','cross_attn_ln'),
           ('decoder.embed_positions.weight','decoder.positional_embedding'),
           ('decoder.embed_tokens','decoder.token_embedding'),
           ('encoder.layer_norm','encoder.ln_post'),
           ('decoder.layer_norm','decoder.ln'),
          ]
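
As a quick illustration of what this renaming does, here is a check of one hypothetical state_dict key against a few of the tuples above:


# Apply a subset of the mapping to one example HuggingFace key
mapping = [('model.', ''), ('encoder.layers', 'encoder.blocks'), ('self_attn.k_proj', 'attn.key')]

key = "model.encoder.layers.0.self_attn.k_proj.weight"
for k, v in mapping:
    key = key.replace(k, v)
print(key)  # -> encoder.blocks.0.attn.key.weight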


Now we apply this mapping to the Whisper base model, using the pretrained weights of the model downloaded from the HuggingFace Hub:


    def set_pt_model(self):
        model = whisper.load_model("medium")
        log.info("Whisper base model loaded")
        pretrained_model = torch.load(self.pretrained_model_path)
        log.info(f"Whisper pretrained model loaded from {self.pretrained_model_path}")
        # Extract state_dict if the loaded model is not already a state_dict
        if hasattr(pretrained_model, "state_dict"):
            pretrained_weights = pretrained_model.state_dict()  # extract the state dict
        else:
            pretrained_weights = pretrained_model  # it's already a state_dict
        #######################################################################
        updated_weights = self.map_hf_to_pt(pretrained_weights)
        model.load_state_dict(updated_weights, strict=True)
        log.info(f"Model weights mapped from HuggingFace model to PyTorch")
        ######################################################################
        model.to(self.device)
        model.requires_grad_(False)
        model.eval()
        log.info("Whisper PyTorch model loaded on " + str(self.device))
        return model


Inference with PyTorch

We are almost there. Next, let's define the Whisper processor and the encoding function:


    def create_processor(self):
        """
        """
        try:
            processor = WhisperProcessor.from_pretrained(
                self.model_name,
                language=self.language,
                task=self.task
            )
            log.info("WhisperProcessor created")
            return processor
        except Exception as e:
            log.error(f"Error during WhisperProcessor creation: {str(e)}")
            raise

    def preprocess_audio(self, waveform):
        """
        """
        # compute log-Mel input features from input audio array
        mel = self.processor.feature_extractor(waveform, sampling_rate=self.sr).input_features
        return torch.tensor(mel, dtype=torch.float32)


Finally, we define the pipeline and the transcription function:


    def inference_pipeline(self, audio_path):
        log.info("1 - Starting audio load:")
        # waveform, sample_rate = librosa.load(audio_path, sr=self.sr)
        waveform, sample_rate = torchaudio.load(audio_path)
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sr)(waveform)[0]
        log.info("2 - starting preprocessing:")
        audio_features = self.preprocess_audio(waveform)
        log.info("3 - Starting forward pass:")
        with torch.no_grad():
            result = whisper.decode(
                self.model,
                audio_features,
                options=whisper.DecodingOptions(
                    fp16=False,
                    language="it",
                    without_timestamps=True,
                    suppress_blank=False,
                    suppress_tokens=[],
                ),
            )
        return result[0].text

    async def transcribe(self, audio_path: str) -> DecodingResult | list[DecodingResult]:
        """
        """
        try:
            loop = asyncio.get_event_loop()
            log.info(f"Transcribing the following file audio: {audio_path}")
            log.info("Transcription started...")
            text = await loop.run_in_executor(
                None,
                lambda: self.inference_pipeline(audio_path)
                )
            log.info("Transcription completed!")
            return text
        except Exception as e:
            log.error(f"Error during transcription: {str(e)}")
            raise


Below is the complete code of the PyTorch inference class. Note the torch.set_num_threads(num_threads) call during initialization; this is where we set the number of CPU cores used for inference, which has a major impact on performance:


import os
from src import log
from src.utils import utils
import asyncio
import whisper
from whisper import DecodingResult
from transformers import WhisperForConditionalGeneration, WhisperProcessor, logging as transformers_log
from huggingface_hub import hf_hub_download, login
import torch
import torchaudio
import torch.quantization
class InferenceService:
    _initialized = False
    def __init__(self, language='it', num_threads=1, quantization=True, device = "cpu"):
        try:
            login(token=os.environ['API_TOKEN'])
            log.info("HuggingFace login successful")
        except Exception as e:
            log.error(f"Error during HuggingFace login: {str(e)}")
            raise
        if not InferenceService._initialized:
            os.environ["TRANSFORMERS_VERBOSITY"] = "error"
            transformers_log.set_verbosity_error()
            self.model_name = utils.MERGED_MODEL_NAME
            self.language = language
            self.pytorch_converted_model_source = utils.PRETRAINED_MODEL_PTH
            self.pytorch_converted_model_filename = utils.PRETRAINED_MODEL_FILENAME
            self.task = utils.TASK
            self.device = device
            self.sr = utils.SAMPLING_RATE
            self.mapping = utils.HF_PT_MAPPING
            try:
                # Initialize model and related components
                log.info("Starting PyTorch Inference service...")
                try:
                    self.pretrained_model_path = hf_hub_download(repo_id=self.pytorch_converted_model_source,
                                                                 filename=self.pytorch_converted_model_filename)
                    log.info(f"Whisper pretrained model downloaded to {self.pretrained_model_path}")
                except Exception as e:
                    log.info(f"Unable to download the PyTorch model: {str(e)} - switching to model from HF for conversion")
                    self.get_hf_model()
                self.model = self.set_pt_model()
                if quantization:
                    self.model = torch.quantization.quantize_dynamic(self.model,
                                                                {torch.nn.Linear},
                                                                dtype=torch.qint8)
                self.model = self.model.cpu()
                self.processor = self.create_processor()
                InferenceService._initialized = True
                log.info("PyTorch Inference service started with success!")
            except Exception as e:
                log.error(f"Error during PyTorch Inference service init: {str(e)}")
                raise
        torch.set_num_threads(num_threads)
        log.info(f"Number of threads set to {num_threads} for PyTorch calculations")
    def get_hf_model(self):
        """
        """
        try:
            merged_model = WhisperForConditionalGeneration.from_pretrained(self.model_name)
            pt_model_name = os.path.basename(self.model_name) + ".pth"
            pt_dir_name = os.path.join("assets","pt_models")
            self.pretrained_model_path = os.path.join(pt_dir_name, pt_model_name)
            if not os.path.exists(pt_dir_name):
                os.makedirs(pt_dir_name)
                log.info(f"Directory {pt_dir_name} created and will be used to store PyTorch models")
            else:
                log.info(f"Directory {pt_dir_name} exists, using it to save PyTorch model")
            torch.save(merged_model.state_dict(), self.pretrained_model_path)
            log.info(f"HF model saved to {self.pretrained_model_path} in PyTorch format for conversion")
        except Exception as e:
            log.error(f"Error during HuggingFace model loading: {str(e)}")
            raise
        return 1
    def map_hf_to_pt(self,pretrained_weights):
        def rename_key(key):
            new_key = key
            for k, v in self.mapping:
                new_key = new_key.replace(k, v)
            return new_key
        # Rename the keys in the state_dict
        updated_weights = {rename_key(k): v for k, v in pretrained_weights.items()}
        updated_weights.pop('proj_out.weight', None)
        return updated_weights
    def set_pt_model(self):
        model = whisper.load_model("medium")
        log.info("Whisper base model loaded")
        pretrained_model = torch.load(self.pretrained_model_path)
        log.info(f"Whisper pretrained model loaded from {self.pretrained_model_path}")
        # Extract state_dict if the loaded model is not already a state_dict
        if hasattr(pretrained_model, "state_dict"):
            pretrained_weights = pretrained_model.state_dict()  # extract the state dict
        else:
            pretrained_weights = pretrained_model  # it's already a state_dict
        #######################################################################
        updated_weights = self.map_hf_to_pt(pretrained_weights)
        model.load_state_dict(updated_weights, strict=True)
        log.info(f"Model weights mapped from HuggingFace model to PyTorch")
        # Activate to save converted model and/or its weights
        # torch.save(model, 'src/model/whisper_pretrained_converted.pth')
        # torch.save(updated_weights, 'src/model/whisper_pretrained_converted_weights.pth')
        ######################################################################
        model.to(self.device)
        model.requires_grad_(False)
        model.eval()
        log.info("Whisper PyTorch model loaded on " + str(self.device))
        return model

    def create_processor(self):
        """
        """
        try:
            processor = WhisperProcessor.from_pretrained(
                self.model_name,
                language=self.language,
                task=self.task
            )
            log.info("WhisperProcessor created")
            return processor
        except Exception as e:
            log.error(f"Error during WhisperProcessor creation: {str(e)}")
            raise

    def preprocess_audio(self, waveform):
        """
        """
        # compute log-Mel input features from input audio array
        mel = self.processor.feature_extractor(waveform, sampling_rate=self.sr).input_features
        return torch.tensor(mel, dtype=torch.float32)

    def inference_pipeline(self,audio_path):
        log.info("1 - Starting audio load:")
        # waveform, sample_rate = librosa.load(audio_path, sr=self.sr)
        waveform, sample_rate = torchaudio.load(audio_path)
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sr)(waveform)[0]
        log.info("2 - starting preprocessing:")
        audio_features = self.preprocess_audio(waveform)
        log.info("3 - Starting forward pass:")
        with torch.no_grad():
            result = whisper.decode(
                self.model,
                audio_features,
                options=whisper.DecodingOptions(
                    fp16=False,
                    language="it",
                    without_timestamps=True,
                    suppress_blank=False,
                    suppress_tokens=[],
                ),
            )
        return result[0].text

    async def transcribe(self, audio_path: str) -> DecodingResult | list[DecodingResult]:
        """
        """
        try:
            loop = asyncio.get_event_loop()
            log.info(f"Transcribing the following file audio: {audio_path}")
            log.info("Transcription started...")
            text = await loop.run_in_executor(
                None,
                lambda: self.inference_pipeline(audio_path)
                )
            log.info("Transcription completed!")
            return text
        except Exception as e:
            log.error(f"Error during transcription: {str(e)}")
            raise
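
To reproduce the thread-count comparison from the summary, the service can simply be constructed with different num_threads values; here is a minimal usage sketch (the audio path is a placeholder):


import asyncio

# Hypothetical usage; num_threads feeds torch.set_num_threads() in __init__
service = InferenceService(language="it", num_threads=8, quantization=True)
print(asyncio.run(service.transcribe("./path_to/audio.wav")))
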
Source: https://medium.com/@miosipof/openai-whisper-inference-on-cpu-comparison-e851d8609048