pinecone_text.sparse.bm25_encoder
import json
import mmh3
import numpy as np
import tempfile
from pathlib import Path
from tqdm.auto import tqdm
import wget
from typing import List, Optional, Dict, Union, Tuple
from collections import Counter

from pinecone_text.sparse import SparseVector
from pinecone_text.sparse.base_sparse_encoder import BaseSparseEncoder
from pinecone_text.sparse.bm25_tokenizer import BM25Tokenizer


class BM25Encoder(BaseSparseEncoder):

    """Okapi BM25 implementation for single fit to a corpus (no continuous corpus updates supported)"""

    def __init__(
        self,
        b: float = 0.75,
        k1: float = 1.2,
        lower_case: bool = True,
        remove_punctuation: bool = True,
        remove_stopwords: bool = True,
        stem: bool = True,
        language: str = "english",
    ):
        """
        Okapi BM25 with mmh3 hashing

        Args:
            b: The length normalization parameter
            k1: The term frequency normalization parameter
            lower_case: Whether to lower case the tokens
            remove_punctuation: Whether to remove punctuation tokens
            remove_stopwords: Whether to remove stopwords tokens
            stem: Whether to stem the tokens (using SnowballStemmer)
            language: The language of the text (used for stopwords and stemmer)

        Example:

            ```python
            from pinecone_text.sparse import BM25Encoder

            bm25 = BM25Encoder()

            bm25.fit(["The quick brown fox jumps over the lazy dog", "The lazy dog is brown"])

            bm25.encode_documents("The brown fox is quick")  # {"indices": [102, 18, 12, ...], "values": [0.21, 0.38, 0.15, ...]}
            bm25.encode_queries("Which fox is brown?")  # {"indices": [102, 16, 18, ...], "values": [0.21, 0.11, 0.15, ...]}
            ```
        """
        # Fixed params
        self.b: float = b
        self.k1: float = k1

        self._tokenizer = BM25Tokenizer(
            lower_case=lower_case,
            remove_punctuation=remove_punctuation,
            remove_stopwords=remove_stopwords,
            stem=stem,
            language=language,
        )

        # Learned params (stay None until `fit`, `load` or `set_params` is called)
        self.doc_freq: Optional[Dict[int, float]] = None
        self.n_docs: Optional[int] = None
        self.avgdl: Optional[float] = None

    def fit(self, corpus: List[str]) -> "BM25Encoder":
        """
        Fit BM25 by calculating document frequency over the corpus

        Documents that yield no tokens are skipped; they count towards
        neither the number of documents nor the average document length.

        Args:
            corpus: list of texts to fit BM25 with

        Returns:
            The fitted encoder itself

        Raises:
            ValueError: if an item of the corpus is not a string, or if no
                document in the corpus yields at least one token
        """
        n_docs = 0
        sum_doc_len = 0
        doc_freq_counter: Counter = Counter()

        for doc in tqdm(corpus):
            if not isinstance(doc, str):
                raise ValueError("corpus must be a list of strings")

            indices, tf = self._tf(doc)
            if not indices:
                continue
            n_docs += 1
            sum_doc_len += sum(tf)

            # Count the number of documents that contain each token
            doc_freq_counter.update(indices)

        # Validate before touching the learned params: dividing by zero here
        # would otherwise leave the encoder half-updated.
        if n_docs == 0:
            raise ValueError(
                "corpus must contain at least one document with at least one token"
            )

        self.doc_freq = dict(doc_freq_counter)
        self.n_docs = n_docs
        self.avgdl = sum_doc_len / n_docs
        return self

    def encode_documents(
        self, texts: Union[str, List[str]]
    ) -> Union[SparseVector, List[SparseVector]]:
        """
        Encode documents to sparse vectors (for upsert to Pinecone)

        Args:
            texts: a single document (string) or a list of documents to encode

        Returns:
            One sparse vector for a string input, a list of sparse vectors for a list input

        Raises:
            ValueError: if the encoder was not fit, or `texts` is neither a string nor a list
        """
        if self.doc_freq is None or self.n_docs is None or self.avgdl is None:
            raise ValueError("BM25 must be fit before encoding documents")

        if isinstance(texts, str):
            return self._encode_single_document(texts)
        elif isinstance(texts, list):
            return [self._encode_single_document(text) for text in texts]
        else:
            raise ValueError("texts must be a string or list of strings")

    def _encode_single_document(self, text: str) -> SparseVector:
        """Encode one document: each value is tf / (k1 * (1 - b + b * doc_len / avgdl) + tf)"""
        indices, doc_tf = self._tf(text)
        tf = np.array(doc_tf)
        tf_sum = tf.sum()

        tf_normed = tf / (
            self.k1 * (1.0 - self.b + self.b * (tf_sum / self.avgdl)) + tf
        )
        return {
            "indices": indices,
            "values": tf_normed.tolist(),
        }

    def encode_queries(
        self, texts: Union[str, List[str]]
    ) -> Union[SparseVector, List[SparseVector]]:
        """
        Encode queries to sparse vectors

        Args:
            texts: a single query (string) or a list of queries to encode

        Returns:
            One sparse vector for a string input, a list of sparse vectors for a list input

        Raises:
            ValueError: if the encoder was not fit, or `texts` is neither a string nor a list
        """
        if self.doc_freq is None or self.n_docs is None or self.avgdl is None:
            raise ValueError("BM25 must be fit before encoding queries")

        if isinstance(texts, str):
            return self._encode_single_query(texts)
        elif isinstance(texts, list):
            return [self._encode_single_query(text) for text in texts]
        else:
            raise ValueError("texts must be a string or list of strings")

    def _encode_single_query(self, text: str) -> SparseVector:
        """Encode one query: each value is the token's IDF divided by the sum of the query's IDFs"""
        # Only the distinct token hashes matter here; the query term counts are not used.
        indices, _ = self._tf(text)

        # Tokens that were not seen during fit fall back to a document frequency of 1
        df = np.array([self.doc_freq.get(idx, 1) for idx in indices])  # type: ignore
        idf = np.log((self.n_docs + 1) / (df + 0.5))  # type: ignore
        idf_norm = idf / idf.sum()
        return {
            "indices": indices,
            "values": idf_norm.tolist(),
        }

    def dump(self, path: str) -> None:
        """
        Store BM25 params to a file in JSON format

        Args:
            path: full file path to save params in

        Raises:
            ValueError: if the encoder was not fit (the file is left untouched in that case)
        """
        # Collect the params before opening the file: opening in "w" mode
        # truncates it, so a failure must happen first.
        params = self.get_params()
        with open(path, "w") as f:
            json.dump(params, f)

    def load(self, path: str) -> "BM25Encoder":
        """
        Load BM25 params from a file in JSON format

        Args:
            path: full file path to load params from

        Returns:
            The encoder itself, with the loaded params applied
        """
        with open(path, "r") as f:
            params = json.load(f)
        return self.set_params(**params)

    def get_params(
        self,
    ) -> Dict[str, Union[int, float, str, Dict[str, List[Union[int, float]]]]]:
        """
        Returns the BM25 params

        Raises:
            ValueError: if the encoder was not fit
        """

        if self.doc_freq is None or self.n_docs is None or self.avgdl is None:
            raise ValueError("BM25 must be fit before storing params")

        df_pairs = list(self.doc_freq.items())
        return {
            "avgdl": self.avgdl,
            "n_docs": self.n_docs,
            "doc_freq": {
                "indices": [int(idx) for idx, _ in df_pairs],
                "values": [float(val) for _, val in df_pairs],
            },
            "b": self.b,
            "k1": self.k1,
            "lower_case": self._tokenizer.lower_case,
            "remove_punctuation": self._tokenizer.remove_punctuation,
            "remove_stopwords": self._tokenizer.remove_stopwords,
            "stem": self._tokenizer.stem,
            "language": self._tokenizer.language,
        }

    def set_params(
        self,
        avgdl: float,
        n_docs: int,
        doc_freq: Dict[str, List[int]],
        b: float,
        k1: float,
        lower_case: bool,
        remove_punctuation: bool,
        remove_stopwords: bool,
        stem: bool,
        language: str,
    ) -> "BM25Encoder":
        """
        Set input parameters to BM25

        Args:
            avgdl: average document length in the corpus
            n_docs: number of documents in the corpus
            doc_freq: document frequency of each term in the corpus, given as
                parallel lists under the keys "indices" and "values"
            b: length normalization parameter
            k1: term frequency normalization parameter
            lower_case: whether to lower case the text
            remove_punctuation: whether to remove punctuation from the text
            remove_stopwords: whether to remove stopwords from the text
            stem: whether to stem the text
            language: language of the text for stopwords and stemmer

        Returns:
            The encoder itself
        """
        self.avgdl = avgdl  # type: ignore
        self.n_docs = n_docs  # type: ignore
        self.doc_freq = {
            idx: val
            for idx, val in zip(doc_freq["indices"], doc_freq["values"])  # type: ignore
        }
        self.b = b  # type: ignore
        self.k1 = k1  # type: ignore
        self._tokenizer = BM25Tokenizer(
            lower_case=lower_case,  # type: ignore
            remove_punctuation=remove_punctuation,  # type: ignore
            remove_stopwords=remove_stopwords,  # type: ignore
            stem=stem,  # type: ignore
            language=language,
        )  # type: ignore
        return self

    @staticmethod
    def default() -> "BM25Encoder":
        """Create a BM25 model from pre-made params for the MS MARCO passages corpus"""
        bm25 = BM25Encoder()
        url = "https://storage.googleapis.com/pinecone-datasets-dev/bm25_params/msmarco_bm25_params_v4_0_0.json"
        # The params are downloaded into a temporary directory that is removed once they are loaded
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = Path(tmp_dir, "msmarco_bm25_params.json")
            wget.download(url, str(tmp_path))
            bm25.load(str(tmp_path))
        return bm25

    @staticmethod
    def _hash_text(token: str) -> int:
        """Use mmh3 to hash text to 32-bit unsigned integer"""
        return mmh3.hash(token, signed=False)

    def _tf(self, text: str) -> Tuple[List[int], List[int]]:
        """
        Calculate term frequency for a given text

        Args:
            text: a document to calculate term frequency for

        Returns: a tuple of two lists:
            indices: list of term indices
            values: list of term frequencies
        """
        counts = Counter((self._hash_text(token) for token in self._tokenizer(text)))

        items = list(counts.items())
        return [idx for idx, _ in items], [val for _, val in items]
