From dfc41839324f6bd5b1267cff276037b8ee92fa58 Mon Sep 17 00:00:00 2001 From: zhouchengbo <zhouchengbo@wmdigit.com> Date: Thu, 9 Nov 2023 16:36:55 +0800 Subject: [PATCH] =?UTF-8?q?=E5=90=AF=E5=8A=A8=E6=97=B6=E5=8A=A0=E8=BD=BDwh?= =?UTF-8?q?isper=5Fmodel?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/video_cut/autocut/wmdigit_transcribe.py | 4 ++-- app/video_cut/load_args.py | 7 ++++++- app/video_cut/main.py | 4 ++-- requirements.txt | 1 + start.py | 4 ++-- 5 files changed, 13 insertions(+), 7 deletions(-) diff --git a/app/video_cut/autocut/wmdigit_transcribe.py b/app/video_cut/autocut/wmdigit_transcribe.py index 15f7115..d73cc12 100644 --- a/app/video_cut/autocut/wmdigit_transcribe.py +++ b/app/video_cut/autocut/wmdigit_transcribe.py @@ -12,10 +12,10 @@ from .type import WhisperMode, SPEECH_ARRAY_INDEX class Transcribe: - def __init__(self, args): + def __init__(self, args, whisper_model): self.args = args self.sampling_rate = 16000 - self.whisper_model = None + self.whisper_model = whisper_model self.vad_model = None self.detect_speech = None diff --git a/app/video_cut/load_args.py b/app/video_cut/load_args.py index 535c1f2..7197860 100644 --- a/app/video_cut/load_args.py +++ b/app/video_cut/load_args.py @@ -135,5 +135,10 @@ def main_args(logger): args.force = True args.vad = "0" args.whisper_model = "large-v2" + args.device = "cuda" - return args \ No newline at end of file + logger.info(f'load whisper_model: {args.whisper_model} device: {args.device}') + import whisper + whisper_model = whisper.load_model(args.whisper_model, args.device) + + return args, whisper_model \ No newline at end of file diff --git a/app/video_cut/main.py b/app/video_cut/main.py index 0ba4030..0224329 100644 --- a/app/video_cut/main.py +++ b/app/video_cut/main.py @@ -31,7 +31,7 @@ def validate_request(): # 主线 -def video_cut_pipeline(logger, args): +def video_cut_pipeline(logger, args, whisper_model): # print(args) time_record = [] 
media_file, lang = validate_request() @@ -50,7 +50,7 @@ def video_cut_pipeline(logger, args): args.inputs = [media_file] args.lang = lang - wmdigit_transcribe.Transcribe(args).run() + wmdigit_transcribe.Transcribe(args, whisper_model).run() time_record.append(f"视频生成srt和md。耗时: {time.time() - start_time:.4f} 秒") diff --git a/requirements.txt b/requirements.txt index a0c9c5b..0f8d4d4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ flask_sqlalchemy redis tqdm +oss2 moviepy==2.0.0.dev2 edge-tts openai-whisper diff --git a/start.py b/start.py index 7f26c22..74945e6 100644 --- a/start.py +++ b/start.py @@ -26,12 +26,12 @@ input_root = os.path.join(root, 'inputs') output_root = os.path.join(root, 'outputs') # 预加载模型 -args = main_args(logger) +args, whisper_model = main_args(logger) # 对外接口 @app.route('/wm_video_cut', methods=['POST']) def wm_video_cut(): - final_video_url, srt_url = video_cut_pipeline(logger, args) + final_video_url, srt_url = video_cut_pipeline(logger, args, whisper_model) return jsonify({"result": {"final_video_url": final_video_url, "srt_url": srt_url}}) -- 2.18.1