From 671efeb34057c505b3fa448058925e16cf463613 Mon Sep 17 00:00:00 2001
From: YijiCu <1165084283@qq.com>
Date: Tue, 13 May 2025 10:53:27 +0800
Subject: [PATCH 1/3] Update README.md
---
README.md | 20 +++++++++++++++++++-
1 file changed, 19 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index ab57007..3525a1c 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,21 @@
+## 2025_05_13 更新
+本项目代码于上次更新后进行了7轮版本更新,中间进行了几轮重建,目前开源出三份核心代码文件供交流,也作为笔者的一些记录,希望大佬们多多指点!
+
+### /2025_05_13/chatui.py -- 主要执行文件
+执行文件,包含项目整体UI框架
+
+### /2025_05_13/model_api.py -- 模型交互模块(被调用)
+1、模型交互代码,目前通过调用本地部署的dify模型问答api来与模型交互(dify上的模型也是本地部署的)
+2、用户上传文件问答功能
+3、根据模型上文限制,平衡RAG召回内容与用户上传内容的输入
+
+### /2025_05_13/rag_milvus.py -- RAG功能模块(被调用)
+RAG功能模块
+
+因数据安全要求,对部分信息做了脱敏处理,大部分的数据、配置代码无法开放~
+
+----------------------------------------------------------------------------------------------------------------------------------
+2024_09_30
# 基于InternLM2.5的清洁能源行业知识专家 #1408
上海AI Lab书生·浦语训练营第三期项目
@@ -100,4 +118,4 @@ https://www.bilibili.com/video/BV1FfWde5EQL/
## 欢迎交流!
-如有任何问题或建议,欢迎随时与我们联系。
\ No newline at end of file
+如有任何问题或建议,欢迎随时与我们联系。
From ff1d5822c5dd09d111f2f381240b2c716e21789e Mon Sep 17 00:00:00 2001
From: YijiCu <1165084283@qq.com>
Date: Tue, 13 May 2025 10:55:42 +0800
Subject: [PATCH 2/3] Create chatui.py
---
2025_05_13/chatui.py | 562 +++++++++++++++++++++++++++++++++++++++++++
1 file changed, 562 insertions(+)
create mode 100644 2025_05_13/chatui.py
diff --git a/2025_05_13/chatui.py b/2025_05_13/chatui.py
new file mode 100644
index 0000000..238262b
--- /dev/null
+++ b/2025_05_13/chatui.py
@@ -0,0 +1,562 @@
+import os
+import logging
+import streamlit as st
+from PIL import Image
+import base64
+from model_api import UnifiedAssistant
+
+class UIConfig:
+ """UI配置类"""
+ PAGE_TITLE = "xxx大模型AI办公助手"
+ PAGE_ICON = "🔍"
+ BACKGROUND_COLOR = "#F3F6FB"
+ USER_ICON_PATH = '/app/rag/config/提问.png'
+ AI_ICON_PATH = '/app/rag/config/模型回复机器人.png'
+ DEEPSEEK_ICON_PATH = '/app/rag/config/小钢笔.png'
+ SHANGHAI_ICON_PATH = '/app/rag/config/书本.png'
+ WELCOME_MESSAGE = "xxx小伙伴您好,请输入您的问题,我可以帮您进行报告分析、生成等各种办公作业,并且我还在持续更新!"
+
+ # 选中状态
+ SELECTED_BG_COLOR = "#3F70FF"
+ SELECTED_TEXT_COLOR = "#fff"
+ SELECTED_FONT_WEIGHT = "bold"
+ SELECTED_FONT_SIZE = "18px"
+
+ # 未选中状态
+ UNSELECTED_BG_COLOR = "#fff"
+ UNSELECTED_TEXT_COLOR = "#4D4B72"
+ UNSELECTED_FONT_WEIGHT = "normal"
+ UNSELECTED_FONT_SIZE = "18px"
+
+ # 思考文字样式
+ THINKING_BG_COLOR = "#FAFDFF"
+ THINKING_TEXT_COLOR = "#8b8b8b"
+ THINKING_FONT_SIZE = "14px"
+ THINKING_BORDER_COLOR = "#D5DBDE"
+
+ @staticmethod
+ def load_image(path: str):
+ """加载图片"""
+ try:
+ if os.path.exists(path):
+ return Image.open(path)
+ return None
+ except Exception as e:
+ logging.error(f"Error loading image from {path}: {str(e)}")
+ return None
+
+def initialize_session_state():
+ """初始化会话状态"""
+ if "messages" not in st.session_state:
+ st.session_state.messages = []
+ # 修改点1:默认为DeepSeek模式,对应rag_enabled为False
+ if "current_mode" not in st.session_state:
+ st.session_state.current_mode = "deepseek" # 可选值: "deepseek" 或 "shanghai_rag"
+ if "rag_enabled" not in st.session_state:
+ st.session_state.rag_enabled = False
+ if "assistant" not in st.session_state:
+ try:
+ st.session_state.assistant = UnifiedAssistant()
+ logging.info("UnifiedAssistant 初始化成功")
+ except Exception as e:
+ logging.error(f"UnifiedAssistant 初始化失败: {e}")
+ st.error("系统初始化失败,请刷新页面重试")
+
+def setup_page():
+ """设置页面配置"""
+ st.set_page_config(
+ page_title=UIConfig.PAGE_TITLE,
+ page_icon=UIConfig.PAGE_ICON,
+ layout="wide",
+ initial_sidebar_state="expanded"
+ )
+ # 使用h1标签直接设置标题,以便更精确控制样式
+ st.markdown(f"
{UIConfig.PAGE_TITLE}
", unsafe_allow_html=True)
+ st.caption("基于DeepSeek R1定制化开发")
+
+def setup_background(background_color):
+ """设置背景样式"""
+ st.markdown(
+ f"""
+
+ """,
+ unsafe_allow_html=True
+ )
+
+def setup_sidebar():
+ """设置侧边栏"""
+ with st.sidebar:
+ st.title("📚 工作台")
+
+def setup_sidebar():
+ """设置侧边栏"""
+ with st.sidebar:
+ # 增大工作台标题字体
+ st.markdown('📚 工作台
', unsafe_allow_html=True)
+
+ # 模式选择按钮,上下布局,使用小齿轮图标
+ st.markdown("### ⚙️ 模式选择")
+
+ # DeepSeek办公助手按钮 - 默认选中,添加钢笔图标
+ deepseek_selected = st.session_state.current_mode == "deepseek"
+ if st.button(
+ "🖋️ DeepSeek AI办公助手",
+ key="deepseek_button",
+ use_container_width=True,
+ type="primary" if deepseek_selected else "secondary"
+ ):
+ if not deepseek_selected: # 只有当前未选中时才触发状态变更
+ st.session_state.current_mode = "deepseek"
+ st.session_state.rag_enabled = False
+ st.rerun()
+
+ # 上海院知识库问答助手按钮,添加书本图标
+ shanghai_selected = st.session_state.current_mode == "shanghai_rag"
+ if st.button(
+ "📚 上海院知识库问答助手",
+ key="shanghai_button",
+ use_container_width=True,
+ type="primary" if shanghai_selected else "secondary"
+ ):
+ if not shanghai_selected: # 只有当前未选中时才触发状态变更
+ st.session_state.current_mode = "shanghai_rag"
+ st.session_state.rag_enabled = True
+ st.rerun()
+
+ # 设置按钮样式
+ st.markdown(f"""
+
+ """, unsafe_allow_html=True)
+
+ # 文件处理
+ st.markdown("### 📁 文件处理")
+
+ # 文件上传器的key要保持一致
+ uploaded_file = st.file_uploader(
+ "上传文件",
+ type=["txt", "pdf", "doc", "docx"],
+ help="上传文件进行问答",
+ key="file_uploader" # 添加固定的key
+ )
+
+ # 处理文件上传和移除
+ if uploaded_file is not None:
+ # 检查是否需要重新处理文件
+ current_file_name = getattr(uploaded_file, 'name', None)
+ if 'last_processed_file' not in st.session_state or \
+ st.session_state.last_processed_file != current_file_name:
+
+ st.write("文件信息:")
+ st.write(f"- 文件名: {uploaded_file.name}")
+ st.write(f"- 文件类型: {uploaded_file.type}")
+ st.write(f"- 文件大小: {uploaded_file.size / 1024:.2f} KB")
+
+ if hasattr(st.session_state, 'assistant'):
+ # 处理文件内容
+ doc_content = st.session_state.assistant.read_file(uploaded_file)
+ if doc_content:
+ st.session_state.current_doc_content = doc_content
+ st.session_state.last_processed_file = current_file_name
+ st.success("文件处理成功!")
+ else:
+ st.error("文件处理失败")
+ else:
+ # 文件被移除时清理相关状态
+ if 'current_doc_content' in st.session_state:
+ del st.session_state.current_doc_content
+ if 'last_processed_file' in st.session_state:
+ del st.session_state.last_processed_file
+
+ # 清除历史按钮 - 使用独立样式,不受模式选择影响
+ st.markdown("### 🗑️ 清除历史")
+ if st.button("清除对话历史", key="clear_history_button", use_container_width=True):
+ st.session_state.messages = []
+ # 清除文件相关的所有状态
+ if 'current_doc_content' in st.session_state:
+ del st.session_state.current_doc_content
+ if 'last_processed_file' in st.session_state:
+ del st.session_state.last_processed_file
+ st.rerun()
+
+ # 添加清除历史按钮的独立样式
+ st.markdown("""
+
+ """, unsafe_allow_html=True)
+
+def handle_file_change():
+ """处理文件变更的回调函数"""
+ if 'current_doc_content' in st.session_state:
+ del st.session_state.current_doc_content
+ logging.info("文件已移除,清除文档内容")
+
+def merge_doc_links(doc_links):
+ """合并来自同一文档的链接"""
+ if not doc_links or not isinstance(doc_links, list):
+ return []
+
+ # 使用字典来合并相同文档的链接
+ merged = {}
+ for doc in doc_links:
+ if isinstance(doc, dict):
+ title = doc.get('title')
+ url = doc.get('url')
+ if title and url:
+ # 如果这个文档已经存在,跳过(只保留第一次出现)
+ if title not in merged:
+ merged[title] = url
+
+ # 转换回列表格式
+ return [{"title": title, "url": url} for title, url in merged.items()]
+
+def display_chat_history():
+ """显示聊天历史"""
+ # 显示欢迎消息
+ if not st.session_state.messages:
+ with st.chat_message("assistant", avatar=UIConfig.load_image(UIConfig.AI_ICON_PATH)):
+ st.write(UIConfig.WELCOME_MESSAGE)
+
+ # 显示历史消息
+ for message in st.session_state.messages:
+ with st.chat_message(
+ message["role"],
+ avatar=UIConfig.load_image(UIConfig.USER_ICON_PATH if message["role"] == "user" else UIConfig.AI_ICON_PATH)
+ ):
+ # 处理思考内容的显示格式
+ content = message["content"]
+ if message["role"] == "assistant" and "" in content and "" in content:
+ try:
+ # 分离思考内容和普通内容
+ parts = content.split("", 1) # 只分割第一个出现的标签
+ before_think = parts[0]
+
+ if len(parts) > 1:
+ remaining = parts[1].split("", 1) # 只分割第一个出现的结束标签
+ think_content = remaining[0]
+ after_think = remaining[1] if len(remaining) > 1 else ""
+
+ # 显示格式化后的内容
+ if before_think.strip():
+ st.markdown(before_think)
+
+ # 显示思考内容(使用自定义CSS样式)
+ if think_content.strip():
+ st.markdown(f'{think_content}
', unsafe_allow_html=True)
+
+ if after_think.strip():
+ st.markdown(after_think)
+ else:
+ # 如果分割失败,显示原始内容
+ st.markdown(content)
+ except Exception as e:
+ # 如果处理过程中出错,回退到显示原始内容
+ logging.error(f"处理思考内容时出错: {str(e)}")
+ st.markdown(content)
+ else:
+ st.markdown(content)
+
+ # 如果是助手消息且有相关文档,显示文档链接
+ if message["role"] == "assistant" and "relevant_docs" in message:
+ docs = merge_doc_links(message["relevant_docs"])
+ if docs:
+ with st.expander("📑 参考文档来源", expanded=True):
+ for idx, doc in enumerate(docs, 1):
+ st.markdown(f"**文档 {idx}**: [{doc['title']}]({doc['url']})")
+
+def handle_user_input():
+ """处理用户输入"""
+ if not hasattr(st.session_state, 'assistant'):
+ st.error("系统未正确初始化,请刷新页面重试")
+ return
+
+ if prompt := st.chat_input("请输入您的问题"):
+ # 添加用户消息
+ st.session_state.messages.append({"role": "user", "content": prompt})
+
+ # 显示用户消息
+ with st.chat_message("user", avatar=UIConfig.load_image(UIConfig.USER_ICON_PATH)):
+ st.markdown(prompt)
+
+ # 创建新的助手消息容器
+ with st.chat_message("assistant", avatar=UIConfig.load_image(UIConfig.AI_ICON_PATH)):
+ try:
+ # 创建消息占位符和等待提示占位符
+ message_placeholder = st.empty()
+ waiting_placeholder = st.empty()
+
+ # 在等待占位符中显示等待提示
+ with waiting_placeholder:
+ waiting_placeholder.markdown("🤔 排队等待回答中...")
+
+ # 用一个变量来标记是否已经开始接收响应
+ response_started = False
+
+ def streaming_callback(text):
+ nonlocal response_started
+ if not response_started:
+ # 清除等待提示
+ waiting_placeholder.empty()
+ response_started = True
+
+ # 处理思考内容的显示格式
+ try:
+ if "" in text and "" in text:
+ # 分离思考内容和普通内容
+ parts = text.split("", 1) # 只分割第一个出现的标签
+ before_think = parts[0]
+
+ if len(parts) > 1:
+ remaining = parts[1].split("", 1) # 只分割第一个出现的结束标签
+ think_content = remaining[0]
+ after_think = remaining[1] if len(remaining) > 1 else ""
+
+ # 显示格式化后的内容
+ html_content = ""
+ if before_think.strip():
+ html_content += before_think
+
+ # 显示思考内容(使用自定义CSS样式)
+ if think_content.strip():
+ html_content += f'{think_content}
'
+
+ if after_think.strip():
+ html_content += after_think
+
+ message_placeholder.markdown(html_content, unsafe_allow_html=True)
+ else:
+ # 如果分割失败,显示原始内容
+ message_placeholder.markdown(text)
+ else:
+ message_placeholder.markdown(text)
+ except Exception as e:
+ # 如果处理过程中出错,回退到显示原始内容
+ logging.error(f"处理流式思考内容时出错: {str(e)}")
+ message_placeholder.markdown(text)
+
+ # 处理文档内容和RAG
+ doc_content = st.session_state.get('current_doc_content')
+ success, full_response, doc_links = st.session_state.assistant.chat(
+ prompt,
+ doc_content=doc_content,
+ use_rag=st.session_state.rag_enabled,
+ streaming_callback=streaming_callback
+ )
+
+ # 确保等待提示被清除
+ waiting_placeholder.empty()
+
+ # 显示完整响应
+ if success:
+ # 处理思考内容的显示格式(确保在流式显示结束后也能正确显示)
+ try:
+ if "" in full_response and "" in full_response:
+ # 分离思考内容和普通内容
+ parts = full_response.split("", 1) # 只分割第一个出现的标签
+ before_think = parts[0]
+
+ if len(parts) > 1:
+ remaining = parts[1].split("", 1) # 只分割第一个出现的结束标签
+ think_content = remaining[0]
+ after_think = remaining[1] if len(remaining) > 1 else ""
+
+ # 显示格式化后的内容
+ html_content = ""
+ if before_think.strip():
+ html_content += before_think
+
+ # 显示思考内容(使用自定义CSS样式)
+ if think_content.strip():
+ html_content += f'{think_content}
'
+
+ if after_think.strip():
+ html_content += after_think
+
+ message_placeholder.markdown(html_content, unsafe_allow_html=True)
+ else:
+ # 如果分割失败,显示原始内容
+ message_placeholder.markdown(full_response)
+ else:
+ message_placeholder.markdown(full_response)
+ except Exception as e:
+ # 如果处理过程中出错,回退到显示原始内容
+ logging.error(f"处理最终思考内容时出错: {str(e)}")
+ message_placeholder.markdown(full_response)
+
+ # 合并并显示文档链接
+ merged_doc_links = merge_doc_links(doc_links)
+ if merged_doc_links:
+ with st.expander("📑 参考文档来源", expanded=True):
+ for idx, doc in enumerate(merged_doc_links, 1):
+ st.markdown(f"**文档 {idx}**: [{doc['title']}]({doc['url']})")
+
+ # 保存消息到历史记录(保存合并后的文档链接)
+ st.session_state.messages.append({
+ "role": "assistant",
+ "content": full_response,
+ "relevant_docs": merged_doc_links
+ })
+ else:
+ st.error(full_response)
+
+ except Exception as e:
+ error_msg = f"处理出错: {str(e)}"
+ logging.error(error_msg)
+ st.error(error_msg)
+
+def main():
+ """主函数"""
+ try:
+ # 配置日志
+ logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s - [%(levelname)s] - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S'
+ )
+
+ # 初始化会话状态
+ initialize_session_state()
+
+ # 设置页面
+ setup_page()
+
+ # 设置背景:修改点2,使用背景色值而非背景图片
+ setup_background(UIConfig.BACKGROUND_COLOR)
+
+ # 设置侧边栏
+ setup_sidebar()
+
+ # 显示聊天历史
+ display_chat_history()
+
+ # 处理用户输入
+ handle_user_input()
+
+ except Exception as e:
+ logging.error(f"程序运行出错: {str(e)}")
+ st.error("系统出错,请刷新页面重试")
+
+if __name__ == "__main__":
+ main()
From b76da7d41cc6400225e7bf8ff02f90d629535b20 Mon Sep 17 00:00:00 2001
From: YijiCu <1165084283@qq.com>
Date: Tue, 13 May 2025 10:56:46 +0800
Subject: [PATCH 3/3] Add files via upload
---
2025_05_13/model_api.py | 391 +++++++++++++++++++++++++++++++++++++++
2025_05_13/rag_milvus.py | 165 +++++++++++++++++
2 files changed, 556 insertions(+)
create mode 100644 2025_05_13/model_api.py
create mode 100644 2025_05_13/rag_milvus.py
diff --git a/2025_05_13/model_api.py b/2025_05_13/model_api.py
new file mode 100644
index 0000000..9a57250
--- /dev/null
+++ b/2025_05_13/model_api.py
@@ -0,0 +1,391 @@
+import time
+import tempfile
+import os
+import io
+import requests
+import json
+import logging
+from docx import Document
+import pdfplumber
+from rag_milvus import VectorRetrieval
+import torch
+import sseclient
+
+class UnifiedAssistant:
+ def __init__(self):
+ """初始化统一助手"""
+ # API配置
+ self.base_url = "***" #外部展示时覆盖
+ self.api_key = "***" #外部展示时覆盖
+ self.headers = {
+ "Authorization": f"Bearer {self.api_key}",
+ "Content-Type": "application/json"
+ }
+
+ # 系统提示词
+ self.prompt_template = """你是xxxxx的Deepseek办公助手,请根据用户上传材料(若有)以及企业内部知识(若有)对用户问题进行解答。
+用户问题:{user_input}
+用户上传材料:{user_file}
+企业内部知识:{knowledge}
+请回答:"""
+
+ # GPU设置
+ try:
+ if torch.cuda.is_available():
+ torch.cuda.set_device(1)
+ logging.info("已设置默认GPU为: cuda:1")
+
+ # 初始化RAG系统
+ self.retriever = VectorRetrieval(
+ model_dir="/app/rag/modeldir/bce-embedding-base_v1",
+ db_path="./milvus_db/vector.db",
+ collection_name="kms",
+ device="cuda:1"
+ )
+ self.rag_available = True
+ logging.info("RAG系统初始化成功")
+
+ except Exception as e:
+ logging.error(f"初始化失败: {e}")
+ self.rag_available = False
+
+ def read_file(self, file):
+ """统一的文件读取接口"""
+ try:
+ filename = getattr(file, 'name', 'Unknown')
+ logging.info(f"开始处理文件: {filename}")
+
+ # 获取文件扩展名
+ file_ext = os.path.splitext(filename)[1].lower()
+
+ # 根据文件类型选择相应的处理方法
+ if file_ext == '.pdf':
+ return self.read_pdf(file)
+ elif file_ext == '.docx':
+ return self.read_docx(file)
+ elif file_ext == '.doc':
+ return self.read_doc(file)
+ elif file_ext == '.txt':
+ return self.read_txt(file)
+ else:
+ logging.error(f"不支持的文件类型: {file_ext}")
+ return None
+
+ except Exception as e:
+ logging.error(f"读取文件失败: {str(e)}")
+ return None
+
+ def read_pdf(self, file):
+ """读取PDF文件内容"""
+ try:
+ # 获取文件内容
+ if hasattr(file, 'getvalue'):
+ pdf_data = io.BytesIO(file.getvalue())
+ else:
+ pdf_data = file
+
+ text_content = []
+ with pdfplumber.open(pdf_data) as pdf:
+ for page in pdf.pages:
+ text = page.extract_text()
+ if text.strip():
+ text_content.append(text)
+
+ content = "\n".join(text_content)
+ logging.info("PDF文件内容提取成功")
+ return content
+
+ except Exception as e:
+ logging.error(f"读取PDF文件失败: {str(e)}")
+ return None
+
+ def read_doc(self, file):
+ """读取旧版 Word (.doc) 文件内容"""
+ try:
+ import subprocess
+
+ # 如果是上传的文件对象,需要先保存到临时文件
+ if hasattr(file, 'getvalue'):
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.doc') as temp_file:
+ temp_file.write(file.getvalue())
+ temp_path = temp_file.name
+ else:
+ temp_path = file
+
+ try:
+ # 使用 antiword 提取文本
+ result = subprocess.run(['antiword', temp_path], capture_output=True, text=True)
+
+ if result.returncode != 0:
+ logging.error(f"Antiword 处理失败: {result.stderr}")
+ return None
+
+ text = result.stdout
+
+ # 清理和格式化文本
+ text = text.replace('\r', '\n')
+ text = '\n'.join(line.strip() for line in text.split('\n') if line.strip())
+
+ logging.info("DOC文件内容提取成功")
+ return text
+
+ finally:
+ # 如果使用了临时文件,需要删除它
+ if hasattr(file, 'getvalue') and os.path.exists(temp_path):
+ os.unlink(temp_path)
+
+ except Exception as e:
+ logging.error(f"读取DOC文件失败: {str(e)}")
+ return None
+
+ def read_docx(self, file):
+ """读取Word (.docx) 文件内容"""
+ try:
+ if hasattr(file, 'getvalue'):
+ doc = Document(io.BytesIO(file.getvalue()))
+ else:
+ doc = Document(file)
+
+ # 提取文本
+ text_content = []
+ for paragraph in doc.paragraphs:
+ if paragraph.text.strip():
+ text_content.append(paragraph.text)
+
+ content = "\n".join(text_content)
+ logging.info("DOCX文件内容提取成功")
+ return content
+
+ except Exception as e:
+ logging.error(f"读取DOCX文件失败: {str(e)}")
+ return None
+
+ def read_txt(self, file):
+ """读取TXT文件内容"""
+ try:
+ # 获取文件内容
+ if hasattr(file, 'getvalue'):
+ content = file.getvalue().decode('utf-8')
+ else:
+ content = file.read().decode('utf-8')
+
+ logging.info("TXT文件内容提取成功")
+ return content
+
+ except UnicodeDecodeError:
+ try:
+ # 如果UTF-8解码失败,尝试使用GBK
+ if hasattr(file, 'getvalue'):
+ content = file.getvalue().decode('gbk')
+ else:
+ content = file.read().decode('gbk')
+ logging.info("TXT文件内容提取成功(GBK编码)")
+ return content
+ except Exception as e:
+ logging.error(f"读取TXT文件失败: {str(e)}")
+ return None
+ except Exception as e:
+ logging.error(f"读取TXT文件失败: {str(e)}")
+ return None
+
+ def truncate_text(self, text, max_length, add_ellipsis=True):
+ """智能截断文本到指定长度,尽量保持完整句子和段落"""
+ if not text or len(text) <= max_length:
+ return text
+
+ # 首先按段落分割
+ paragraphs = text.split('\n')
+ truncated = []
+ current_length = 0
+
+ for para in paragraphs:
+ if not para.strip():
+ continue
+
+ # 如果单个段落就超过了最大长度
+ if len(para) > max_length:
+ # 按句子分割
+ sentences = para.split('。')
+ for sentence in sentences:
+ if current_length + len(sentence) + 1 > max_length:
+ break
+ truncated.append(sentence + '。')
+ current_length += len(sentence) + 1
+ else:
+ # 如果添加整个段落后不超过最大长度
+ if current_length + len(para) + 1 <= max_length:
+ truncated.append(para)
+ current_length += len(para) + 1
+ else:
+ break
+
+ result = '\n'.join(truncated)
+ if add_ellipsis and result != text:
+ result = result.rstrip('。\n') + '...'
+
+ return result
+
+ def get_rag_knowledge(self, query, limit=5):
+ """获取RAG检索结果"""
+ if not self.rag_available:
+ return None, []
+
+ try:
+ results = self.retriever.search(query, limit=limit)
+ if not results or not results[0]:
+ return None, []
+
+ knowledge_parts = []
+ doc_links = []
+
+ for item in results[0]:
+ try:
+ metadata = item.get('entity', {}).get('metadata', {})
+ if isinstance(metadata, str):
+ metadata = json.loads(metadata)
+
+ similarity = 1 - item.get('distance', 0)
+ text = item.get('entity', {}).get('text', '')
+
+ if metadata and text:
+ knowledge_parts.append(
+ f"相关度{similarity:.2f}的内容:\n{text}\n"
+ f"(来源:{metadata.get('title', '未知文档')})"
+ )
+
+ if 'url' in item:
+ doc_links.append({
+ "title": metadata.get('title', '未知文档'),
+ "url": item['url']
+ })
+ except Exception as e:
+ logging.warning(f"处理搜索结果时出错: {e}")
+ continue
+
+ return "\n\n".join(knowledge_parts), doc_links
+
+ except Exception as e:
+ logging.error(f"RAG检索失败: {e}")
+ return None, []
+
+ def process_stream(self, response, streaming_callback=None):
+ """处理流式响应"""
+ try:
+ full_response = ""
+ client = sseclient.SSEClient(response)
+
+ for event in client.events():
+ if event.data:
+ try:
+ data = json.loads(event.data)
+ event_type = data.get("event")
+
+ # 检查错误状态
+ if event_type == "error":
+ logging.error(f"Stream error: {data.get('message')}")
+ break
+
+ # 检查消息结束事件
+ if event_type == "message_end":
+ logging.info("收到message_end事件,响应生成完成")
+ return full_response
+
+ # 处理新的文本块
+ if event_type == "message" and "answer" in data:
+ chunk = data["answer"]
+ if chunk: # 确保chunk不为空
+ full_response += chunk
+ if streaming_callback:
+ streaming_callback(full_response)
+
+ except json.JSONDecodeError:
+ continue
+
+ logging.info("事件流结束")
+ return full_response
+
+ except Exception as e:
+ logging.error(f"处理流式响应时出错: {str(e)}")
+ return None
+
+ def chat(self, user_input, doc_content=None, use_rag=False, streaming_callback=None):
+ try:
+ # 准备对话内容
+ file_content = doc_content if doc_content else "无上传文件"
+ knowledge, doc_links = "无相关知识", []
+
+ if use_rag and self.rag_available:
+ knowledge, doc_links = self.get_rag_knowledge(user_input)
+ knowledge = knowledge if knowledge else "未找到相关知识"
+
+ # 计算系统提示词和用户问题的长度
+ base_prompt = self.prompt_template.format(
+ user_input=user_input,
+ user_file="", # 临时空内容用于计算基础长度
+ knowledge=""
+ )
+ base_length = len(base_prompt)
+ MAX_LENGTH = 4400
+
+ # 计算可用于文件内容的最大长度
+ available_length = MAX_LENGTH - base_length - 100 # 预留一些空间给格式化字符
+ if file_content and file_content != "无上传文件":
+ # 使用更大的限制,因为没有使用RAG
+ if not use_rag:
+ available_length = min(available_length, 4000) # 给文件内容分配更多空间
+ file_content = self.truncate_text(file_content, available_length)
+ logging.info(f"处理后的文件内容长度: {len(file_content)}")
+
+ # 组装最终提示词
+ prompt = self.prompt_template.format(
+ user_input=user_input,
+ user_file=file_content,
+ knowledge=knowledge
+ )
+
+ logging.info(f"最终提示词长度: {len(prompt)}")
+
+ # 准备请求
+ payload = {
+ "inputs": {},
+ "query": prompt,
+ "response_mode": "streaming",
+ "user": "test_user_1"
+ }
+
+ # 发送请求时修改超时设置
+ with requests.post(
+ f"{self.base_url}/chat-messages",
+ headers=self.headers,
+ json=payload,
+ stream=True,
+ timeout=(30, 120) # 连接超时30秒,读取超时120秒
+ ) as response:
+ response.raise_for_status()
+
+ # 处理响应
+ full_response = self.process_stream(response, streaming_callback)
+ if full_response: # 改为检查响应内容是否为空
+ logging.info(f"获取响应成功,长度: {len(full_response)}")
+ return True, full_response, doc_links
+
+ logging.error("未能获取有效响应")
+ return False, "生成回答失败", []
+
+ except requests.Timeout:
+ error_msg = "请求超时"
+ logging.error(error_msg)
+ return False, error_msg, []
+ except Exception as e:
+ error_msg = f"聊天过程出错: {str(e)}"
+ logging.error(error_msg)
+ return False, error_msg, []
+
+ def cleanup(self):
+ """清理资源"""
+ if self.rag_available and hasattr(self, 'retriever'):
+ try:
+ self.retriever.cleanup()
+ logging.info("资源清理完成")
+ except Exception as e:
+ logging.error(f"清理资源时出错: {e}")
diff --git a/2025_05_13/rag_milvus.py b/2025_05_13/rag_milvus.py
new file mode 100644
index 0000000..7dccc1c
--- /dev/null
+++ b/2025_05_13/rag_milvus.py
@@ -0,0 +1,165 @@
+import os
+import logging
+import json
+import torch
+from pymilvus import MilvusClient
+from BCEmbedding import EmbeddingModel
+
+class VectorRetrieval:
+ def __init__(
+ self,
+ model_dir="/app/rag/modeldir/bce-embedding-base_v1",
+ db_path="./milvus_db/vector.db",
+ collection_name="kms",
+ device="cuda:1",
+ doclink_path="/app/rag/kb/kms/doclink"
+ ):
+ """初始化向量检索类"""
+ logging.info("初始化向量检索系统...")
+ self.device = device
+ self.doclink_path = doclink_path
+
+ try:
+ # 设置当前设备
+ device_id = int(self.device.split(':')[1])
+ torch.cuda.set_device(device_id)
+
+ # 加载模型
+ self.model = EmbeddingModel(
+ model_name_or_path=model_dir,
+ device=device
+ )
+ logging.info(f"成功在设备 {self.device} 上初始化Embedding模型")
+
+ # 连接数据库
+ try:
+ self.client = MilvusClient(uri=db_path)
+ self.collection_name = collection_name
+
+ if self.client.has_collection(self.collection_name):
+ self.client.load_collection(self.collection_name)
+ logging.info(f"Collection {self.collection_name} 加载成功")
+ else:
+ raise ValueError(f"Collection {collection_name} 不存在!")
+ except Exception as e:
+ raise RuntimeError(f"连接数据库失败: {str(e)}")
+
+ except Exception as e:
+ logging.error(f"初始化失败: {str(e)}")
+ raise
+
+ def find_doc_url(self, filename):
+ """根据文件名在doclink目录下查找对应的URL"""
+ try:
+ # 确保文件名是解码后的中文
+ if isinstance(filename, bytes):
+ filename = filename.decode('utf-8')
+
+ link_file_path = os.path.join(self.doclink_path, f"{filename}.txt")
+ logging.info(f"尝试读取link文件: {link_file_path}")
+
+ if not os.path.exists(link_file_path):
+ logging.warning(f"找不到对应的link文件: {link_file_path}")
+ return None
+
+ with open(link_file_path, 'r', encoding='utf-8') as f:
+ url = f.read().strip()
+ logging.info(f"成功读取URL for {filename}: {url}")
+
+ return url
+
+ except Exception as e:
+ logging.error(f"读取文档URL失败 - 文件名: {filename}, 错误: {str(e)}")
+ return None
+
+ def search(self, query, limit=5):
+ """执行向量检索并返回带URL的结果"""
+ try:
+ logging.info(f"开始执行向量检索 - 查询: {query}, 限制数量: {limit}")
+
+ query_vector = self.model.encode([query])[0].tolist()
+ logging.info("成功生成查询向量")
+
+ results = self.client.search(
+ collection_name=self.collection_name,
+ data=[query_vector],
+ limit=limit,
+ output_fields=["text", "metadata"]
+ )
+
+ if not results or not results[0]:
+ logging.warning("未找到相关结果")
+ return None
+
+ logging.info(f"检索到 {len(results[0])} 条结果")
+
+ # 处理搜索结果
+ for idx, result in enumerate(results[0], 1):
+ logging.info(f"\n--- 结果 {idx} ---")
+
+ # 记录原始结果
+ logging.info(f"原始结果: {json.dumps(result, ensure_ascii=False)}")
+
+ # 处理元数据
+ metadata = result.get('entity', {}).get('metadata') if 'entity' in result else result.get('metadata')
+ if metadata:
+ if isinstance(metadata, str):
+ try:
+ metadata = json.loads(metadata)
+ except json.JSONDecodeError:
+ logging.warning(f"metadata解析失败: {metadata}")
+ continue
+
+ logging.info(f"元数据: {json.dumps(metadata, ensure_ascii=False)}")
+
+ # 获取文件名和URL
+ if 'source' in metadata:
+ filename = os.path.basename(metadata['source'])
+ url = self.find_doc_url(filename)
+ if url:
+ result['url'] = url
+
+ # 记录文本内容
+ text = result.get('entity', {}).get('text') if 'entity' in result else result.get('text')
+ if text:
+ text_preview = text[:200] + "..." if len(text) > 200 else text
+ logging.info(f"文本内容预览:\n{text_preview}")
+
+ return results
+
+ except Exception as e:
+ logging.error(f"搜索过程出错: {str(e)}", exc_info=True)
+ return None
+
+ def cleanup(self):
+ """清理资源"""
+ if hasattr(self, 'client'):
+ try:
+ self.client.close()
+ logging.info("数据库连接已关闭")
+ except Exception as e:
+ logging.error(f"关闭数据库连接时出错: {str(e)}")
+
+def main():
+ # 配置日志
+ logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s - [%(levelname)s] - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S'
+ )
+
+ try:
+ # 初始化检索系统
+ retriever = VectorRetrieval()
+
+ # 获取用户输入
+ query = input("请输入搜索问题: ")
+
+ # 执行搜索
+ retriever.search(query)
+
+ except Exception as e:
+ logging.error(f"发生错误: {str(e)}")
+
+if __name__ == "__main__":
+ main()