Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 38 additions & 85 deletions python/zvec/extension/sentence_transformer_embedding_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np

from ..common.constants import TEXT, DenseVectorType, SparseVectorType
from ..tool import require_module
from .embedding_function import DenseEmbeddingFunction, SparseEmbeddingFunction
from .sentence_transformer_function import SentenceTransformerFunctionBase

Expand Down Expand Up @@ -197,6 +198,19 @@ def __init__(
# Store extra parameters
self._extra_params = kwargs

def _get_model_class(self):
    """Return the class used to build the dense embedding model.

    Returns:
        class: ``SentenceTransformer``, looked up on the
        ``sentence_transformers`` module.

    Raises:
        ImportError: If required packages are not installed.
    """
    return require_module("sentence_transformers").SentenceTransformer

@property
def dimension(self) -> int:
"""int: The expected dimensionality of the embedding vector."""
Expand Down Expand Up @@ -362,8 +376,7 @@ class DefaultLocalSparseEmbedding(
from huggingface_hub import login
login(token="your_huggingface_token")

5. To use a custom SPLADE model, you can subclass this class and override
the model_name in ``__init__``, or create your own implementation
5. To use a custom SPLADE model, create your own implementation
inheriting from ``SentenceTransformerFunctionBase`` and
``SparseEmbeddingFunction``.

Expand Down Expand Up @@ -656,6 +669,19 @@ def __init__(
# Load model to ensure it's available (will use cache if exists)
self._get_model()

def _get_model_class(self):
    """Get the Sentence Transformers class used to load the SPLADE model.

    The class is the same regardless of where the model weights come from.

    Returns:
        class: ``SparseEncoder``, the class used for SPLADE sparse embeddings.

    Raises:
        ImportError: If required packages are not installed, or if the
            installed ``sentence_transformers`` does not provide
            ``SparseEncoder``.
    """
    sentence_transformers = require_module("sentence_transformers")

    # A release without ``SparseEncoder`` is a dependency problem; surface it
    # as ImportError with upgrade guidance instead of a bare AttributeError.
    sparse_encoder = getattr(sentence_transformers, "SparseEncoder", None)
    if sparse_encoder is None:
        raise ImportError(
            "SPLADE sparse embeddings require a 'sentence-transformers' "
            "release that provides 'SparseEncoder'. "
            "Please upgrade it with: pip install -U sentence-transformers"
        )

    return sparse_encoder

@property
def extra_params(self) -> dict:
"""dict: Extra parameters for model-specific customization."""
Expand Down Expand Up @@ -714,41 +740,21 @@ def embed(self, input: str) -> SparseVectorType:
model = self._get_model()

# Use appropriate encoding method based on type
if self._encoding_type == "document" and hasattr(model, "encode_document"):
if self._encoding_type == "document":
# Use document encoding
sparse_matrix = model.encode_document([input])
elif hasattr(model, "encode_query"):
else:
# Use query encoding (default)
sparse_matrix = model.encode_query([input])
else:
# Fallback: manual implementation for older sentence-transformers
return self._manual_sparse_encode(input)

# Convert sparse matrix to dictionary
# SPLADE returns shape [1, vocab_size] for single input

# Check if it's a sparse matrix (duck typing - has toarray method)
if hasattr(sparse_matrix, "toarray"):
# Sparse matrix (CSR/CSC/etc.) - convert to dense array
sparse_array = sparse_matrix[0].toarray().flatten()
sparse_dict = {
int(idx): float(val)
for idx, val in enumerate(sparse_array)
if val > 0
}
else:
# Dense array format (numpy array or similar)
if isinstance(sparse_matrix, np.ndarray):
sparse_array = sparse_matrix[0]
else:
sparse_array = sparse_matrix

sparse_dict = {
int(idx): float(val)
for idx, val in enumerate(sparse_array)
if val > 0
}

# The decode method returns a list of (token_string, score) pairs for non-zero dimensions
# Then we post-process the tokens to IDs again
decoded = model.decode(sparse_matrix)[0]
if not decoded:
return {}
token_strings, scores = zip(*decoded, strict=True)
token_ids = model.tokenizer.convert_tokens_to_ids(token_strings)
sparse_dict = dict(zip(token_ids, scores, strict=True))
# Sort by indices (keys) to ensure consistent ordering
return dict(sorted(sparse_dict.items()))

Expand All @@ -757,59 +763,6 @@ def embed(self, input: str) -> SparseVectorType:
raise
raise RuntimeError(f"Failed to generate sparse embedding: {e!s}") from e

def _manual_sparse_encode(self, input: str) -> SparseVectorType:
    """Encode text into a SPLADE-style sparse vector without the encode helpers.

    Fallback path for older sentence-transformers: the text is tokenized and
    run through the model's ``forward`` directly, then ``log(1 + relu(x))``
    is applied and the result is max-pooled over the first dimension.

    Args:
        input (str): Input text to encode.

    Returns:
        SparseVectorType: Mapping of index to weight containing only the
        strictly positive entries, ordered by index.
    """
    import torch

    encoder = self._get_model()

    # Tokenize the single input and place every tensor on the model's device.
    batch = {
        name: tensor.to(encoder.device)
        for name, tensor in encoder.tokenize([input]).items()
    }

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        output = encoder.forward(batch)

    # Locate the per-token output; 'token_embeddings' is preferred, otherwise
    # fall back to the first dict value or to the raw output itself.
    output_is_dict = isinstance(output, dict)
    if output_is_dict and "token_embeddings" in output:
        per_token = output["token_embeddings"]
    elif hasattr(output, "token_embeddings"):
        per_token = output.token_embeddings
    elif output_is_dict:
        per_token = next(iter(output.values()))
    else:
        per_token = output
    logits = per_token[0]  # first (and only) batch item

    # SPLADE activation: log(1 + relu(x)).
    activated = torch.log(1 + torch.relu(logits))

    # Collapse the leading dimension with a max when there is more than one.
    if activated.dim() > 1:
        activated = torch.max(activated, dim=0).values

    weights = activated.cpu().numpy()
    positive = {
        int(position): float(weight)
        for position, weight in enumerate(weights)
        if weight > 0
    }

    # Return the entries ordered by index for a consistent layout.
    return {index: positive[index] for index in sorted(positive)}

def _get_model(self):
"""Load or retrieve the SPLADE model from class-level cache.

Expand Down
27 changes: 13 additions & 14 deletions python/zvec/extension/sentence_transformer_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,17 @@ def device(self) -> str:
return str(model.device)
return self._device or "cpu"

def _get_model_class(self):
    """Get the Sentence Transformers class used to load the model.

    This base implementation is a placeholder that always raises; each
    concrete subclass must override it and return the class to instantiate.
    The return value is the class itself, not an instance, so a caller builds
    the model by calling the result, e.g.
    ``self._get_model_class()(model_name, device=device)``.

    Returns:
        class: The Sentence Transformer class to use for loading models.

    Raises:
        NotImplementedError: Always, in this base implementation.
        ImportError: From overriding implementations, if required packages
            are not installed.
    """
    raise NotImplementedError(
        f"{type(self).__name__} must override _get_model_class()"
    )

def _get_model(self):
"""Load or retrieve the Sentence Transformer model.

Expand All @@ -104,8 +115,6 @@ def _get_model(self):

# Load model
try:
sentence_transformers = require_module("sentence_transformers")

if self._model_source == "modelscope":
# Load from ModelScope
require_module("modelscope")
Expand All @@ -115,12 +124,12 @@ def _get_model(self):
model_dir = snapshot_download(self._model_name)

# Load from local path
self._model = sentence_transformers.SentenceTransformer(
self._model = self._get_model_class(
model_dir, device=self._device, trust_remote_code=True
)
else:
# Load from Hugging Face (default)
self._model = sentence_transformers.SentenceTransformer(
self._model = self._get_model_class(
self._model_name, device=self._device, trust_remote_code=True
)

Expand All @@ -138,13 +147,3 @@ def _get_model(self):
f"Failed to load Sentence Transformer model '{self._model_name}' "
f"from {self._model_source}: {e!s}"
) from e

def _is_sparse_model(self) -> bool:
    """Report whether the loaded model exposes sparse-encoding methods.

    Returns:
        bool: True if the model has an ``encode_query`` or an
        ``encode_document`` attribute.
    """
    loaded = self._get_model()
    # Duck-typed check: either method is enough.
    return any(
        hasattr(loaded, method_name)
        for method_name in ("encode_query", "encode_document")
    )
65 changes: 12 additions & 53 deletions python/zvec/extension/sentence_transformer_rerank_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class DefaultLocalReRanker(SentenceTransformerFunctionBase, RerankFunction):
.. code-block:: python

# Recommended for users in China
reranker = SentenceTransformerReRanker(
reranker = DefaultLocalReRanker(
query="机器学习算法",
rerank_field="content",
model_source="modelscope"
Expand All @@ -109,9 +109,9 @@ class DefaultLocalReRanker(SentenceTransformerFunctionBase, RerankFunction):

Examples:
>>> # Basic usage with default MS MARCO MiniLM model
>>> from zvec.extension import SentenceTransformerReRanker
>>> from zvec.extension import DefaultLocalReRanker
>>>
>>> reranker = SentenceTransformerReRanker(
>>> reranker = DefaultLocalReRanker(
... query="machine learning algorithms",
... topn=5,
... rerank_field="content"
Expand All @@ -125,15 +125,15 @@ class DefaultLocalReRanker(SentenceTransformerFunctionBase, RerankFunction):
... )

>>> # Using ModelScope for users in China
>>> reranker = SentenceTransformerReRanker(
>>> reranker = DefaultLocalReRanker(
... query="深度学习",
... topn=10,
... rerank_field="content",
... model_source="modelscope"
... )

>>> # Using larger model for better quality
>>> reranker = SentenceTransformerReRanker(
>>> reranker = DefaultLocalReRanker(
... query="neural networks",
... topn=5,
... rerank_field="content",
Expand Down Expand Up @@ -177,7 +177,7 @@ def __init__(
device: Optional[str] = None,
batch_size: int = 32,
):
"""Initialize SentenceTransformerReRanker with query and configuration.
"""Initialize DefaultLocalReRanker with query and configuration.

Args:
query (Optional[str]): Query text for semantic matching. Required.
Expand Down Expand Up @@ -214,59 +214,18 @@ def __init__(
)
self._model = model

def _get_model(self):
"""Load or retrieve the CrossEncoder model.

This overrides the base class method to load CrossEncoder instead of
SentenceTransformer, as reranking requires cross-encoder models.
def _get_model_class(self):
"""Get the Sentence Transformer class.

Returns:
CrossEncoder: The loaded cross-encoder model instance.
class: CrossEncoder, the class used for cross-encoder re-ranking.

Raises:
ImportError: If required packages are not installed.
ValueError: If model cannot be loaded.
"""
# Return cached model if exists
if self._model is not None:
return self._model

# Load cross-encoder model
try:
sentence_transformers = require_module("sentence_transformers")

if self._model_source == "modelscope":
# Load from ModelScope
require_module("modelscope")
from modelscope.hub.snapshot_download import snapshot_download
sentence_transformers = require_module("sentence_transformers")

# Download model to cache
model_dir = snapshot_download(self._model_name)

# Load CrossEncoder from local path
model = sentence_transformers.CrossEncoder(
model_dir, device=self._device
)
else:
# Load CrossEncoder from Hugging Face (default)
model = sentence_transformers.CrossEncoder(
self._model_name, device=self._device
)

return model

except ImportError as e:
if "modelscope" in str(e) and self._model_source == "modelscope":
raise ImportError(
"ModelScope support requires the 'modelscope' package. "
"Please install it with: pip install modelscope"
) from e
raise
except Exception as e:
raise ValueError(
f"Failed to load CrossEncoder model '{self._model_name}' "
f"from {self._model_source}: {e!s}"
) from e
return sentence_transformers.CrossEncoder

@property
def query(self) -> str:
Expand Down Expand Up @@ -305,7 +264,7 @@ def rerank(self, query_results: dict[str, list[Doc]]) -> list[Doc]:
- Processing time is O(n) where n is the number of documents

Examples:
>>> reranker = SentenceTransformerReRanker(
>>> reranker = DefaultLocalReRanker(
... query="machine learning",
... topn=3,
... rerank_field="content"
Expand Down