diff --git a/app/controller/app-center/aiChat.ts b/app/controller/app-center/aiChat.ts index 350af76..ce9422b 100644 --- a/app/controller/app-center/aiChat.ts +++ b/app/controller/app-center/aiChat.ts @@ -10,27 +10,76 @@ * */ import { Controller } from 'egg'; -import { E_FOUNDATION_MODEL } from '../../lib/enum'; export default class AiChatController extends Controller { public async aiChat() { const { ctx } = this; - const { foundationModel, messages } = ctx.request.body; - this.ctx.logger.info('ai接口请求参参数 model选型:', foundationModel); + const options = ctx.request.body; + this.ctx.logger.info('ai接口请求参数 model选型:', options); + + const messages = options.messages; if (!messages || !Array.isArray(messages)) { return this.ctx.helper.getResponseData('Not passing the correct message parameter'); } - const model = foundationModel?.model ?? E_FOUNDATION_MODEL.GPT_35_TURBO; - const token = foundationModel.token; - ctx.body = await ctx.service.appCenter.aiChat.getAnswerFromAi(messages, { model, token }); + const apiKey = ctx.request.header?.authorization?.replace('Bearer', ''); + const baseUrl = options?.baseUrl; + const model = options?.model; + const stream = options?.stream || false; + const tools = options?.tools || []; + + if (stream) { + ctx.status = 200; + ctx.set({ + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache', + Connection: 'keep-alive' + }); + try { + const result = await ctx.service.appCenter.aiChat.getAnswerFromAi(messages, { + apiKey, + baseUrl, + model, + stream, + tools + }); + + for await (const chunk of result) { + ctx.res.write(`data: ${JSON.stringify(chunk)}\n\n`); // SSE 格式 + } + + // 添加结束标记 + ctx.res.write('data: [DONE]'); + } catch (e: any) { + this.ctx.logger.error(`调用AI大模型接口失败: ${(e as Error).message}`); + } finally { + console.log('end'); + ctx.res.end(); // 关闭连接 + } + + return; + } + + // 非流式模式 + ctx.body = await ctx.service.appCenter.aiChat.getAnswerFromAi(messages, { + apiKey, + baseUrl, + model, + stream, + tools + }); } + public async search() { + const { ctx } = this; + const { content } = ctx.request.body; + + ctx.body = await ctx.service.appCenter.aiChat.search(content); + } public async uploadFile() { const { ctx } = this; - const fileStream = await ctx.getFileStream(); - const foundationModelObject = JSON.parse(fileStream.fields.foundationModel); - const { model, token } = foundationModelObject.foundationModel; - ctx.body = await ctx.service.appCenter.aiChat.getFileContentFromAi(fileStream, { model, token }); + const stream = await ctx.getFileStream(); + + ctx.body = await ctx.service.appCenter.aiChat.uploadFile(stream); } } diff --git a/app/router/appCenter/base.ts b/app/router/appCenter/base.ts index 2b7c273..45f694d 100644 --- a/app/router/appCenter/base.ts +++ b/app/router/appCenter/base.ts @@ -113,5 +113,6 @@ export default (app: Application) => { // AI大模型聊天接口 subRouter.post('/ai/chat', controller.appCenter.aiChat.aiChat); - subRouter.post('/ai/files', controller.appCenter.aiChat.uploadFile); + subRouter.post('/ai/search', controller.appCenter.aiChat.search); + subRouter.post('/ai/uploadFile', controller.appCenter.aiChat.uploadFile); }; diff --git a/app/service/app-center/aiChat.ts b/app/service/app-center/aiChat.ts index 8ed5207..93ee597 100644 --- a/app/service/app-center/aiChat.ts +++ b/app/service/app-center/aiChat.ts @@ -10,13 +10,17 @@ * */ import { Service } from 'egg'; -import { E_FOUNDATION_MODEL } from '../../lib/enum'; -import * as fs from 'fs'; -import * as path from 'path'; - -const to = require('await-to-js').default; -const OpenAI = require('openai'); - +import OpenApi, * as $OpenApi from '@alicloud/openapi-client'; +import OpenApiUtil from '@alicloud/openapi-util'; +import * as $Util from '@alicloud/tea-util'; +import Credential, { Config } from '@alicloud/credentials'; +import OpenAI from 'openai'; +import { ChatCompletionMessageParam } from 'openai/resources/chat/completions'; +import path from 'path'; +import FormData from 'form-data'; +import fs from 'fs'; +import axios from 'axios'; +import pump from 'mz-modules/pump'; export type AiMessage = { role: string; // 角色 @@ -25,251 +29,243 @@ export type AiMessage = { partial?: boolean; }; -interface ConfigModel { - model: string; - token: string; -} - export default class AiChat extends Service { /** * 获取ai的答复 * - * 根据后续引进的大模型情况决定,是否通过重构来对不同大模型进行统一的适配 - * * @param messages * @param model * @return */ - async getAnswerFromAi(messages: Array, chatConfig: any) { - let res = await this.requestAnswerFromAi(messages, chatConfig); - let answerContent = ''; - let isFinish = res.choices[0].finish_reason; + async getAnswerFromAi(messages: ChatCompletionMessageParam[], chatConfig: any) { + let result: any = null; - if (isFinish !== 'length') { - answerContent = res.choices[0]?.message.content; - } - - // 若内容过长被截断,继续回复 - while (isFinish === 'length') { - const prefix = res.choices[0].message.content; - answerContent += prefix; - messages.push({ - role: 'assistant', - content: prefix, - partial: true + try { + const openai = new OpenAI({ + apiKey: chatConfig.apiKey || process.env.OPEN_AI_API_KEY, + baseURL: chatConfig.baseUrl || process.env.OPEN_AI_BASE_URL, + defaultHeaders: { + 'X-DashScope-OssResourceResolve': 'enable' + } }); - res = await this.requestAnswerFromAi(messages, chatConfig); - answerContent += res.choices[0].message.content; - isFinish = res.choices[0].finish_reason; - } + const options: any = { + model: chatConfig.model || process.env.OPEN_AI_MODEL, + messages, + stream: chatConfig.stream + }; + if (chatConfig.tools.length) { + options.tools = chatConfig.tools; + } - const code = this.extractCode(answerContent); - const schema = this.extractSchemaCode(code); - const answer = { - role: res.choices[0].message.role, - content: answerContent - }; - const replyWithoutCode = this.removeCode(answerContent); - return this.ctx.helper.getResponseData({ - originalResponse: answer, - replyWithoutCode, - schema - }); - } + result = await openai.chat.completions.create(options); - async requestAnswerFromAi(messages: Array, chatConfig: any) { - const { ctx } = this; - this.formatMessage(messages); - let res: any = null; - try { - // 根据大模型的不同匹配不同的配置 - const aiChatConfig = this.config.aiChat(messages, chatConfig.token); - const { httpRequestUrl, httpRequestOption } = aiChatConfig[chatConfig.model]; - this.ctx.logger.debug(httpRequestOption); - res = await ctx.curl(httpRequestUrl, httpRequestOption); + return result; } catch (e: any) { - this.ctx.logger.debug(`调用AI大模型接口失败: ${(e as Error).message}`); + this.ctx.logger.error(`调用AI大模型接口失败: ${(e as Error).message}`); return this.ctx.helper.getResponseData(`调用AI大模型接口失败: ${(e as Error).message}`); } - - if (!res) { - return this.ctx.helper.getResponseData(`调用AI大模型接口未返回正确数据.`); - } - - // 适配文心一言的响应数据结构,文心的部分异常情况status也是200,需要转为400,以免前端无所适从 - if (res.data?.error_code) { - return this.ctx.helper.getResponseData(res.data?.error_msg); - } - - // 适配chatgpt的响应数据结构 - if (res.status !== 200) { - return this.ctx.helper.getResponseData(res.data?.error?.message, res.status); - } - - // 适配文心一言的响应数据结构 - if (chatConfig.model === E_FOUNDATION_MODEL.ERNIE_BOT_TURBO) { - return { - ...res.data, - choices: [ - { - message: { - role: 'assistant', - content: res.data.result - } - } - ] - }; - } - - return res.data; } /** - * 提取回复中的代码 - * - * 暂且只满足回复中只包括一个代码块的场景 + * 知识库检索 + * @remarks + * 使用凭据初始化账号Client + * @returns Client * - * @param content ai回复的内容 - * @return 提取的文本 + * @throws Exception */ - private extractCode(content: string) { - const { start, end } = this.getStartAndEnd(content); - if (start < 0 || end < 0) { - return ''; - } - return content.substring(start, end); + private createClient(): OpenApi { + const credentialsConfig1 = new Config({ + type: 'access_key', + accessKeyId: process.env.ALIBABA_CLOUD_ACCESS_KEY_ID, + accessKeySecret: process.env.ALIBABA_CLOUD_ACCESS_KEY_SECRET + }); + let credential = new Credential(credentialsConfig1); + let config = new $OpenApi.Config({ + credential: credential + }); + + config.endpoint = `bailian.cn-beijing.aliyuncs.com`; + return new OpenApi(config); } /** - * 去除回复中的代码 - * - * 暂且只满足回复中只包括一个代码块的场景 + * @remarks + * API 相关 * - * @param content ai回复的内容 - * @return 去除代码后的回复内容 + * @param path - string Path parameters + * @returns OpenApi.Params */ - private removeCode(content: string) { - const { start, end } = this.getStartAndEnd(content); - if (start < 0 || end < 0) { - return content; - } - return content.substring(0, start) + '<代码在画布中展示>' + content.substring(end); + private createApiInfo(WorkspaceId): $OpenApi.Params { + let params = new $OpenApi.Params({ + // 接口名称 + action: 'Retrieve', + // 接口版本 + version: '2023-12-29', + // 接口协议 + protocol: 'HTTPS', + // 接口 HTTP 方法 + method: 'POST', + authType: 'AK', + style: 'ROA', + // 接口 PATH + pathname: `/${WorkspaceId}/index/retrieve`, + // 接口请求体内容格式 + reqBodyType: 'json', + // 接口响应体内容格式 + bodyType: 'json' + }); + return params; } - private extractSchemaCode(content) { - const startMarker = /```json/; - const endMarker = /```/; - - const start = content.search(startMarker); - const end = content.slice(start + 7).search(endMarker) + start + 7; - - if (start >= 0 && end >= 0) { - return JSON.parse(content.substring(start + 7, end).trim()); - } - - return null; + private getSearchList(res) { + const list = res?.body?.Data?.Nodes ?? []; + + return { + data: list.map((node) => { + return { + score: node.Score, + content: node.Text, + doc_name: node.Metadata.doc_name + }; + }) + }; } - private getStartAndEnd(str: string) { - const start = str.search(/```|