Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
dist/
build/
*.egg-info
*.egg-info
__pycache__/
114 changes: 78 additions & 36 deletions chatgpt/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

import diskcache
import openai
from openai.error import OpenAIError
from openai import OpenAIError, AsyncOpenAI, AsyncStream, Stream
from openai.types.chat import ChatCompletionMessage, ChatCompletion, ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam


logs_dir = os.path.join(os.getcwd(), '.chatgpt_history/logs')
Expand Down Expand Up @@ -52,64 +53,71 @@ def sync_wrapper(*args, **kwargs):


@retry_on_exception()
def complete(messages: list[ChatCompletionMessageParam] | None = None, model='gpt-4', temperature=0, use_cache=False, **kwargs):
    """Synchronously request a chat completion from the OpenAI API.

    Args:
        messages: conversation so far, as ChatCompletionMessageParam dicts.
        model: model name passed straight through to the API.
        temperature: sampling temperature (0 = deterministic).
        use_cache: if True, return a previously cached result for the same
            messages instead of calling the API.
        **kwargs: forwarded to ``chat.completions.create`` (e.g. ``stream``,
            ``n``, ``functions``).

    Returns:
        The parsed completion (see parse_response), or a generator of text
        chunks when ``stream=True`` (see parse_stream).
    """
    if use_cache:
        key = get_key(messages)
        if key in cache:
            return cache.get(key)
    response: ChatCompletion | Stream[ChatCompletionChunk] = openai.chat.completions.create(
        messages=messages,
        model=model,
        temperature=temperature,
        **kwargs
    )
    if kwargs.get('stream', False):
        # Streaming responses are consumed incrementally: hand the caller a
        # generator rather than a fully-parsed message.
        return parse_stream(response, messages, n=kwargs.get('n', 1))
    return parse_response(response, messages, **kwargs)


@retry_on_exception()
async def acomplete(messages: list[ChatCompletionMessageParam] | None = None, model='gpt-4', temperature=0, use_cache=False, **kwargs):
    """Async variant of complete(): request a chat completion via AsyncOpenAI.

    Args:
        messages: conversation so far, as ChatCompletionMessageParam dicts.
        model: model name passed straight through to the API.
        temperature: sampling temperature (0 = deterministic).
        use_cache: if True, return a previously cached result for the same
            messages instead of calling the API.
        **kwargs: forwarded to ``chat.completions.create``.

    Returns:
        The parsed completion, or an async generator of text chunks when
        ``stream=True``.
    """
    if use_cache:
        key = get_key(messages)
        if key in cache:
            return cache.get(key)
    client = AsyncOpenAI()
    response: ChatCompletion | AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
        messages=messages,
        model=model,
        temperature=temperature,
        **kwargs
    )
    if kwargs.get('stream', False):
        # BUGFIX: parse_astream is an async *generator* function; calling it
        # returns the generator directly. Awaiting it raises TypeError.
        return parse_astream(response, messages, n=kwargs.get('n', 1))
    # BUGFIX: parse_response is synchronous; its return value is not
    # awaitable, so it must not be awaited.
    return parse_response(response, messages, **kwargs)


def parse_response(response: ChatCompletion, messages: list[ChatCompletionMessageParam], **kwargs):
    """Extract the assistant message(s) from a non-streaming completion.

    Each choice's message is logged alongside the prompt and appended to the
    results. When the response carries a function call, its JSON arguments
    are validated eagerly so malformed output fails here, not at the caller.

    Args:
        response: the ChatCompletion returned by the API (or a Stream, in
            which case parsing is delegated to parse_stream).
        messages: the prompt messages, used for logging and as the cache key.
        **kwargs: the kwargs given to the API call; ``n``, ``stream`` and
            ``functions`` are inspected here.

    Returns:
        A single ChatCompletionMessage when n == 1, else a list of them.
        The result is also stored in the cache under the messages' key.
    """
    n = kwargs.get('n', 1)
    if kwargs.get('stream', False):
        return parse_stream(response, messages, n=n)

    results = []
    for choice in response.choices:
        message = choice.message
        if kwargs.get('functions', None) and message.function_call:
            # Validate the function-call payload eagerly: surface invalid
            # JSON at parse time with a clear message.
            try:
                json.loads(message.function_call.arguments)
            except json.decoder.JSONDecodeError:
                print('ERROR: OpenAI returned invalid JSON for function call arguments')
                raise
        results.append(message)
        log_completion(messages, message)

    output = results if n > 1 else results[0]
    cache.set(get_key(messages), output)
    return output


