pinecone_text.sparse.bm25_encoder

import json
import mmh3
import numpy as np
import tempfile
from pathlib import Path
from tqdm.auto import tqdm
import wget
from typing import List, Optional, Dict, Union, Tuple
from collections import Counter

from pinecone_text.sparse import SparseVector
from pinecone_text.sparse.base_sparse_encoder import BaseSparseEncoder
from pinecone_text.sparse.bm25_tokenizer import BM25Tokenizer


class BM25Encoder(BaseSparseEncoder):
    """Okapi BM25 implementation for a single fit to a corpus (no continuous corpus updates supported)"""

    def __init__(
        self,
        b: float = 0.75,
        k1: float = 1.2,
        lower_case: bool = True,
        remove_punctuation: bool = True,
        remove_stopwords: bool = True,
        stem: bool = True,
        language: str = "english",
    ):
        """
        Okapi BM25 with mmh3 hashing

        Args:
            b: The length normalization parameter
            k1: The term frequency normalization parameter
            lower_case: Whether to lower case the tokens
            remove_punctuation: Whether to remove punctuation tokens
            remove_stopwords: Whether to remove stopword tokens
            stem: Whether to stem the tokens (using SnowballStemmer)
            language: The language of the text (used for stopwords and stemmer)

        Example:

            ```python
            from pinecone_text.sparse import BM25Encoder

            bm25 = BM25Encoder()

            bm25.fit(["The quick brown fox jumps over the lazy dog", "The lazy dog is brown"])

            bm25.encode_documents("The brown fox is quick")  # {"indices": [102, 18, 12, ...], "values": [0.21, 0.38, 0.15, ...]}
            bm25.encode_queries("Which fox is brown?")  # {"indices": [102, 16, 18, ...], "values": [0.21, 0.11, 0.15, ...]}
            ```
        """
        # Fixed params
        self.b: float = b
        self.k1: float = k1

        self._tokenizer = BM25Tokenizer(
            lower_case=lower_case,
            remove_punctuation=remove_punctuation,
            remove_stopwords=remove_stopwords,
            stem=stem,
            language=language,
        )

        # Learned params
        self.doc_freq: Optional[Dict[int, float]] = None
        self.n_docs: Optional[int] = None
        self.avgdl: Optional[float] = None

    def fit(self, corpus: List[str]) -> "BM25Encoder":
        """
        Fit BM25 by calculating document frequency over the corpus

        Args:
            corpus: list of texts to fit BM25 with
        """
        n_docs = 0
        sum_doc_len = 0
        doc_freq_counter: Counter = Counter()

        for doc in tqdm(corpus):
            if not isinstance(doc, str):
                raise ValueError("corpus must be a list of strings")

            indices, tf = self._tf(doc)
            if len(indices) == 0:
                continue
            n_docs += 1
            sum_doc_len += sum(tf)

            # Count the number of documents that contain each token
            doc_freq_counter.update(indices)

        # Guard against division by zero when no document survives tokenization
        if n_docs == 0:
            raise ValueError("corpus must contain at least one non-empty document")

        self.doc_freq = dict(doc_freq_counter)
        self.n_docs = n_docs
        self.avgdl = sum_doc_len / n_docs
        return self

    def encode_documents(
        self, texts: Union[str, List[str]]
    ) -> Union[SparseVector, List[SparseVector]]:
        """
        Encode documents to sparse vectors (for upsert to Pinecone)

        Args:
            texts: a single document or a list of documents to encode
        """
        if self.doc_freq is None or self.n_docs is None or self.avgdl is None:
            raise ValueError("BM25 must be fit before encoding documents")

        if isinstance(texts, str):
            return self._encode_single_document(texts)
        elif isinstance(texts, list):
            return [self._encode_single_document(text) for text in texts]
        else:
            raise ValueError("texts must be a string or list of strings")

    def _encode_single_document(self, text: str) -> SparseVector:
        indices, doc_tf = self._tf(text)
        tf = np.array(doc_tf)
        tf_sum = sum(tf)

        # Length-normalized term frequency: tf / (k1 * (1 - b + b * dl / avgdl) + tf)
        tf_normed = tf / (
            self.k1 * (1.0 - self.b + self.b * (tf_sum / self.avgdl)) + tf
        )
        return {
            "indices": indices,
            "values": tf_normed.tolist(),
        }

    def encode_queries(
        self, texts: Union[str, List[str]]
    ) -> Union[SparseVector, List[SparseVector]]:
        """
        Encode queries to sparse vectors

        Args:
            texts: a single query or a list of queries to encode
        """
        if self.doc_freq is None or self.n_docs is None or self.avgdl is None:
            raise ValueError("BM25 must be fit before encoding queries")

        if isinstance(texts, str):
            return self._encode_single_query(texts)
        elif isinstance(texts, list):
            return [self._encode_single_query(text) for text in texts]
        else:
            raise ValueError("texts must be a string or list of strings")

    def _encode_single_query(self, text: str) -> SparseVector:
        indices, query_tf = self._tf(text)

        # Inverse document frequency: log((N + 1) / (df + 0.5)), normalized to sum to 1
        df = np.array([self.doc_freq.get(idx, 1) for idx in indices])  # type: ignore
        idf = np.log((self.n_docs + 1) / (df + 0.5))  # type: ignore
        idf_norm = idf / idf.sum()
        return {
            "indices": indices,
            "values": idf_norm.tolist(),
        }

    def dump(self, path: str) -> None:
        """
        Store BM25 params to a file in JSON format

        Args:
            path: full file path to save params in
        """
        with open(path, "w") as f:
            json.dump(self.get_params(), f)

    def load(self, path: str) -> "BM25Encoder":
        """
        Load BM25 params from a file in JSON format

        Args:
            path: full file path to load params from
        """
        with open(path, "r") as f:
            params = json.load(f)
        return self.set_params(**params)

    def get_params(
        self,
    ) -> Dict[str, Union[int, float, str, Dict[str, List[Union[int, float]]]]]:
        """Returns the BM25 params"""

        if self.doc_freq is None or self.n_docs is None or self.avgdl is None:
            raise ValueError("BM25 must be fit before storing params")

        tf_pairs = list(self.doc_freq.items())
        return {
            "avgdl": self.avgdl,
            "n_docs": self.n_docs,
            "doc_freq": {
                "indices": [int(idx) for idx, _ in tf_pairs],
                "values": [float(val) for _, val in tf_pairs],
            },
            "b": self.b,
            "k1": self.k1,
            "lower_case": self._tokenizer.lower_case,
            "remove_punctuation": self._tokenizer.remove_punctuation,
            "remove_stopwords": self._tokenizer.remove_stopwords,
            "stem": self._tokenizer.stem,
            "language": self._tokenizer.language,
        }

    def set_params(
        self,
        avgdl: float,
        n_docs: int,
        doc_freq: Dict[str, List[Union[int, float]]],
        b: float,
        k1: float,
        lower_case: bool,
        remove_punctuation: bool,
        remove_stopwords: bool,
        stem: bool,
        language: str,
    ) -> "BM25Encoder":
        """
        Set input parameters to BM25

        Args:
            avgdl: average document length in the corpus
            n_docs: number of documents in the corpus
            doc_freq: document frequency of each term in the corpus
            b: length normalization parameter
            k1: term frequency normalization parameter
            lower_case: whether to lower case the text
            remove_punctuation: whether to remove punctuation from the text
            remove_stopwords: whether to remove stopwords from the text
            stem: whether to stem the text
            language: language of the text for stopwords and stemmer
        """
        self.avgdl = avgdl  # type: ignore
        self.n_docs = n_docs  # type: ignore
        self.doc_freq = {
            idx: val
            for idx, val in zip(doc_freq["indices"], doc_freq["values"])  # type: ignore
        }
        self.b = b  # type: ignore
        self.k1 = k1  # type: ignore
        self._tokenizer = BM25Tokenizer(
            lower_case=lower_case,  # type: ignore
            remove_punctuation=remove_punctuation,  # type: ignore
            remove_stopwords=remove_stopwords,  # type: ignore
            stem=stem,  # type: ignore
            language=language,
        )  # type: ignore
        return self

    @staticmethod
    def default() -> "BM25Encoder":
        """Create a BM25 model from pre-made params for the MS MARCO passages corpus"""
        bm25 = BM25Encoder()
        url = "https://storage.googleapis.com/pinecone-datasets-dev/bm25_params/msmarco_bm25_params_v4_0_0.json"
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = Path(tmp_dir, "msmarco_bm25_params.json")
            wget.download(url, str(tmp_path))
            bm25.load(str(tmp_path))
        return bm25

    @staticmethod
    def _hash_text(token: str) -> int:
        """Use mmh3 to hash text to a 32-bit unsigned integer"""
        return mmh3.hash(token, signed=False)

    def _tf(self, text: str) -> Tuple[List[int], List[int]]:
        """
        Calculate term frequency for a given text

        Args:
            text: a document to calculate term frequency for

        Returns: a tuple of two lists:
            indices: list of term indices
            values: list of term frequencies
        """
        counts = Counter(self._hash_text(token) for token in self._tokenizer(text))

        items = list(counts.items())
        return [idx for idx, _ in items], [val for _, val in items]

Okapi BM25 implementation for a single fit to a corpus (no continuous corpus updates supported)

BM25Encoder(b: float = 0.75, k1: float = 1.2, lower_case: bool = True, remove_punctuation: bool = True, remove_stopwords: bool = True, stem: bool = True, language: str = 'english')

Okapi BM25 with mmh3 hashing

Arguments:
  • b: The length normalization parameter
  • k1: The term frequency normalization parameter
  • lower_case: Whether to lower case the tokens
  • remove_punctuation: Whether to remove punctuation tokens
  • remove_stopwords: Whether to remove stopword tokens
  • stem: Whether to stem the tokens (using SnowballStemmer)
  • language: The language of the text (used for stopwords and stemmer)
Example:

from pinecone_text.sparse import BM25Encoder

bm25 = BM25Encoder()

bm25.fit(["The quick brown fox jumps over the lazy dog", "The lazy dog is brown"])

bm25.encode_documents("The brown fox is quick")  # {"indices": [102, 18, 12, ...], "values": [0.21, 0.38, 0.15, ...]}
bm25.encode_queries("Which fox is brown?")  # {"indices": [102, 16, 18, ...], "values": [0.21, 0.11, 0.15, ...]}
def fit(self, corpus: List[str]) -> pinecone_text.sparse.bm25_encoder.BM25Encoder:

Fit BM25 by calculating document frequency over the corpus

Arguments:
  • corpus: list of texts to fit BM25 with
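
A minimal end-to-end sketch of fitting; the toy corpus below is illustrative:

```python
from pinecone_text.sparse import BM25Encoder

corpus = [
    "The quick brown fox jumps over the lazy dog",
    "The lazy dog is brown",
    "Foxes are quick and dogs are lazy",
]

bm25 = BM25Encoder().fit(corpus)  # returns self, so fit() can be chained

# fit() learns per-token document frequencies, the number of non-empty
# documents, and the average tokenized document length
print(bm25.n_docs)  # 3
print(bm25.avgdl)   # average number of tokens per document after tokenization
```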
def encode_documents(self, texts: Union[str, List[str]]) -> Union[Dict[str, Union[List[int], List[float]]], List[Dict[str, Union[List[int], List[float]]]]]:

Encode documents to sparse vectors (for upsert to Pinecone)

Arguments:
  • texts: a single document or a list of documents to encode
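
Document values carry only the length-normalized term frequencies; the IDF factor is applied on the query side. A short sketch, reusing the fitted `bm25` from the fit example above:

```python
doc_vectors = bm25.encode_documents(["The brown fox is quick", "Dogs are lazy"])

for vec in doc_vectors:
    # indices are mmh3 token hashes; values are length-normalized TFs in (0, 1)
    assert len(vec["indices"]) == len(vec["values"])
```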
def encode_queries(self, texts: Union[str, List[str]]) -> Union[Dict[str, Union[List[int], List[float]]], List[Dict[str, Union[List[int], List[float]]]]]:

Encode queries to sparse vectors

Arguments:
  • texts: a single query or a list of queries to encode
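
Query values are IDF weights normalized per query, so they always sum to 1. A short sketch, again assuming a fitted `bm25`:

```python
query_vectors = bm25.encode_queries(["Which fox is brown?", "Is the dog lazy?"])

for vec in query_vectors:
    # each query's IDF weights are normalized to sum to 1
    assert abs(sum(vec["values"]) - 1.0) < 1e-9
```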
def dump(self, path: str) -> None:

Store BM25 params to a file in JSON format

Arguments:
  • path: full file path to save params in
def load(self, path: str) -> pinecone_text.sparse.bm25_encoder.BM25Encoder:

Load BM25 params from a file in JSON format

Arguments:
  • path: full file path to load params from
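
`dump` and `load` round-trip through the JSON produced by `get_params`. A minimal sketch, assuming a previously fitted `bm25`:

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp_dir:
    params_path = os.path.join(tmp_dir, "bm25_params.json")
    bm25.dump(params_path)                      # serialize fitted params to JSON

    restored = BM25Encoder().load(params_path)  # load() returns the encoder
    assert restored.n_docs == bm25.n_docs
    assert restored.avgdl == bm25.avgdl
```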
def get_params(self) -> Dict[str, Union[int, float, str, Dict[str, List[Union[int, float]]]]]:

Returns the BM25 params
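The returned dict bundles the learned statistics with the tokenizer settings; `doc_freq` is flattened into parallel `indices`/`values` lists so it serializes cleanly to JSON. A short sketch of its shape, assuming a fitted `bm25`:

```python
params = bm25.get_params()
print(sorted(params.keys()))
# ['avgdl', 'b', 'doc_freq', 'k1', 'language', 'lower_case',
#  'n_docs', 'remove_punctuation', 'remove_stopwords', 'stem']

# doc_freq maps hashed token ids to document frequencies, as parallel lists
assert len(params["doc_freq"]["indices"]) == len(params["doc_freq"]["values"])
```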

def set_params(self, avgdl: float, n_docs: int, doc_freq: Dict[str, List[Union[int, float]]], b: float, k1: float, lower_case: bool, remove_punctuation: bool, remove_stopwords: bool, stem: bool, language: str) -> pinecone_text.sparse.bm25_encoder.BM25Encoder:

Set input parameters to BM25

Arguments:
  • avgdl: average document length in the corpus
  • n_docs: number of documents in the corpus
  • doc_freq: document frequency of each term in the corpus
  • b: length normalization parameter
  • k1: term frequency normalization parameter
  • lower_case: whether to lower case the text
  • remove_punctuation: whether to remove punctuation from the text
  • remove_stopwords: whether to remove stopwords from the text
  • stem: whether to stem the text
  • language: language of the text for stopwords and stemmer
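
`set_params` is the inverse of `get_params`, which makes it easy to copy fitted state between encoders. A minimal sketch, reusing `corpus` from the fit example:

```python
source = BM25Encoder().fit(corpus)
target = BM25Encoder().set_params(**source.get_params())
assert target.get_params() == source.get_params()
```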
@staticmethod
def default() -> pinecone_text.sparse.bm25_encoder.BM25Encoder:

Create a BM25 model from pre-made params for the MS MARCO passages corpus
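
A short sketch; note that `default()` downloads a pre-fitted params file, so it requires network access:

```python
bm25 = BM25Encoder.default()  # pre-fitted on the MS MARCO passages corpus
sparse_query = bm25.encode_queries("what is the capital of france?")
```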