涉及到的模型为 iic 发布的几款模型,记得将代码中的 cuda 改成 npu

# coding=utf-8
import asyncio
import io
import json
import numpy as np
import os
import re
import requests
import sys
import configparser
import time
import torchaudio
import traceback
import warnings
# import zhconv  # 添加简体中文转换库
from concurrent.futures import ThreadPoolExecutor
from funasr import AutoModel
from io import BytesIO
from urllib.parse import urlparse
import uuid
import tempfile  # 导入tempfile模块
import wave
import tornado.ioloop
import tornado.web
from tornado.concurrent import run_on_executor
import torch
from transformers import pipeline
import base64
import struct
import logging

# Silence all Python warnings for the whole process.
warnings.filterwarnings("ignore")
print(f"cuda is available: {torch.cuda.is_available()}")

# Cap OpenMP thread counts to avoid libgomp thread-creation failures.
# Tune to the machine; typically the CPU core count or fewer.
# NOTE(review): these variables are assigned after numpy/torch were already
# imported above - confirm the limits still take effect at this point.
os.environ["OMP_NUM_THREADS"] = os.environ.get("OMP_NUM_THREADS", "4")
os.environ["MKL_NUM_THREADS"] = os.environ.get("MKL_NUM_THREADS", "4")
os.environ["NUMEXPR_NUM_THREADS"] = os.environ.get("NUMEXPR_NUM_THREADS", "4")
os.environ["FORCE_USE_TORCHAUDIO"] = "1"
os.environ["DISABLE_FFMPEG"] = "1"

# device = "npu:0" if torch.npu.is_available() else "cpu"
# The device may be either cpu or gpu; pick CUDA device 0 when available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# To force CPU instead:
# device = "cpu"
warnings.filterwarnings("ignore")
print("--------")
print(device)
print("--------")
# Request offline mode from transformers / modelscope.
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["MODELSCOPE_OFFLINE"] = "1"

# 创建简体中文优化的pipeline
def create_simplified_asr_pipe(model_path):
    """Create a speech-recognition pipeline set up for Chinese transcription.

    Args:
        model_path: Model identifier or local directory handed to
            ``transformers.pipeline``.

    Returns:
        The configured ``automatic-speech-recognition`` pipeline.
    """
    # ``device`` is a ``torch.device`` (or a plain string if the module-level
    # assignment is switched); ``"npu" in torch.device(...)`` raises TypeError,
    # so compare against its string form instead.
    use_half_precision = "npu" in str(device)

    asr_pipe = pipeline(
        "automatic-speech-recognition",
        model=model_path,
        chunk_length_s=30,
        device=device,
        torch_dtype=torch.float16 if use_half_precision else torch.float32,
    )
    # Force the decoder prompt to Chinese transcription.
    asr_pipe.model.config.forced_decoder_ids = (
        asr_pipe.tokenizer.get_decoder_prompt_ids(
            language="zh", task="transcribe", no_timestamps=False
        )
    )
    # Pass a "Simplified Chinese" prefix to models that expose ``set_prefix``.
    if hasattr(asr_pipe.model, "set_prefix"):
        asr_pipe.model.set_prefix("简体中文")
    return asr_pipe

# Build the Simplified-Chinese speech recognition pipeline (currently disabled).
# asr_pipe = create_simplified_asr_pipe("/mnt/raid1/whisper-large-v3-turbo")
# On Windows, switch asyncio to the selector event loop policy
# (presumably because Tornado needs a selector-based loop there - confirm).
if sys.platform == "win32":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

class ModelService(object):
    """Holds the FunASR ``AutoModel`` instances shared by the HTTP handlers."""

    def __init__(self):
        """Load all five model instances onto the module-level ``device``."""
        print("开始加载模型")
        device_name = str(device)

        # Pipeline addressed by hub-style identifiers (ASR + VAD + punctuation).
        hub_pipeline = {
            "model": "iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
            "vad_model": "iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
            "punc_model": "iic/punc_ct-transformer_cn-en-common-vocab471067-large",
        }
        # Same three components addressed by local directories.
        local_asr = "F:/model/iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
        local_pipeline = {
            "model": local_asr,
            "vad_model": "F:/model/iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
            "punc_model": "F:/model/iic/punc_ct-transformer_cn-en-common-vocab471067-large",
        }

        # Offline instance without a speaker model.
        self.offLinePrmodel = AutoModel(**hub_pipeline, device=device_name)
        # Offline instance without a speaker model, loaded from local paths.
        self.offLinePrmodel_no_spk = AutoModel(**local_pipeline, device=device_name)
        # Bare ASR model (no VAD / punctuation), used for raw sample inference.
        self.asr_model = AutoModel(
            model=local_asr,
            trust_remote_code=True,
            # remote_code="./model.py",
            device=device_name,
        )
        # Instance that additionally loads the speaker model.
        self.onlinePrmodel = AutoModel(
            **local_pipeline,
            spk_model="F:/model/iic/speech_campplus_sv_zh-cn_16k-common",
            device=device_name,
        )
        # Same pipeline again, without the speaker model.
        self.onlinePrmodel_no_spk = AutoModel(**local_pipeline, device=device_name)

    def model_process(self, data):
        """Return ``data`` unchanged."""
        return data

