# Source code for pyod.utils.encoders.sentence_transformer

# -*- coding: utf-8 -*-
"""SentenceTransformerEncoder for EmbeddingOD."""
# Author: Yue Zhao <yzhao062@gmail.com>
# License: BSD 2 clause

try:
    from sentence_transformers import SentenceTransformer
except ImportError:
    SentenceTransformer = None

from . import BaseEncoder


class SentenceTransformerEncoder(BaseEncoder):
    """Encoder using sentence-transformers library.

    Wraps ``sentence_transformers.SentenceTransformer`` to produce text
    embeddings compatible with PyOD detectors. The underlying model is
    loaded lazily on the first call to :meth:`encode` and cached on the
    instance as ``model_``; it is reloaded if ``model_name`` or ``device``
    is changed afterwards.

    Parameters
    ----------
    model_name : str, optional (default='all-MiniLM-L6-v2')
        Name or path of a sentence-transformers model.

    device : str or None, optional (default=None)
        Device for inference ('cpu', 'cuda', etc.). None for auto-detection.

    normalize : bool, optional (default=False)
        L2-normalize output embeddings.

    truncate_dim : int or None, optional (default=None)
        Truncate embeddings to this dimensionality (Matryoshka).

    Raises
    ------
    ImportError
        If the ``sentence-transformers`` package is not installed.

    Examples
    --------
    >>> from pyod.utils.encoders.sentence_transformer import \\
    ...     SentenceTransformerEncoder
    >>> encoder = SentenceTransformerEncoder('all-MiniLM-L6-v2')
    >>> embeddings = encoder.encode(["hello world", "anomaly text"])
    >>> embeddings.shape
    (2, 384)
    """

    def __init__(self, model_name='all-MiniLM-L6-v2', device=None,
                 normalize=False, truncate_dim=None):
        # Fail at construction time rather than on the first encode call,
        # so a missing optional dependency is reported immediately.
        if SentenceTransformer is None:
            raise ImportError(
                "SentenceTransformerEncoder requires 'sentence-transformers'. "
                "Install with: pip install sentence-transformers")
        self.model_name = model_name
        self.device = device
        self.normalize = normalize
        self.truncate_dim = truncate_dim

    def encode(self, X, batch_size=32, show_progress=True):
        """Encode text strings to embeddings.

        Parameters
        ----------
        X : str or list of str
            Text strings to encode. A single string is treated as a
            one-element batch, so the output is always 2-D. Any other
            iterable of strings (tuple, numpy array, generator, ...) is
            materialized into a list first.

        batch_size : int, optional (default=32)
            Batch size for encoding.

        show_progress : bool, optional (default=True)
            Show progress bar.

        Returns
        -------
        embeddings : numpy array of shape (n_samples, n_features)
        """
        if isinstance(X, str):
            # sentence-transformers returns a 1-D embedding for a bare
            # string, and len(X) would count characters rather than
            # samples; wrap it so shape and n_samples stay consistent.
            X = [X]
        else:
            # Make len() valid for any iterable and hand the model a
            # concrete sequence (an object without __len__ would
            # otherwise be treated as one single input).
            X = list(X)

        # (Re)load the model when it has not been built yet, or when
        # model_name/device changed since it was built. A model_ that was
        # assigned without a recorded key is treated as current.
        key = (self.model_name, self.device)
        if (not hasattr(self, 'model_')
                or getattr(self, '_model_key_', key) != key):
            self.model_ = SentenceTransformer(
                self.model_name, device=self.device)
            self._model_key_ = key

        embeddings = self.model_.encode(
            X,
            batch_size=batch_size,
            show_progress_bar=show_progress,
            convert_to_numpy=True,
            normalize_embeddings=self.normalize,
            truncate_dim=self.truncate_dim,
        )
        return self._validate_output(embeddings, n_samples=len(X))