class BM25Encoder(BaseSparseEncoder):

    """Okapi BM25 implementation for single fit to a corpus (no continuous corpus updates supported)"""

    def __init__(
        self,
        b: float = 0.75,
        k1: float = 1.2,
        lower_case: bool = True,
        remove_punctuation: bool = True,
        remove_stopwords: bool = True,
        stem: bool = True,
        language: str = "english",
    ):
        """
        Okapi BM25 with mmh3 hashing

        Args:
            b: The length normalization parameter
            k1: The term frequency normalization parameter
            lower_case: Whether to lower case the tokens
            remove_punctuation: Whether to remove punctuation tokens
            remove_stopwords: Whether to remove stopwords tokens
            stem: Whether to stem the tokens (using SnowballStemmer)
            language: The language of the text (used for stopwords and stemmer)

        Example:

            ```python
            from pinecone_text.sparse import BM25Encoder

            bm25 = BM25Encoder()

            bm25.fit(["The quick brown fox jumps over the lazy dog", "The lazy dog is brown"])

            bm25.encode_documents("The brown fox is quick")  # {"indices": [102, 18, 12, ...], "values": [0.21, 0.38, 0.15, ...]}
            bm25.encode_queries("Which fox is brown?")  # {"indices": [102, 16, 18, ...], "values": [0.21, 0.11, 0.15, ...]}
            ```
        """
        # Fixed params
        self.b: float = b
        self.k1: float = k1

        self._tokenizer = BM25Tokenizer(
            lower_case=lower_case,
            remove_punctuation=remove_punctuation,
            remove_stopwords=remove_stopwords,
            stem=stem,
            language=language,
        )

        # Learned params (stay None until `fit`, `load` or `set_params` is called)
        self.doc_freq: Optional[Dict[int, float]] = None
        self.n_docs: Optional[int] = None
        self.avgdl: Optional[float] = None

    def fit(self, corpus: List[str]) -> "BM25Encoder":
        """
        Fit BM25 by calculating document frequency over the corpus

        Documents that yield no tokens are skipped; they count towards
        neither the number of documents nor the average document length.

        Args:
            corpus: list of texts to fit BM25 with

        Returns:
            The fitted encoder itself

        Raises:
            ValueError: if an item of the corpus is not a string, or if no
                document in the corpus yields at least one token
        """
        n_docs = 0
        sum_doc_len = 0
        doc_freq_counter: Counter = Counter()

        for doc in tqdm(corpus):
            if not isinstance(doc, str):
                raise ValueError("corpus must be a list of strings")

            indices, tf = self._tf(doc)
            if not indices:
                continue
            n_docs += 1
            sum_doc_len += sum(tf)

            # Count the number of documents that contain each token
            doc_freq_counter.update(indices)

        # Validate before touching the learned params: dividing by zero here
        # would otherwise leave the encoder half-updated.
        if n_docs == 0:
            raise ValueError(
                "corpus must contain at least one document with at least one token"
            )

        self.doc_freq = dict(doc_freq_counter)
        self.n_docs = n_docs
        self.avgdl = sum_doc_len / n_docs
        return self

    def encode_documents(
        self, texts: Union[str, List[str]]
    ) -> Union[SparseVector, List[SparseVector]]:
        """
        Encode documents to sparse vectors (for upsert to Pinecone)

        Args:
            texts: a single document (string) or a list of documents to encode

        Returns:
            One sparse vector for a string input, a list of sparse vectors for a list input

        Raises:
            ValueError: if the encoder was not fit, or `texts` is neither a string nor a list
        """
        if self.doc_freq is None or self.n_docs is None or self.avgdl is None:
            raise ValueError("BM25 must be fit before encoding documents")

        if isinstance(texts, str):
            return self._encode_single_document(texts)
        elif isinstance(texts, list):
            return [self._encode_single_document(text) for text in texts]
        else:
            raise ValueError("texts must be a string or list of strings")

    def _encode_single_document(self, text: str) -> SparseVector:
        """Encode one document: each value is tf / (k1 * (1 - b + b * doc_len / avgdl) + tf)"""
        indices, doc_tf = self._tf(text)
        tf = np.array(doc_tf)
        tf_sum = tf.sum()

        tf_normed = tf / (
            self.k1 * (1.0 - self.b + self.b * (tf_sum / self.avgdl)) + tf
        )
        return {
            "indices": indices,
            "values": tf_normed.tolist(),
        }

    def encode_queries(
        self, texts: Union[str, List[str]]
    ) -> Union[SparseVector, List[SparseVector]]:
        """
        Encode queries to sparse vectors

        Args:
            texts: a single query (string) or a list of queries to encode

        Returns:
            One sparse vector for a string input, a list of sparse vectors for a list input

        Raises:
            ValueError: if the encoder was not fit, or `texts` is neither a string nor a list
        """
        if self.doc_freq is None or self.n_docs is None or self.avgdl is None:
            raise ValueError("BM25 must be fit before encoding queries")

        if isinstance(texts, str):
            return self._encode_single_query(texts)
        elif isinstance(texts, list):
            return [self._encode_single_query(text) for text in texts]
        else:
            raise ValueError("texts must be a string or list of strings")

    def _encode_single_query(self, text: str) -> SparseVector:
        """Encode one query: each value is the token's IDF divided by the sum of the query's IDFs"""
        # Only the distinct token hashes matter here; the query term counts are not used.
        indices, _ = self._tf(text)

        # Tokens that were not seen during fit fall back to a document frequency of 1
        df = np.array([self.doc_freq.get(idx, 1) for idx in indices])  # type: ignore
        idf = np.log((self.n_docs + 1) / (df + 0.5))  # type: ignore
        idf_norm = idf / idf.sum()
        return {
            "indices": indices,
            "values": idf_norm.tolist(),
        }

    def dump(self, path: str) -> None:
        """
        Store BM25 params to a file in JSON format

        Args:
            path: full file path to save params in

        Raises:
            ValueError: if the encoder was not fit (the file is left untouched in that case)
        """
        # Collect the params before opening the file: opening in "w" mode
        # truncates it, so a failure must happen first.
        params = self.get_params()
        with open(path, "w") as f:
            json.dump(params, f)

    def load(self, path: str) -> "BM25Encoder":
        """
        Load BM25 params from a file in JSON format

        Args:
            path: full file path to load params from

        Returns:
            The encoder itself, with the loaded params applied
        """
        with open(path, "r") as f:
            params = json.load(f)
        return self.set_params(**params)

    def get_params(
        self,
    ) -> Dict[str, Union[int, float, str, Dict[str, List[Union[int, float]]]]]:
        """
        Returns the BM25 params

        Raises:
            ValueError: if the encoder was not fit
        """

        if self.doc_freq is None or self.n_docs is None or self.avgdl is None:
            raise ValueError("BM25 must be fit before storing params")

        df_pairs = list(self.doc_freq.items())
        return {
            "avgdl": self.avgdl,
            "n_docs": self.n_docs,
            "doc_freq": {
                "indices": [int(idx) for idx, _ in df_pairs],
                "values": [float(val) for _, val in df_pairs],
            },
            "b": self.b,
            "k1": self.k1,
            "lower_case": self._tokenizer.lower_case,
            "remove_punctuation": self._tokenizer.remove_punctuation,
            "remove_stopwords": self._tokenizer.remove_stopwords,
            "stem": self._tokenizer.stem,
            "language": self._tokenizer.language,
        }

    def set_params(
        self,
        avgdl: float,
        n_docs: int,
        doc_freq: Dict[str, List[int]],
        b: float,
        k1: float,
        lower_case: bool,
        remove_punctuation: bool,
        remove_stopwords: bool,
        stem: bool,
        language: str,
    ) -> "BM25Encoder":
        """
        Set input parameters to BM25

        Args:
            avgdl: average document length in the corpus
            n_docs: number of documents in the corpus
            doc_freq: document frequency of each term in the corpus, given as
                parallel lists under the keys "indices" and "values"
            b: length normalization parameter
            k1: term frequency normalization parameter
            lower_case: whether to lower case the text
            remove_punctuation: whether to remove punctuation from the text
            remove_stopwords: whether to remove stopwords from the text
            stem: whether to stem the text
            language: language of the text for stopwords and stemmer

        Returns:
            The encoder itself
        """
        self.avgdl = avgdl  # type: ignore
        self.n_docs = n_docs  # type: ignore
        self.doc_freq = {
            idx: val
            for idx, val in zip(doc_freq["indices"], doc_freq["values"])  # type: ignore
        }
        self.b = b  # type: ignore
        self.k1 = k1  # type: ignore
        self._tokenizer = BM25Tokenizer(
            lower_case=lower_case,  # type: ignore
            remove_punctuation=remove_punctuation,  # type: ignore
            remove_stopwords=remove_stopwords,  # type: ignore
            stem=stem,  # type: ignore
            language=language,
        )  # type: ignore
        return self

    @staticmethod
    def default() -> "BM25Encoder":
        """Create a BM25 model from pre-made params for the MS MARCO passages corpus"""
        bm25 = BM25Encoder()
        url = "https://storage.googleapis.com/pinecone-datasets-dev/bm25_params/msmarco_bm25_params_v4_0_0.json"
        # The params are downloaded into a temporary directory that is removed once they are loaded
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = Path(tmp_dir, "msmarco_bm25_params.json")
            wget.download(url, str(tmp_path))
            bm25.load(str(tmp_path))
        return bm25

    @staticmethod
    def _hash_text(token: str) -> int:
        """Use mmh3 to hash text to 32-bit unsigned integer"""
        return mmh3.hash(token, signed=False)

    def _tf(self, text: str) -> Tuple[List[int], List[int]]:
        """
        Calculate term frequency for a given text

        Args:
            text: a document to calculate term frequency for

        Returns: a tuple of two lists:
            indices: list of term indices
            values: list of term frequencies
        """
        counts = Counter((self._hash_text(token) for token in self._tokenizer(text)))

        items = list(counts.items())
        return [idx for idx, _ in items], [val for _, val in items]