class OfflineServiceHandler(tornado.web.RequestHandler):
    """Transcribe an audio file referenced by URL (v1 endpoint).

    The request body is JSON with a ``speech_path`` key holding the audio URL.
    The response is a JSON envelope whose ``data`` field carries the
    transcript (or the error text when request parsing fails).
    """

    # Worker pool used by ``run_on_executor`` so the blocking download and
    # inference run off the IOLoop thread.
    executor = ThreadPoolExecutor(5)

    def initialize(self, modelService):
        """Store the shared model service and emit the default headers."""
        self.modelService = modelService
        self.set_default_header()

    @tornado.gen.coroutine
    def post(self):
        """Run the blocking pipeline on the executor and write its JSON result."""
        res = yield self.data_process_and_model()
        print("&&&&&&&&&&&&&&&&&&&&&&&&")
        print((res))
        self.write(json.dumps(res, ensure_ascii=False))
        self.finish()

    def set_default_header(self):
        """Set the CORS and content-type headers for every response."""
        print("setting headers!!!")
        self.set_header("Access-Control-Allow-Origin", "*")
        self.set_header("Access-Control-Allow-Headers", "*")
        self.set_header(
            "Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS"
        )
        self.set_header("Content-Type", "application/json; charset=UTF-8")
        # This second assignment replaces the "*" value set above.
        self.set_header("Access-Control-Allow-Headers", "Content-Type")

    def options(self):
        """Answer CORS pre-flight requests with an empty body."""
        pass

    @run_on_executor
    def data_process_and_model(self):
        """Parse the request, transcribe the audio and build the response dict.

        Returns:
            dict: envelope with ``timestamp``, ``message``, ``status``,
            ``code`` and ``data``.
        """
        try:
            json_content = self.request.body.decode("utf8")
            print(json_content)
            json_data = json.loads(json_content)
            speech_path = json_data["speech_path"]
            # NOTE(review): ``audio_to_text`` returns None on failure, which is
            # still reported here as status 200 with ``data`` set to None.
            summary = self.audio_to_text(speech_path)
            print(summary)
            resultInResult = {}
            resultInResult["timestamp"] = time.time()
            resultInResult["message"] = "ok"
            resultInResult["status"] = 200
            resultInResult["code"] = 10170000
            resultInResult["data"] = summary
            return resultInResult
        except Exception as e:
            resultInResult = {}
            resultInResult["timestamp"] = time.time()
            resultInResult["message"] = "error"
            resultInResult["status"] = 500
            resultInResult["code"] = 10170001
            resultInResult["data"] = str(e)
            return resultInResult

    def audio_to_text(self, audio_url, hotword="魔搭"):
        """Download an audio file from a URL and run speech recognition on it.

        The downloaded file is handed to the model by path. Previously the
        path variable was overwritten with the PCM bytes returned by
        ``deep_analyze_wav``, which leaked the temporary file and made every
        non-WAV input fail.

        Args:
            audio_url (str): URL of the audio file.
            hotword (str): Hotword passed to the recogniser.

        Returns:
            str | None: Transcript with one sentence per line, or None when
            the download or the recognition fails.
        """
        temp_file_path = None
        try:
            print(f"开始处理音频URL: {audio_url}")
            # NOTE(review): ``verify=False`` disables TLS certificate checks and
            # the URL comes straight from the request body - confirm that this
            # is acceptable for the deployment.
            response = requests.get(audio_url, verify=False, timeout=1200)
            response.raise_for_status()
            print(f"音频文件下载成功,大小: {len(response.content)} 字节")

            # Derive the file extension from the URL path, defaulting to .wav.
            parsed_url = urlparse(audio_url)
            original_name = os.path.basename(parsed_url.path)
            if "." in original_name:
                extension = "." + original_name.split(".")[-1].lower()
                valid_extensions = {
                    ".wav",
                    ".mp3",
                    ".flac",
                    ".ogg",
                    ".m4a",
                    ".aac",
                    ".webm",
                }
                if extension not in valid_extensions:
                    print(f"未知的扩展名 '{extension}',使用.wav作为默认")
                    extension = ".wav"
            else:
                print("URL中没有扩展名,使用.wav作为默认")
                extension = ".wav"

            # Unique file name (UUID + extension) inside the system temp dir.
            file_name = f"{str(uuid.uuid4())}{extension}"
            temp_file_path = os.path.join(tempfile.gettempdir(), file_name)
            with open(temp_file_path, "wb") as temp_file:
                temp_file.write(response.content)
            print(f"音频文件保存到临时路径: {temp_file_path}")
            if not os.path.exists(temp_file_path):
                raise FileNotFoundError(f"临时文件创建失败: {temp_file_path}")
            file_size = os.path.getsize(temp_file_path)
            if file_size == 0:
                raise ValueError("下载的音频文件为空")
            print(f"临时文件大小: {file_size} 字节 ({file_size / 1024 / 1024:.2f} MB)")

            # Larger files get a smaller batch window.
            if file_size > 50 * 1024 * 1024:  # 50MB
                batch_size_s = 60
                print(
                    f"检测到大文件 ({file_size / 1024 / 1024:.2f} MB),batch_size_s={batch_size_s}"
                )
            elif file_size > 20 * 1024 * 1024:  # 20MB
                batch_size_s = 120
                print(
                    f"检测到中等文件 ({file_size / 1024 / 1024:.2f} MB),batch_size_s={batch_size_s}"
                )
            else:
                batch_size_s = 300  # default

            # Only WAV files are header-checked; other formats go straight to
            # the model, mirroring the v2 handler.
            if extension == ".wav":
                if not validate_wav_file(temp_file_path):
                    raise ValueError("音频文件格式验证失败")
            else:
                print(f"检测到非 WAV 音频格式 {extension},直接交给模型处理")

            print("开始语音识别...")
            res = self.modelService.offLinePrmodel.generate(
                input=temp_file_path,
                batch_size_s=batch_size_s,
                hotword=hotword,
                cache={},
            )
            recognized_text = (
                res[0]["text"] if res and isinstance(res, list) and len(res) > 0 else ""
            )
            print(f"识别成功: {recognized_text}")

            # Re-join the transcript with one sentence per line.
            sentences = [s.strip() for s in recognized_text.split("。") if s.strip()]
            result_text = "。\n".join(sentences)
            if sentences:
                result_text += "。"
            return result_text
        except requests.exceptions.RequestException as e:
            print(f"下载失败: {str(e)}")
            return None
        except Exception as e:
            print(f"识别错误: {str(e)}")
            print("Traceback (most recent call last):")
            print(traceback.format_exc())
            return None
        finally:
            # ``temp_file_path`` is always the on-disk path, so the download
            # is removed on every exit path.
            if temp_file_path and os.path.exists(temp_file_path):
                try:
                    os.remove(temp_file_path)
                    print(f"已删除临时文件: {temp_file_path}")
                except OSError as e:
                    print(f"删除临时文件失败: {str(e)}")

