Skip to content

Commit ad2ea69

Browse files
Change package name.
1 parent 6bc10df commit ad2ea69

File tree

3 files changed

+208
-178
lines changed

3 files changed

+208
-178
lines changed

singlestoredb/ai/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .chat import SingleStoreChatFactory # noqa: F401
2-
from .chat import SingleStoreChatFactoryDebug # noqa: F401
2+
from .debug import SingleStoreChatFactoryDebug # noqa: F401
33
from .embeddings import SingleStoreEmbeddingsFactory # noqa: F401

singlestoredb/ai/chat.py

Lines changed: 0 additions & 177 deletions
Original file line numberDiff line numberDiff line change
@@ -205,180 +205,3 @@ def inject_auth_headers(request: httpx.Request) -> None:
205205
**openai_kwargs,
206206
**kwargs,
207207
)
208-
209-
210-
def SingleStoreChatFactoryDebug(
211-
model_name: str,
212-
api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
213-
streaming: bool = True,
214-
http_client: Optional[httpx.Client] = None,
215-
obo_token: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
216-
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
217-
base_url: Optional[str] = None,
218-
hosting_platform: Optional[str] = None,
219-
timeout: Optional[float] = None,
220-
**kwargs: Any,
221-
) -> Union[ChatOpenAI, ChatBedrockConverse]:
222-
"""Return a chat model instance (ChatOpenAI or ChatBedrockConverse).
223-
"""
224-
# Handle api_key and obo_token as callable functions
225-
if callable(api_key):
226-
api_key_getter_fn = api_key
227-
else:
228-
def api_key_getter_fn() -> Optional[str]:
229-
if api_key is None:
230-
return os.environ.get('SINGLESTOREDB_USER_TOKEN')
231-
return api_key
232-
233-
if obo_token_getter is not None:
234-
obo_token_getter_fn = obo_token_getter
235-
else:
236-
if callable(obo_token):
237-
obo_token_getter_fn = obo_token
238-
else:
239-
def obo_token_getter_fn() -> Optional[str]:
240-
return obo_token
241-
242-
# handle model info
243-
if base_url is None:
244-
base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL')
245-
if hosting_platform is None:
246-
hosting_platform = os.environ.get('SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM')
247-
if base_url is None or hosting_platform is None:
248-
inference_api_manager = (
249-
manage_workspaces().organizations.current.inference_apis
250-
)
251-
info = inference_api_manager.get(model_name=model_name)
252-
else:
253-
info = InferenceAPIInfo(
254-
service_id='',
255-
model_name=model_name,
256-
name='',
257-
connection_url=base_url,
258-
project_id='',
259-
hosting_platform=hosting_platform,
260-
)
261-
if base_url is not None:
262-
info.connection_url = base_url
263-
if hosting_platform is not None:
264-
info.hosting_platform = hosting_platform
265-
266-
# Extract timeouts from http_client if provided
267-
t = http_client.timeout if http_client is not None else None
268-
connect_timeout = None
269-
read_timeout = None
270-
if t is not None:
271-
if isinstance(t, httpx.Timeout):
272-
if t.connect is not None:
273-
connect_timeout = float(t.connect)
274-
if t.read is not None:
275-
read_timeout = float(t.read)
276-
if connect_timeout is None and read_timeout is not None:
277-
connect_timeout = read_timeout
278-
if read_timeout is None and connect_timeout is not None:
279-
read_timeout = connect_timeout
280-
elif isinstance(t, (int, float)):
281-
connect_timeout = float(t)
282-
read_timeout = float(t)
283-
if timeout is not None:
284-
connect_timeout = timeout
285-
read_timeout = timeout
286-
t = httpx.Timeout(timeout)
287-
288-
if info.hosting_platform == 'Amazon':
289-
# Instantiate Bedrock client
290-
cfg_kwargs = {
291-
'signature_version': UNSIGNED,
292-
'retries': {'max_attempts': 1, 'mode': 'standard'},
293-
}
294-
if read_timeout is not None:
295-
cfg_kwargs['read_timeout'] = read_timeout
296-
if connect_timeout is not None:
297-
cfg_kwargs['connect_timeout'] = connect_timeout
298-
299-
cfg = Config(**cfg_kwargs)
300-
client = boto3.client(
301-
'bedrock-runtime',
302-
endpoint_url=info.connection_url,
303-
region_name='us-east-1',
304-
aws_access_key_id='placeholder',
305-
aws_secret_access_key='placeholder',
306-
config=cfg,
307-
)
308-
309-
def _inject_headers(request: Any, **_ignored: Any) -> None:
310-
"""Inject dynamic auth/OBO headers prior to Bedrock sending."""
311-
if api_key_getter_fn is not None:
312-
token_val = api_key_getter_fn()
313-
if token_val:
314-
request.headers['Authorization'] = f'Bearer {token_val}'
315-
if obo_token_getter_fn is not None:
316-
obo_val = obo_token_getter_fn()
317-
if obo_val:
318-
request.headers['X-S2-OBO'] = obo_val
319-
request.headers.pop('X-Amz-Date', None)
320-
request.headers.pop('X-Amz-Security-Token', None)
321-
322-
emitter = client._endpoint._event_emitter
323-
emitter.register_first(
324-
'before-send.bedrock-runtime.Converse',
325-
_inject_headers,
326-
)
327-
emitter.register_first(
328-
'before-send.bedrock-runtime.ConverseStream',
329-
_inject_headers,
330-
)
331-
emitter.register_first(
332-
'before-send.bedrock-runtime.InvokeModel',
333-
_inject_headers,
334-
)
335-
emitter.register_first(
336-
'before-send.bedrock-runtime.InvokeModelWithResponseStream',
337-
_inject_headers,
338-
)
339-
340-
return ChatBedrockConverse(
341-
model_id=model_name,
342-
endpoint_url=info.connection_url,
343-
region_name='us-east-1',
344-
aws_access_key_id='placeholder',
345-
aws_secret_access_key='placeholder',
346-
disable_streaming=not streaming,
347-
client=client,
348-
**kwargs,
349-
)
350-
351-
def inject_auth_headers(request: httpx.Request) -> None:
352-
"""Inject dynamic auth/OBO headers before request is sent."""
353-
if api_key_getter_fn is not None:
354-
token_val = api_key_getter_fn()
355-
if token_val:
356-
request.headers['Authorization'] = f'Bearer {token_val}'
357-
if obo_token_getter_fn is not None:
358-
obo_val = obo_token_getter_fn()
359-
if obo_val:
360-
request.headers['X-S2-OBO'] = obo_val
361-
362-
if t is not None:
363-
http_client = httpx.Client(
364-
timeout=t,
365-
event_hooks={'request': [inject_auth_headers]},
366-
)
367-
else:
368-
http_client = httpx.Client(
369-
timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout
370-
event_hooks={'request': [inject_auth_headers]},
371-
)
372-
373-
# OpenAI / Azure OpenAI path
374-
openai_kwargs = dict(
375-
base_url=info.connection_url,
376-
api_key='placeholder',
377-
model=model_name,
378-
streaming=streaming,
379-
http_client=http_client,
380-
)
381-
return ChatOpenAI(
382-
**openai_kwargs,
383-
**kwargs,
384-
)

