Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def test_image_caption_extraction_schema_defaults():
assert schema.endpoint_url.startswith("https://")
assert schema.prompt.startswith("Caption")
assert schema.model_name.startswith("nvidia/")
assert schema.context_text_max_chars == 0
assert schema.temperature == 1.0
assert schema.raise_on_failure is False


Expand All @@ -40,6 +42,26 @@ def test_image_caption_extraction_schema_accepts_truthy_values():
assert schema.raise_on_failure is False


def test_image_caption_extraction_schema_context_text_max_chars_custom():
schema = ImageCaptionExtractionSchema(context_text_max_chars=512)
assert schema.context_text_max_chars == 512


def test_image_caption_extraction_schema_context_text_max_chars_none_coerced():
schema = ImageCaptionExtractionSchema(context_text_max_chars=None)
assert schema.context_text_max_chars == 0


def test_image_caption_extraction_schema_temperature_custom():
schema = ImageCaptionExtractionSchema(temperature=0.5)
assert schema.temperature == 0.5


def test_image_caption_extraction_schema_temperature_none_coerced():
schema = ImageCaptionExtractionSchema(temperature=None)
assert schema.temperature == 1.0


def test_image_caption_extraction_schema_rejects_extra_fields():
with pytest.raises(ValidationError) as excinfo:
ImageCaptionExtractionSchema(extra_field="oops")
Expand Down
216 changes: 214 additions & 2 deletions api/api_tests/internal/transform/test_caption_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def test_transform_image_create_vlm_caption_internal_happy_path(
dummy_task_config["api_key"],
dummy_task_config["endpoint_url"],
dummy_task_config["model_name"],
temperature=1.0,
)

# Assert captions updated correctly in the DataFrame
Expand Down Expand Up @@ -125,6 +126,7 @@ def test_transform_image_create_vlm_caption_internal_uses_fallback_config(
dummy_transform_config.api_key,
dummy_transform_config.endpoint_url,
dummy_transform_config.model_name,
temperature=1.0,
)

# Assert captions updated correctly
Expand Down Expand Up @@ -236,7 +238,7 @@ def test_generate_captions_happy_path(mock_scale, mock_create_client):

# Assert infer called with correct data
expected_payload = {"base64_images": ["scaled_b64img1", "scaled_b64img2"], "prompt": "describe this"}
mock_client.infer.assert_called_once_with(expected_payload, model_name="test_model")
mock_client.infer.assert_called_once_with(expected_payload, model_name="test_model", temperature=1.0)

# Result matches mock captions
assert result == ["Caption 1", "Caption 2"]
Expand Down Expand Up @@ -280,6 +282,216 @@ def test_generate_captions_empty_images_returns_empty_list(mock_scale, mock_crea
model_name="test_model",
)

mock_client.infer.assert_called_once_with({"base64_images": [], "prompt": "describe this"}, model_name="test_model")
mock_client.infer.assert_called_once_with(
{"base64_images": [], "prompt": "describe this"}, model_name="test_model", temperature=1.0
)

assert result == []


# --- _gather_context_text_for_image tests ---


def test_gather_context_text_page_match():
"""Page text is returned when page_number matches."""
image_meta = {
"content_metadata": {
"type": "image",
"page_number": 3,
},
}
page_text_map = {3: ["page three text", "more text"]}
result = module_under_test._gather_context_text_for_image(image_meta, page_text_map, 200)
assert result == "page three text more text"


def test_gather_context_text_truncation():
"""Text is truncated to max_chars."""
image_meta = {
"content_metadata": {
"type": "image",
"page_number": 0,
},
}
page_text_map = {0: ["a" * 500]}
result = module_under_test._gather_context_text_for_image(image_meta, page_text_map, 10)
assert len(result) == 10


def test_gather_context_text_safety_cap():
"""Text is capped at _MAX_CONTEXT_TEXT_CHARS even if max_chars is larger."""
image_meta = {
"content_metadata": {
"type": "image",
"page_number": 0,
},
}
big_text = "x" * 10000
page_text_map = {0: [big_text]}
result = module_under_test._gather_context_text_for_image(image_meta, page_text_map, 99999)
assert len(result) == module_under_test._MAX_CONTEXT_TEXT_CHARS


def test_gather_context_text_no_text():
"""Returns empty string when no text is available."""
image_meta = {
"content_metadata": {
"type": "image",
"page_number": 5,
},
}
result = module_under_test._gather_context_text_for_image(image_meta, {}, 200)
assert result == ""


def test_gather_context_text_wrong_page():
"""Returns empty string when page number doesn't match any text."""
image_meta = {
"content_metadata": {
"type": "image",
"page_number": 99,
},
}
page_text_map = {0: ["some text"]}
result = module_under_test._gather_context_text_for_image(image_meta, page_text_map, 200)
assert result == ""


# --- _build_prompt_with_context tests ---


def test_build_prompt_with_context():
result = module_under_test._build_prompt_with_context("Caption this:", "nearby text")
assert result == "Text near this image:\n---\nnearby text\n---\n\nCaption this:"