class OfflineServiceHandlerV2(tornado.web.RequestHandler):
    """Transcribe an audio file referenced by URL (v2 endpoint).

    The request body is JSON with a ``speech_path`` key. On success the
    transcript is returned under ``result``; when request parsing fails the
    error text is returned under ``data``.
    """

    # Worker pool used by ``run_on_executor`` for the blocking work.
    executor = ThreadPoolExecutor(5)

    def initialize(self, modelService):
        """Store the shared model service and emit the default headers."""
        self.modelService = modelService
        self.set_default_header()

    @tornado.gen.coroutine
    def post(self):
        """Run the blocking pipeline on the executor and write its JSON result."""
        res = yield self.data_process_and_model()
        print("&&&&&&&&&&&&&&&&&&&&&&&&")
        print((res))
        self.write(json.dumps(res, ensure_ascii=False))
        self.finish()

    def set_default_header(self):
        """Set the CORS and content-type headers for every response."""
        print("setting headers!!!")
        self.set_header("Access-Control-Allow-Origin", "*")
        self.set_header("Access-Control-Allow-Headers", "*")
        self.set_header(
            "Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS"
        )
        self.set_header("Content-Type", "application/json; charset=UTF-8")
        # This second assignment replaces the "*" value set above.
        self.set_header("Access-Control-Allow-Headers", "Content-Type")

    def options(self):
        """Answer CORS pre-flight requests with an empty body."""
        pass

    @run_on_executor
    def data_process_and_model(self):
        """Parse the request, transcribe the audio and build the response dict.

        A failed transcription (``audio_to_text`` returning None) is reported
        as an empty ``result`` string with status 200.
        """
        try:
            json_content = self.request.body.decode("utf8")
            print(json_content)
            json_data = json.loads(json_content)
            speech_path = json_data["speech_path"]
            res = self.audio_to_text(speech_path)
            print(res)
            result_text = res if res else ""
            resultInResult = {}
            resultInResult["timestamp"] = time.time()
            resultInResult["message"] = "ok"
            resultInResult["status"] = 200
            resultInResult["code"] = 10170000
            resultInResult["result"] = result_text
            return resultInResult
        except Exception as e:
            resultInResult = {}
            resultInResult["timestamp"] = time.time()
            resultInResult["message"] = "error"
            resultInResult["status"] = 500
            resultInResult["code"] = 10170001
            resultInResult["data"] = str(e)
            return resultInResult

    def audio_to_text(self, audio_url, hotword="魔搭"):
        """Download an audio file from a URL and run speech recognition on it.

        Args:
            audio_url (str): URL of the audio file.
            hotword (str): Hotword passed to the recogniser.

        Returns:
            str | None: Transcript with one sentence per line, or None when
            the download or the recognition fails.
        """
        temp_file_path = None
        cleanup_paths = []
        try:
            print(f"开始处理音频URL: {audio_url}")
            # NOTE(review): ``verify=False`` disables TLS certificate checks and
            # the URL comes straight from the request body - confirm that this
            # is acceptable for the deployment.
            response = requests.get(audio_url, verify=False, timeout=1200)
            response.raise_for_status()
            print(f"音频文件下载成功,大小: {len(response.content)} 字节")

            # Derive the file extension from the URL path, defaulting to .wav.
            parsed_url = urlparse(audio_url)
            original_name = os.path.basename(parsed_url.path)
            if "." in original_name:
                extension = "." + original_name.split(".")[-1].lower()
                valid_extensions = {
                    ".wav",
                    ".mp3",
                    ".flac",
                    ".ogg",
                    ".m4a",
                    ".aac",
                    ".webm",
                    ".mp4",
                }
                if extension not in valid_extensions:
                    print(f"未知的扩展名 '{extension}',使用.wav作为默认")
                    extension = ".wav"
            else:
                print("URL中没有扩展名,使用.wav作为默认")
                extension = ".wav"

            file_name = f"{str(uuid.uuid4())}{extension}"
            temp_dir = tempfile.gettempdir()
            temp_file_path = os.path.join(temp_dir, file_name)
            # Register the path before writing so a partially written file is
            # still removed if the write itself fails.
            cleanup_paths.append(temp_file_path)
            with open(temp_file_path, "wb") as temp_file:
                temp_file.write(response.content)
            print(f"音频文件保存到临时路径: {temp_file_path}")
            if not os.path.exists(temp_file_path):
                raise FileNotFoundError(f"临时文件创建失败: {temp_file_path}")
            file_size = os.path.getsize(temp_file_path)
            if file_size == 0:
                raise ValueError("下载的音频文件为空")
            print(f"临时文件大小: {file_size} 字节 ({file_size / 1024 / 1024:.2f} MB)")

            # Larger files get a smaller batch window.
            if file_size > 50 * 1024 * 1024:
                batch_size_s = 60
                print(
                    f"检测到大文件 ({file_size / 1024 / 1024:.2f} MB),batch_size_s={batch_size_s}"
                )
            elif file_size > 20 * 1024 * 1024:
                batch_size_s = 120
                print(
                    f"检测到中等文件 ({file_size / 1024 / 1024:.2f} MB),batch_size_s={batch_size_s}"
                )
            else:
                batch_size_s = 300

            processing_path = temp_file_path
            if extension == ".wav":
                is_valid_wav = validate_wav_file(processing_path)
                if not is_valid_wav:
                    raise ValueError("音频文件格式验证失败")
                # Called for its diagnostic output only; the return value is
                # intentionally unused and the model still receives the path.
                deep_analyze_wav(processing_path)
            else:
                print(f"检测到非 WAV 音频格式 {extension},直接交给模型处理")

            print("开始语音识别...")
            res = self.modelService.offLinePrmodel.generate(
                input=processing_path,
                batch_size_s=batch_size_s,
                hotword=hotword,
                cache={},
            )
            recognized_text = (
                res[0]["text"] if res and isinstance(res, list) and len(res) > 0 else ""
            )
            print(f"识别成功: {recognized_text}")

            # Re-join the transcript with one sentence per line.
            sentences = [s.strip() for s in recognized_text.split("。") if s.strip()]
            result_text = "。\n".join(sentences)
            if sentences:
                result_text += "。"
            return result_text
        except requests.exceptions.RequestException as e:
            print(f"下载失败: {str(e)}")
            return None
        except Exception as e:
            print(f"识别错误: {str(e)}")
            print("Traceback (most recent call last):")
            print(traceback.format_exc())
            return None
        finally:
            for path in cleanup_paths:
                if path and os.path.exists(path):
                    try:
                        os.remove(path)
                        print(f"已删除临时文件: {path}")
                    except OSError as e:
                        print(f"删除临时文件失败: {str(e)}")

