pinecone_text.sparse.bm25_tokenizer
"""BM25 tokenizer: configurable NLTK-based text preprocessing for sparse
(BM25) encoding — tokenization, lower-casing, punctuation and stop-word
removal, and Snowball stemming."""

import string
from typing import List

import nltk
from nltk import SnowballStemmer, word_tokenize
from nltk.corpus import stopwords


class BM25Tokenizer:
    """Callable tokenization pipeline used to feed a BM25 sparse encoder."""

    def __init__(
        self,
        lower_case: bool,
        remove_punctuation: bool,
        remove_stopwords: bool,
        stem: bool,
        language: str,
    ):
        """Configure the tokenizer pipeline.

        Args:
            lower_case: lower-case every token.
            remove_punctuation: drop tokens that are pure punctuation.
            remove_stopwords: drop stop-words for ``language``.
            stem: apply Snowball stemming (requires ``lower_case=True``).
            language: language name understood by NLTK's SnowballStemmer
                and stopwords corpus (e.g. ``"english"``).

        Raises:
            ValueError: if ``stem`` is True while ``lower_case`` is False.
        """
        # Fail fast on an invalid configuration BEFORE any downloads or
        # resource construction (the original validated last, after
        # potentially downloading corpora).
        if stem and not lower_case:
            raise ValueError(
                "Stemming applies lower case to tokens, so lower_case must be True if stem is True"
            )

        self.nltk_setup()

        self.lower_case = lower_case
        self.remove_punctuation = remove_punctuation
        self.remove_stopwords = remove_stopwords
        self.stem = stem
        self.language = language
        self._stemmer = SnowballStemmer(language)
        self._stop_words = set(stopwords.words(language))
        self._punctuation = set(string.punctuation)

    @staticmethod
    def nltk_setup() -> None:
        """Download the NLTK resources this tokenizer needs, if missing."""
        try:
            nltk.data.find("tokenizers/punkt")
        except LookupError:
            nltk.download("punkt")

        try:
            nltk.data.find("corpora/stopwords")
        except LookupError:
            nltk.download("stopwords")

    def __call__(self, text: str) -> List[str]:
        """Tokenize ``text`` and apply the configured filters in order:
        lower-case -> punctuation removal -> stop-word removal -> stemming.
        """
        # NOTE(review): word_tokenize defaults to the English punkt model
        # regardless of self.language — confirm whether non-English use
        # should pass language= here (punkt and snowball language sets differ).
        tokens = word_tokenize(text)
        if self.lower_case:
            tokens = [word.lower() for word in tokens]
        if self.remove_punctuation:
            tokens = [word for word in tokens if word not in self._punctuation]
        if self.remove_stopwords:
            if self.lower_case:
                tokens = [word for word in tokens if word not in self._stop_words]
            else:
                # Stop-word lists are lower-case; compare case-insensitively
                # when tokens were not lower-cased.
                tokens = [
                    word for word in tokens if word.lower() not in self._stop_words
                ]
        if self.stem:
            tokens = [self._stemmer.stem(word) for word in tokens]
        return tokens
class BM25Tokenizer:
class BM25Tokenizer:
    """Callable tokenization pipeline used to feed a BM25 sparse encoder.

    Applies, in order: NLTK word tokenization, optional lower-casing,
    optional punctuation removal, optional stop-word removal, and optional
    Snowball stemming.
    """

    def __init__(
        self,
        lower_case: bool,
        remove_punctuation: bool,
        remove_stopwords: bool,
        stem: bool,
        language: str,
    ):
        """Configure the tokenizer pipeline.

        Args:
            lower_case: lower-case every token.
            remove_punctuation: drop tokens that are pure punctuation.
            remove_stopwords: drop stop-words for ``language``.
            stem: apply Snowball stemming (requires ``lower_case=True``).
            language: language name understood by NLTK's SnowballStemmer
                and stopwords corpus (e.g. ``"english"``).

        Raises:
            ValueError: if ``stem`` is True while ``lower_case`` is False.
        """
        # Fail fast on an invalid configuration BEFORE any downloads or
        # resource construction (the original validated last, after
        # potentially downloading corpora).
        if stem and not lower_case:
            raise ValueError(
                "Stemming applies lower case to tokens, so lower_case must be True if stem is True"
            )

        self.nltk_setup()

        self.lower_case = lower_case
        self.remove_punctuation = remove_punctuation
        self.remove_stopwords = remove_stopwords
        self.stem = stem
        self.language = language
        self._stemmer = SnowballStemmer(language)
        self._stop_words = set(stopwords.words(language))
        self._punctuation = set(string.punctuation)

    @staticmethod
    def nltk_setup() -> None:
        """Download the NLTK resources this tokenizer needs, if missing."""
        try:
            nltk.data.find("tokenizers/punkt")
        except LookupError:
            nltk.download("punkt")

        try:
            nltk.data.find("corpora/stopwords")
        except LookupError:
            nltk.download("stopwords")

    def __call__(self, text: str) -> List[str]:
        """Tokenize ``text`` and apply the configured filters in order:
        lower-case -> punctuation removal -> stop-word removal -> stemming.
        """
        # NOTE(review): word_tokenize defaults to the English punkt model
        # regardless of self.language — confirm whether non-English use
        # should pass language= here (punkt and snowball language sets differ).
        tokens = word_tokenize(text)
        if self.lower_case:
            tokens = [word.lower() for word in tokens]
        if self.remove_punctuation:
            tokens = [word for word in tokens if word not in self._punctuation]
        if self.remove_stopwords:
            if self.lower_case:
                tokens = [word for word in tokens if word not in self._stop_words]
            else:
                # Stop-word lists are lower-case; compare case-insensitively
                # when tokens were not lower-cased.
                tokens = [
                    word for word in tokens if word.lower() not in self._stop_words
                ]
        if self.stem:
            tokens = [self._stemmer.stem(word) for word in tokens]
        return tokens
BM25Tokenizer( lower_case: bool, remove_punctuation: bool, remove_stopwords: bool, stem: bool, language: str)
11 def __init__( 12 self, 13 lower_case: bool, 14 remove_punctuation: bool, 15 remove_stopwords: bool, 16 stem: bool, 17 language: str, 18 ): 19 self.nltk_setup() 20 21 self.lower_case = lower_case 22 self.remove_punctuation = remove_punctuation 23 self.remove_stopwords = remove_stopwords 24 self.stem = stem 25 self.language = language 26 self._stemmer = SnowballStemmer(language) 27 self._stop_words = set(stopwords.words(language)) 28 self._punctuation = set(string.punctuation) 29 30 if self.stem and not self.lower_case: 31 raise ValueError( 32 "Stemming applying lower case to tokens, so lower_case must be True if stem is True" 33 )