Okapi BM25 implementation that is fit once to a corpus (continuous corpus updates are not supported)
BM25Encoder( b: float = 0.75, k1: float = 1.2, lower_case: bool = True, remove_punctuation: bool = True, remove_stopwords: bool = True, stem: bool = True, language: str = 'english')
def __init__(
    self,
    b: float = 0.75,
    k1: float = 1.2,
    lower_case: bool = True,
    remove_punctuation: bool = True,
    remove_stopwords: bool = True,
    stem: bool = True,
    language: str = "english",
):
    """
    Okapi BM25 with mmh3 hashing

    Args:
        b: The length normalization parameter
        k1: The term frequency normalization parameter
        lower_case: Whether to lower case the tokens
        remove_punctuation: Whether to remove punctuation tokens
        remove_stopwords: Whether to remove stopwords tokens
        stem: Whether to stem the tokens (using SnowballStemmer)
        language: The language of the text (used for stopwords and stemmer)

    Example:

        ```python
        from pinecone_text.sparse import BM25Encoder

        bm25 = BM25Encoder()

        bm25.fit(["The quick brown fox jumps over the lazy dog", "The lazy dog is brown"])

        bm25.encode_documents("The brown fox is quick")  # {"indices": [102, 18, 12, ...], "values": [0.21, 0.38, 0.15, ...]}
        bm25.encode_queries("Which fox is brown?")  # {"indices": [102, 16, 18, ...], "values": [0.21, 0.11, 0.15, ...]}
        ```
    """
    # Learned params: nothing is known about the corpus until the encoder is fit or loaded
    self.doc_freq: Optional[Dict[int, float]] = None
    self.n_docs: Optional[int] = None
    self.avgdl: Optional[float] = None

    # Fixed params
    self.b: float = b
    self.k1: float = k1

    # All text-processing options are handed over to the tokenizer
    tokenizer_options = {
        "lower_case": lower_case,
        "remove_punctuation": remove_punctuation,
        "remove_stopwords": remove_stopwords,
        "stem": stem,
        "language": language,
    }
    self._tokenizer = BM25Tokenizer(**tokenizer_options)