def array_to_bytes_basic(audio_array, sample_width=2, byte_order="<"):
    """Serialise a sequence of audio samples into packed bytes.

    Args:
        audio_array: Iterable of numeric samples.
        sample_width: ``1``, ``2`` or ``4`` for signed integers of that many
            bytes, or ``4.1`` for 32-bit floats (clamped to ``[-1.0, 1.0]``).
        byte_order: ``struct`` byte-order prefix, ``"<"`` by default.

    Returns:
        bytes: The packed samples in input order.

    Raises:
        ValueError: If ``sample_width`` is not one of the supported values.
        struct.error: If an integer sample does not fit the chosen width.
    """
    format_map = {
        1: "b",
        2: "h",
        4: "i",
        4.1: "f",
    }
    if sample_width not in format_map:
        raise ValueError(f"不支持的样本宽度: {sample_width}")

    type_code = format_map[sample_width]
    packer = struct.Struct(byte_order + type_code)
    if type_code == "f":
        # Float samples are clamped and packed as floats. The previous code
        # ran them through int(), which collapsed every value to 0 or +/-1.
        values = (max(min(float(sample), 1.0), -1.0) for sample in audio_array)
    else:
        values = (int(sample) for sample in audio_array)
    return b"".join(packer.pack(value) for value in values)

class OnlineServiceHandler(tornado.web.RequestHandler):
    """Transcribe base64-encoded 16-bit PCM samples sent in the request body.

    The JSON body must contain ``wavContent`` (base64 audio), ``sessionId``
    and ``serialId``. The response echoes both ids and carries the filtered
    transcript under ``content``.
    """

    # Worker pool used by ``run_on_executor`` for the blocking inference.
    executor = ThreadPoolExecutor(5)

    def initialize(self, modelService):
        """Store the shared model service and emit the default headers."""
        self.modelService = modelService
        self.set_default_header()

    @tornado.gen.coroutine
    def post(self):
        """Run the blocking pipeline on the executor and write its JSON result."""
        res = yield self.data_process_and_model()
        print("&&&&&&&&&&&&&&&&&&&&&&&&")
        print((res))
        self.write(json.dumps(res, ensure_ascii=False))
        self.finish()

    def set_default_header(self):
        """Set the CORS and content-type headers for every response."""
        print("setting headers!!!")
        self.set_header("Access-Control-Allow-Origin", "*")
        self.set_header("Access-Control-Allow-Headers", "*")
        self.set_header(
            "Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS"
        )
        self.set_header("Content-Type", "application/json; charset=UTF-8")
        # This second assignment replaces the "*" value set above.
        self.set_header("Access-Control-Allow-Headers", "Content-Type")

    def options(self):
        """Answer CORS pre-flight requests with an empty body."""
        pass

    @staticmethod
    def _error_payload(message, code, result):
        """Build the error envelope shared by every failure branch."""
        return {
            "timestamp": time.time(),
            "message": message,
            "status": 500,
            "code": code,
            "result": result,
        }

    @run_on_executor
    def data_process_and_model(self):
        """Decode the samples, run ASR inference and build the response dict.

        Returns:
            dict: ``sessionId`` / ``serialId`` / ``content`` on success, or an
            error envelope (status 500) on any failure.
        """
        try:
            # This helper was referenced without ever being imported, so every
            # request ended in a NameError. Import it where it is used.
            from funasr.utils.postprocess_utils import rich_transcription_postprocess

            totalStartTime = round(time.time(), 3)
            json_content = self.request.body.decode("utf8")
            print(json_content)
            json_data = json.loads(json_content)
            audio_data_array = json_data["wavContent"]
            audio_data_array = base64.b64decode(audio_data_array)
            sessionId = json_data["sessionId"]
            serialId = json_data["serialId"]
            # Interpret the payload as little-endian signed 16-bit samples.
            samples = np.frombuffer(audio_data_array, dtype="<i2")
            samples = samples.astype(np.float32)
            # NOTE(review): the samples are not normalised before being scaled
            # by 32768 - confirm that this is the range the model expects.
            result = self.modelService.asr_model.inference(
                samples * 32768, language="zh", disable_pbar=True
            )
            text = rich_transcription_postprocess(result[0]["text"])
            text = filter_text(text)
            print(text)
            resultInResult = {
                "sessionId": sessionId,
                "serialId": serialId,
                "content": text,
            }
            totalEndTime = round(time.time(), 3)
            execution_time = totalEndTime - totalStartTime
            print(execution_time)
            return resultInResult
        except Exception as e:
            detail = str(e)
            print(detail)
            # print_exc() writes the traceback itself and returns None, so it
            # must not be wrapped in print().
            traceback.print_exc()
            if "Length mismatch" in detail:
                return self._error_payload(
                    "数据字段数与字段数对应不上", 10170004, "请输入正确的数据格式"
                )
            if "HTTP Error" in detail:
                return self._error_payload(
                    "imgurl请求403错误", 10170002, "检查请求的url"
                )
            return self._error_payload("模型计算错误", 10170001, "请检查输入格式")

