Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,28 @@ uv run python examples/realtime_synthetic.py
uv lock --upgrade
```

### Test UI

The SDK includes an interactive test UI built with Gradio for quickly testing all SDK features without writing code.

```bash
# Install Gradio
pip install gradio

# Run the test UI
python test_ui.py
```

Then open http://localhost:7860 in your browser.

The UI provides tabs for:
- **Image Generation** - Text-to-image and image-to-image transformations
- **Video Generation** - Text-to-video, image-to-video, and video-to-video
- **Video Restyle** - Restyle videos using text prompts or reference images
- **Tokens** - Create short-lived client tokens

Enter your API key at the top of the interface to start testing.

### Publishing a New Version

The package is automatically published to PyPI when you create a GitHub release.
Expand Down
6 changes: 5 additions & 1 deletion decart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
QueueResultError,
TokenCreateError,
)
from .models import models, ModelDefinition
from .models import models, ModelDefinition, VideoRestyleInput
from .types import FileInput, ModelState, Prompt
from .queue import (
QueueClient,
Expand All @@ -31,6 +31,7 @@
RealtimeClient,
RealtimeConnectOptions,
ConnectionState,
AvatarOptions,
)

REALTIME_AVAILABLE = True
Expand All @@ -39,6 +40,7 @@
RealtimeClient = None # type: ignore
RealtimeConnectOptions = None # type: ignore
ConnectionState = None # type: ignore
AvatarOptions = None # type: ignore

__version__ = "0.0.1"

Expand All @@ -56,6 +58,7 @@
"QueueResultError",
"models",
"ModelDefinition",
"VideoRestyleInput",
"FileInput",
"ModelState",
"Prompt",
Expand All @@ -75,5 +78,6 @@
"RealtimeClient",
"RealtimeConnectOptions",
"ConnectionState",
"AvatarOptions",
]
)
52 changes: 50 additions & 2 deletions decart/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Literal, Optional, List, Generic, TypeVar
from pydantic import BaseModel, Field, ConfigDict
from pydantic import BaseModel, Field, ConfigDict, model_validator
from .errors import ModelNotFoundError
from .types import FileInput, MotionTrajectoryInput


RealTimeModels = Literal["mirage", "mirage_v2", "lucy_v2v_720p_rt"]
RealTimeModels = Literal["mirage", "mirage_v2", "lucy_v2v_720p_rt", "avatar-live"]
VideoModels = Literal[
"lucy-dev-i2v",
"lucy-fast-v2v",
Expand All @@ -13,6 +13,7 @@
"lucy-pro-v2v",
"lucy-pro-flf2v",
"lucy-motion",
"lucy-restyle-v2v",
]
ImageModels = Literal["lucy-pro-t2i", "lucy-pro-i2i"]
Model = Literal[RealTimeModels, VideoModels, ImageModels]
Expand Down Expand Up @@ -95,6 +96,36 @@ class ImageToMotionVideoInput(DecartBaseModel):
resolution: Optional[str] = None


class VideoRestyleInput(DecartBaseModel):
    """Request payload for the lucy-restyle-v2v model.

    The restyle is driven by exactly one of `prompt` or `reference_image`;
    supplying both, or neither, is rejected. `enhance_prompt` may only be
    set when the restyle is driven by `prompt`.
    """

    prompt: Optional[str] = Field(default=None, min_length=1, max_length=1000)
    reference_image: Optional[FileInput] = None
    data: FileInput
    seed: Optional[int] = None
    resolution: Optional[str] = None
    enhance_prompt: Optional[bool] = None

    @model_validator(mode="after")
    def validate_prompt_or_reference_image(self) -> "VideoRestyleInput":
        """Enforce the prompt / reference-image exclusivity rules.

        Raises:
            ValueError: If zero or two style sources are given, or if
                `enhance_prompt` is set alongside `reference_image`.
        """
        uses_image = self.reference_image is not None
        uses_prompt = self.prompt is not None

        # Exactly one style source must be present.
        if sum((uses_prompt, uses_image)) != 1:
            raise ValueError("Must provide either 'prompt' or 'reference_image', but not both")

        # Past this point `uses_image` means the prompt is absent, so any
        # explicit `enhance_prompt` value has nothing to apply to.
        if uses_image and self.enhance_prompt is not None:
            raise ValueError(
                "'enhance_prompt' is only valid when using 'prompt', not 'reference_image'"
            )

        return self


class TextToImageInput(BaseModel):
prompt: str = Field(
...,
Expand Down Expand Up @@ -144,6 +175,14 @@ class ImageToImageInput(DecartBaseModel):
height=704,
input_schema=BaseModel,
),
"avatar-live": ModelDefinition(
name="avatar-live",
url_path="/v1/avatar-live/stream",
fps=25,
width=1280,
height=720,
input_schema=BaseModel,
),
},
"video": {
"lucy-dev-i2v": ModelDefinition(
Expand Down Expand Up @@ -202,6 +241,14 @@ class ImageToImageInput(DecartBaseModel):
height=704,
input_schema=ImageToMotionVideoInput,
),
"lucy-restyle-v2v": ModelDefinition(
name="lucy-restyle-v2v",
url_path="/v1/generate/lucy-restyle-v2v",
fps=25,
width=1280,
height=704,
input_schema=VideoRestyleInput,
),
},
"image": {
"lucy-pro-t2i": ModelDefinition(
Expand Down Expand Up @@ -247,6 +294,7 @@ def video(model: VideoModels) -> VideoModelDefinition:
- "lucy-dev-i2v" - Image-to-video (Dev quality)
- "lucy-fast-v2v" - Video-to-video (Fast quality)
- "lucy-motion" - Image-to-motion-video
- "lucy-restyle-v2v" - Video-to-video with prompt or reference image
"""
try:
return _MODELS["video"][model] # type: ignore[return-value]
Expand Down
2 changes: 1 addition & 1 deletion decart/queue/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def submit_job(

for key, value in inputs.items():
if value is not None:
if key in ("data", "start", "end"):
if key in ("data", "start", "end", "reference_image"):
content, content_type = await file_input_to_bytes(value, session)
form_data.add_field(key, content, content_type=content_type)
else:
Expand Down
3 changes: 2 additions & 1 deletion decart/realtime/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .client import RealtimeClient
from .types import RealtimeConnectOptions, ConnectionState
from .types import RealtimeConnectOptions, ConnectionState, AvatarOptions

__all__ = [
"RealtimeClient",
"RealtimeConnectOptions",
"ConnectionState",
"AvatarOptions",
]
96 changes: 91 additions & 5 deletions decart/realtime/client.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,33 @@
from typing import Callable, Optional
import asyncio
import base64
import logging
import uuid
import aiohttp
from aiortc import MediaStreamTrack

from .webrtc_manager import WebRTCManager, WebRTCConfiguration
from .messages import PromptMessage
from .messages import PromptMessage, SetAvatarImageMessage
from .types import ConnectionState, RealtimeConnectOptions
from ..types import FileInput
from ..errors import DecartSDKError, InvalidInputError, WebRTCError
from ..process.request import file_input_to_bytes

logger = logging.getLogger(__name__)


class RealtimeClient:
def __init__(self, manager: WebRTCManager, session_id: str):
def __init__(
self,
manager: WebRTCManager,
session_id: str,
http_session: Optional[aiohttp.ClientSession] = None,
is_avatar_live: bool = False,
):
self._manager = manager
self.session_id = session_id
self._http_session = http_session
self._is_avatar_live = is_avatar_live
self._connection_callbacks: list[Callable[[ConnectionState], None]] = []
self._error_callbacks: list[Callable[[DecartSDKError], None]] = []

Expand All @@ -24,14 +36,16 @@ async def connect(
cls,
base_url: str,
api_key: str,
local_track: MediaStreamTrack,
local_track: Optional[MediaStreamTrack],
options: RealtimeConnectOptions,
integration: Optional[str] = None,
) -> "RealtimeClient":
session_id = str(uuid.uuid4())
ws_url = f"{base_url}{options.model.url_path}"
ws_url += f"?api_key={api_key}&model={options.model.name}"

is_avatar_live = options.model.name == "avatar-live"

config = WebRTCConfiguration(
webrtc_url=ws_url,
api_key=api_key,
Expand All @@ -43,24 +57,55 @@ async def connect(
initial_state=options.initial_state,
customize_offer=options.customize_offer,
integration=integration,
is_avatar_live=is_avatar_live,
)

# Create HTTP session for file conversions
http_session = aiohttp.ClientSession()

manager = WebRTCManager(config)
client = cls(manager=manager, session_id=session_id)
client = cls(
manager=manager,
session_id=session_id,
http_session=http_session,
is_avatar_live=is_avatar_live,
)

config.on_connection_state_change = client._emit_connection_change
config.on_error = lambda error: client._emit_error(WebRTCError(str(error), cause=error))

try:
await manager.connect(local_track)
# For avatar-live, convert and send avatar image before WebRTC connection
avatar_image_base64: Optional[str] = None
if is_avatar_live and options.avatar:
image_bytes, _ = await file_input_to_bytes(
options.avatar.avatar_image, http_session
)
avatar_image_base64 = base64.b64encode(image_bytes).decode("utf-8")

# Prepare initial prompt if provided
initial_prompt: Optional[dict] = None
if options.initial_prompt:
initial_prompt = {
"text": options.initial_prompt.text,
"enhance": options.initial_prompt.enhance,
}

await manager.connect(
local_track,
avatar_image_base64=avatar_image_base64,
initial_prompt=initial_prompt,
)

# Handle initial_state.prompt for backward compatibility (after WebRTC connection)
if options.initial_state:
if options.initial_state.prompt:
await client.set_prompt(
options.initial_state.prompt.text,
enrich=options.initial_state.prompt.enrich,
)
except Exception as e:
await http_session.close()
raise WebRTCError(str(e), cause=e)

return client
Expand Down Expand Up @@ -100,6 +145,45 @@ async def set_prompt(self, prompt: str, enrich: bool = True) -> None:
finally:
self._manager.unregister_prompt_wait(prompt)

async def set_image(self, image: FileInput) -> None:
    """Send a new avatar image to the server and wait for its acknowledgment.

    Only usable when the client was created for the avatar-live model.

    Args:
        image: The image to set, passed to `file_input_to_bytes` for
            conversion before being base64-encoded.

    Raises:
        InvalidInputError: If the client is not in avatar-live mode or has
            no HTTP session.
        DecartSDKError: If no acknowledgment arrives within 15 seconds, or
            the acknowledgment reports a failure.
    """
    if not self._is_avatar_live:
        raise InvalidInputError("set_image() is only available for avatar-live model")

    session = self._http_session
    if not session:
        raise InvalidInputError("HTTP session not available")

    # The content type returned alongside the bytes is not needed here.
    raw_bytes, _content_type = await file_input_to_bytes(image, session)
    encoded = base64.b64encode(raw_bytes).decode("utf-8")

    # Register the waiter before sending so the acknowledgment cannot be missed.
    ack_event, ack = self._manager.register_image_set_wait()

    try:
        message = SetAvatarImageMessage(type="set_image", image_data=encoded)
        await self._manager.send_message(message)

        try:
            await asyncio.wait_for(ack_event.wait(), timeout=15.0)
        except asyncio.TimeoutError:
            raise DecartSDKError("Image set acknowledgment timed out")

        succeeded = ack["success"]
        if not succeeded:
            raise DecartSDKError(ack.get("error") or "Failed to set avatar image")
    finally:
        # Always drop the waiter, whether the call succeeded or raised.
        self._manager.unregister_image_set_wait()

def is_connected(self) -> bool:
    """Return whatever the underlying manager reports for `is_connected()`."""
    manager = self._manager
    return manager.is_connected()

Expand All @@ -108,6 +192,8 @@ def get_connection_state(self) -> ConnectionState:

async def disconnect(self) -> None:
    """Tear down the realtime connection and release the HTTP session.

    The manager's `cleanup()` runs first. The HTTP session, if one exists
    and is still open, is closed afterwards even when `cleanup()` raises,
    so a failed teardown does not leak the session. Any exception raised
    by `cleanup()` still propagates to the caller.
    """
    try:
        await self._manager.cleanup()
    finally:
        session = self._http_session
        if session and not session.closed:
            await session.close()

def on(self, event: str, callback: Callable) -> None:
if event == "connection_change":
Expand Down
Loading
Loading