diff --git a/python/zvec/extension/sentence_transformer_embedding_function.py b/python/zvec/extension/sentence_transformer_embedding_function.py index 032f02e0..e19999ed 100644 --- a/python/zvec/extension/sentence_transformer_embedding_function.py +++ b/python/zvec/extension/sentence_transformer_embedding_function.py @@ -18,6 +18,7 @@ import numpy as np from ..common.constants import TEXT, DenseVectorType, SparseVectorType +from ..tool import require_module from .embedding_function import DenseEmbeddingFunction, SparseEmbeddingFunction from .sentence_transformer_function import SentenceTransformerFunctionBase @@ -197,6 +198,19 @@ def __init__( # Store extra parameters self._extra_params = kwargs + def _get_model_class(self): + """Get the Sentence Transformer class. + + Returns: + class: SentenceTransformer, the class used for dense embeddings. + + Raises: + ImportError: If required packages are not installed. + """ + sentence_transformers = require_module("sentence_transformers") + + return sentence_transformers.SentenceTransformer + @property def dimension(self) -> int: """int: The expected dimensionality of the embedding vector.""" @@ -362,8 +376,7 @@ class DefaultLocalSparseEmbedding( from huggingface_hub import login login(token="your_huggingface_token") - 5. To use a custom SPLADE model, you can subclass this class and override - the model_name in ``__init__``, or create your own implementation + 5. To use a custom SPLADE model, create your own implementation inheriting from ``SentenceTransformerFunctionBase`` and ``SparseEmbeddingFunction``. @@ -656,6 +669,19 @@ def __init__( # Load model to ensure it's available (will use cache if exists) self._get_model() + def _get_model_class(self): + """Get the Sentence Transformer class based on the model source. + + Returns: + class: SparseEncoder, the class used for SPLADE sparse embeddings. + + Raises: + ImportError: If required packages are not installed. 
+ """ + sentence_transformers = require_module("sentence_transformers") + + return sentence_transformers.SparseEncoder + @property def extra_params(self) -> dict: """dict: Extra parameters for model-specific customization.""" @@ -714,41 +740,21 @@ def embed(self, input: str) -> SparseVectorType: model = self._get_model() # Use appropriate encoding method based on type - if self._encoding_type == "document" and hasattr(model, "encode_document"): + if self._encoding_type == "document": # Use document encoding sparse_matrix = model.encode_document([input]) - elif hasattr(model, "encode_query"): + else: # Use query encoding (default) sparse_matrix = model.encode_query([input]) - else: - # Fallback: manual implementation for older sentence-transformers - return self._manual_sparse_encode(input) - - # Convert sparse matrix to dictionary - # SPLADE returns shape [1, vocab_size] for single input - - # Check if it's a sparse matrix (duck typing - has toarray method) - if hasattr(sparse_matrix, "toarray"): - # Sparse matrix (CSR/CSC/etc.) 
- convert to dense array - sparse_array = sparse_matrix[0].toarray().flatten() - sparse_dict = { - int(idx): float(val) - for idx, val in enumerate(sparse_array) - if val > 0 - } - else: - # Dense array format (numpy array or similar) - if isinstance(sparse_matrix, np.ndarray): - sparse_array = sparse_matrix[0] - else: - sparse_array = sparse_matrix - - sparse_dict = { - int(idx): float(val) - for idx, val in enumerate(sparse_array) - if val > 0 - } + # The decode method returns a list of (token_string, score) pairs for non-zero dimensions + # Then we post-process the tokens to IDs again + decoded = model.decode(sparse_matrix)[0] + if not decoded: + return {} + token_strings, scores = zip(*decoded, strict=True) + token_ids = model.tokenizer.convert_tokens_to_ids(token_strings) + sparse_dict = dict(zip(token_ids, scores, strict=True)) # Sort by indices (keys) to ensure consistent ordering return dict(sorted(sparse_dict.items())) @@ -757,59 +763,6 @@ def embed(self, input: str) -> SparseVectorType: raise raise RuntimeError(f"Failed to generate sparse embedding: {e!s}") from e - def _manual_sparse_encode(self, input: str) -> SparseVectorType: - """Fallback manual SPLADE encoding for older sentence-transformers. - - Args: - input (str): Input text to encode. - - Returns: - SparseVectorType: Sparse vector as dictionary. 
- """ - import torch - - model = self._get_model() - - # Tokenize input - features = model.tokenize([input]) - - # Move to correct device - features = {k: v.to(model.device) for k, v in features.items()} - - # Forward pass with no gradient - with torch.no_grad(): - embeddings = model.forward(features) - - # Get logits from model output - # SPLADE models typically output 'token_embeddings' - if isinstance(embeddings, dict) and "token_embeddings" in embeddings: - logits = embeddings["token_embeddings"][0] # First batch item - elif hasattr(embeddings, "token_embeddings"): - logits = embeddings.token_embeddings[0] - # Fallback: try to get first value - elif isinstance(embeddings, dict): - logits = next(iter(embeddings.values()))[0] - else: - logits = embeddings[0] - - # Apply SPLADE activation: log(1 + relu(x)) - relu_log = torch.log(1 + torch.relu(logits)) - - # Max pooling over token dimension (reduce to vocab size) - if relu_log.dim() > 1: - sparse_vec, _ = torch.max(relu_log, dim=0) - else: - sparse_vec = relu_log - - # Convert to sparse dictionary (only non-zero values) - sparse_vec_np = sparse_vec.cpu().numpy() - sparse_dict = { - int(idx): float(val) for idx, val in enumerate(sparse_vec_np) if val > 0 - } - - # Sort by indices (keys) to ensure consistent ordering - return dict(sorted(sparse_dict.items())) - def _get_model(self): """Load or retrieve the SPLADE model from class-level cache. diff --git a/python/zvec/extension/sentence_transformer_function.py b/python/zvec/extension/sentence_transformer_function.py index 1ba1662a..78be6a99 100644 --- a/python/zvec/extension/sentence_transformer_function.py +++ b/python/zvec/extension/sentence_transformer_function.py @@ -88,6 +88,17 @@ def device(self) -> str: return str(model.device) return self._device or "cpu" + def _get_model_class(self): + """Get the Sentence Transformer class. + + Returns: + class: The Sentence Transformer class to use for loading models. 
+
+        Raises:
+            ImportError: If required packages are not installed.
+        """
+        raise NotImplementedError()
+
     def _get_model(self):
         """Load or retrieve the Sentence Transformer model.
@@ -104,8 +115,6 @@ def _get_model(self):
 
         # Load model
         try:
-            sentence_transformers = require_module("sentence_transformers")
-
             if self._model_source == "modelscope":
                 # Load from ModelScope
                 require_module("modelscope")
@@ -115,12 +124,12 @@ def _get_model(self):
                 model_dir = snapshot_download(self._model_name)
 
                 # Load from local path
-                self._model = sentence_transformers.SentenceTransformer(
+                self._model = self._get_model_class()(
                     model_dir, device=self._device, trust_remote_code=True
                 )
             else:
                 # Load from Hugging Face (default)
-                self._model = sentence_transformers.SentenceTransformer(
+                self._model = self._get_model_class()(
                     self._model_name, device=self._device, trust_remote_code=True
                 )
 
@@ -138,13 +147,3 @@ def _get_model(self):
                 f"Failed to load Sentence Transformer model '{self._model_name}' "
                 f"from {self._model_source}: {e!s}"
             ) from e
-
-    def _is_sparse_model(self) -> bool:
-        """Check if the loaded model is a sparse encoder (e.g., SPLADE).
-
-        Returns:
-            bool: True if model supports sparse encoding.
-        """
-        model = self._get_model()
-        # Check if model has sparse encoding methods
-        return hasattr(model, "encode_query") or hasattr(model, "encode_document")
diff --git a/python/zvec/extension/sentence_transformer_rerank_function.py b/python/zvec/extension/sentence_transformer_rerank_function.py
index 58c5838f..adc719d9 100644
--- a/python/zvec/extension/sentence_transformer_rerank_function.py
+++ b/python/zvec/extension/sentence_transformer_rerank_function.py
@@ -95,7 +95,7 @@ class DefaultLocalReRanker(SentenceTransformerFunctionBase, RerankFunction):
 
        .. 
code-block:: python # Recommended for users in China - reranker = SentenceTransformerReRanker( + reranker = DefaultLocalReRanker( query="机器学习算法", rerank_field="content", model_source="modelscope" @@ -109,9 +109,9 @@ class DefaultLocalReRanker(SentenceTransformerFunctionBase, RerankFunction): Examples: >>> # Basic usage with default MS MARCO MiniLM model - >>> from zvec.extension import SentenceTransformerReRanker + >>> from zvec.extension import DefaultLocalReRanker >>> - >>> reranker = SentenceTransformerReRanker( + >>> reranker = DefaultLocalReRanker( ... query="machine learning algorithms", ... topn=5, ... rerank_field="content" @@ -125,7 +125,7 @@ class DefaultLocalReRanker(SentenceTransformerFunctionBase, RerankFunction): ... ) >>> # Using ModelScope for users in China - >>> reranker = SentenceTransformerReRanker( + >>> reranker = DefaultLocalReRanker( ... query="深度学习", ... topn=10, ... rerank_field="content", @@ -133,7 +133,7 @@ class DefaultLocalReRanker(SentenceTransformerFunctionBase, RerankFunction): ... ) >>> # Using larger model for better quality - >>> reranker = SentenceTransformerReRanker( + >>> reranker = DefaultLocalReRanker( ... query="neural networks", ... topn=5, ... rerank_field="content", @@ -177,7 +177,7 @@ def __init__( device: Optional[str] = None, batch_size: int = 32, ): - """Initialize SentenceTransformerReRanker with query and configuration. + """Initialize DefaultLocalReRanker with query and configuration. Args: query (Optional[str]): Query text for semantic matching. Required. @@ -214,59 +214,18 @@ def __init__( ) self._model = model - def _get_model(self): - """Load or retrieve the CrossEncoder model. - - This overrides the base class method to load CrossEncoder instead of - SentenceTransformer, as reranking requires cross-encoder models. + def _get_model_class(self): + """Get the Sentence Transformer class. Returns: - CrossEncoder: The loaded cross-encoder model instance. 
+ class: CrossEncoder, the class used for cross-encoder re-ranking. Raises: ImportError: If required packages are not installed. - ValueError: If model cannot be loaded. """ - # Return cached model if exists - if self._model is not None: - return self._model - - # Load cross-encoder model - try: - sentence_transformers = require_module("sentence_transformers") - - if self._model_source == "modelscope": - # Load from ModelScope - require_module("modelscope") - from modelscope.hub.snapshot_download import snapshot_download + sentence_transformers = require_module("sentence_transformers") - # Download model to cache - model_dir = snapshot_download(self._model_name) - - # Load CrossEncoder from local path - model = sentence_transformers.CrossEncoder( - model_dir, device=self._device - ) - else: - # Load CrossEncoder from Hugging Face (default) - model = sentence_transformers.CrossEncoder( - self._model_name, device=self._device - ) - - return model - - except ImportError as e: - if "modelscope" in str(e) and self._model_source == "modelscope": - raise ImportError( - "ModelScope support requires the 'modelscope' package. " - "Please install it with: pip install modelscope" - ) from e - raise - except Exception as e: - raise ValueError( - f"Failed to load CrossEncoder model '{self._model_name}' " - f"from {self._model_source}: {e!s}" - ) from e + return sentence_transformers.CrossEncoder @property def query(self) -> str: @@ -305,7 +264,7 @@ def rerank(self, query_results: dict[str, list[Doc]]) -> list[Doc]: - Processing time is O(n) where n is the number of documents Examples: - >>> reranker = SentenceTransformerReRanker( + >>> reranker = DefaultLocalReRanker( ... query="machine learning", ... topn=3, ... rerank_field="content"