OKapi BM25 with mmh3 hashing
Arguments:
- b: The length normalization parameter
- k1: The term frequency normalization parameter
- lower_case: Whether to lower case the tokens
- remove_punctuation: Whether to remove punctuation tokens
- remove_stopwords: Whether to remove stopwords tokens
- stem: Whether to stem the tokens (using SnowballStemmer)
- language: The language of the text (used for stopwords and stemmer)
Example:
from pinecone_text.sparse import BM25Encoder

bm25 = BM25Encoder()

bm25.fit(["The quick brown fox jumps over the lazy dog", "The lazy dog is brown"])

bm25.encode_documents("The brown fox is quick")  # {"indices": [102, 18, 12, ...], "values": [0.21, 0.38, 0.15, ...]}
bm25.encode_queries("Which fox is brown?")  # {"indices": [102, 16, 18, ...], "values": [0.21, 0.11, 0.15, ...]}
def fit(self, corpus: List[str]) -> "BM25Encoder":
    """
    Fit BM25 by calculating document frequency over the corpus

    Documents that yield no tokens are skipped; they count towards
    neither the number of documents nor the average document length.

    Args:
        corpus: list of texts to fit BM25 with

    Returns:
        The fitted encoder itself

    Raises:
        ValueError: if an item of the corpus is not a string, or if no
            document in the corpus yields at least one token
    """
    n_docs = 0
    sum_doc_len = 0
    doc_freq_counter: Counter = Counter()

    for doc in tqdm(corpus):
        if not isinstance(doc, str):
            raise ValueError("corpus must be a list of strings")

        indices, tf = self._tf(doc)
        if not indices:
            continue
        n_docs += 1
        sum_doc_len += sum(tf)

        # Count the number of documents that contain each token
        doc_freq_counter.update(indices)

    # Validate before touching the learned params: an empty corpus used to
    # fail with a bare ZeroDivisionError after doc_freq and n_docs had
    # already been overwritten.
    if n_docs == 0:
        raise ValueError(
            "corpus must contain at least one document with at least one token"
        )

    self.doc_freq = dict(doc_freq_counter)
    self.n_docs = n_docs
    self.avgdl = sum_doc_len / n_docs
    return self
