pinecone_text.sparse.splade_encoder

  1from typing import List, Union, Optional
  2
  3try:
  4    import torch
  5except (OSError, ImportError, ModuleNotFoundError) as e:
  6    _torch_installed = False
  7else:
  8    _torch_installed = True
  9
 10try:
 11    from transformers import AutoTokenizer, AutoModelForMaskedLM
 12except (OSError, ImportError, ModuleNotFoundError) as e:
 13    _transformers_installed = False
 14else:
 15    _transformers_installed = True
 16
 17
 18from pinecone_text.sparse import SparseVector
 19from pinecone_text.sparse.base_sparse_encoder import BaseSparseEncoder
 20
 21
 22class SpladeEncoder(BaseSparseEncoder):
 23
 24    """
 25    SPLADE sparse vector encoder.
 26    Currently only supports inference with  naver/splade-cocondenser-ensembledistil
 27    """
 28
 29    def __init__(self, max_seq_length: int = 256, device: Optional[str] = None):
 30        """
 31        Args:
 32            max_seq_length: Maximum sequence length for the model. Must be between 1 and 512.
 33            device: Device to use for inference. Defaults to GPU if available, otherwise CPU.
 34
 35        Example:
 36
 37            ```python
 38            from pinecone_text.sparse import SPLADE
 39
 40            splade = SPLADE()
 41
 42            splade.encode_documents("this is a document") # [{"indices": [102, 18, 12, ...], "values": [0.21, 0.38, 0.15, ...]}, ...]
 43            ```
 44        """
 45        if not _torch_installed:
 46            raise ImportError(
 47                """Failed to import torch. Make sure you install pytorch extra dependencies by running: `pip install pinecone-text[splade]`
 48        If this doesn't help, it is probably a CUDA error. If you do want to use GPU, please check your CUDA driver.
 49        If you want to use CPU only, run the following command:
 50        `pip uninstall -y torch torchvision;pip install -y torch torchvision --index-url https://download.pytorch.org/whl/cpu`"""
 51            )
 52
 53        if not _transformers_installed:
 54            raise ImportError(
 55                "Failed to import transformers. Make sure you install splade "
 56                "extra dependencies by running: `pip install pinecone-text[splade]`"
 57            )
 58
 59        if not 0 < max_seq_length <= 512:
 60            raise ValueError("max_seq_length must be between 1 and 512")
 61
 62        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
 63        self.device = device
 64
 65        model = "naver/splade-cocondenser-ensembledistil"
 66        self.tokenizer = AutoTokenizer.from_pretrained(model)
 67        self.model = AutoModelForMaskedLM.from_pretrained(model).to(self.device)
 68        self.max_seq_length = max_seq_length
 69
 70    def encode_documents(
 71        self, texts: Union[str, List[str]]
 72    ) -> Union[SparseVector, List[SparseVector]]:
 73        """
 74        encode documents to a sparse vector (for upsert to pinecone)
 75
 76        Args:
 77            texts: a single or list of documents to encode as a string
 78        """
 79        return self._encode(texts)
 80
 81    def encode_queries(
 82        self, texts: Union[str, List[str]]
 83    ) -> Union[SparseVector, List[SparseVector]]:
 84        """
 85        encode queries to a sparse vector (for upsert to pinecone)
 86
 87        Args:
 88            texts: a single or list of queries to encode as a string
 89        """
 90        return self._encode(texts)
 91
 92    def _encode(
 93        self, texts: Union[str, List[str]]
 94    ) -> Union[SparseVector, List[SparseVector]]:
 95        """
 96        Args:
 97            texts: single or list of texts to encode.
 98
 99        Returns a list of Splade sparse vectors, one for each input text.
100        """
101        inputs = self.tokenizer(
102            texts,
103            return_tensors="pt",
104            padding=True,
105            truncation=True,
106            max_length=self.max_seq_length,
107        ).to(self.device)
108        with torch.no_grad():
109            logits = self.model(**inputs).logits
110
111        inter = torch.log1p(torch.relu(logits))
112        token_max = torch.max(inter, dim=1)
113
114        nz_tokens_i, nz_tokens_j = torch.where(token_max.values > 0)
115
116        output = []
117        for i in range(token_max.values.shape[0]):
118            nz_tokens = nz_tokens_j[nz_tokens_i == i]
119            nz_weights = token_max.values[i, nz_tokens]
120            output.append(
121                {"indices": nz_tokens.tolist(), "values": nz_weights.tolist()}
122            )
123
124        return output[0] if isinstance(texts, str) else output
 23class SpladeEncoder(BaseSparseEncoder):
 24
 25    """
 26    SPLADE sparse vector encoder.
 27    Currently only supports inference with  naver/splade-cocondenser-ensembledistil
 28    """
 29
 30    def __init__(self, max_seq_length: int = 256, device: Optional[str] = None):
 31        """
 32        Args:
 33            max_seq_length: Maximum sequence length for the model. Must be between 1 and 512.
 34            device: Device to use for inference. Defaults to GPU if available, otherwise CPU.
 35
 36        Example:
 37
 38            ```python
 39            from pinecone_text.sparse import SPLADE
 40
 41            splade = SPLADE()
 42
 43            splade.encode_documents("this is a document") # [{"indices": [102, 18, 12, ...], "values": [0.21, 0.38, 0.15, ...]}, ...]
 44            ```
 45        """
 46        if not _torch_installed:
 47            raise ImportError(
 48                """Failed to import torch. Make sure you install pytorch extra dependencies by running: `pip install pinecone-text[splade]`
 49        If this doesn't help, it is probably a CUDA error. If you do want to use GPU, please check your CUDA driver.
 50        If you want to use CPU only, run the following command:
 51        `pip uninstall -y torch torchvision;pip install -y torch torchvision --index-url https://download.pytorch.org/whl/cpu`"""
 52            )
 53
 54        if not _transformers_installed:
 55            raise ImportError(
 56                "Failed to import transformers. Make sure you install splade "
 57                "extra dependencies by running: `pip install pinecone-text[splade]`"
 58            )
 59
 60        if not 0 < max_seq_length <= 512:
 61            raise ValueError("max_seq_length must be between 1 and 512")
 62
 63        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
 64        self.device = device
 65
 66        model = "naver/splade-cocondenser-ensembledistil"
 67        self.tokenizer = AutoTokenizer.from_pretrained(model)
 68        self.model = AutoModelForMaskedLM.from_pretrained(model).to(self.device)
 69        self.max_seq_length = max_seq_length
 70
 71    def encode_documents(
 72        self, texts: Union[str, List[str]]
 73    ) -> Union[SparseVector, List[SparseVector]]:
 74        """
 75        encode documents to a sparse vector (for upsert to pinecone)
 76
 77        Args:
 78            texts: a single or list of documents to encode as a string
 79        """
 80        return self._encode(texts)
 81
 82    def encode_queries(
 83        self, texts: Union[str, List[str]]
 84    ) -> Union[SparseVector, List[SparseVector]]:
 85        """
 86        encode queries to a sparse vector (for upsert to pinecone)
 87
 88        Args:
 89            texts: a single or list of queries to encode as a string
 90        """
 91        return self._encode(texts)
 92
 93    def _encode(
 94        self, texts: Union[str, List[str]]
 95    ) -> Union[SparseVector, List[SparseVector]]:
 96        """
 97        Args:
 98            texts: single or list of texts to encode.
 99
100        Returns a list of Splade sparse vectors, one for each input text.
101        """
102        inputs = self.tokenizer(
103            texts,
104            return_tensors="pt",
105            padding=True,
106            truncation=True,
107            max_length=self.max_seq_length,
108        ).to(self.device)
109        with torch.no_grad():
110            logits = self.model(**inputs).logits
111
112        inter = torch.log1p(torch.relu(logits))
113        token_max = torch.max(inter, dim=1)
114
115        nz_tokens_i, nz_tokens_j = torch.where(token_max.values > 0)
116
117        output = []
118        for i in range(token_max.values.shape[0]):
119            nz_tokens = nz_tokens_j[nz_tokens_i == i]
120            nz_weights = token_max.values[i, nz_tokens]
121            output.append(
122                {"indices": nz_tokens.tolist(), "values": nz_weights.tolist()}
123            )
124
125        return output[0] if isinstance(texts, str) else output