class PrServiceHandler(tornado.web.RequestHandler):
    """Transcribe base64-encoded audio and return per-sentence speaker info.

    The JSON body must contain ``wavContent`` (base64 audio). On success the
    response carries a list of ``{"text", "start", "end", "spk"}`` items
    under ``result``; on failure the error text is returned under ``content``.
    """

    # Worker pool used by ``run_on_executor`` for the blocking inference.
    executor = ThreadPoolExecutor(5)

    def initialize(self, modelService):
        """Store the shared model service and emit the default headers."""
        self.modelService = modelService
        self.set_default_header()

    @tornado.gen.coroutine
    def post(self):
        """Run the blocking pipeline on the executor and write its JSON result."""
        res = yield self.data_process_and_model()
        print("&&&&&&&&&&&&&&&&&&&&&&&&")
        print((res))
        self.write(json.dumps(res, ensure_ascii=False))
        self.finish()

    def set_default_header(self):
        """Set the CORS and content-type headers for every response."""
        print("setting headers!!!")
        self.set_header("Access-Control-Allow-Origin", "*")
        self.set_header("Access-Control-Allow-Headers", "*")
        self.set_header(
            "Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS"
        )
        self.set_header("Content-Type", "application/json; charset=UTF-8")
        # This second assignment replaces the "*" value set above.
        self.set_header("Access-Control-Allow-Headers", "Content-Type")

    def options(self):
        """Answer CORS pre-flight requests with an empty body."""
        pass

    @staticmethod
    def _remove_temp_file(path):
        """Delete ``path`` if it exists, logging instead of raising on failure."""
        if path and os.path.exists(path):
            try:
                os.remove(path)
                print(f"已删除临时文件: {path}")
            except OSError as cleanup_error:
                print(f"删除临时文件失败: {str(cleanup_error)}")

    @run_on_executor
    def data_process_and_model(self):
        """Write the audio to a temp WAV, run the speaker-aware model and
        build the response dict.

        The temp file is handed to the model by path and removed in a
        ``finally`` block. Previously the path variable was overwritten with
        the PCM bytes returned by ``deep_analyze_wav``, so the file was never
        deleted, and the empty-result branch skipped cleanup entirely.
        """
        wav_file = None
        try:
            totalStartTime = round(time.time(), 3)
            json_content = self.request.body.decode("utf8")
            json_data = json.loads(json_content)
            audioBae64 = json_data["wavContent"]
            totalEndTime = round(time.time(), 3)
            execution_time = totalEndTime - totalStartTime
            print("***")
            print(execution_time)

            wav_file = process_audio_base64(audioBae64, temp_dir="./bak/")
            if not wav_file:
                # process_audio_base64 signals failure by returning None.
                raise ValueError("音频数据解码或写入失败")
            is_valid = validate_wav_file(wav_file)
            print(is_valid)
            if not is_valid:
                raise ValueError("音频文件格式验证失败")

            # NOTE(review): this hotword is "魔塔" while the other handlers
            # default to "魔搭" - confirm which spelling is intended.
            res = self.modelService.onlinePrmodel.generate(
                input=wav_file,
                batch_size_s=300,
                hotword="魔塔",
                cache={},
            )
            if len(res) > 0:
                text = res[0]["text"]
                print(f"识别成功: {text}")
                print(f"识别成功: {res}")
                sentence_info = (
                    res[0].get("sentence_info", []) if isinstance(res, list) else []
                )
                extracted_results = [
                    {
                        "text": item.get("text", ""),
                        "start": item.get("start", 0),
                        "end": item.get("end", 0),
                        "spk": item.get("spk", 0),
                    }
                    for item in sentence_info
                ]
                totalEndTime = round(time.time(), 3)
                execution_time = totalEndTime - totalStartTime
                print("***")
                print(execution_time)
            else:
                print("失败入参")
                print(json_content)
                extracted_results = []

            resultInResult = {}
            resultInResult["timestamp"] = time.time()
            resultInResult["message"] = "ok"
            resultInResult["status"] = 200
            resultInResult["code"] = 10170000
            resultInResult["result"] = extracted_results
            return resultInResult
        except Exception as e:
            print("Traceback (most recent call last):")
            print(traceback.format_exc())
            resultInResult = {}
            resultInResult["timestamp"] = time.time()
            resultInResult["message"] = "error"
            resultInResult["status"] = 500
            resultInResult["code"] = 10170001
            resultInResult["content"] = str(e)
            return resultInResult
        finally:
            # Runs on success, empty-result and error paths alike.
            self._remove_temp_file(wav_file)

    def audio_to_text(self, audio_url, hotword="魔搭"):
        """Download an audio file from a URL and run speech recognition on it.

        Args:
            audio_url (str): URL of the audio file.
            hotword (str): Hotword passed to the recogniser.

        Returns:
            str | None: The recognised text, or None when the download or the
            recognition fails.
        """
        temp_file_path = None
        try:
            print(f"开始处理音频URL: {audio_url}")
            # NOTE(review): ``verify=False`` disables TLS certificate checks and
            # the URL is caller-supplied - confirm that this is acceptable.
            response = requests.get(audio_url, verify=False, timeout=1200)
            response.raise_for_status()
            print(f"音频文件下载成功,大小: {len(response.content)} 字节")

            # Derive the file extension from the URL path, defaulting to .wav.
            parsed_url = urlparse(audio_url)
            original_name = os.path.basename(parsed_url.path)
            if "." in original_name:
                extension = "." + original_name.split(".")[-1].lower()
                valid_extensions = {
                    ".wav",
                    ".mp3",
                    ".flac",
                    ".ogg",
                    ".m4a",
                    ".aac",
                    ".webm",
                    ".mp4",
                }
                if extension not in valid_extensions:
                    print(f"未知的扩展名 '{extension}',使用.wav作为默认")
                    extension = ".wav"
            else:
                print("URL中没有扩展名,使用.wav作为默认")
                extension = ".wav"

            file_name = f"{str(uuid.uuid4())}{extension}"
            temp_dir = tempfile.gettempdir()
            temp_file_path = os.path.join(temp_dir, file_name)
            with open(temp_file_path, "wb") as temp_file:
                temp_file.write(response.content)
            print(f"音频文件保存到临时路径: {temp_file_path}")
            if not os.path.exists(temp_file_path):
                raise FileNotFoundError(f"临时文件创建失败: {temp_file_path}")
            file_size = os.path.getsize(temp_file_path)
            if file_size == 0:
                raise ValueError("下载的音频文件为空")
            print(f"临时文件大小: {file_size} 字节")
            print("开始语音识别...")

            # The model service defines no ``prmodel`` attribute in this file,
            # so fall back to the speaker-aware instance this handler uses.
            model = getattr(self.modelService, "prmodel", None)
            if model is None:
                model = self.modelService.onlinePrmodel
            res = model.generate(
                input=temp_file_path,
                batch_size_s=300,
                hotword=hotword,
            )
            recognized_text = res[0]["text"]
            print(f"识别成功: {recognized_text}")
            return recognized_text
        except requests.exceptions.RequestException as e:
            print(f"下载失败: {str(e)}")
            return None
        except Exception as e:
            print(f"识别错误: {str(e)}")
            print("Traceback (most recent call last):")
            print(traceback.format_exc())
            return None
        finally:
            self._remove_temp_file(temp_file_path)