Fit BM25 by calculating document frequency over the corpus
Arguments:
- corpus: list of texts to fit BM25 with
def
encode_documents( self, texts: Union[str, List[str]]) -> Union[Dict[str, Union[List[int], List[float]]], List[Dict[str, Union[List[int], List[float]]]]]:
def encode_documents(
    self, texts: Union[str, List[str]]
) -> Union[SparseVector, List[SparseVector]]:
    """
    Encode documents to sparse vectors (for upsert to Pinecone)

    Args:
        texts: a single document (string) or a list of documents to encode

    Returns:
        One sparse vector for a string input, a list of sparse vectors for a list input

    Raises:
        ValueError: if the encoder was not fit, or `texts` is neither a string nor a list
    """
    is_fitted = not (
        self.doc_freq is None or self.n_docs is None or self.avgdl is None
    )
    if not is_fitted:
        raise ValueError("BM25 must be fit before encoding documents")

    # A bare string is one document; a list is encoded item by item
    if isinstance(texts, str):
        return self._encode_single_document(texts)
    if isinstance(texts, list):
        encoded = []
        for document in texts:
            encoded.append(self._encode_single_document(document))
        return encoded
    raise ValueError("texts must be a string or list of strings")
Encode documents to sparse vectors (for upsert to Pinecone)
Arguments:
- texts: a single document (string) or a list of documents to encode
def
encode_queries( self, texts: Union[str, List[str]]) -> Union[Dict[str, Union[List[int], List[float]]], List[Dict[str, Union[List[int], List[float]]]]]:
def encode_queries(
    self, texts: Union[str, List[str]]
) -> Union[SparseVector, List[SparseVector]]:
    """
    Encode queries to sparse vectors

    Args:
        texts: a single query (string) or a list of queries to encode

    Returns:
        One sparse vector for a string input, a list of sparse vectors for a list input

    Raises:
        ValueError: if the encoder was not fit, or `texts` is neither a string nor a list
    """
    is_fitted = not (
        self.doc_freq is None or self.n_docs is None or self.avgdl is None
    )
    if not is_fitted:
        raise ValueError("BM25 must be fit before encoding queries")

    # A bare string is one query; a list is encoded item by item
    if isinstance(texts, str):
        return self._encode_single_query(texts)
    if isinstance(texts, list):
        encoded = []
        for query in texts:
            encoded.append(self._encode_single_query(query))
        return encoded
    raise ValueError("texts must be a string or list of strings")