SPLADE sparse vector encoder. Currently only supports inference with the naver/splade-cocondenser-ensembledistil model.

SpladeEncoder(max_seq_length: int = 256, device: Optional[str] = None)
30    def __init__(self, max_seq_length: int = 256, device: Optional[str] = None):
31        """
32        Args:
33            max_seq_length: Maximum sequence length for the model. Must be between 1 and 512.
34            device: Device to use for inference. Defaults to GPU if available, otherwise CPU.
35
36        Example:
37
38            ```python
39            from pinecone_text.sparse import SPLADE
40
41            splade = SPLADE()
42
43            splade.encode_documents("this is a document") # [{"indices": [102, 18, 12, ...], "values": [0.21, 0.38, 0.15, ...]}, ...]
44            ```
45        """
46        if not _torch_installed:
47            raise ImportError(
48                """Failed to import torch. Make sure you install pytorch extra dependencies by running: `pip install pinecone-text[splade]`
49        If this doesn't help, it is probably a CUDA error. If you do want to use GPU, please check your CUDA driver.
50        If you want to use CPU only, run the following command:
51        `pip uninstall -y torch torchvision;pip install -y torch torchvision --index-url https://download.pytorch.org/whl/cpu`"""
52            )
53
54        if not _transformers_installed:
55            raise ImportError(
56                "Failed to import transformers. Make sure you install splade "
57                "extra dependencies by running: `pip install pinecone-text[splade]`"
58            )
59
60        if not 0 < max_seq_length <= 512:
61            raise ValueError("max_seq_length must be between 1 and 512")
62
63        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
64        self.device = device
65
66        model = "naver/splade-cocondenser-ensembledistil"
67        self.tokenizer = AutoTokenizer.from_pretrained(model)
68        self.model = AutoModelForMaskedLM.from_pretrained(model).to(self.device)
69        self.max_seq_length = max_seq_length
Arguments:
  • max_seq_length: Maximum sequence length for the model. Must be between 1 and 512.
  • device: Device to use for inference. Defaults to GPU if available, otherwise CPU.
