From 311c8d2b12b0f600a0b4064057256e2f788bdeae Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 17 Feb 2026 16:12:30 +0100 Subject: [PATCH 1/6] Allow custom model names for Dense, Sparse embeddings, fix Sparse Embeddings --- ...sentence_transformer_embedding_function.py | 146 +++++++----------- .../sentence_transformer_function.py | 29 ++-- .../sentence_transformer_rerank_function.py | 66 ++------ 3 files changed, 85 insertions(+), 156 deletions(-) diff --git a/python/zvec/extension/sentence_transformer_embedding_function.py b/python/zvec/extension/sentence_transformer_embedding_function.py index 032f02e0..e0ebb8bf 100644 --- a/python/zvec/extension/sentence_transformer_embedding_function.py +++ b/python/zvec/extension/sentence_transformer_embedding_function.py @@ -18,6 +18,7 @@ import numpy as np from ..common.constants import TEXT, DenseVectorType, SparseVectorType +from ..tool import require_module from .embedding_function import DenseEmbeddingFunction, SparseEmbeddingFunction from .sentence_transformer_function import SentenceTransformerFunctionBase @@ -39,6 +40,9 @@ class DefaultLocalDenseEmbedding( similarity tasks. It runs locally without requiring API keys. Args: + model_name (Optional[str]): Model identifier or local path. Defaults to: + - ``"all-MiniLM-L6-v2"`` for Hugging Face + - ``"iic/nlp_gte_sentence-embedding_chinese-small"`` for ModelScope model_source (Literal["huggingface", "modelscope"], optional): Model source. - ``"huggingface"``: Use Hugging Face Hub (default, for international users) - ``"modelscope"``: Use ModelScope (recommended for users in China) @@ -153,6 +157,7 @@ class DefaultLocalDenseEmbedding( def __init__( self, + model_name: Optional[str] = None, model_source: Literal["huggingface", "modelscope"] = "huggingface", device: Optional[str] = None, normalize_embeddings: bool = True, @@ -162,6 +167,9 @@ def __init__( """Initialize with all-MiniLM-L6-v2 model. Args: + model_name (Optional[str]): Model identifier or local path. Defaults to: + - ``"all-MiniLM-L6-v2"`` for Hugging Face + - ``"iic/nlp_gte_sentence-embedding_chinese-small"`` for ModelScope model_source (Literal["huggingface", "modelscope"]): Model source. Defaults to "huggingface". device (Optional[str]): Target device ("cpu", "cuda", "mps", or None). @@ -176,11 +184,12 @@ def __init__( ValueError: If model cannot be loaded. """ # Use different models based on source - if model_source == "modelscope": - # Use Chinese-optimized model for ModelScope (better for Chinese text) - model_name = "iic/nlp_gte_sentence-embedding_chinese-small" - else: - model_name = "all-MiniLM-L6-v2" + if model_name is None: + if model_source == "modelscope": + # Use Chinese-optimized model for ModelScope (better for Chinese text) + model_name = "iic/nlp_gte_sentence-embedding_chinese-small" + else: + model_name = "all-MiniLM-L6-v2" # Initialize base class for model loading SentenceTransformerFunctionBase.__init__( @@ -197,6 +206,20 @@ def __init__( # Store extra parameters self._extra_params = kwargs + @property + def _get_model_class(self): + """Get the Sentence Transformer class. + + Returns: + class: SentenceTransformer, the class used for dense embeddings. + + Raises: + ImportError: If required packages are not installed. 
+ """ + sentence_transformers = require_module("sentence_transformers") + + return sentence_transformers.SentenceTransformer + @property def dimension(self) -> int: """int: The expected dimensionality of the embedding vector.""" @@ -368,6 +391,8 @@ class DefaultLocalSparseEmbedding( ``SparseEmbeddingFunction``. Args: + model_name (Optional[str]): Model identifier or local path. Defaults to + ``"naver/splade-cocondenser-ensembledistil"`` if None. model_source (Literal["huggingface", "modelscope"], optional): Model source. Defaults to ``"huggingface"``. ModelScope support may vary for SPLADE models. device (Optional[str], optional): Device to run the model on. @@ -589,6 +614,7 @@ def remove_from_cache( def __init__( self, + model_name: Optional[str] = None, model_source: Literal["huggingface", "modelscope"] = "huggingface", device: Optional[str] = None, encoding_type: Literal["query", "document"] = "query", @@ -597,6 +623,8 @@ def __init__( """Initialize with SPLADE model. Args: + model_name (Optional[str]): Model identifier or local path. Defaults to + ``"naver/splade-cocondenser-ensembledistil"`` if None. model_source (Literal["huggingface", "modelscope"]): Model source. Defaults to "huggingface". device (Optional[str]): Target device ("cpu", "cuda", "mps", or None). @@ -640,7 +668,8 @@ def __init__( # Use publicly available SPLADE model (no gated access required) # Note: naver/splade-v3 requires authentication, so we use the # cocondenser-ensembledistil variant which is publicly accessible - model_name = "naver/splade-cocondenser-ensembledistil" + if model_name is None: + model_name = "naver/splade-cocondenser-ensembledistil" # Initialize base class for model loading SentenceTransformerFunctionBase.__init__( @@ -656,6 +685,20 @@ def __init__( # Load model to ensure it's available (will use cache if exists) self._get_model() + @property + def _get_model_class(self): + """Get the Sentence Transformer class based on the model source. + + Returns: + class: SparseEncoder, the class used for SPLADE sparse embeddings. + + Raises: + ImportError: If required packages are not installed. + """ + sentence_transformers = require_module("sentence_transformers") + + return sentence_transformers.SparseEncoder + @property def extra_params(self) -> dict: """dict: Extra parameters for model-specific customization.""" @@ -714,41 +757,19 @@ def embed(self, input: str) -> SparseVectorType: model = self._get_model() # Use appropriate encoding method based on type - if self._encoding_type == "document" and hasattr(model, "encode_document"): + if self._encoding_type == "document": # Use document encoding sparse_matrix = model.encode_document([input]) - elif hasattr(model, "encode_query"): + else: # Use query encoding (default) sparse_matrix = model.encode_query([input]) - else: - # Fallback: manual implementation for older sentence-transformers - return self._manual_sparse_encode(input) - - # Convert sparse matrix to dictionary - # SPLADE returns shape [1, vocab_size] for single input - - # Check if it's a sparse matrix (duck typing - has toarray method) - if hasattr(sparse_matrix, "toarray"): - # Sparse matrix (CSR/CSC/etc.) 
- convert to dense array - sparse_array = sparse_matrix[0].toarray().flatten() - sparse_dict = { - int(idx): float(val) - for idx, val in enumerate(sparse_array) - if val > 0 - } - else: - # Dense array format (numpy array or similar) - if isinstance(sparse_matrix, np.ndarray): - sparse_array = sparse_matrix[0] - else: - sparse_array = sparse_matrix - - sparse_dict = { - int(idx): float(val) - for idx, val in enumerate(sparse_array) - if val > 0 - } + # The decode method returns a list of (token_string, score) pairs for non-zero dimensions + # Then we post-process the tokens to IDs again + decoded = model.decode(sparse_matrix)[0] + token_strings, scores = zip(*decoded, strict=True) + token_ids = model.tokenizer.convert_tokens_to_ids(token_strings) + sparse_dict = dict(zip(token_ids, scores, strict=True)) # Sort by indices (keys) to ensure consistent ordering return dict(sorted(sparse_dict.items())) @@ -757,59 +778,6 @@ def embed(self, input: str) -> SparseVectorType: raise raise RuntimeError(f"Failed to generate sparse embedding: {e!s}") from e - def _manual_sparse_encode(self, input: str) -> SparseVectorType: - """Fallback manual SPLADE encoding for older sentence-transformers. - - Args: - input (str): Input text to encode. - - Returns: - SparseVectorType: Sparse vector as dictionary. - """ - import torch - - model = self._get_model() - - # Tokenize input - features = model.tokenize([input]) - - # Move to correct device - features = {k: v.to(model.device) for k, v in features.items()} - - # Forward pass with no gradient - with torch.no_grad(): - embeddings = model.forward(features) - - # Get logits from model output - # SPLADE models typically output 'token_embeddings' - if isinstance(embeddings, dict) and "token_embeddings" in embeddings: - logits = embeddings["token_embeddings"][0] # First batch item - elif hasattr(embeddings, "token_embeddings"): - logits = embeddings.token_embeddings[0] - # Fallback: try to get first value - elif isinstance(embeddings, dict): - logits = next(iter(embeddings.values()))[0] - else: - logits = embeddings[0] - - # Apply SPLADE activation: log(1 + relu(x)) - relu_log = torch.log(1 + torch.relu(logits)) - - # Max pooling over token dimension (reduce to vocab size) - if relu_log.dim() > 1: - sparse_vec, _ = torch.max(relu_log, dim=0) - else: - sparse_vec = relu_log - - # Convert to sparse dictionary (only non-zero values) - sparse_vec_np = sparse_vec.cpu().numpy() - sparse_dict = { - int(idx): float(val) for idx, val in enumerate(sparse_vec_np) if val > 0 - } - - # Sort by indices (keys) to ensure consistent ordering - return dict(sorted(sparse_dict.items())) - def _get_model(self): """Load or retrieve the SPLADE model from class-level cache. diff --git a/python/zvec/extension/sentence_transformer_function.py b/python/zvec/extension/sentence_transformer_function.py index 1ba1662a..83caca50 100644 --- a/python/zvec/extension/sentence_transformer_function.py +++ b/python/zvec/extension/sentence_transformer_function.py @@ -88,6 +88,18 @@ def device(self) -> str: return str(model.device) return self._device or "cpu" + @property + def _get_model_class(self): + """Get the Sentence Transformer class. + + Returns: + class: The Sentence Transformer class to use for loading models. + + Raises: + ImportError: If required packages are not installed. + """ + raise NotImplementedError() + def _get_model(self): """Load or retrieve the Sentence Transformer model. 
@@ -104,8 +116,6 @@ def _get_model(self): # Load model try: - sentence_transformers = require_module("sentence_transformers") - if self._model_source == "modelscope": # Load from ModelScope require_module("modelscope") @@ -115,12 +125,13 @@ def _get_model(self): model_dir = snapshot_download(self._model_name) # Load from local path - self._model = sentence_transformers.SentenceTransformer( + self._model = self._get_model_class( model_dir, device=self._device, trust_remote_code=True ) else: # Load from Hugging Face (default) - self._model = sentence_transformers.SentenceTransformer( + self._model = self._get_model_class( + self._model_name, device=self._device, trust_remote_code=True ) @@ -138,13 +149,3 @@ def _get_model(self): f"Failed to load Sentence Transformer model '{self._model_name}' " f"from {self._model_source}: {e!s}" ) from e - - def _is_sparse_model(self) -> bool: - """Check if the loaded model is a sparse encoder (e.g., SPLADE). - - Returns: - bool: True if model supports sparse encoding. - """ - model = self._get_model() - # Check if model has sparse encoding methods - return hasattr(model, "encode_query") or hasattr(model, "encode_document") diff --git a/python/zvec/extension/sentence_transformer_rerank_function.py b/python/zvec/extension/sentence_transformer_rerank_function.py index 58c5838f..4fd80749 100644 --- a/python/zvec/extension/sentence_transformer_rerank_function.py +++ b/python/zvec/extension/sentence_transformer_rerank_function.py @@ -95,7 +95,7 @@ class DefaultLocalReRanker(SentenceTransformerFunctionBase, RerankFunction): .. code-block:: python # Recommended for users in China - reranker = SentenceTransformerReRanker( + reranker = DefaultLocalReRanker( query="机器学习算法", rerank_field="content", model_source="modelscope" @@ -109,9 +109,9 @@ class DefaultLocalReRanker(SentenceTransformerFunctionBase, RerankFunction): Examples: >>> # Basic usage with default MS MARCO MiniLM model - >>> from zvec.extension import SentenceTransformerReRanker + >>> from zvec.extension import DefaultLocalReRanker >>> - >>> reranker = SentenceTransformerReRanker( + >>> reranker = DefaultLocalReRanker( ... query="machine learning algorithms", ... topn=5, ... rerank_field="content" @@ -125,7 +125,7 @@ class DefaultLocalReRanker(SentenceTransformerFunctionBase, RerankFunction): ... ) >>> # Using ModelScope for users in China - >>> reranker = SentenceTransformerReRanker( + >>> reranker = DefaultLocalReRanker( ... query="深度学习", ... topn=10, ... rerank_field="content", @@ -133,7 +133,7 @@ class DefaultLocalReRanker(SentenceTransformerFunctionBase, RerankFunction): ... ) >>> # Using larger model for better quality - >>> reranker = SentenceTransformerReRanker( + >>> reranker = DefaultLocalReRanker( ... query="neural networks", ... topn=5, ... rerank_field="content", @@ -177,7 +177,7 @@ def __init__( device: Optional[str] = None, batch_size: int = 32, ): - """Initialize SentenceTransformerReRanker with query and configuration. + """Initialize DefaultLocalReRanker with query and configuration. Args: query (Optional[str]): Query text for semantic matching. Required. @@ -214,59 +214,19 @@ def __init__( ) self._model = model - def _get_model(self): - """Load or retrieve the CrossEncoder model. - - This overrides the base class method to load CrossEncoder instead of - SentenceTransformer, as reranking requires cross-encoder models. + @property + def _get_model_class(self): + """Get the Sentence Transformer class. Returns: - CrossEncoder: The loaded cross-encoder model instance. 
+ class: CrossEncoder, the class used for cross-encoder re-ranking. Raises: ImportError: If required packages are not installed. - ValueError: If model cannot be loaded. """ - # Return cached model if exists - if self._model is not None: - return self._model - - # Load cross-encoder model - try: - sentence_transformers = require_module("sentence_transformers") - - if self._model_source == "modelscope": - # Load from ModelScope - require_module("modelscope") - from modelscope.hub.snapshot_download import snapshot_download + sentence_transformers = require_module("sentence_transformers") - # Download model to cache - model_dir = snapshot_download(self._model_name) - - # Load CrossEncoder from local path - model = sentence_transformers.CrossEncoder( - model_dir, device=self._device - ) - else: - # Load CrossEncoder from Hugging Face (default) - model = sentence_transformers.CrossEncoder( - self._model_name, device=self._device - ) - - return model - - except ImportError as e: - if "modelscope" in str(e) and self._model_source == "modelscope": - raise ImportError( - "ModelScope support requires the 'modelscope' package. " - "Please install it with: pip install modelscope" - ) from e - raise - except Exception as e: - raise ValueError( - f"Failed to load CrossEncoder model '{self._model_name}' " - f"from {self._model_source}: {e!s}" - ) from e + return sentence_transformers.CrossEncoder @property def query(self) -> str: @@ -305,7 +265,7 @@ def rerank(self, query_results: dict[str, list[Doc]]) -> list[Doc]: - Processing time is O(n) where n is the number of documents Examples: - >>> reranker = SentenceTransformerReRanker( + >>> reranker = DefaultLocalReRanker( ... query="machine learning", ... topn=3, ... rerank_field="content" From 9abce9fcc0e81f27fc6db27c06afef1940db8301 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 26 Feb 2026 14:33:40 +0100 Subject: [PATCH 2/6] Revert model_name for Dense and Sparse models --- ...sentence_transformer_embedding_function.py | 29 +++++-------------- 1 file changed, 7 insertions(+), 22 deletions(-) diff --git a/python/zvec/extension/sentence_transformer_embedding_function.py b/python/zvec/extension/sentence_transformer_embedding_function.py index e0ebb8bf..2565685e 100644 --- a/python/zvec/extension/sentence_transformer_embedding_function.py +++ b/python/zvec/extension/sentence_transformer_embedding_function.py @@ -40,9 +40,6 @@ class DefaultLocalDenseEmbedding( similarity tasks. It runs locally without requiring API keys. Args: - model_name (Optional[str]): Model identifier or local path. Defaults to: - - ``"all-MiniLM-L6-v2"`` for Hugging Face - - ``"iic/nlp_gte_sentence-embedding_chinese-small"`` for ModelScope model_source (Literal["huggingface", "modelscope"], optional): Model source. - ``"huggingface"``: Use Hugging Face Hub (default, for international users) - ``"modelscope"``: Use ModelScope (recommended for users in China) @@ -157,7 +154,6 @@ class DefaultLocalDenseEmbedding( def __init__( self, - model_name: Optional[str] = None, model_source: Literal["huggingface", "modelscope"] = "huggingface", device: Optional[str] = None, normalize_embeddings: bool = True, @@ -167,9 +163,6 @@ def __init__( """Initialize with all-MiniLM-L6-v2 model. Args: - model_name (Optional[str]): Model identifier or local path. Defaults to: - - ``"all-MiniLM-L6-v2"`` for Hugging Face - - ``"iic/nlp_gte_sentence-embedding_chinese-small"`` for ModelScope model_source (Literal["huggingface", "modelscope"]): Model source. Defaults to "huggingface". 
device (Optional[str]): Target device ("cpu", "cuda", "mps", or None). @@ -184,12 +177,11 @@ def __init__( ValueError: If model cannot be loaded. """ # Use different models based on source - if model_name is None: - if model_source == "modelscope": - # Use Chinese-optimized model for ModelScope (better for Chinese text) - model_name = "iic/nlp_gte_sentence-embedding_chinese-small" - else: - model_name = "all-MiniLM-L6-v2" + if model_source == "modelscope": + # Use Chinese-optimized model for ModelScope (better for Chinese text) + model_name = "iic/nlp_gte_sentence-embedding_chinese-small" + else: + model_name = "all-MiniLM-L6-v2" # Initialize base class for model loading SentenceTransformerFunctionBase.__init__( @@ -385,14 +377,11 @@ class DefaultLocalSparseEmbedding( from huggingface_hub import login login(token="your_huggingface_token") - 5. To use a custom SPLADE model, you can subclass this class and override - the model_name in ``__init__``, or create your own implementation + 5. To use a custom SPLADE model, create your own implementation inheriting from ``SentenceTransformerFunctionBase`` and ``SparseEmbeddingFunction``. Args: - model_name (Optional[str]): Model identifier or local path. Defaults to - ``"naver/splade-cocondenser-ensembledistil"`` if None. model_source (Literal["huggingface", "modelscope"], optional): Model source. Defaults to ``"huggingface"``. ModelScope support may vary for SPLADE models. device (Optional[str], optional): Device to run the model on. @@ -614,7 +603,6 @@ def remove_from_cache( def __init__( self, - model_name: Optional[str] = None, model_source: Literal["huggingface", "modelscope"] = "huggingface", device: Optional[str] = None, encoding_type: Literal["query", "document"] = "query", @@ -623,8 +611,6 @@ def __init__( """Initialize with SPLADE model. Args: - model_name (Optional[str]): Model identifier or local path. Defaults to - ``"naver/splade-cocondenser-ensembledistil"`` if None. model_source (Literal["huggingface", "modelscope"]): Model source. Defaults to "huggingface". device (Optional[str]): Target device ("cpu", "cuda", "mps", or None). 
@@ -668,8 +654,7 @@ def __init__( # Use publicly available SPLADE model (no gated access required) # Note: naver/splade-v3 requires authentication, so we use the # cocondenser-ensembledistil variant which is publicly accessible - if model_name is None: - model_name = "naver/splade-cocondenser-ensembledistil" + model_name = "naver/splade-cocondenser-ensembledistil" # Initialize base class for model loading SentenceTransformerFunctionBase.__init__( From 5a86b1656a8504a42ab649e6ba959212bf792586 Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Fri, 27 Feb 2026 10:14:54 +0100 Subject: [PATCH 3/6] Remove extra blank line Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- python/zvec/extension/sentence_transformer_function.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/zvec/extension/sentence_transformer_function.py b/python/zvec/extension/sentence_transformer_function.py index 83caca50..07f44216 100644 --- a/python/zvec/extension/sentence_transformer_function.py +++ b/python/zvec/extension/sentence_transformer_function.py @@ -131,7 +131,6 @@ def _get_model(self): else: # Load from Hugging Face (default) self._model = self._get_model_class( - self._model_name, device=self._device, trust_remote_code=True ) From 2cbc8538bc56e97ad67acc40d8901d0e01cd724d Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Fri, 27 Feb 2026 10:16:24 +0100 Subject: [PATCH 4/6] Patch edge case with zero active dimensions This will effectively never happen, but it technically can --- .../zvec/extension/sentence_transformer_embedding_function.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/zvec/extension/sentence_transformer_embedding_function.py b/python/zvec/extension/sentence_transformer_embedding_function.py index 2565685e..b8af05a1 100644 --- a/python/zvec/extension/sentence_transformer_embedding_function.py +++ b/python/zvec/extension/sentence_transformer_embedding_function.py @@ -752,6 +752,8 @@ def embed(self, input: str) -> SparseVectorType: # The decode method returns a list of (token_string, score) pairs for non-zero dimensions # Then we post-process the tokens to IDs again decoded = model.decode(sparse_matrix)[0] + if not decoded: + return {} token_strings, scores = zip(*decoded, strict=True) token_ids = model.tokenizer.convert_tokens_to_ids(token_strings) sparse_dict = dict(zip(token_ids, scores, strict=True)) From 4a2d706709e70b76dcc339eaf6f7d83f66d2b952 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Fri, 27 Feb 2026 10:19:48 +0100 Subject: [PATCH 5/6] Run ruff formatter --- python/zvec/extension/sentence_transformer_function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/zvec/extension/sentence_transformer_function.py b/python/zvec/extension/sentence_transformer_function.py index 07f44216..9f9dadd7 100644 --- a/python/zvec/extension/sentence_transformer_function.py +++ b/python/zvec/extension/sentence_transformer_function.py @@ -89,7 +89,7 @@ def device(self) -> str: return self._device or "cpu" @property - def _get_model_class(self): + def _get_model_class(self): """Get the Sentence Transformer class. 
Returns: From 2211e4dc70891b85c035cf811ffaa8b79906becf Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Fri, 27 Feb 2026 16:16:07 +0100 Subject: [PATCH 6/6] Turn _get_model_class into a regular method, makes more sense than a property --- .../zvec/extension/sentence_transformer_embedding_function.py | 2 -- python/zvec/extension/sentence_transformer_function.py | 1 - python/zvec/extension/sentence_transformer_rerank_function.py | 1 - 3 files changed, 4 deletions(-) diff --git a/python/zvec/extension/sentence_transformer_embedding_function.py b/python/zvec/extension/sentence_transformer_embedding_function.py index b8af05a1..e19999ed 100644 --- a/python/zvec/extension/sentence_transformer_embedding_function.py +++ b/python/zvec/extension/sentence_transformer_embedding_function.py @@ -198,7 +198,6 @@ def __init__( # Store extra parameters self._extra_params = kwargs - @property def _get_model_class(self): """Get the Sentence Transformer class. @@ -670,7 +669,6 @@ def __init__( # Load model to ensure it's available (will use cache if exists) self._get_model() - @property def _get_model_class(self): """Get the Sentence Transformer class based on the model source. diff --git a/python/zvec/extension/sentence_transformer_function.py b/python/zvec/extension/sentence_transformer_function.py index 9f9dadd7..78be6a99 100644 --- a/python/zvec/extension/sentence_transformer_function.py +++ b/python/zvec/extension/sentence_transformer_function.py @@ -88,7 +88,6 @@ def device(self) -> str: return str(model.device) return self._device or "cpu" - @property def _get_model_class(self): """Get the Sentence Transformer class. diff --git a/python/zvec/extension/sentence_transformer_rerank_function.py b/python/zvec/extension/sentence_transformer_rerank_function.py index 4fd80749..adc719d9 100644 --- a/python/zvec/extension/sentence_transformer_rerank_function.py +++ b/python/zvec/extension/sentence_transformer_rerank_function.py @@ -214,7 +214,6 @@ def __init__( ) self._model = model - @property def _get_model_class(self): """Get the Sentence Transformer class.
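
Usage sketch for reviewers (illustrative only, not part of the patches above). It assumes sentence-transformers v5+ with SparseEncoder support is installed, and that DefaultLocalSparseEmbedding is importable from zvec.extension in the same way the reranker examples import DefaultLocalReRanker.

    # Illustrative sketch only -- not part of the patch series.
    # Assumes: sentence-transformers >= 5.0 (provides SparseEncoder) and that
    # DefaultLocalSparseEmbedding is exported from zvec.extension.
    from zvec.extension import DefaultLocalSparseEmbedding

    # Query-side SPLADE encoding; the model is loaded once and cached at class level.
    embedder = DefaultLocalSparseEmbedding(encoding_type="query")

    # After this series, embed() decodes the SparseEncoder output into
    # (token, score) pairs, maps the tokens back to vocabulary IDs via the
    # model's tokenizer, and returns {token_id: weight} sorted by token ID.
    # An input that activates no dimensions returns {} instead of raising.
    sparse = embedder.embed("machine learning algorithms")
    assert all(isinstance(token_id, int) for token_id in sparse)

The same _get_model_class() hook is what lets each subclass pick its loader (SentenceTransformer, SparseEncoder, or CrossEncoder) while sharing the base class's _get_model() caching and ModelScope/Hugging Face handling.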