Encode queries to sparse vectors
Arguments:
- texts: a single query (string) or a list of queries to encode
def
dump(self, path: str) -> None:
def dump(self, path: str) -> None:
    """
    Store BM25 params to a file in JSON format

    Args:
        path: full file path to save params in

    Raises:
        ValueError: propagated from `get_params` when the encoder was not
            fit; the file at `path` is left untouched in that case
    """
    # Collect the params BEFORE opening the file: opening in "w" mode
    # truncates it, so a failure in get_params() used to wipe an existing
    # params file (or leave an empty one behind).
    params = self.get_params()
    with open(path, "w") as f:
        json.dump(params, f)
Store BM25 params to a file in JSON format
Arguments:
- path: full file path to save params in
def load(self, path: str) -> "BM25Encoder":
    """
    Load BM25 params from a file in JSON format

    Args:
        path: full file path to load params from

    Returns:
        Whatever `set_params` returns when called with the stored params
    """
    with open(path, "r") as params_file:
        stored = json.load(params_file)
    # Every top-level JSON key is forwarded as a keyword argument to set_params
    configured = self.set_params(**stored)
    return configured
Load BM25 params from a file in JSON format
Arguments:
- path: full file path to load params from
def
get_params( self) -> Dict[str, Union[int, float, str, Dict[str, List[Union[float, int]]]]]:
def get_params(
    self,
) -> Dict[str, Union[int, float, str, Dict[str, List[Union[int, float]]]]]:
    """
    Returns the BM25 params

    The document frequencies are returned as two parallel lists ("indices"
    and "values") next to the corpus statistics and the tokenizer settings.

    Raises:
        ValueError: if the encoder was not fit
    """
    learned = (self.doc_freq, self.n_docs, self.avgdl)
    if any(param is None for param in learned):
        raise ValueError("BM25 must be fit before storing params")

    # Keys and values of one dict iterate in the same order, so the two lists stay aligned
    term_indices = [int(term) for term in self.doc_freq]
    term_counts = [float(count) for count in self.doc_freq.values()]

    tokenizer = self._tokenizer
    params = {
        "avgdl": self.avgdl,
        "n_docs": self.n_docs,
        "doc_freq": {"indices": term_indices, "values": term_counts},
        "b": self.b,
        "k1": self.k1,
        "lower_case": tokenizer.lower_case,
        "remove_punctuation": tokenizer.remove_punctuation,
        "remove_stopwords": tokenizer.remove_stopwords,
        "stem": tokenizer.stem,
        "language": tokenizer.language,
    }
    return params