Example:
from pinecone_text.sparse.splade_encoder import SpladeEncoder

splade = SpladeEncoder()

splade.encode_documents("this is a document") # {"indices": [102, 18, 12, ...], "values": [0.21, 0.38, 0.15, ...]}
def encode_documents( self, texts: Union[str, List[str]]) -> Union[Dict[str, Union[List[int], List[float]]], List[Dict[str, Union[List[int], List[float]]]]]:
71    def encode_documents(
72        self, texts: Union[str, List[str]]
73    ) -> Union[SparseVector, List[SparseVector]]:
74        """
75        encode documents to a sparse vector (for upsert to pinecone)
76
77        Args:
78            texts: a single or list of documents to encode as a string
79        """
80        return self._encode(texts)

Encode documents to sparse vectors (for upsert to Pinecone).

Arguments:
  • texts: a single or list of documents to encode as a string
def encode_queries( self, texts: Union[str, List[str]]) -> Union[Dict[str, Union[List[int], List[float]]], List[Dict[str, Union[List[int], List[float]]]]]:
82    def encode_queries(
83        self, texts: Union[str, List[str]]
84    ) -> Union[SparseVector, List[SparseVector]]:
85        """
86        encode queries to a sparse vector (for upsert to pinecone)
87
88        Args:
89            texts: a single or list of queries to encode as a string
90        """
91        return self._encode(texts)

Encode queries to sparse vectors (for querying Pinecone).

Arguments:
  • texts: a single or list of queries to encode as a string