pinecone_text.sparse.bm25_tokenizer

import string
import nltk
from typing import List

from nltk import word_tokenize, SnowballStemmer
from nltk.corpus import stopwords


class BM25Tokenizer:
    def __init__(
        self,
        lower_case: bool,
        remove_punctuation: bool,
        remove_stopwords: bool,
        stem: bool,
        language: str,
    ):
        self.nltk_setup()

        self.lower_case = lower_case
        self.remove_punctuation = remove_punctuation
        self.remove_stopwords = remove_stopwords
        self.stem = stem
        self.language = language
        self._stemmer = SnowballStemmer(language)
        self._stop_words = set(stopwords.words(language))
        self._punctuation = set(string.punctuation)

        if self.stem and not self.lower_case:
            raise ValueError(
                "Stemming applies lower case to tokens, so lower_case must be True if stem is True"
            )

    @staticmethod
    def nltk_setup() -> None:
        # Download the NLTK resources needed for tokenization and stopword
        # removal, but only if they are not already available locally.
        try:
            nltk.data.find("tokenizers/punkt")
        except LookupError:
            nltk.download("punkt")

        try:
            nltk.data.find("corpora/stopwords")
        except LookupError:
            nltk.download("stopwords")

    def __call__(self, text: str) -> List[str]:
        # Tokenize the text, then apply the configured lower-casing,
        # punctuation removal, stopword removal, and stemming steps in order.
        tokens = word_tokenize(text)
        if self.lower_case:
            tokens = [word.lower() for word in tokens]
        if self.remove_punctuation:
            tokens = [word for word in tokens if word not in self._punctuation]
        if self.remove_stopwords:
            if self.lower_case:
                tokens = [word for word in tokens if word not in self._stop_words]
            else:
                tokens = [
                    word for word in tokens if word.lower() not in self._stop_words
                ]
        if self.stem:
            tokens = [self._stemmer.stem(word) for word in tokens]
        return tokens
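
For orientation, a minimal usage sketch (not part of the module source). It assumes the pinecone-text package and NLTK are installed and that nltk_setup can download the punkt and stopwords data on first use; the printed tokens are illustrative.

from pinecone_text.sparse.bm25_tokenizer import BM25Tokenizer

# Build a tokenizer that lower-cases, strips punctuation, drops English
# stopwords, and stems the remaining tokens.
tokenizer = BM25Tokenizer(
    lower_case=True,
    remove_punctuation=True,
    remove_stopwords=True,
    stem=True,
    language="english",
)

tokens = tokenizer("The quick brown foxes are jumping!")
print(tokens)  # expected to look something like ['quick', 'brown', 'fox', 'jump']
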
class BM25Tokenizer:
BM25Tokenizer(lower_case: bool, remove_punctuation: bool, remove_stopwords: bool, stem: bool, language: str)
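
One constraint from the constructor is worth calling out: stemming lower-cases tokens internally, so stem=True combined with lower_case=False is rejected with a ValueError. A small sketch of that failure mode (the argument values are chosen purely for illustration):

from pinecone_text.sparse.bm25_tokenizer import BM25Tokenizer

try:
    BM25Tokenizer(
        lower_case=False,
        remove_punctuation=True,
        remove_stopwords=True,
        stem=True,  # stemming requires lower_case=True
        language="english",
    )
except ValueError as err:
    print(err)  # explains that lower_case must be True when stem is True
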
@staticmethod
def nltk_setup() -> None:
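
Because nltk_setup is a static method, the required NLTK data can also be fetched explicitly ahead of time, for example in an environment or image build step, so that constructing a tokenizer later does not need to download anything. A minimal sketch, assuming network access at setup time:

from pinecone_text.sparse.bm25_tokenizer import BM25Tokenizer

# Pre-fetch the "punkt" tokenizer model and the stopwords corpus; subsequent
# BM25Tokenizer(...) calls will find the data already present.
BM25Tokenizer.nltk_setup()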