def filter_text(text, keep_numbers=True, keep_spaces=False):
    """Keep only whitelisted characters of ``text`` and drop everything else.

    Chinese ideographs, ASCII letters and the listed punctuation marks are
    always kept. Digits are kept when ``keep_numbers`` is true and whitespace
    when ``keep_spaces`` is true.

    Returns:
        str: The kept characters, in their original order.
    """
    allowed = [
        # CJK ideograph ranges
        r"[\u3400-\u4DBF\u4E00-\u9FFF\U00020000-\U0002A6DF]",
        # ASCII letters
        r"[a-zA-Z]",
        # punctuation
        r'[,。!?;:"“”‘’、…—《》【】〔〕〈〉『』「」〖〗().・,\.!?;:"\'\/\\\|\[\]\(\)\{\}\-–—~`]',
    ]
    if keep_numbers:
        allowed += [r"[0-9]"]
    if keep_spaces:
        allowed += [r"\s"]

    matcher = re.compile("(" + "|".join(allowed) + ")")
    return "".join(hit.group(1) for hit in matcher.finditer(text))

def process_base64_audio(base64_audio, temp_dir=None):
    """Decode base64 audio and write it to a uniquely named ``.wav`` file.

    Args:
        base64_audio: Base64-encoded audio payload.
        temp_dir: Directory for the file, created if missing. The system
            temp directory is used when this is falsy.

    Returns:
        str | None: Path of the written file, or None on any failure. The
        decode-failure path used to return ``(None, None)`` and the
        write-failure path returned a path to a file that may not exist;
        every failure now returns None.
    """
    try:
        audio_bytes = base64.b64decode(base64_audio)
        print("Base64 解码成功")
    except (ValueError, TypeError) as e:
        # binascii.Error is a ValueError subclass; TypeError covers None etc.
        print(f"Base64 解码失败: {str(e)}")
        return None

    temp_file_path = None
    try:
        if temp_dir:
            if not os.path.exists(temp_dir):
                os.makedirs(temp_dir)
                print(f"创建临时目录: {temp_dir}")
        else:
            temp_dir = tempfile.gettempdir()
        file_name = f"{str(uuid.uuid4())}.wav"
        temp_file_path = os.path.join(temp_dir, file_name)
        print(f"生成临时文件路径: {temp_file_path}")
        try:
            with open(temp_file_path, "wb") as wav_file:
                wav_file.write(audio_bytes)
            print(f"成功写入 WAV 文件: {temp_file_path}")
        except Exception as e:
            print(f"写入 WAV 文件失败: {str(e)}")
            # Do not leave a partially written file behind.
            if os.path.exists(temp_file_path):
                try:
                    os.remove(temp_file_path)
                except OSError:
                    pass
            return None
        return temp_file_path
    except Exception as e:
        print(f"处理过程中出错: {str(e)}")
        return None

def process_audio_base64(base64_audio, temp_dir=None):
    """Decode base64 audio, ensure it has a WAV header, and save it to disk.

    Payloads shorter than 44 bytes or not starting with ``RIFF`` are wrapped
    with a header via ``add_wav_header``. The file is written under
    ``temp_dir`` (created if needed) or the system temp directory.

    Returns:
        str | None: Path of the written file, or None on any failure.
    """
    try:
        print("开始处理音频 Base64 数据")
        try:
            raw_audio = base64.b64decode(base64_audio)
            print(f"Base64 解码成功,音频数据大小: {len(raw_audio)} 字节")
        except Exception as decode_error:
            logging.error(f"Base64 解码失败: {str(decode_error)}")
            return None

        has_wav_header = len(raw_audio) >= 44 and raw_audio[:4] == b"RIFF"
        if not has_wav_header:
            logging.warning("音频数据缺少有效的 WAV 头信息")
            raw_audio = add_wav_header(raw_audio)
            print("已添加 WAV 头信息")

        if not temp_dir:
            temp_dir = tempfile.gettempdir()
            print(f"使用系统临时目录: {temp_dir}")
        else:
            os.makedirs(temp_dir, exist_ok=True)
            print(f"使用指定临时目录: {temp_dir}")

        output_path = os.path.join(temp_dir, f"{str(uuid.uuid4())}.wav")
        print(f"生成临时文件路径: {output_path}")
        try:
            with open(output_path, "wb") as sink:
                sink.write(raw_audio)
            print(f"成功写入 WAV 文件: {output_path}")
        except Exception as write_error:
            logging.error(f"写入 WAV 文件失败: {str(write_error)}")
            return None

        # Sanity-check what actually landed on disk.
        if not os.path.exists(output_path):
            logging.error(f"临时文件创建失败: {output_path}")
            return None
        written_size = os.path.getsize(output_path)
        if written_size == 0:
            logging.error("写入的音频文件为空")
            return None

        print(f"音频文件验证成功,大小: {written_size} 字节")
        print(output_path)
        return output_path
    except Exception as unexpected:
        logging.error(f"处理过程中出错: {str(unexpected)}")
        return None

def add_wav_header(audio_bytes, sample_rate=16000, channels=1, sample_width=2):
    """Prefix raw PCM data with a 44-byte RIFF/WAVE header.

    Args:
        audio_bytes: Raw audio payload.
        sample_rate: Samples per second written to the header.
        channels: Channel count written to the header.
        sample_width: Bytes per sample; the header stores ``sample_width * 8``
            as bits per sample.

    Returns:
        bytes: Header followed by ``audio_bytes``.
    """
    payload_len = len(audio_bytes)
    frame_bytes = channels * sample_width
    header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF",
        payload_len + 36,  # RIFF chunk size: rest of header plus payload
        b"WAVE",
        b"fmt ",
        16,  # size of the fmt sub-chunk
        1,  # audio format tag
        channels,
        sample_rate,
        sample_rate * frame_bytes,  # byte rate
        frame_bytes,  # block align
        sample_width * 8,  # bits per sample
        b"data",
        payload_len,
    )
    return header + audio_bytes

def convert_to_simplified_chinese(text):
    """Convert Traditional Chinese characters in ``text`` to Simplified.

    Uses the optional ``zhconv`` package when it can be imported. Otherwise,
    or if the conversion raises, a small built-in character table is applied.

    Args:
        text (str): Text to convert.

    Returns:
        str: The converted text.
    """
    # ``zhconv`` is not imported at module level, so the old code always hit a
    # NameError that a bare ``except:`` swallowed. Import it explicitly here.
    try:
        import zhconv
    except ImportError:
        zhconv = None

    if zhconv is not None:
        try:
            return zhconv.convert(text, "zh-cn")
        except Exception:
            # Best-effort: fall through to the built-in table below.
            pass

    trad_simp_map = {
        "麼": "么", "為": "为", "裡": "里", "畫": "画", "複": "复", "週": "周",
        "於": "于", "後": "后", "纔": "才", "別": "别", "當": "当", "點": "点",
        "電": "电", "話": "话", "邊": "边", "過": "过", "還": "还", "這": "这",
        "個": "个", "時": "时", "會": "会", "體": "体", "國": "国", "學": "学",
        "書": "书", "寫": "写", "讀": "读", "認": "认", "識": "识", "語": "语",
        "言": "言", "係": "系", "統": "统", "腦": "脑", "網": "网", "絡": "络",
        "資": "资", "訊": "讯", "郵": "邮", "件": "件", "視": "视",
    }
    # Every key is a single character, so one translate() pass replaces the
    # chain of per-character replace() calls.
    return text.translate(str.maketrans(trad_simp_map))

def validate_wav_file(file_path):
    """Report whether ``file_path`` can be opened as a WAV file.

    The channel count, sample width, sample rate and frame count are printed
    when the file opens. Any exception is printed and reported as failure.

    Returns:
        bool: True if the file opened and its parameters were read.
    """
    try:
        with wave.open(file_path, "rb") as reader:
            channels, width, rate, frame_count = reader.getparams()[:4]
        print(f"验证通过: {file_path}")
        print(f"声道数: {channels}, 采样宽度: {width}字节")
        print(f"采样率: {rate}Hz, 帧数: {frame_count}")
    except Exception as problem:
        print(f"WAV文件验证失败: {str(problem)}")
        return False
    return True