singlestoredb/ai/debug.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
import os
2+
from typing import Any
3+
from typing import Callable
4+
from typing import Optional
5+
from typing import Union
6+
7+
import httpx
8+
9+
from singlestoredb import manage_workspaces
10+
from singlestoredb.management.inference_api import InferenceAPIInfo
11+
12+
try:
13+
from langchain_openai import ChatOpenAI
14+
except ImportError:
15+
raise ImportError(
16+
'Could not import langchain_openai python package. '
17+
'Please install it with `pip install langchain_openai`.',
18+
)
19+
20+
try:
21+
from langchain_aws import ChatBedrockConverse
22+
except ImportError:
23+
raise ImportError(
24+
'Could not import langchain-aws python package. '
25+
'Please install it with `pip install langchain-aws`.',
26+
)
27+
28+
import boto3
29+
from botocore import UNSIGNED
30+
from botocore.config import Config
31+
32+
33+
def SingleStoreChatFactoryDebug(
34+
model_name: str,
35+
api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
36+
streaming: bool = True,
37+
http_client: Optional[httpx.Client] = None,
38+
obo_token: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
39+
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
40+
base_url: Optional[str] = None,
41+
hosting_platform: Optional[str] = None,
42+
timeout: Optional[float] = None,
43+
**kwargs: Any,
44+
) -> Union[ChatOpenAI, ChatBedrockConverse]:
45+
"""Return a chat model instance (ChatOpenAI or ChatBedrockConverse).
46+
"""
47+
# Handle api_key and obo_token as callable functions
48+
if callable(api_key):
49+
api_key_getter_fn = api_key
50+
else:
51+
def api_key_getter_fn() -> Optional[str]:
52+
if api_key is None:
53+
return os.environ.get('SINGLESTOREDB_USER_TOKEN')
54+
return api_key
55+
56+
if obo_token_getter is not None:
57+
obo_token_getter_fn = obo_token_getter
58+
else:
59+
if callable(obo_token):
60+
obo_token_getter_fn = obo_token
61+
else:
62+
def obo_token_getter_fn() -> Optional[str]:
63+
return obo_token
64+
65+
# handle model info
66+
if base_url is None:
67+
base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL')
68+
if hosting_platform is None:
69+
hosting_platform = os.environ.get('SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM')
70+
if base_url is None or hosting_platform is None:
71+
inference_api_manager = (
72+
manage_workspaces().organizations.current.inference_apis
73+
)
74+
info = inference_api_manager.get(model_name=model_name)
75+
else:
76+
info = InferenceAPIInfo(
77+
service_id='',
78+
model_name=model_name,
79+
name='',
80+
connection_url=base_url,
81+
project_id='',
82+
hosting_platform=hosting_platform,
83+
)
84+
if base_url is not None:
85+
info.connection_url = base_url
86+
if hosting_platform is not None:
87+
info.hosting_platform = hosting_platform
88+
89+
# Extract timeouts from http_client if provided
90+
t = http_client.timeout if http_client is not None else None
91+
connect_timeout = None
92+
read_timeout = None
93+
if t is not None:
94+
if isinstance(t, httpx.Timeout):
95+
if t.connect is not None:
96+
connect_timeout = float(t.connect)
97+
if t.read is not None:
98+
read_timeout = float(t.read)
99+
if connect_timeout is None and read_timeout is not None:
100+
connect_timeout = read_timeout
101+
if read_timeout is None and connect_timeout is not None:
102+
read_timeout = connect_timeout
103+
elif isinstance(t, (int, float)):
104+
connect_timeout = float(t)
105+
read_timeout = float(t)
106+
if timeout is not None:
107+
connect_timeout = timeout
108+
read_timeout = timeout
109+
t = httpx.Timeout(timeout)
110+
111+
if info.hosting_platform == 'Amazon':
112+
# Instantiate Bedrock client
113+
cfg_kwargs = {
114+
'signature_version': UNSIGNED,
115+
'retries': {'max_attempts': 1, 'mode': 'standard'},
116+
}
117+
if read_timeout is not None:
118+
cfg_kwargs['read_timeout'] = read_timeout
119+
if connect_timeout is not None:
120+
cfg_kwargs['connect_timeout'] = connect_timeout
121+
122+
cfg = Config(**cfg_kwargs)
123+
client = boto3.client(
124+
'bedrock-runtime',
125+
endpoint_url=info.connection_url,
126+
region_name='us-east-1',
127+
aws_access_key_id='placeholder',
128+
aws_secret_access_key='placeholder',
129+
config=cfg,
130+
)
131+
132+
def _inject_headers(request: Any, **_ignored: Any) -> None:
133+
"""Inject dynamic auth/OBO headers prior to Bedrock sending."""
134+
if api_key_getter_fn is not None:
135+
token_val = api_key_getter_fn()
136+
if token_val:
137+
request.headers['Authorization'] = f'Bearer {token_val}'
138+
if obo_token_getter_fn is not None:
139+
obo_val = obo_token_getter_fn()
140+
if obo_val:
141+
request.headers['X-S2-OBO'] = obo_val
142+
request.headers.pop('X-Amz-Date', None)
143+
request.headers.pop('X-Amz-Security-Token', None)
144+
145+
emitter = client._endpoint._event_emitter
146+
emitter.register_first(
147+
'before-send.bedrock-runtime.Converse',
148+
_inject_headers,
149+
)
150+
emitter.register_first(
151+
'before-send.bedrock-runtime.ConverseStream',
152+
_inject_headers,
153+
)
154+
emitter.register_first(
155+
'before-send.bedrock-runtime.InvokeModel',
156+
_inject_headers,
157+
)
158+
emitter.register_first(
159+
'before-send.bedrock-runtime.InvokeModelWithResponseStream',
160+
_inject_headers,
161+
)
162+
163+
return ChatBedrockConverse(
164+
model_id=model_name,
165+
endpoint_url=info.connection_url,
166+
region_name='us-east-1',
167+
aws_access_key_id='placeholder',
168+
aws_secret_access_key='placeholder',
169+
disable_streaming=not streaming,
170+
client=client,
171+
**kwargs,
172+
)
173+
174+
def inject_auth_headers(request: httpx.Request) -> None:
175+
"""Inject dynamic auth/OBO headers before request is sent."""
176+
if api_key_getter_fn is not None:
177+
token_val = api_key_getter_fn()
178+
if token_val:
179+
request.headers['Authorization'] = f'Bearer {token_val}'
180+
if obo_token_getter_fn is not None:
181+
obo_val = obo_token_getter_fn()
182+
if obo_val:
183+
request.headers['X-S2-OBO'] = obo_val
184+
185+
if t is not None:
186+
http_client = httpx.Client(
187+
timeout=t,
188+
event_hooks={'request': [inject_auth_headers]},
189+
)
190+
else:
191+
http_client = httpx.Client(
192+
timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout
193+
event_hooks={'request': [inject_auth_headers]},
194+
)
195+
196+
# OpenAI / Azure OpenAI path
197+
openai_kwargs = dict(
198+
base_url=info.connection_url,
199+
api_key='placeholder',
200+
model=model_name,
201+
streaming=streaming,
202+
http_client=http_client,
203+
)
204+
return ChatOpenAI(
205+
**openai_kwargs,
206+
**kwargs,
207+
)

0 commit comments

Comments
 (0)