# wmdigit_transcribe.py
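"""Transcribe media files into SRT subtitles plus a Markdown marking file.

Speech segments are located with Silero VAD and transcribed with one of
three Whisper backends: local whisper, the OpenAI API, or faster-whisper.
"""
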
import logging
import os
import time
from typing import List, Any

import numpy as np
import srt
import torch

from . import utils, whisper_model
from .type import WhisperMode, SPEECH_ARRAY_INDEX


class Transcribe:
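    """End-to-end pipeline: load audio -> VAD -> Whisper -> .srt and .md."""
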
    def __init__(self, args, whispermodel):
        self.args = args
        self.sampling_rate = 16000
        self.whisper_model = whispermodel
        self.vad_model = None
        self.detect_speech = None

        tic = time.time()
        if self.whisper_model is None:
            if self.args.whisper_mode == WhisperMode.WHISPER.value:
                self.whisper_model = whisper_model.WhisperModel(self.sampling_rate)
                self.whisper_model.load(self.args.whisper_model, self.args.device)
            elif self.args.whisper_mode == WhisperMode.OPENAI.value:
                self.whisper_model = whisper_model.OpenAIModel(
                    self.args.openai_rpm, self.sampling_rate
                )
                self.whisper_model.load()
            elif self.args.whisper_mode == WhisperMode.FASTER.value:
                self.whisper_model = whisper_model.FasterWhisperModel(
                    self.sampling_rate
                )
                self.whisper_model.load(self.args.whisper_model, self.args.device)
        logging.info(f"Done Init model in {time.time() - tic:.1f} sec")


    def run(self, retry=1):
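        """Transcribe each input, writing <name>.srt and <name>.md next to it.

        Returns (md_fn, srt_fn, srt_json) for the first input that is not
        skipped. When retry != 1, a failure sleeps briefly and recurses with
        an incremented retry counter instead of raising.
        """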
        for input_fn in self.args.inputs:
            logging.info(f"Transcribing {input_fn}")
            name, _ = os.path.splitext(input_fn)
            if utils.check_exists(name + ".md", self.args.force):
                continue

            try:
                audio = utils.load_audio(input_fn, sr=self.sampling_rate)
                speech_array_indices = self._detect_voice_activity(audio)
                transcribe_results = self._transcribe(
                    input_fn, audio, speech_array_indices
                )

                srt_fn = name + ".srt"
                md_fn = name + ".md"
                srt_json = self._save_srt(srt_fn, transcribe_results)
                logging.info(f"Transcribed {input_fn} to {srt_fn}")
                self._save_md(md_fn, srt_fn, input_fn, bool(self.args.wmdigit))
                logging.info(f"Saved texts to {md_fn} to mark sentences")
                return md_fn, srt_fn, srt_json
            except Exception as e:
                if retry == 1:
                    raise RuntimeError(f"Failed to transcribe {input_fn}: {e}") from e
                time.sleep(1)
                logging.info(f"Retry {retry} transcribing {input_fn}")
                retry += 1
                return self.run(retry)

    def _detect_voice_activity(self, audio) -> List[SPEECH_ARRAY_INDEX]:
        """Detect segments that have voice activities"""
        if self.args.vad == "0":
            return [{"start": 0, "end": len(audio)}]

        tic = time.time()
        if self.vad_model is None or self.detect_speech is None:
            # Work around torch.hub's forked-repo check
            # (https://github.com/pytorch/vision/issues/4156).
            torch.hub._validate_not_a_forked_repo = lambda a, b, c: True

            # Prefer the pre-downloaded local copy of Silero VAD; fall back
            # to fetching it from GitHub when the local cache is missing.
            local_repo = "/home/ubuntu/.cache/torch/hub/snakers4_silero-vad_master"
            if os.path.isdir(local_repo):
                self.vad_model, funcs = torch.hub.load(
                    repo_or_dir=local_repo, model="silero_vad", source="local"
                )
            else:
                self.vad_model, funcs = torch.hub.load(
                    repo_or_dir="snakers4/silero-vad", model="silero_vad"
                )

            self.detect_speech = funcs[0]  # get_speech_timestamps

        speeches = self.detect_speech(
            audio, self.vad_model, sampling_rate=self.sampling_rate
        )

        # Drop segments shorter than 1 second
        speeches = utils.remove_short_segments(speeches, 1.0 * self.sampling_rate)

        # Pad segment boundaries to avoid too-tight cuts; the pad lengths are tunable
        speeches = utils.expand_segments(
            speeches, 0.2 * self.sampling_rate, 0.0 * self.sampling_rate, audio.shape[0]
        )

        # Merge segments that are closer than 0.5 second
        speeches = utils.merge_adjacent_segments(speeches, 0.5 * self.sampling_rate)

        logging.info(f"Done voice activity detection in {time.time() - tic:.1f} sec")
        return speeches if len(speeches) > 1 else [{"start": 0, "end": len(audio)}]

    def _transcribe(
        self,
        input_fn: str,
        audio: np.ndarray,
        speech_array_indices: List[SPEECH_ARRAY_INDEX],
    ) -> List[Any]:
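        """Run the selected Whisper backend over the detected speech segments.

        The OpenAI backend takes the input file path as an extra leading
        argument, which is why the two call shapes below differ.
        """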
        tic = time.time()
        logging.debug(f"Speech array indices: {speech_array_indices}")
        if self.args.whisper_mode in (
            WhisperMode.WHISPER.value,
            WhisperMode.FASTER.value,
        ):
            res = self.whisper_model.transcribe(
                audio, speech_array_indices, self.args.lang, self.args.prompt
            )
        else:
            res = self.whisper_model.transcribe(
                input_fn, audio, speech_array_indices, self.args.lang, self.args.prompt
            )
        logging.info(f"Done transcription in {time.time() - tic:.1f} sec")
        return res

    def _save_srt(self, output, transcribe_results):
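        """Compose and write an .srt file, then return the subtitles as JSON."""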
        subs = self.whisper_model.gen_srt(transcribe_results)
        # Strip Chinese characters from translated subtitles; some
        # translations handle them poorly.
        if self.args.lang not in ("zh", "Japanese"):
            for s in subs:
                s.content = utils.remove_chinese(s.content)
        # Write the subtitle file
        with open(output, "wb") as f:
            f.write(srt.compose(subs).encode(self.args.encoding, "replace"))
        # Build the subtitle JSON
        sub_json = utils.gen_subjson_from_subs(subs)
        logging.debug(sub_json)
        return sub_json

    def _save_md(self, md_fn, srt_fn, video_fn, is_auto_edit=False):
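        """Write a Markdown checklist of subtitles for marking sentences to keep."""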
        with open(srt_fn, encoding=self.args.encoding) as f:
            subs = srt.parse(f.read())

        md = utils.MD(md_fn, self.args.encoding)
        md.clear()
        md.add_done_editing(is_auto_edit)
        md.add_video(os.path.basename(video_fn))
        md.add(
            f"\nTexts generated from [{os.path.basename(srt_fn)}]({os.path.basename(srt_fn)}). "
            "Mark the sentences to keep for autocut.\n"
            "The format is [subtitle_index,start_time] subtitle content.\n\n"
        )

        for s in subs:
            sec = s.start.seconds
            # Start time rendered as mm:ss
            pre = f"[{s.index},{sec // 60:02d}:{sec % 60:02d}]"
            md.add_task(is_auto_edit, f"{pre:11} {s.content.strip()}")
        md.write()
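

# Example usage (a minimal sketch; in the real project `args` is built by an
# argparse-based CLI, so the attribute values below are illustrative
# assumptions inferred from how this module reads them):
#
#   from argparse import Namespace
#
#   args = Namespace(
#       inputs=["talk.mp4"], whisper_mode="whisper", whisper_model="small",
#       device=None, openai_rpm=3, force=False, vad="1", lang="zh",
#       prompt="", wmdigit=False, encoding="utf-8",
#   )
#   md_fn, srt_fn, srt_json = Transcribe(args, None).run()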