Commit 17ab6fbf authored by Administrator

修正模型加载,增加环境判断

parent b27d52ec
......@@ -12,10 +12,10 @@ from .type import WhisperMode, SPEECH_ARRAY_INDEX
class Transcribe:
def __init__(self, args):
def __init__(self, args, whispermodel):
self.args = args
self.sampling_rate = 16000
self.whisper_model = None
self.whisper_model = whispermodel
self.vad_model = None
self.detect_speech = None
......
import argparse
import logging
import os
from app.video_cut.autocut import utils
from app.video_cut.autocut import whisper_model
from app.video_cut.autocut.type import WhisperMode, WhisperModel
def main_args(logger):
def main_args(logger, debug):
logger.info('load augument')
parser = argparse.ArgumentParser()
parser.add_argument("--inputs", type=str, help="Inputs filenames/folders")
......@@ -134,12 +132,14 @@ def main_args(logger):
args.wmdigit = True
args.force = True
args.vad = "0"
args.whisper_model = "large-v2"
args.device = "cuda"
logger.info(f'load whisper_model: {args.whisper_model} device: {args.device}')
import whisper
whisper_model = whisper.load_model(args.whisper_model, args.device)
if not debug:
args.whisper_model = "large-v2"
args.device = "cuda"
logger.info(f'load whisper_model: {args.whisper_model}, device: {args.device}')
whispermodel = whisper_model.WhisperModel(16000)
whispermodel.load(args.whisper_model, args.device)
logger.info(f'done.')
return args, whisper_model
\ No newline at end of file
return args, whispermodel
\ No newline at end of file
......@@ -31,7 +31,7 @@ def validate_request():
# 主线
def video_cut_pipeline(logger, args, whisper_model):
def video_cut_pipeline(logger, args, whispermodel):
# print(args)
time_record = []
media_file, lang = validate_request()
......@@ -50,7 +50,7 @@ def video_cut_pipeline(logger, args, whisper_model):
args.inputs = [media_file]
args.lang = lang
wmdigit_transcribe.Transcribe(args, whisper_model).run()
wmdigit_transcribe.Transcribe(args, whispermodel).run()
time_record.append(f"视频生成srt和md。耗时: {time.time() - start_time:.4f} 秒")
......
......@@ -26,12 +26,12 @@ input_root = os.path.join(root, 'inputs')
output_root = os.path.join(root, 'outputs')
# 预加载模型
args, whisper_model = main_args(logger)
args, whispermodel = main_args(logger, app.config['DEBUG'])
# 对外接口
@app.route('/wm_video_cut', methods=['POST'])
def wm_video_cut():
final_video_url, srt_url = video_cut_pipeline(logger, args, whisper_model)
final_video_url, srt_url = video_cut_pipeline(logger, args, whispermodel)
return jsonify({"result": {"final_video_url": final_video_url, "srt_url": srt_url}})
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment