From 671efeb34057c505b3fa448058925e16cf463613 Mon Sep 17 00:00:00 2001 From: YijiCu <1165084283@qq.com> Date: Tue, 13 May 2025 10:53:27 +0800 Subject: [PATCH 1/3] Update README.md --- README.md | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ab57007..3525a1c 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,21 @@ +## 2025_05_13 更新 +本项目代码于上次更新后进行了7轮版本更新,中间进行了几轮重建,目前开源出三份核心代码文件供交流,也作为笔者的一些记录,希望大佬们多多指点! + +### /2025_05_13/chatui.py -- 主要执行文件 +执行文件,包含项目整体UI框架 + +### /2025_05_13/model_api.py -- 模型交互模块(被调用) +1、模型交互代码,目前通过调用本地部署的dify模型问答api来与模型交互(dify上的模型也是本地部署的) +2、用户上传文件问答功能 +3、根据模型上文限制,平衡RAG召回内容与用户上传内容的输入 + +### /2025_05_13/rag_milvus.py -- RAG功能模块(被调用) +RAG功能模块 + +因数据安全要求,对部分信息做了脱敏处理,大部分的数据、配置代码无法开放~ + +---------------------------------------------------------------------------------------------------------------------------------- +2024_09_30 # 基于InternLM2.5的清洁能源行业知识专家 #1408 上海AI Lab书生·浦语训练营第三期项目 @@ -100,4 +118,4 @@ https://www.bilibili.com/video/BV1FfWde5EQL/ ## 欢迎交流! 
# File: 2025_05_13/chatui.py -- Streamlit entry point containing the whole UI.
#
# NOTE(review): the extracted source had every HTML-like tag stripped: the
# <style> blocks, the heading/div wrappers and the "<think>"/"</think>" markers
# (the code compared against empty strings). The markup in this file is a
# reconstruction driven by the UIConfig constants -- confirm selectors and
# sizes against the real file before shipping.
import os
import logging
import html
import streamlit as st
from PIL import Image
import base64
from model_api import UnifiedAssistant

# Delimiters of the model's reasoning section.
# NOTE(review): presumably the DeepSeek-R1 "<think>" tags -- TODO confirm.
THINK_OPEN = "<think>"
THINK_CLOSE = "</think>"


class UIConfig:
    """Static UI configuration: texts, icon paths and colours."""
    PAGE_TITLE = "xxx大模型AI办公助手"
    PAGE_ICON = "🔍"
    BACKGROUND_COLOR = "#F3F6FB"
    USER_ICON_PATH = '/app/rag/config/提问.png'
    AI_ICON_PATH = '/app/rag/config/模型回复机器人.png'
    DEEPSEEK_ICON_PATH = '/app/rag/config/小钢笔.png'
    SHANGHAI_ICON_PATH = '/app/rag/config/书本.png'
    WELCOME_MESSAGE = "xxx小伙伴您好,请输入您的问题,我可以帮您进行报告分析、生成等各种办公作业,并且我还在持续更新!"

    # Selected mode button
    SELECTED_BG_COLOR = "#3F70FF"
    SELECTED_TEXT_COLOR = "#fff"
    SELECTED_FONT_WEIGHT = "bold"
    SELECTED_FONT_SIZE = "18px"

    # Unselected mode button
    UNSELECTED_BG_COLOR = "#fff"
    UNSELECTED_TEXT_COLOR = "#4D4B72"
    UNSELECTED_FONT_WEIGHT = "normal"
    UNSELECTED_FONT_SIZE = "18px"

    # Reasoning ("thinking") text
    THINKING_BG_COLOR = "#FAFDFF"
    THINKING_TEXT_COLOR = "#8b8b8b"
    THINKING_FONT_SIZE = "14px"
    THINKING_BORDER_COLOR = "#D5DBDE"

    @staticmethod
    def load_image(path: str):
        """Return the image at *path*, or None when it is missing or unreadable."""
        try:
            if os.path.exists(path):
                return Image.open(path)
            return None
        except Exception as e:
            logging.error(f"Error loading image from {path}: {str(e)}")
            return None


def initialize_session_state():
    """Create the session-state entries the app relies on (idempotent)."""
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # Default mode is DeepSeek, which corresponds to rag_enabled == False.
    if "current_mode" not in st.session_state:
        st.session_state.current_mode = "deepseek"  # "deepseek" or "shanghai_rag"
    if "rag_enabled" not in st.session_state:
        st.session_state.rag_enabled = False
    if "assistant" not in st.session_state:
        try:
            st.session_state.assistant = UnifiedAssistant()
            logging.info("UnifiedAssistant 初始化成功")
        except Exception as e:
            logging.error(f"UnifiedAssistant 初始化失败: {e}")
            st.error("系统初始化失败,请刷新页面重试")


def setup_page():
    """Configure the page and draw the title; must run before other st.* calls."""
    st.set_page_config(
        page_title=UIConfig.PAGE_TITLE,
        page_icon=UIConfig.PAGE_ICON,
        layout="wide",
        initial_sidebar_state="expanded"
    )
    # A raw <h1> is used so the title can be styled precisely.
    # NOTE(review): original tag attributes were lost in extraction.
    st.markdown(f"<h1>{UIConfig.PAGE_TITLE}</h1>", unsafe_allow_html=True)
    st.caption("基于DeepSeek R1定制化开发")


def setup_background(background_color):
    """Paint the app background with the CSS colour *background_color*."""
    # NOTE(review): the original <style> body was lost; ".stApp" is an assumption.
    st.markdown(
        f"""
        <style>
        .stApp {{
            background-color: {background_color};
        }}
        </style>
        """,
        unsafe_allow_html=True
    )


def _mode_button(label, key, mode, rag_enabled):
    """Draw one mode button; switch mode and rerun when an inactive one is clicked."""
    selected = st.session_state.current_mode == mode
    clicked = st.button(
        label,
        key=key,
        use_container_width=True,
        type="primary" if selected else "secondary"
    )
    # Only an actual change of mode triggers a state update and a rerun.
    if clicked and not selected:
        st.session_state.current_mode = mode
        st.session_state.rag_enabled = rag_enabled
        st.rerun()


def _clear_document_state():
    """Forget the uploaded document (its content and its processed marker)."""
    for state_key in ('current_doc_content', 'last_processed_file'):
        if state_key in st.session_state:
            del st.session_state[state_key]


def setup_sidebar():
    """Draw the sidebar: mode selection, file upload and history clearing.

    The source defined this function twice; the first definition was shadowed
    by the second and never ran, so only the effective one is kept.
    """
    with st.sidebar:
        # Larger workbench title.
        # NOTE(review): original wrapper markup was lost; the size is a guess.
        st.markdown(
            '<div style="font-size: 28px; font-weight: bold;">📚 工作台</div>',
            unsafe_allow_html=True
        )

        st.markdown("### ⚙️ 模式选择")
        _mode_button("🖋️ DeepSeek AI办公助手", "deepseek_button", "deepseek", False)
        _mode_button("📚 上海院知识库问答助手", "shanghai_button", "shanghai_rag", True)

        # Mode-button styling driven by UIConfig.
        # NOTE(review): the original selectors were lost; button[kind=...] is an
        # assumption about Streamlit's DOM.
        st.markdown(f"""
        <style>
        section[data-testid="stSidebar"] button[kind="primary"] {{
            background-color: {UIConfig.SELECTED_BG_COLOR};
            color: {UIConfig.SELECTED_TEXT_COLOR};
            font-weight: {UIConfig.SELECTED_FONT_WEIGHT};
            font-size: {UIConfig.SELECTED_FONT_SIZE};
        }}
        section[data-testid="stSidebar"] button[kind="secondary"] {{
            background-color: {UIConfig.UNSELECTED_BG_COLOR};
            color: {UIConfig.UNSELECTED_TEXT_COLOR};
            font-weight: {UIConfig.UNSELECTED_FONT_WEIGHT};
            font-size: {UIConfig.UNSELECTED_FONT_SIZE};
        }}
        </style>
        """, unsafe_allow_html=True)

        st.markdown("### 📁 文件处理")

        # The key must stay constant so the widget keeps its state across reruns.
        uploaded_file = st.file_uploader(
            "上传文件",
            type=["txt", "pdf", "doc", "docx"],
            help="上传文件进行问答",
            key="file_uploader"
        )

        if uploaded_file is not None:
            # Re-process only when the file differs from the last processed one.
            current_file_name = getattr(uploaded_file, 'name', None)
            if st.session_state.get('last_processed_file') != current_file_name:
                st.write("文件信息:")
                st.write(f"- 文件名: {uploaded_file.name}")
                st.write(f"- 文件类型: {uploaded_file.type}")
                st.write(f"- 文件大小: {uploaded_file.size / 1024:.2f} KB")

                if "assistant" in st.session_state:
                    doc_content = st.session_state.assistant.read_file(uploaded_file)
                    if doc_content:
                        st.session_state.current_doc_content = doc_content
                        st.session_state.last_processed_file = current_file_name
                        st.success("文件处理成功!")
                    else:
                        st.error("文件处理失败")
        else:
            # The file was removed from the uploader: drop everything derived from it.
            _clear_document_state()

        st.markdown("### 🗑️ 清除历史")
        if st.button("清除对话历史", key="clear_history_button", use_container_width=True):
            st.session_state.messages = []
            _clear_document_state()
            st.rerun()

        # Independent style for the clear-history button so that the mode
        # styling above does not affect it.
        # NOTE(review): the original rules were lost; the ".st-key-<key>"
        # selector and the values are assumptions.
        st.markdown(f"""
        <style>
        .st-key-clear_history_button button {{
            background-color: {UIConfig.UNSELECTED_BG_COLOR};
            color: {UIConfig.UNSELECTED_TEXT_COLOR};
            font-weight: {UIConfig.UNSELECTED_FONT_WEIGHT};
        }}
        </style>
        """, unsafe_allow_html=True)


def handle_file_change():
    """Callback for file changes: drop the cached document content."""
    if 'current_doc_content' in st.session_state:
        del st.session_state.current_doc_content
        logging.info("文件已移除,清除文档内容")


def merge_doc_links(doc_links):
    """Deduplicate document links by title, keeping the first URL per title.

    Returns a list of {"title", "url"} dicts in first-seen order; anything that
    is not a list of dicts with both fields set yields an empty list.
    """
    if not doc_links or not isinstance(doc_links, list):
        return []

    merged = {}
    for doc in doc_links:
        if not isinstance(doc, dict):
            continue
        title = doc.get('title')
        url = doc.get('url')
        if title and url:
            merged.setdefault(title, url)  # first occurrence wins

    return [{"title": title, "url": url} for title, url in merged.items()]


def _split_think(content):
    """Split *content* into (before, think, after) around the reasoning section.

    *think* is None when there is no opening marker. A missing closing marker
    (typical while the answer is still streaming) makes everything after the
    opening marker count as reasoning.
    """
    before, marker, rest = content.partition(THINK_OPEN)
    if not marker:
        return content, None, ""
    think, _, after = rest.partition(THINK_CLOSE)
    return before, think, after


def _think_html(think_content):
    """Return the styled HTML block for the reasoning text.

    The text is HTML-escaped because it is model output rendered with
    unsafe_allow_html=True.
    """
    return (
        f'<div style="background-color: {UIConfig.THINKING_BG_COLOR}; '
        f'color: {UIConfig.THINKING_TEXT_COLOR}; '
        f'font-size: {UIConfig.THINKING_FONT_SIZE}; '
        f'border-left: 3px solid {UIConfig.THINKING_BORDER_COLOR}; '
        f'padding: 8px 12px; margin: 8px 0; white-space: pre-wrap;">'
        f'{html.escape(think_content)}</div>'
    )


def _render_assistant_content(content):
    """Render an assistant answer, showing its reasoning section in its own style."""
    before, think, after = _split_think(content)
    if think is None:
        st.markdown(content)
        return
    if before.strip():
        st.markdown(before)
    if think.strip():
        st.markdown(_think_html(think), unsafe_allow_html=True)
    if after.strip():
        st.markdown(after)


def _render_doc_links(docs):
    """Show the already merged reference documents in an expander, if any."""
    if not docs:
        return
    with st.expander("📑 参考文档来源", expanded=True):
        for idx, doc in enumerate(docs, 1):
            st.markdown(f"**文档 {idx}**: [{doc['title']}]({doc['url']})")


def display_chat_history():
    """Replay the stored conversation, or the welcome message when it is empty."""
    if not st.session_state.messages:
        with st.chat_message("assistant", avatar=UIConfig.load_image(UIConfig.AI_ICON_PATH)):
            st.write(UIConfig.WELCOME_MESSAGE)

    for message in st.session_state.messages:
        is_user = message["role"] == "user"
        icon_path = UIConfig.USER_ICON_PATH if is_user else UIConfig.AI_ICON_PATH
        with st.chat_message(message["role"], avatar=UIConfig.load_image(icon_path)):
            content = message["content"]
            if message["role"] != "assistant":
                st.markdown(content)
                continue
            try:
                _render_assistant_content(content)
            except Exception as e:
                # Fall back to the raw text rather than losing the message.
                logging.error(f"处理思考内容时出错: {str(e)}")
                st.markdown(content)
            _render_doc_links(merge_doc_links(message.get("relevant_docs")))


def handle_user_input():
    """Read a question from the chat box, stream the answer and store both."""
    if "assistant" not in st.session_state:
        st.error("系统未正确初始化,请刷新页面重试")
        return

    prompt = st.chat_input("请输入您的问题")
    if not prompt:
        return

    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user", avatar=UIConfig.load_image(UIConfig.USER_ICON_PATH)):
        st.markdown(prompt)

    with st.chat_message("assistant", avatar=UIConfig.load_image(UIConfig.AI_ICON_PATH)):
        try:
            message_placeholder = st.empty()
            waiting_placeholder = st.empty()
            waiting_placeholder.markdown("🤔 排队等待回答中...")

            response_started = False

            def streaming_callback(text):
                """Redraw the answer so far; *text* is the accumulated response."""
                nonlocal response_started
                if not response_started:
                    waiting_placeholder.empty()
                    response_started = True
                try:
                    # container() lets one placeholder hold several elements
                    # and replaces them as a unit on every update.
                    with message_placeholder.container():
                        _render_assistant_content(text)
                except Exception as e:
                    logging.error(f"处理流式思考内容时出错: {str(e)}")
                    message_placeholder.markdown(text)

            doc_content = st.session_state.get('current_doc_content')
            success, full_response, doc_links = st.session_state.assistant.chat(
                prompt,
                doc_content=doc_content,
                use_rag=st.session_state.rag_enabled,
                streaming_callback=streaming_callback
            )

            # The hint must disappear even when no chunk ever arrived.
            waiting_placeholder.empty()

            if not success:
                st.error(full_response)
                return

            # Final render, so the complete answer is shown even if the last
            # streaming update was partial.
            try:
                with message_placeholder.container():
                    _render_assistant_content(full_response)
            except Exception as e:
                logging.error(f"处理最终思考内容时出错: {str(e)}")
                message_placeholder.markdown(full_response)

            merged_doc_links = merge_doc_links(doc_links)
            _render_doc_links(merged_doc_links)

            # Store the merged links so the history replay matches this render.
            st.session_state.messages.append({
                "role": "assistant",
                "content": full_response,
                "relevant_docs": merged_doc_links
            })

        except Exception as e:
            error_msg = f"处理出错: {str(e)}"
            logging.error(error_msg)
            st.error(error_msg)


def main():
    """Application entry point, executed on every Streamlit rerun."""
    try:
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - [%(levelname)s] - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )

        # set_page_config has to be the first Streamlit command, and
        # initialize_session_state may call st.error -- so the page is set up first.
        setup_page()
        initialize_session_state()

        # A plain colour is used for the background (not an image).
        setup_background(UIConfig.BACKGROUND_COLOR)
        setup_sidebar()
        display_chat_history()
        handle_user_input()

    except Exception as e:
        logging.error(f"程序运行出错: {str(e)}")
        st.error("系统出错,请刷新页面重试")


if __name__ == "__main__":
    main()
# File: 2025_05_13/model_api.py -- model interaction module (imported by chatui.py).
import time
import tempfile
import os
import io
import json
import logging

# Third-party packages (requests, sseclient, python-docx, pdfplumber, torch and
# the local rag_milvus module) are imported inside the methods that use them.
# A missing optional dependency then only disables that one feature instead of
# making this whole module unimportable.


class UnifiedAssistant:
    """Chat assistant: builds the prompt, calls the model API, reads uploaded files."""

    # Upper bound (in characters) for the prompt sent to the model.
    MAX_PROMPT_LENGTH = 4400
    # Characters kept free for formatting overhead and short placeholders.
    PROMPT_RESERVE = 100
    # Placeholder used in the prompt when the user uploaded nothing.
    NO_FILE_PLACEHOLDER = "无上传文件"

    def __init__(self, base_url="***", api_key="***", rag_device="cuda:1"):
        """Set up the API access data and, if possible, the RAG retriever.

        Args:
            base_url: Base URL of the chat API ("***" is the redacted placeholder).
            api_key: Bearer token for the API ("***" is the redacted placeholder).
            rag_device: Device string handed to the retriever.

        A failure while setting up the retriever is logged and leaves
        ``rag_available`` False; it never propagates.
        """
        self.base_url = base_url
        self.api_key = api_key
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        # System prompt template
        self.prompt_template = """你是xxxxx的Deepseek办公助手,请根据用户上传材料(若有)以及企业内部知识(若有)对用户问题进行解答。
用户问题:{user_input}
用户上传材料:{user_file}
企业内部知识:{knowledge}
请回答:"""

        self.retriever = None
        self.rag_available = False
        try:
            import torch
            from rag_milvus import VectorRetrieval

            if torch.cuda.is_available():
                gpu = torch.device(rag_device)
                # set_device needs an explicit index; "cuda" alone has none.
                if gpu.type == "cuda" and gpu.index is not None:
                    torch.cuda.set_device(gpu.index)
                    logging.info(f"已设置默认GPU为: {rag_device}")

            self.retriever = VectorRetrieval(
                model_dir="/app/rag/modeldir/bce-embedding-base_v1",
                db_path="./milvus_db/vector.db",
                collection_name="kms",
                device=rag_device
            )
            self.rag_available = True
            logging.info("RAG系统初始化成功")

        except Exception as e:
            logging.error(f"初始化失败: {e}")
            self.rag_available = False

    def read_file(self, file):
        """Extract the text of *file*, dispatching on its extension.

        Args:
            file: An upload-like object with a ``name`` attribute, or a path.

        Returns:
            The extracted text, or None for unsupported types and on failure.
        """
        try:
            if isinstance(file, (str, os.PathLike)):
                filename = os.fspath(file)
            else:
                filename = getattr(file, 'name', 'Unknown')
            logging.info(f"开始处理文件: {filename}")

            file_ext = os.path.splitext(filename)[1].lower()
            readers = {
                '.pdf': self.read_pdf,
                '.docx': self.read_docx,
                '.doc': self.read_doc,
                '.txt': self.read_txt,
            }
            reader = readers.get(file_ext)
            if reader is None:
                logging.error(f"不支持的文件类型: {file_ext}")
                return None
            return reader(file)

        except Exception as e:
            logging.error(f"读取文件失败: {str(e)}")
            return None

    def read_pdf(self, file):
        """Return the text of all PDF pages joined by newlines, or None on failure."""
        try:
            import pdfplumber

            if hasattr(file, 'getvalue'):
                pdf_data = io.BytesIO(file.getvalue())
            else:
                pdf_data = file

            text_content = []
            with pdfplumber.open(pdf_data) as pdf:
                for page in pdf.pages:
                    text = page.extract_text()
                    # A page without a text layer may yield None or "".
                    if text and text.strip():
                        text_content.append(text)

            content = "\n".join(text_content)
            logging.info("PDF文件内容提取成功")
            return content

        except Exception as e:
            logging.error(f"读取PDF文件失败: {str(e)}")
            return None

    def read_doc(self, file):
        """Return the text of a legacy Word (.doc) file via ``antiword``, or None."""
        try:
            import subprocess

            temp_path = None
            try:
                if hasattr(file, 'getvalue'):
                    # antiword needs a real file, so uploads are spooled to disk.
                    with tempfile.NamedTemporaryFile(delete=False, suffix='.doc') as temp_file:
                        temp_path = temp_file.name  # recorded first so cleanup always runs
                        temp_file.write(file.getvalue())
                    doc_path = temp_path
                else:
                    doc_path = file

                result = subprocess.run(['antiword', doc_path], capture_output=True, text=True)
                if result.returncode != 0:
                    logging.error(f"Antiword 处理失败: {result.stderr}")
                    return None

                # Normalise line endings and drop empty lines.
                text = result.stdout.replace('\r', '\n')
                text = '\n'.join(line.strip() for line in text.split('\n') if line.strip())

                logging.info("DOC文件内容提取成功")
                return text

            finally:
                if temp_path is not None and os.path.exists(temp_path):
                    os.unlink(temp_path)

        except Exception as e:
            logging.error(f"读取DOC文件失败: {str(e)}")
            return None

    def read_docx(self, file):
        """Return the non-empty paragraphs of a .docx file joined by newlines, or None."""
        try:
            from docx import Document

            if hasattr(file, 'getvalue'):
                doc = Document(io.BytesIO(file.getvalue()))
            else:
                doc = Document(file)

            text_content = [p.text for p in doc.paragraphs if p.text.strip()]

            content = "\n".join(text_content)
            logging.info("DOCX文件内容提取成功")
            return content

        except Exception as e:
            logging.error(f"读取DOCX文件失败: {str(e)}")
            return None

    def read_txt(self, file):
        """Return the text of a plain-text file (UTF-8, falling back to GBK), or None.

        The bytes are read exactly once. Decoding retries therefore also work
        for streams that cannot be re-read.
        """
        try:
            if hasattr(file, 'getvalue'):
                raw = file.getvalue()
            elif isinstance(file, (str, os.PathLike)):
                with open(file, 'rb') as fh:
                    raw = fh.read()
            else:
                raw = file.read()

            if isinstance(raw, str):  # text-mode stream: already decoded
                logging.info("TXT文件内容提取成功")
                return raw

            try:
                content = raw.decode('utf-8-sig')  # also strips a UTF-8 BOM
                logging.info("TXT文件内容提取成功")
            except UnicodeDecodeError:
                content = raw.decode('gbk')
                logging.info("TXT文件内容提取成功(GBK编码)")
            return content

        except Exception as e:
            logging.error(f"读取TXT文件失败: {str(e)}")
            return None

    def truncate_text(self, text, max_length, add_ellipsis=True):
        """Cut *text* to at most *max_length* characters at a natural boundary.

        Text that already fits is returned unchanged. Otherwise blank lines are
        dropped and a contiguous prefix is kept: whole paragraphs first, then
        whole sentences (ending in '。') of the first paragraph that does not
        fit. If not even one sentence fits, the prefix is cut mid-sentence.
        Nothing after the cut is ever included.

        Args:
            text: The text to shorten (None/"" are returned as they are).
            max_length: Maximum length of the result, ellipsis included.
            add_ellipsis: Append "..." when content was cut.

        Returns:
            The shortened text; "" when *max_length* is not positive.
        """
        if not text or len(text) <= max_length:
            return text
        if max_length <= 0:
            return ""

        paragraphs = [p for p in text.split('\n') if p.strip()]
        normalized = '\n'.join(paragraphs)
        if len(normalized) <= max_length:
            return normalized  # only blank lines had to go; nothing was cut

        ellipsis = '...' if add_ellipsis else ''
        budget = max_length - len(ellipsis)
        if budget <= 0:
            return normalized[:max_length]

        kept = []
        used = 0
        for para in paragraphs:
            joiner = 1 if kept else 0  # the '\n' that will precede this paragraph
            if used + joiner + len(para) <= budget:
                kept.append(para)
                used += joiner + len(para)
                continue

            # First paragraph that does not fit: keep its leading whole
            # sentences and stop, so the result stays a prefix of the text.
            room = budget - used - joiner
            head = para[:room] if room > 0 else ''
            cut = head.rfind('。')
            if cut != -1:
                kept.append(head[:cut + 1])
            elif not kept:
                kept.append(head)  # nothing fits cleanly: hard cut
            break

        result = '\n'.join(kept)
        if add_ellipsis:
            result = result.rstrip('。\n') + ellipsis
        return result

    def get_rag_knowledge(self, query, limit=5):
        """Retrieve knowledge for *query* from the retriever.

        Returns:
            (knowledge, doc_links): the formatted hits joined by blank lines and
            a list of {"title", "url"} dicts, or (None, []) when RAG is
            unavailable, nothing was found or the search failed.
        """
        if not self.rag_available:
            return None, []

        try:
            results = self.retriever.search(query, limit=limit)
            if not results or not results[0]:
                return None, []

            knowledge_parts = []
            doc_links = []

            for item in results[0]:
                try:
                    # Hits may nest their fields under "entity" or carry them directly.
                    entity = item.get('entity') or item
                    metadata = entity.get('metadata', {})
                    if isinstance(metadata, str):
                        metadata = json.loads(metadata)

                    similarity = 1 - item.get('distance', 0)
                    text = entity.get('text', '')

                    if metadata and text:
                        knowledge_parts.append(
                            f"相关度{similarity:.2f}的内容:\n{text}\n"
                            f"(来源:{metadata.get('title', '未知文档')})"
                        )

                        if 'url' in item:
                            doc_links.append({
                                "title": metadata.get('title', '未知文档'),
                                "url": item['url']
                            })
                except Exception as e:
                    # One malformed hit must not discard the others.
                    logging.warning(f"处理搜索结果时出错: {e}")
                    continue

            return "\n\n".join(knowledge_parts), doc_links

        except Exception as e:
            logging.error(f"RAG检索失败: {e}")
            return None, []

    def process_stream(self, response, streaming_callback=None):
        """Consume the server-sent event stream of *response*.

        Args:
            response: The streaming HTTP response.
            streaming_callback: Optional callable that receives the text
                accumulated so far after every chunk.

        Returns:
            The accumulated answer (possibly partial after an "error" event),
            or None when reading the stream raised.
        """
        try:
            import sseclient

            full_response = ""
            client = sseclient.SSEClient(response)

            for event in client.events():
                if not event.data:
                    continue
                try:
                    data = json.loads(event.data)
                except json.JSONDecodeError:
                    continue  # ignore events that are not JSON

                event_type = data.get("event")

                if event_type == "error":
                    logging.error(f"Stream error: {data.get('message')}")
                    break

                if event_type == "message_end":
                    logging.info("收到message_end事件,响应生成完成")
                    return full_response

                if event_type == "message" and "answer" in data:
                    chunk = data["answer"]
                    if chunk:
                        full_response += chunk
                        if streaming_callback:
                            streaming_callback(full_response)

            logging.info("事件流结束")
            return full_response

        except Exception as e:
            logging.error(f"处理流式响应时出错: {str(e)}")
            return None

    def _build_prompt(self, user_input, file_content, knowledge):
        """Fill the prompt template, shortening file text and knowledge to fit.

        The template plus *user_input* are measured first. What is left of
        MAX_PROMPT_LENGTH (minus PROMPT_RESERVE) is shared: when both an
        uploaded file and knowledge are present, knowledge gets at most half,
        or more if the file is short, and the file gets the remainder.
        """
        base_length = len(self.prompt_template.format(
            user_input=user_input,
            user_file="",
            knowledge=""
        ))
        budget = max(self.MAX_PROMPT_LENGTH - base_length - self.PROMPT_RESERVE, 0)

        has_file = bool(file_content) and file_content != self.NO_FILE_PLACEHOLDER
        file_length = len(file_content) if has_file else 0

        if file_length + len(knowledge) > budget:
            if has_file:
                knowledge_cap = max(budget // 2, budget - file_length)
                knowledge = self.truncate_text(knowledge, knowledge_cap)
                file_content = self.truncate_text(file_content, budget - len(knowledge))
            else:
                knowledge = self.truncate_text(knowledge, budget)

        if has_file:
            logging.info(f"处理后的文件内容长度: {len(file_content)}")

        return self.prompt_template.format(
            user_input=user_input,
            user_file=file_content,
            knowledge=knowledge
        )

    def chat(self, user_input, doc_content=None, use_rag=False, streaming_callback=None):
        """Answer *user_input*, optionally using an uploaded document and RAG.

        Args:
            user_input: The user's question.
            doc_content: Text of the uploaded document, if any.
            use_rag: Whether to add retrieved knowledge to the prompt.
            streaming_callback: Optional callable that receives the answer
                accumulated so far while it streams in.

        Returns:
            (success, text, doc_links): on success *text* is the answer,
            otherwise it is the error message and *doc_links* is empty.
        """
        try:
            import requests
        except ImportError as e:
            error_msg = f"聊天过程出错: {str(e)}"
            logging.error(error_msg)
            return False, error_msg, []

        try:
            file_content = doc_content if doc_content else self.NO_FILE_PLACEHOLDER
            knowledge, doc_links = "无相关知识", []

            if use_rag and self.rag_available:
                knowledge, doc_links = self.get_rag_knowledge(user_input)
                knowledge = knowledge if knowledge else "未找到相关知识"

            prompt = self._build_prompt(user_input, file_content, knowledge)
            logging.info(f"最终提示词长度: {len(prompt)}")

            payload = {
                "inputs": {},
                "query": prompt,
                "response_mode": "streaming",
                "user": "test_user_1"
            }

            with requests.post(
                f"{self.base_url}/chat-messages",
                headers=self.headers,
                json=payload,
                stream=True,
                timeout=(30, 120)  # 30 s to connect, 120 s between reads
            ) as response:
                response.raise_for_status()

                full_response = self.process_stream(response, streaming_callback)
                if full_response:
                    logging.info(f"获取响应成功,长度: {len(full_response)}")
                    return True, full_response, doc_links

            logging.error("未能获取有效响应")
            return False, "生成回答失败", []

        except requests.Timeout:
            error_msg = "请求超时"
            logging.error(error_msg)
            return False, error_msg, []
        except Exception as e:
            error_msg = f"聊天过程出错: {str(e)}"
            logging.error(error_msg)
            return False, error_msg, []

    def cleanup(self):
        """Release the retriever's resources, if a retriever was set up."""
        if self.rag_available and getattr(self, 'retriever', None) is not None:
            try:
                self.retriever.cleanup()
                logging.info("资源清理完成")
            except Exception as e:
                logging.error(f"清理资源时出错: {e}")
# File: 2025_05_13/rag_milvus.py -- RAG retrieval module (imported by model_api.py).
import os
import logging
import json

# torch, pymilvus and BCEmbedding are imported inside VectorRetrieval.__init__,
# so importing this module does not require them to be installed.


class VectorRetrieval:
    """Vector search over a Milvus collection using a BCE embedding model."""

    def __init__(
        self,
        model_dir="/app/rag/modeldir/bce-embedding-base_v1",
        db_path="./milvus_db/vector.db",
        collection_name="kms",
        device="cuda:1",
        doclink_path="/app/rag/kb/kms/doclink"
    ):
        """Load the embedding model and open the collection.

        Args:
            model_dir: Path of the embedding model.
            db_path: URI handed to MilvusClient.
            collection_name: Collection to search; must already exist.
            device: Device string for the embedding model, e.g. "cuda:1" or "cpu".
            doclink_path: Directory holding the "<filename>.txt" link files.

        Raises:
            RuntimeError: If the database cannot be opened or the collection is
                missing (the cause is chained).
            Exception: Any other initialisation error is logged and re-raised.
        """
        logging.info("初始化向量检索系统...")
        self.device = device
        self.doclink_path = doclink_path

        try:
            import torch
            from pymilvus import MilvusClient
            from BCEmbedding import EmbeddingModel

            # Select the GPU only when the device string names a CUDA index and
            # CUDA is available; "cpu" and index-less "cuda" are left alone.
            target = torch.device(device)
            if target.type == "cuda" and target.index is not None and torch.cuda.is_available():
                torch.cuda.set_device(target.index)

            self.model = EmbeddingModel(
                model_name_or_path=model_dir,
                device=device
            )
            logging.info(f"成功在设备 {self.device} 上初始化Embedding模型")

            try:
                self.client = MilvusClient(uri=db_path)
                self.collection_name = collection_name

                if self.client.has_collection(self.collection_name):
                    self.client.load_collection(self.collection_name)
                    logging.info(f"Collection {self.collection_name} 加载成功")
                else:
                    raise ValueError(f"Collection {collection_name} 不存在!")
            except Exception as e:
                raise RuntimeError(f"连接数据库失败: {str(e)}") from e

        except Exception as e:
            logging.error(f"初始化失败: {str(e)}")
            raise

    def find_doc_url(self, filename):
        """Return the URL stored in "<doclink_path>/<filename>.txt".

        Args:
            filename: Document file name as str, or UTF-8 encoded bytes.

        Returns:
            The stripped file content, or None when the link file is missing or
            cannot be read.
        """
        try:
            if isinstance(filename, bytes):
                filename = filename.decode('utf-8')

            link_file_path = os.path.join(self.doclink_path, f"{filename}.txt")
            logging.info(f"尝试读取link文件: {link_file_path}")

            try:
                with open(link_file_path, 'r', encoding='utf-8') as f:
                    url = f.read().strip()
            except FileNotFoundError:
                logging.warning(f"找不到对应的link文件: {link_file_path}")
                return None

            logging.info(f"成功读取URL for {filename}: {url}")
            return url

        except Exception as e:
            logging.error(f"读取文档URL失败 - 文件名: {filename}, 错误: {str(e)}")
            return None

    def _annotate_result(self, result):
        """Log one hit and attach its document URL under the "url" key if found."""
        # default=str: this dump is for the log only, and an unserialisable
        # value in a hit must not abort the search.
        logging.info(f"原始结果: {json.dumps(result, ensure_ascii=False, default=str)}")

        # Hits may nest their fields under "entity" or carry them directly.
        entity = result.get('entity') or result

        metadata = entity.get('metadata')
        if isinstance(metadata, str):
            try:
                metadata = json.loads(metadata)
            except json.JSONDecodeError:
                logging.warning(f"metadata解析失败: {metadata}")
                return

        if metadata:
            logging.info(f"元数据: {json.dumps(metadata, ensure_ascii=False, default=str)}")
            if isinstance(metadata, dict) and 'source' in metadata:
                filename = os.path.basename(metadata['source'])
                url = self.find_doc_url(filename)
                if url:
                    result['url'] = url

        text = entity.get('text')
        if text:
            text_preview = text[:200] + "..." if len(text) > 200 else text
            logging.info(f"文本内容预览:\n{text_preview}")

    def search(self, query, limit=5):
        """Embed *query* and search the collection.

        Args:
            query: The search text.
            limit: Maximum number of hits.

        Returns:
            The client's result (one list of hits per query vector) with "url"
            added to every hit whose link file was found, or None when nothing
            was found or the search failed.
        """
        try:
            logging.info(f"开始执行向量检索 - 查询: {query}, 限制数量: {limit}")

            query_vector = self.model.encode([query])[0].tolist()
            logging.info("成功生成查询向量")

            results = self.client.search(
                collection_name=self.collection_name,
                data=[query_vector],
                limit=limit,
                output_fields=["text", "metadata"]
            )

            if not results or not results[0]:
                logging.warning("未找到相关结果")
                return None

            logging.info(f"检索到 {len(results[0])} 条结果")

            for idx, result in enumerate(results[0], 1):
                logging.info(f"\n--- 结果 {idx} ---")
                self._annotate_result(result)

            return results

        except Exception as e:
            logging.error(f"搜索过程出错: {str(e)}", exc_info=True)
            return None

    def cleanup(self):
        """Close the database client, if one was created."""
        if hasattr(self, 'client'):
            try:
                self.client.close()
                logging.info("数据库连接已关闭")
            except Exception as e:
                logging.error(f"关闭数据库连接时出错: {str(e)}")


def main():
    """Interactive smoke test: run one query and log the hits."""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - [%(levelname)s] - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    retriever = None
    try:
        retriever = VectorRetrieval()
        query = input("请输入搜索问题: ")
        retriever.search(query)

    except Exception as e:
        logging.error(f"发生错误: {str(e)}")
    finally:
        # Close the database connection even when the search failed.
        if retriever is not None:
            retriever.cleanup()


if __name__ == "__main__":
    main()