def parse_stream(response, messages, n=1):
def parse_stream(response: Stream[ChatCompletionChunk], messages:list[ChatCompletionMessageParam], n=1):
results = ['' for _ in range(n)]
chunk: ChatCompletionChunk
for chunk in response:
for choice in chunk.choices:
if not choice.delta:
Expand All @@ -125,34 +133,68 @@ def parse_stream(response, messages, n=1):
yield (text, idx)

for r in results:
log_completion(messages + [{'role': 'assistant', 'content': r}])
log_completion(messages, r)
cache.set(get_key(messages), results)

async def parse_astream(response: AsyncStream[ChatCompletionChunk], messages: list[ChatCompletionMessageParam], n=1):
    """Asynchronously yield completion text as it streams in.

    Yields each non-empty content delta as plain text when n == 1, or as a
    ``(text, choice_index)`` tuple when several choices were requested.
    Once the stream is exhausted, the accumulated per-choice texts are
    logged and stored in the cache under the messages' key.
    """
    accumulated = ['' for _ in range(n)]
    chunk: ChatCompletionChunk
    async for chunk in response:
        for choice in chunk.choices:
            delta = choice.delta
            if not delta:
                continue
            piece = delta.content
            if not piece:
                continue
            idx = choice.index
            accumulated[idx] += piece
            yield piece if n == 1 else (piece, idx)

    for full_text in accumulated:
        log_completion(messages, full_text)
    cache.set(get_key(messages), accumulated)

def log_completion(messages: list[ChatCompletionMessageParam], result: 'ChatCompletionMessage | str | None' = None):
    """Write a human-readable transcript of one exchange to the logs directory.

    Args:
        messages: the prompt messages (dict-like params sent to the API).
        result: the model's reply — a ChatCompletionMessage, or the plain
            text accumulated from a streamed response (parse_stream /
            parse_astream pass a str), or None to log the prompt only.
    """
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')

    save_path = os.path.join(logs_dir, timestamp + '.txt')
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    log = ""
    for message in messages:
        log += message['role'].upper() + ' ' + '-'*100 + '\n'
        if "content" in message:
            log += 'Content:\n' + message['content'] + "\n"
        if "function_call" in message:  # TODO: remove later since function_call is deprecated
            # BUGFIX: colon belongs before the newline, not after it.
            log += f'Call function:\n{message["function_call"]["name"]}({message["function_call"]["arguments"]})\n'
        if "tool_calls" in message:
            for tool in message["tool_calls"]:
                log += f'\nCall {tool["type"]}:\n'
                if tool["type"] == 'function':
                    log += f'{tool["function"]["name"]}({tool["function"]["arguments"]}) id={tool["id"]}\n'
                else:
                    raise NotImplementedError(f'Tool type {tool["type"]} not implemented in logger')
        log += '\n\n'

    if result is not None:
        if isinstance(result, str):
            # Streamed completions hand us accumulated plain text, not a
            # message object — log it as an assistant reply.
            log += 'ASSISTANT ' + '-'*100 + '\n'
            log += 'Content:\n' + result + '\n'
        else:
            log += result.role.upper() + ' ' + '-'*100 + '\n'
            if result.content:
                # BUGFIX: ChatCompletionMessage is not subscriptable;
                # use attribute access, and terminate the line.
                log += 'Content:\n' + result.content + '\n'
            if result.function_call:
                log += f'Called function:\n{result.function_call.name}({result.function_call.arguments})\n'
            if result.tool_calls:
                for tool in result.tool_calls:
                    log += f'\nCalled {tool.type}:\n'
                    if tool.type == 'function':
                        log += f'{tool.function.name}({tool.function.arguments}) id={tool.id}\n'
                    else:
                        raise NotImplementedError(f"Tool type {tool.type} not implemented in logger")
        log += '\n\n'

    with open(save_path, 'w') as f:
        f.write(log)