Skip to content

Commit 69fd116

Browse files
committed
v0.0.38
1 parent 1c45c55 commit 69fd116

File tree

2 files changed

+256
-2
lines changed

2 files changed

+256
-2
lines changed

agixtsdk/__init__.py

Lines changed: 255 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,21 @@
1+
import tiktoken
2+
import functools
3+
import uuid
14
import requests
2-
from typing import Dict, List, Any
5+
import base64
6+
import time
7+
import openai
8+
from datetime import datetime
9+
from pydub import AudioSegment
10+
import requests
11+
from pydantic import BaseModel
12+
from typing import Dict, List, Any, Optional
13+
14+
15+
def get_tokens(text: str) -> int:
16+
encoding = tiktoken.get_encoding("cl100k_base")
17+
num_tokens = len(encoding.encode(text))
18+
return num_tokens
319

420

521
class AGiXTSDK:
@@ -962,3 +978,241 @@ def text_to_speech(self, agent_name: str, text: str, conversation_name: str):
962978
conversation_name=conversation_name,
963979
)
964980
return response
981+
982+
983+
class ChatCompletions(BaseModel):
984+
model: str = "NurseGPT" # This is the agent name
985+
messages: List[dict] = None
986+
temperature: Optional[float] = 0.9
987+
top_p: Optional[float] = 1.0
988+
tools: Optional[List[dict]] = None
989+
tools_choice: Optional[str] = "auto"
990+
n: Optional[int] = 1
991+
stream: Optional[bool] = False
992+
stop: Optional[List[str]] = None
993+
max_tokens: Optional[int] = 4096
994+
presence_penalty: Optional[float] = 0.0
995+
frequency_penalty: Optional[float] = 0.0
996+
logit_bias: Optional[Dict[str, float]] = None
997+
user: Optional[str] = "Chat" # This is the conversation name
998+
999+
1000+
# Chat Completion Decorator
1001+
def AGiXT_chat(base_uri: str, api_key: str = None):
1002+
def decorator(func):
1003+
@functools.wraps(func)
1004+
def wrapper(prompt: ChatCompletions):
1005+
agixt = AGiXTSDK(
1006+
base_uri=base_uri, api_key=api_key if api_key else base_uri
1007+
)
1008+
agent_name = prompt.model # prompt.model is the agent name
1009+
conversation_name = prompt.user # prompt.user is the conversation name
1010+
agent_config = agixt.get_agentconfig(agent_name=agent_name)
1011+
agent_settings = (
1012+
agent_config["settings"] if "settings" in agent_config else {}
1013+
)
1014+
images = []
1015+
new_prompt = ""
1016+
for message in prompt.messages:
1017+
if "content" not in message:
1018+
continue
1019+
if isinstance(message["content"], str):
1020+
role = message["role"] if "role" in message else "User"
1021+
if role.lower() == "system":
1022+
if "/" in message["content"]:
1023+
new_prompt += f"{message['content']}\n\n"
1024+
if role.lower() == "user":
1025+
new_prompt += f"{message['content']}\n\n"
1026+
if isinstance(message["content"], list):
1027+
for msg in message["content"]:
1028+
if "text" in msg:
1029+
role = message["role"] if "role" in message else "User"
1030+
if role.lower() == "user":
1031+
new_prompt += f"{msg['text']}\n\n"
1032+
if "image_url" in msg:
1033+
url = (
1034+
msg["image_url"]["url"]
1035+
if "url" in msg["image_url"]
1036+
else msg["image_url"]
1037+
)
1038+
image_path = f"./WORKSPACE/{uuid.uuid4().hex}.jpg"
1039+
if url.startswith("http"):
1040+
image = requests.get(url).content
1041+
else:
1042+
file_type = (
1043+
url.split(",")[0].split("/")[1].split(";")[0]
1044+
)
1045+
if file_type == "jpeg":
1046+
file_type = "jpg"
1047+
file_name = f"{uuid.uuid4().hex}.{file_type}"
1048+
image_path = f"./WORKSPACE/{file_name}"
1049+
image = base64.b64decode(url.split(",")[1])
1050+
with open(image_path, "wb") as f:
1051+
f.write(image)
1052+
images.append(image_path)
1053+
if "audio_url" in msg:
1054+
audio_url = (
1055+
msg["audio_url"]["url"]
1056+
if "url" in msg["audio_url"]
1057+
else msg["audio_url"]
1058+
)
1059+
# If it is not a url, we need to find the file type and convert with pydub
1060+
if not audio_url.startswith("http"):
1061+
file_type = (
1062+
audio_url.split(",")[0].split("/")[1].split(";")[0]
1063+
)
1064+
audio_data = base64.b64decode(audio_url.split(",")[1])
1065+
audio_path = (
1066+
f"./WORKSPACE/{uuid.uuid4().hex}.{file_type}"
1067+
)
1068+
with open(audio_path, "wb") as f:
1069+
f.write(audio_data)
1070+
audio_url = audio_path
1071+
else:
1072+
# Download the audio file from the url, get the file type and convert to wav
1073+
audio_type = audio_url.split(".")[-1]
1074+
audio_url = (
1075+
f"./WORKSPACE/{uuid.uuid4().hex}.{audio_type}"
1076+
)
1077+
audio_data = requests.get(audio_url).content
1078+
with open(audio_url, "wb") as f:
1079+
f.write(audio_data)
1080+
wav_file = f"./WORKSPACE/{uuid.uuid4().hex}.wav"
1081+
AudioSegment.from_file(audio_url).set_frame_rate(
1082+
16000
1083+
).export(wav_file, format="wav")
1084+
# Switch this to use the endpoint
1085+
openai.api_key = (
1086+
agixt.headers["Authorization"]
1087+
.replace("Bearer ", "")
1088+
.replace("bearer ", "")
1089+
)
1090+
openai.base_url = f"{agixt.base_uri}/v1/"
1091+
with open(wav_file, "rb") as audio_file:
1092+
transcription = openai.audio.transcriptions.create(
1093+
model=agent_name, file=audio_file
1094+
)
1095+
new_prompt += transcription.text
1096+
if "video_url" in msg:
1097+
video_url = str(
1098+
msg["video_url"]["url"]
1099+
if "url" in msg["video_url"]
1100+
else msg["video_url"]
1101+
)
1102+
if "collection_number" in msg:
1103+
collection_number = int(msg["collection_number"])
1104+
else:
1105+
collection_number = 0
1106+
if video_url.startswith("https://www.youtube.com/watch?v="):
1107+
agixt.learn_url(
1108+
agent_name=agent_name,
1109+
url=video_url,
1110+
collection_number=collection_number,
1111+
)
1112+
if (
1113+
"file_url" in msg
1114+
or "application_url" in msg
1115+
or "text_url" in msg
1116+
or "url" in msg
1117+
):
1118+
file_url = str(
1119+
msg["file_url"]["url"]
1120+
if "url" in msg["file_url"]
1121+
else msg["file_url"]
1122+
)
1123+
if (
1124+
"collection_number" in message
1125+
or "collection_number" in msg
1126+
):
1127+
collection_number = int(
1128+
message["collection_number"]
1129+
if "collection_number" in message
1130+
else msg["collection_number"]
1131+
)
1132+
else:
1133+
collection_number = 0
1134+
if file_url.startswith("http"):
1135+
if file_url.startswith(
1136+
"https://www.youtube.com/watch?v="
1137+
):
1138+
agixt.learn_url(
1139+
agent_name=agent_name,
1140+
url=file_url,
1141+
collection_number=collection_number,
1142+
)
1143+
elif file_url.startswith("https://github.com"):
1144+
agixt.learn_github_repo(
1145+
agent_name=agent_name,
1146+
github_repo=file_url,
1147+
github_user=(
1148+
agent_settings["GITHUB_USER"]
1149+
if "GITHUB_USER" in agent_settings
1150+
else None
1151+
),
1152+
github_token=(
1153+
agent_settings["GITHUB_TOKEN"]
1154+
if "GITHUB_TOKEN" in agent_settings
1155+
else None
1156+
),
1157+
github_branch=(
1158+
"main"
1159+
if "branch" not in message
1160+
else message["branch"]
1161+
),
1162+
collection_number=collection_number,
1163+
)
1164+
else:
1165+
agixt.learn_url(
1166+
agent_name=agent_name,
1167+
url=file_url,
1168+
collection_number=collection_number,
1169+
)
1170+
else:
1171+
file_type = (
1172+
file_url.split(",")[0].split("/")[1].split(";")[0]
1173+
)
1174+
file_data = base64.b64decode(file_url.split(",")[1])
1175+
file_path = (
1176+
f"./WORKSPACE/{uuid.uuid4().hex}.{file_type}"
1177+
)
1178+
with open(file_path, "wb") as f:
1179+
f.write(file_data)
1180+
# file name should be a safe timestamp
1181+
file_name = f"Uploaded File {datetime.now().strftime('%Y%m%d%H%M%S')}.{file_type}"
1182+
agixt.learn_file(
1183+
agent_name=agent_name,
1184+
file_name=f"Uploaded File {uuid.uuid4().hex}.{file_type}",
1185+
file_content=file_data,
1186+
collection_number=collection_number,
1187+
)
1188+
response = func(new_prompt)
1189+
prompt_tokens = get_tokens(new_prompt)
1190+
completion_tokens = get_tokens(response)
1191+
total_tokens = int(prompt_tokens) + int(completion_tokens)
1192+
res_model = {
1193+
"id": conversation_name,
1194+
"object": "chat.completion",
1195+
"created": int(time.time()),
1196+
"model": agent_name,
1197+
"choices": [
1198+
{
1199+
"index": 0,
1200+
"message": {
1201+
"role": "assistant",
1202+
"content": str(response),
1203+
},
1204+
"finish_reason": "stop",
1205+
"logprobs": None,
1206+
}
1207+
],
1208+
"usage": {
1209+
"prompt_tokens": prompt_tokens,
1210+
"completion_tokens": completion_tokens,
1211+
"total_tokens": total_tokens,
1212+
},
1213+
}
1214+
return res_model
1215+
1216+
return wrapper
1217+
1218+
return decorator

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
setup(
1010
name="agixtsdk",
11-
version="0.0.37",
11+
version="0.0.38",
1212
description="The AGiXT SDK for Python.",
1313
long_description=long_description,
1414
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)