def test_build_prompt_with_empty_context():
result = module_under_test._build_prompt_with_context("Caption this:", "")
assert result == "Caption this:"


# --- _build_page_text_map tests ---


def test_build_page_text_map():
df = pd.DataFrame(
[
{
"metadata": {
"content": "text on page 0",
"content_metadata": {"type": "text", "page_number": 0},
}
},
{
"metadata": {
"content": "more on page 0",
"content_metadata": {"type": "text", "page_number": 0},
}
},
{
"metadata": {
"content": "image content",
"content_metadata": {"type": "image", "page_number": 0},
}
},
{
"metadata": {
"content": "page 1 text",
"content_metadata": {"type": "text", "page_number": 1},
}
},
]
)
result = module_under_test._build_page_text_map(df)
assert result == {0: ["text on page 0", "more on page 0"], 1: ["page 1 text"]}


# --- Context-enabled integration tests ---


@patch(f"{MODULE_UNDER_TEST}._generate_captions")
def test_transform_context_enabled_per_image_calls(mock_generate, dummy_transform_config):
"""With context enabled, each image gets its own VLM call with enriched prompt."""
df = pd.DataFrame(
[
{
"metadata": {
"content": "b64_img1",
"content_metadata": {"type": "image", "page_number": 0},
"image_metadata": {},
}
},
{
"metadata": {
"content": "page zero text",
"content_metadata": {"type": "text", "page_number": 0},
"image_metadata": {},
}
},
]
)
mock_generate.return_value = ["caption_with_context"]

task_config = {
"api_key": "key",
"prompt": "Caption this:",
"system_prompt": "sys",
"endpoint_url": "https://url",
"model_name": "model",
"context_text_max_chars": 500,
}

result = transform_image_create_vlm_caption_internal(df.copy(), task_config, dummy_transform_config)

# Should be called once (one image)
assert mock_generate.call_count == 1
call_args = mock_generate.call_args
# The prompt should be enriched with context
assert "Text near this image:" in call_args[0][1]
assert "page zero text" in call_args[0][1]
assert "Caption this:" in call_args[0][1]
# The image should be passed individually
assert call_args[0][0] == ["b64_img1"]
# Caption should be set
assert result.iloc[0]["metadata"]["image_metadata"]["caption"] == "caption_with_context"


@patch(f"{MODULE_UNDER_TEST}._generate_captions")
def test_transform_temperature_forwarded(mock_generate, dummy_df_with_images, dummy_transform_config):
"""Temperature from task_config is forwarded to _generate_captions."""
mock_generate.return_value = ["c1", "c2"]

task_config = {
"api_key": "key",
"prompt": "Describe",
"system_prompt": "sys",
"endpoint_url": "https://url",
"model_name": "model",
"temperature": 0.7,
}

transform_image_create_vlm_caption_internal(dummy_df_with_images.copy(), task_config, dummy_transform_config)

mock_generate.assert_called_once()
_, kwargs = mock_generate.call_args
assert kwargs["temperature"] == 0.7


@patch(f"{MODULE_UNDER_TEST}._generate_captions")
def test_transform_context_disabled_batch_preserved(
mock_generate, dummy_df_with_images, dummy_task_config, dummy_transform_config
):
"""With context disabled (default), batch behavior is unchanged."""
mock_generate.return_value = ["c1", "c2"]

_ = transform_image_create_vlm_caption_internal(
dummy_df_with_images.copy(), dummy_task_config, dummy_transform_config
)

# Should be called once in batch mode
mock_generate.assert_called_once()
call_args = mock_generate.call_args
# All images passed at once
assert call_args[0][0] == ["base64_image_1", "base64_image_2"]
# Prompt should NOT be enriched
assert "Text near this image:" not in call_args[0][1]
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ class IngestTaskCaptionSchema(BaseModelNoExt):
prompt: Optional[str] = None
system_prompt: Optional[str] = None
model_name: Optional[str] = None
context_text_max_chars: Optional[int] = None
temperature: Optional[float] = None


class IngestTaskFilterParamsSchema(BaseModelNoExt):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ class ImageCaptionExtractionSchema(BaseModel):
prompt: str = "Caption the content of this image:"
system_prompt: str = "/no_think"
model_name: str = "nvidia/nemotron-nano-12b-v2-vl"
context_text_max_chars: int = 0
temperature: float = 1.0
raise_on_failure: bool = False
model_config = ConfigDict(extra="forbid")

Expand All @@ -33,4 +35,8 @@ def _coerce_none_to_defaults(cls, values):
values["prompt"] = cls.model_fields["prompt"].default
if values.get("system_prompt") is None:
values["system_prompt"] = cls.model_fields["system_prompt"].default
if values.get("context_text_max_chars") is None:
values["context_text_max_chars"] = cls.model_fields["context_text_max_chars"].default
if values.get("temperature") is None:
values["temperature"] = cls.model_fields["temperature"].default
return values
Loading
Loading