From dfc41839324f6bd5b1267cff276037b8ee92fa58 Mon Sep 17 00:00:00 2001
From: zhouchengbo <zhouchengbo@wmdigit.com>
Date: Thu, 9 Nov 2023 16:36:55 +0800
Subject: [PATCH] =?UTF-8?q?=E5=90=AF=E5=8A=A8=E6=97=B6=E5=8A=A0=E8=BD=BDwh?=
 =?UTF-8?q?isper=5Fmodel?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 app/video_cut/autocut/wmdigit_transcribe.py | 4 ++--
 app/video_cut/load_args.py                  | 7 ++++++-
 app/video_cut/main.py                       | 4 ++--
 requirements.txt                            | 1 +
 start.py                                    | 4 ++--
 5 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/app/video_cut/autocut/wmdigit_transcribe.py b/app/video_cut/autocut/wmdigit_transcribe.py
index 15f7115..d73cc12 100644
--- a/app/video_cut/autocut/wmdigit_transcribe.py
+++ b/app/video_cut/autocut/wmdigit_transcribe.py
@@ -12,10 +12,10 @@ from .type import WhisperMode, SPEECH_ARRAY_INDEX
 
 
 class Transcribe:
-    def __init__(self, args):
+    def __init__(self, args, whisper_model):
         self.args = args
         self.sampling_rate = 16000
-        self.whisper_model = None
+        self.whisper_model = whisper_model
         self.vad_model = None
         self.detect_speech = None
 
diff --git a/app/video_cut/load_args.py b/app/video_cut/load_args.py
index 535c1f2..7197860 100644
--- a/app/video_cut/load_args.py
+++ b/app/video_cut/load_args.py
@@ -135,5 +135,10 @@ def main_args(logger):
     args.force = True
     args.vad = "0"
     args.whisper_model = "large-v2"
+    args.device = "cuda"
 
-    return args
\ No newline at end of file
+    logger.info(f'load whisper_model: {args.whisper_model} device: {args.device}')
+    import whisper
+    whisper_model = whisper.load_model(args.whisper_model, args.device)
+
+    return args, whisper_model
\ No newline at end of file
diff --git a/app/video_cut/main.py b/app/video_cut/main.py
index 0ba4030..0224329 100644
--- a/app/video_cut/main.py
+++ b/app/video_cut/main.py
@@ -31,7 +31,7 @@ def validate_request():
 
 
 # 主线
-def video_cut_pipeline(logger, args):
+def video_cut_pipeline(logger, args, whisper_model):
     # print(args)
     time_record = []
     media_file, lang = validate_request()
@@ -50,7 +50,7 @@ def video_cut_pipeline(logger, args):
 
     args.inputs = [media_file]
     args.lang = lang
-    wmdigit_transcribe.Transcribe(args).run()
+    wmdigit_transcribe.Transcribe(args, whisper_model).run()
     time_record.append(f"视频生成srt和md。耗时: {time.time() - start_time:.4f} 秒")
 
 
diff --git a/requirements.txt b/requirements.txt
index a0c9c5b..0f8d4d4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,6 +7,7 @@ flask_sqlalchemy
 redis
 tqdm
 
+oss2
 moviepy==2.0.0.dev2
 edge-tts
 openai-whisper
diff --git a/start.py b/start.py
index 7f26c22..74945e6 100644
--- a/start.py
+++ b/start.py
@@ -26,12 +26,12 @@ input_root = os.path.join(root, 'inputs')
 output_root = os.path.join(root, 'outputs')
 
 # 预加载模型
-args = main_args(logger)
+args, whisper_model = main_args(logger)
 
 # 对外接口
 @app.route('/wm_video_cut', methods=['POST'])
 def wm_video_cut():
-    final_video_url, srt_url = video_cut_pipeline(logger, args)
+    final_video_url, srt_url = video_cut_pipeline(logger, args, whisper_model)
     return jsonify({"result": {"final_video_url": final_video_url, "srt_url": srt_url}})
 
 
-- 
2.18.1