|
| 1 | +import tiktoken |
| 2 | +import functools |
| 3 | +import uuid |
1 | 4 | import requests |
2 | | -from typing import Dict, List, Any |
| 5 | +import base64 |
| 6 | +import time |
| 7 | +import openai |
| 8 | +from datetime import datetime |
| 9 | +from pydub import AudioSegment |
| 10 | +import requests |
| 11 | +from pydantic import BaseModel |
| 12 | +from typing import Dict, List, Any, Optional |
| 13 | + |
| 14 | + |
def get_tokens(text: str) -> int:
    """Return how many ``cl100k_base`` tokens ``text`` encodes to.

    Args:
        text: The string to tokenize.

    Returns:
        The length of the token sequence produced by tiktoken's
        ``cl100k_base`` encoding for ``text``.
    """
    return len(tiktoken.get_encoding("cl100k_base").encode(text))
3 | 19 |
|
4 | 20 |
|
5 | 21 | class AGiXTSDK: |
@@ -962,3 +978,241 @@ def text_to_speech(self, agent_name: str, text: str, conversation_name: str): |
962 | 978 | conversation_name=conversation_name, |
963 | 979 | ) |
964 | 980 | return response |
| 981 | + |
| 982 | + |
class ChatCompletions(BaseModel):
    """Request model consumed by the ``AGiXT_chat`` decorator's wrapper.

    Two fields are repurposed: ``model`` carries the agent name and ``user``
    carries the conversation name. Every field has a default, so an empty
    request is constructible.
    """

    model: str = "NurseGPT"  # This is the agent name
    # Each message is a dict; the consumer reads its "role" and "content" keys.
    # Declared Optional because the default is None.
    messages: Optional[List[dict]] = None
    temperature: Optional[float] = 0.9
    top_p: Optional[float] = 1.0
    tools: Optional[List[dict]] = None
    # NOTE(review): OpenAI spells this parameter ``tool_choice``; the existing
    # field name is kept for backward compatibility - confirm which is wanted.
    tools_choice: Optional[str] = "auto"
    n: Optional[int] = 1
    stream: Optional[bool] = False
    stop: Optional[List[str]] = None
    max_tokens: Optional[int] = 4096
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = "Chat"  # This is the conversation name
| 998 | + |
| 999 | + |
| 1000 | +# Chat Completion Decorator |
def AGiXT_chat(base_uri: str, api_key: str = None):
    """Build a decorator that feeds a chat-completion request to ``func``.

    The decorated function is called with a single prompt string assembled
    from the request's messages, and its return value is wrapped in a
    ``chat.completion``-shaped dict.

    While assembling the prompt the wrapper:

    * appends the text of ``user`` messages (and of ``system`` messages whose
      content contains a ``/``);
    * saves ``image_url`` parts into ``./WORKSPACE``;
    * converts ``audio_url`` parts to 16 kHz WAV, transcribes them through
      the ``openai`` client pointed at ``{base_uri}/v1/`` and appends the
      transcription;
    * passes ``video_url`` parts that are YouTube watch links to
      ``learn_url``;
    * passes ``file_url`` / ``application_url`` / ``text_url`` / ``url``
      parts to ``learn_url``, ``learn_github_repo`` or ``learn_file``.

    Args:
        base_uri: Base URI handed to ``AGiXTSDK``.
        api_key: API key handed to ``AGiXTSDK``. When falsy, ``base_uri`` is
            passed in its place.

    Returns:
        A decorator. The wrapper it produces takes a ``ChatCompletions``
        instance and returns a dict with ``id``, ``object``, ``created``,
        ``model``, ``choices`` and ``usage`` keys.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(prompt: ChatCompletions):
            # NOTE(review): with no api_key, base_uri is used as the API key.
            # Preserved from the original - confirm this is intended.
            agixt = AGiXTSDK(
                base_uri=base_uri, api_key=api_key if api_key else base_uri
            )
            agent_name = prompt.model  # prompt.model is the agent name
            conversation_name = prompt.user  # prompt.user is the conversation name
            agent_config = agixt.get_agentconfig(agent_name=agent_name)
            agent_settings = (
                agent_config["settings"] if "settings" in agent_config else {}
            )
            # NOTE(review): saved image paths are collected but not forwarded
            # to ``func`` anywhere in this function.
            images = []
            new_prompt = ""
            # ``messages`` defaults to None on the request model.
            for message in prompt.messages or []:
                if "content" not in message:
                    continue
                content = message["content"]
                role = str(message.get("role", "User")).lower()
                if isinstance(content, str):
                    # System messages are only included when they contain "/".
                    if role == "user" or (role == "system" and "/" in content):
                        new_prompt += f"{content}\n\n"
                    continue
                if not isinstance(content, list):
                    continue
                for msg in content:
                    if "text" in msg and role == "user":
                        new_prompt += f"{msg['text']}\n\n"
                    if "image_url" in msg:
                        images.append(_save_image(_part_url(msg, "image_url")))
                    if "audio_url" in msg:
                        new_prompt += _transcribe_audio(
                            agixt, agent_name, _part_url(msg, "audio_url")
                        )
                    if "video_url" in msg:
                        video_url = _part_url(msg, "video_url")
                        if video_url.startswith("https://www.youtube.com/watch?v="):
                            agixt.learn_url(
                                agent_name=agent_name,
                                url=video_url,
                                collection_number=_collection_number(msg),
                            )
                    # Read the URL from whichever key is actually present,
                    # not unconditionally from "file_url".
                    file_key = next(
                        (key for key in _FILE_URL_KEYS if key in msg), None
                    )
                    if file_key is not None:
                        _learn_from_file_part(
                            agixt=agixt,
                            agent_name=agent_name,
                            agent_settings=agent_settings,
                            message=message,
                            file_url=_part_url(msg, file_key),
                            collection_number=_collection_number(message, msg),
                        )
            response = func(new_prompt)
            response_text = str(response)
            prompt_tokens = get_tokens(new_prompt)
            # Count tokens on the same string that is returned as content.
            completion_tokens = get_tokens(response_text)
            total_tokens = int(prompt_tokens) + int(completion_tokens)
            res_model = {
                "id": conversation_name,
                "object": "chat.completion",
                "created": int(time.time()),
                "model": agent_name,
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": response_text,
                        },
                        "finish_reason": "stop",
                        "logprobs": None,
                    }
                ],
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": total_tokens,
                },
            }
            return res_model

        return wrapper

    return decorator


# Content-part keys that are treated as a document to learn from, in lookup order.
_FILE_URL_KEYS = ("file_url", "application_url", "text_url", "url")


def _part_url(part: dict, key: str) -> str:
    """Return the URL under ``part[key]``, given as ``{"url": ...}`` or bare."""
    value = part[key]
    # Only index into real mappings; "url" in "<string>" is a substring test.
    if isinstance(value, dict) and "url" in value:
        value = value["url"]
    return str(value)


def _decode_data_url(data_url: str) -> tuple:
    """Split ``<type>/<subtype>;...,<base64>`` into ``(subtype, decoded bytes)``."""
    pieces = data_url.split(",")
    file_type = pieces[0].split("/")[1].split(";")[0]
    return file_type, base64.b64decode(pieces[1])


def _write_workspace_file(data: bytes, extension: str) -> str:
    """Write ``data`` to a uuid-named file in ``./WORKSPACE``; return its path."""
    path = f"./WORKSPACE/{uuid.uuid4().hex}.{extension}"
    with open(path, "wb") as f:
        f.write(data)
    return path


def _collection_number(*sources: dict) -> int:
    """Return the first ``collection_number`` found in ``sources`` as int, else 0."""
    for source in sources:
        if "collection_number" in source:
            return int(source["collection_number"])
    return 0


def _save_image(url: str) -> str:
    """Download or decode an image into ``./WORKSPACE``; return the saved path."""
    if url.startswith("http"):
        # Remote images are always stored with a .jpg extension.
        return _write_workspace_file(requests.get(url).content, "jpg")
    file_type, image = _decode_data_url(url)
    if file_type == "jpeg":
        file_type = "jpg"
    return _write_workspace_file(image, file_type)


def _transcribe_audio(agixt, agent_name: str, audio_url: str) -> str:
    """Fetch or decode audio, convert it to 16 kHz WAV and return its transcription."""
    if audio_url.startswith("http"):
        audio_type = audio_url.split(".")[-1]
        # Download from the remote URL; the local path is only built afterwards.
        audio_data = requests.get(audio_url).content
    else:
        audio_type, audio_data = _decode_data_url(audio_url)
    audio_path = _write_workspace_file(audio_data, audio_type)
    wav_file = f"./WORKSPACE/{uuid.uuid4().hex}.wav"
    AudioSegment.from_file(audio_path).set_frame_rate(16000).export(
        wav_file, format="wav"
    )
    # Switch this to use the endpoint
    openai.api_key = (
        agixt.headers["Authorization"]
        .replace("Bearer ", "")
        .replace("bearer ", "")
    )
    openai.base_url = f"{agixt.base_uri}/v1/"
    with open(wav_file, "rb") as audio_file:
        transcription = openai.audio.transcriptions.create(
            model=agent_name, file=audio_file
        )
    return transcription.text


def _learn_from_file_part(
    agixt,
    agent_name: str,
    agent_settings: dict,
    message: dict,
    file_url: str,
    collection_number: int,
) -> None:
    """Send a linked or embedded document to the matching ``learn_*`` SDK call."""
    if file_url.startswith("http"):
        if file_url.startswith("https://github.com"):
            agixt.learn_github_repo(
                agent_name=agent_name,
                github_repo=file_url,
                github_user=agent_settings.get("GITHUB_USER"),
                github_token=agent_settings.get("GITHUB_TOKEN"),
                github_branch=message.get("branch", "main"),
                collection_number=collection_number,
            )
        else:
            # YouTube links and every other http(s) URL share this call.
            agixt.learn_url(
                agent_name=agent_name,
                url=file_url,
                collection_number=collection_number,
            )
        return
    file_type, file_data = _decode_data_url(file_url)
    # A copy is also written to ./WORKSPACE, as before.
    _write_workspace_file(file_data, file_type)
    agixt.learn_file(
        agent_name=agent_name,
        file_name=f"Uploaded File {uuid.uuid4().hex}.{file_type}",
        file_content=file_data,
        collection_number=collection_number,
    )
0 commit comments