Returns the BM25 params
def
set_params( self, avgdl: float, n_docs: int, doc_freq: Dict[str, List[int]], b: float, k1: float, lower_case: bool, remove_punctuation: bool, remove_stopwords: bool, stem: bool, language: str) -> pinecone_text.sparse.bm25_encoder.BM25Encoder:
def set_params(
    self,
    avgdl: float,
    n_docs: int,
    doc_freq: Dict[str, List[int]],
    b: float,
    k1: float,
    lower_case: bool,
    remove_punctuation: bool,
    remove_stopwords: bool,
    stem: bool,
    language: str,
) -> "BM25Encoder":
    """
    Set input parameters to BM25

    Args:
        avgdl: average document length in the corpus
        n_docs: number of documents in the corpus
        doc_freq: document frequency of each term in the corpus, given as
            parallel lists under the keys "indices" and "values"
        b: length normalization parameter
        k1: term frequency normalization parameter
        lower_case: whether to lower case the text
        remove_punctuation: whether to remove punctuation from the text
        remove_stopwords: whether to remove stopwords from the text
        stem: whether to stem the text
        language: language of the text for stopwords and stemmer

    Returns:
        The encoder itself
    """
    # Corpus statistics
    self.avgdl = avgdl  # type: ignore
    self.n_docs = n_docs  # type: ignore
    # Pair each term index with the frequency found at the same position
    term_indices = doc_freq["indices"]
    term_counts = doc_freq["values"]
    self.doc_freq = dict(zip(term_indices, term_counts))  # type: ignore

    # Scoring constants
    self.b = b  # type: ignore
    self.k1 = k1  # type: ignore

    # The tokenizer is rebuilt from scratch with the given text-processing options
    tokenizer_options = {
        "lower_case": lower_case,
        "remove_punctuation": remove_punctuation,
        "remove_stopwords": remove_stopwords,
        "stem": stem,
        "language": language,
    }
    self._tokenizer = BM25Tokenizer(**tokenizer_options)  # type: ignore
    return self
Set input parameters to BM25
Arguments:
- avgdl: average document length in the corpus
- n_docs: number of documents in the corpus
- doc_freq: document frequency of each term in the corpus
- b: length normalization parameter
- k1: term frequency normalization parameter
- lower_case: whether to lower case the text
- remove_punctuation: whether to remove punctuation from the text
- remove_stopwords: whether to remove stopwords from the text
- stem: whether to stem the text
- language: language of the text for stopwords and stemmer
@staticmethod
def default() -> "BM25Encoder":
    """Create a BM25 model from pre-made params for the MS MARCO passages corpus"""
    params_url = "https://storage.googleapis.com/pinecone-datasets-dev/bm25_params/msmarco_bm25_params_v4_0_0.json"
    encoder = BM25Encoder()
    # Download into a temporary directory, which is removed once the params are loaded
    with tempfile.TemporaryDirectory() as download_dir:
        local_copy = str(Path(download_dir) / "msmarco_bm25_params.json")
        wget.download(params_url, local_copy)
        encoder.load(local_copy)
    return encoder
Create a BM25 model from pre-made params for the MS MARCO passages corpus