def deep_analyze_wav(file_path):
    """Read a WAV file, print its parameters and return 16-bit PCM data.

    8-bit samples (unsigned in WAV files) and 24-bit samples are converted to
    little-endian signed 16-bit. Any other sample width is returned unchanged.

    The conversion uses numpy instead of ``audioop``, which was removed in
    Python 3.13. The old 8-bit path also read the unsigned samples as signed
    and applied the bias at the wrong width.

    Args:
        file_path: Path of the WAV file.

    Returns:
        tuple: ``(pcm_bytes, sample_rate)``, or ``(None, None)`` if the
        conversion fails.

    Raises:
        wave.Error, OSError: If the file cannot be opened as a WAV file.
    """
    with wave.open(file_path, "rb") as wf:
        nchannels = wf.getnchannels()
        sampwidth = wf.getsampwidth()
        framerate = wf.getframerate()
        nframes = wf.getnframes()
        comptype = wf.getcomptype()
        compname = wf.getcompname()
        frames = wf.readframes(nframes)

    print(f"深度分析文件: {file_path}")
    print(f" 声道数: {nchannels}, 采样宽度: {sampwidth}字节")
    print(f" 采样率: {framerate}Hz, 帧数: {nframes}")
    print(f" 压缩类型: {comptype} ({compname})")
    print(f" 数据类型: {type(frames).__name__}, 大小: {len(frames)}字节")
    try:
        if sampwidth == 3:
            # Keep the two most significant bytes of each little-endian
            # 24-bit sample.
            raw = np.frombuffer(frames, dtype=np.uint8)
            converted = raw.reshape(-1, 3)[:, 1:].tobytes()
        elif sampwidth == 1:
            # Re-centre the unsigned 8-bit samples around zero, then widen.
            centred = np.frombuffer(frames, dtype=np.uint8).astype(np.int16) - 128
            converted = (centred * 256).astype("<i2").tobytes()
        else:
            converted = frames
        return converted, framerate
    except Exception as e:
        print(f"格式转换失败: {e}")
        return None, None

def handle_audio(audio_data, content_type=None):
    """Transcribe in-memory audio bytes with the module-level ``asr_pipe``.

    The audio is down-mixed to mono and resampled to 16 kHz before it is
    passed to the pipeline. The transcript is converted to Simplified Chinese.

    Args:
        audio_data (bytes): Encoded audio file content.
        content_type: Unused.

    Returns:
        str: The transcript in Simplified Chinese.

    Raises:
        RuntimeError: If loading, resampling or transcription fails, or if no
            ``asr_pipe`` has been created. The original error is chained as
            the cause.
    """
    try:
        # The module-level creation of ``asr_pipe`` is commented out in this
        # file, so look it up explicitly and fail with a clear message rather
        # than a NameError.
        pipe = globals().get("asr_pipe")
        if pipe is None:
            raise RuntimeError("asr_pipe 未初始化")

        buffer = BytesIO(audio_data)
        waveform, sample_rate = torchaudio.load(buffer)
        # Down-mix multi-channel audio to mono.
        if waveform.size(0) > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=16000
            )
            waveform = resampler(waveform)
            sample_rate = 16000
        waveform_np = waveform.numpy().squeeze(0)
        text_result = pipe(
            {"array": waveform_np, "sampling_rate": sample_rate},
            batch_size=8,
            generate_kwargs={"language": "zh", "task": "transcribe"},
        )["text"]
        return convert_to_simplified_chinese(text_result)
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"音频处理失败: {str(e)}") from e

def transcribe_audio_from_url(audio_url, simplified=True):
    """Download audio from ``audio_url`` and return its transcript.

    Failures are not raised. A descriptive message string is returned
    instead, prefixed according to the kind of failure.

    Args:
        audio_url (str): URL of the audio file.
        simplified (bool): When true, pass the transcript through
            ``convert_to_simplified_chinese`` before returning it.

    Returns:
        str: The transcript, or an error message.
    """
    try:
        download = requests.get(audio_url, timeout=300, verify=False)
        download.raise_for_status()
        transcript = handle_audio(download.content)
        if simplified:
            return convert_to_simplified_chinese(transcript)
        return transcript
    except requests.exceptions.RequestException as download_error:
        return f"下载失败: {str(download_error)}"
    except RuntimeError as processing_error:
        return f"转录失败: {str(processing_error)}"
    except Exception as unexpected:
        return f"未知错误: {str(unexpected)}"

def make_app(modelService):
    """Build the Tornado application with all four speech endpoints.

    Every handler receives ``modelService`` through its ``initialize`` kwargs.
    """
    endpoints = (
        (r"/api/v1.0/funasr/service", OfflineServiceHandler),
        (r"/api/speechBytes2text", OnlineServiceHandler),
        (r"/api/speechOnlineText", PrServiceHandler),
        (r"/api/v2.0/funasr/service", OfflineServiceHandlerV2),
    )
    routes = [
        (url_pattern, handler_cls, {"modelService": modelService})
        for url_pattern, handler_cls in endpoints
    ]
    return tornado.web.Application(routes)

if __name__ == "__main__":
    # Load every model once, before the server starts accepting requests.
    modelService = ModelService()
    app = make_app(modelService)
    # The listening port comes from the [server] section of ./config.ini.
    cf = configparser.ConfigParser()
    cf.read("./config.ini")
    # Parse the port as an integer so ``listen`` receives the numeric port
    # instead of the raw config string.
    port = cf.getint("server", "port")
    print(port)
    app.listen(port)
    tornado.ioloop.IOLoop.current().start()

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