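"""Transcribe media files to subtitles and Markdown, using silero-vad for voice activity detection and a Whisper backend for speech recognition."""
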
import logging
import os
import time
from typing import List, Any

import numpy as np
import srt
import torch

from . import utils, whisper_model
from .type import WhisperMode, SPEECH_ARRAY_INDEX


class Transcribe:
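    """Transcribe media files to SRT subtitles and a companion Markdown file."""
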
    def __init__(self, args, whispermodel):
        self.args = args
        self.sampling_rate = 16000
        self.whisper_model = whispermodel
        self.vad_model = None
        self.detect_speech = None

        # Reuse an injected model if one was passed in; otherwise load one for the configured mode.
        tic = time.time()
        if self.whisper_model is None:
            if self.args.whisper_mode == WhisperMode.WHISPER.value:
                self.whisper_model = whisper_model.WhisperModel(self.sampling_rate)
                self.whisper_model.load(self.args.whisper_model, self.args.device)
            elif self.args.whisper_mode == WhisperMode.OPENAI.value:
                self.whisper_model = whisper_model.OpenAIModel(
                    self.args.openai_rpm, self.sampling_rate
                )
                self.whisper_model.load()
            elif self.args.whisper_mode == WhisperMode.FASTER.value:
                self.whisper_model = whisper_model.FasterWhisperModel(
                    self.sampling_rate
                )
                self.whisper_model.load(self.args.whisper_model, self.args.device)
        logging.info(f"Done initializing model in {time.time() - tic:.1f} sec")
    def run(self, retry=1):
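        """Transcribe each input file and write a .srt and a .md next to it."""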
        for input in self.args.inputs:
            logging.info(f"Transcribing {input}")
            name, _ = os.path.splitext(input)
            if utils.check_exists(name + ".md", self.args.force):
                continue
            try:
                audio = utils.load_audio(input, sr=self.sampling_rate)
                speech_array_indices = self._detect_voice_activity(audio)
                transcribe_results = self._transcribe(input, audio, speech_array_indices)
                srt_fn = name + ".srt"
                md_fn = name + ".md"
                # print(transcribe_results)
                srt_json = self._save_srt(srt_fn, transcribe_results)
                logging.info(f"Transcribed {input} to {srt_fn}")
                self._save_md(md_fn, srt_fn, input, bool(self.args.wmdigit))
                logging.info(f"Saved texts to {md_fn} to mark sentences")
                return md_fn, srt_fn, srt_json
            except Exception as e:
                if retry == 1:
                    raise RuntimeError(f"Failed to transcribe {input}: {e}")
                time.sleep(1)
                logging.info(f"Retry {retry}: transcribing {input}")
                retry += 1
                return self.run(retry)
    def _detect_voice_activity(self, audio) -> List[SPEECH_ARRAY_INDEX]:
        """Detect segments that contain voice activity."""
        if self.args.vad == "0":
            return [{"start": 0, "end": len(audio)}]

        tic = time.time()
        if self.vad_model is None or self.detect_speech is None:
            # torch.hub load limit, see https://github.com/pytorch/vision/issues/4156
            torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
            self.vad_model, funcs = torch.hub.load(
                repo_or_dir="/home/ubuntu/.cache/torch/hub/snakers4_silero-vad_master",
                model="silero_vad",
                source="local",
            )
            self.detect_speech = funcs[0]

        speeches = self.detect_speech(
            audio, self.vad_model, sampling_rate=self.sampling_rate
        )

        # Remove segments that are too short
        speeches = utils.remove_short_segments(speeches, 1.0 * self.sampling_rate)

        # Expand segments to avoid cutting too tightly; the pad lengths can be tuned
        speeches = utils.expand_segments(
            speeches, 0.2 * self.sampling_rate, 0.0 * self.sampling_rate, audio.shape[0]
        )

        # Merge segments that are very close to each other
        speeches = utils.merge_adjacent_segments(speeches, 0.5 * self.sampling_rate)

        logging.info(f"Done voice activity detection in {time.time() - tic:.1f} sec")
        return speeches if len(speeches) > 1 else [{"start": 0, "end": len(audio)}]
    def _transcribe(
        self,
        input: str,
        audio: np.ndarray,
        speech_array_indices: List[SPEECH_ARRAY_INDEX],
    ) -> List[Any]:
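        """Run the configured Whisper backend over the detected speech segments."""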
        tic = time.time()
        logging.debug(f"Speech array indices: {speech_array_indices}")
        # The local Whisper backends take the raw audio; the OpenAI backend also needs the input path.
        res = (
            self.whisper_model.transcribe(
                audio, speech_array_indices, self.args.lang, self.args.prompt
            )
            if self.args.whisper_mode
            in (WhisperMode.WHISPER.value, WhisperMode.FASTER.value)
            else self.whisper_model.transcribe(
                input, audio, speech_array_indices, self.args.lang, self.args.prompt
            )
        )
        logging.info(f"Done transcription in {time.time() - tic:.1f} sec")
        return res
    def _save_srt(self, output, transcribe_results):
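        """Write the transcription results to an SRT file and return them as subtitle JSON."""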
        subs = self.whisper_model.gen_srt(transcribe_results)
        # print(subs)
        # Strip Chinese text from translated subtitles; some of the translations are poor.
        if self.args.lang not in ("zh", "Japanese"):
            for s in subs:
                s.content = utils.remove_chinese(s.content)

        # Write the subtitle file
        with open(output, "wb") as f:
            f.write(srt.compose(subs).encode(self.args.encoding, "replace"))

        # Build the subtitle JSON
        sub_json = utils.gen_subjson_from_subs(subs)
        logging.debug(sub_json)
        return sub_json
    def _save_md(self, md_fn, srt_fn, video_fn, is_auto_edit=False):
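        """Generate a Markdown file listing each subtitle so sentences can be marked for cutting."""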
        with open(srt_fn, encoding=self.args.encoding) as f:
            subs = srt.parse(f.read())

        md = utils.MD(md_fn, self.args.encoding)
        md.clear()
        md.add_done_editing(is_auto_edit)
        md.add_video(os.path.basename(video_fn))
        md.add(
            f"\nTexts generated from [{os.path.basename(srt_fn)}]({os.path.basename(srt_fn)})."
            " Mark the sentences to keep for autocut.\n"
            "The format is [subtitle_index,duration_in_second] subtitle context.\n\n"
        )

        for s in subs:
            sec = s.start.seconds
            pre = f"[{s.index},{sec // 60:02d}:{sec % 60:02d}]"
            md.add_task(is_auto_edit, f"{pre:11} {s.content.strip()}")
        md.write()