Commit dfc41839 authored by Administrator

启动时加载whisper_model

parent 91031c74
...@@ -12,10 +12,10 @@ from .type import WhisperMode, SPEECH_ARRAY_INDEX ...@@ -12,10 +12,10 @@ from .type import WhisperMode, SPEECH_ARRAY_INDEX
class Transcribe: class Transcribe:
def __init__(self, args): def __init__(self, args, whisper_model):
self.args = args self.args = args
self.sampling_rate = 16000 self.sampling_rate = 16000
self.whisper_model = None self.whisper_model = whisper_model
self.vad_model = None self.vad_model = None
self.detect_speech = None self.detect_speech = None
......
...@@ -135,5 +135,10 @@ def main_args(logger): ...@@ -135,5 +135,10 @@ def main_args(logger):
args.force = True args.force = True
args.vad = "0" args.vad = "0"
args.whisper_model = "large-v2" args.whisper_model = "large-v2"
args.device = "cuda"
return args logger.info(f'load whisper_model: {args.whisper_model} device: {args.device}')
\ No newline at end of file import whisper
whisper_model = whisper.load_model(args.whisper_model, args.device)
return args, whisper_model
\ No newline at end of file
...@@ -31,7 +31,7 @@ def validate_request(): ...@@ -31,7 +31,7 @@ def validate_request():
# 主线 # 主线
def video_cut_pipeline(logger, args): def video_cut_pipeline(logger, args, whisper_model):
# print(args) # print(args)
time_record = [] time_record = []
media_file, lang = validate_request() media_file, lang = validate_request()
...@@ -50,7 +50,7 @@ def video_cut_pipeline(logger, args): ...@@ -50,7 +50,7 @@ def video_cut_pipeline(logger, args):
args.inputs = [media_file] args.inputs = [media_file]
args.lang = lang args.lang = lang
wmdigit_transcribe.Transcribe(args).run() wmdigit_transcribe.Transcribe(args, whisper_model).run()
time_record.append(f"视频生成srt和md。耗时: {time.time() - start_time:.4f} 秒") time_record.append(f"视频生成srt和md。耗时: {time.time() - start_time:.4f} 秒")
......
...@@ -7,6 +7,7 @@ flask_sqlalchemy ...@@ -7,6 +7,7 @@ flask_sqlalchemy
redis redis
tqdm tqdm
oss2
moviepy==2.0.0.dev2 moviepy==2.0.0.dev2
edge-tts edge-tts
openai-whisper openai-whisper
......
...@@ -26,12 +26,12 @@ input_root = os.path.join(root, 'inputs') ...@@ -26,12 +26,12 @@ input_root = os.path.join(root, 'inputs')
output_root = os.path.join(root, 'outputs') output_root = os.path.join(root, 'outputs')
# 预加载模型 # 预加载模型
args = main_args(logger) args, whisper_model = main_args(logger)
# 对外接口 # 对外接口
@app.route('/wm_video_cut', methods=['POST']) @app.route('/wm_video_cut', methods=['POST'])
def wm_video_cut(): def wm_video_cut():
final_video_url, srt_url = video_cut_pipeline(logger, args) final_video_url, srt_url = video_cut_pipeline(logger, args, whisper_model)
return jsonify({"result": {"final_video_url": final_video_url, "srt_url": srt_url}}) return jsonify({"result": {"final_video_url": final_video_url, "srt_url": srt_url}})
......
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment