1 use std::cmp::Ordering;
2 use std::collections::{BTreeSet, BinaryHeap};
3 use std::io::{self, BufRead, BufReader};
4 
5 use ordered_float::OrderedFloat;
6 
7 use super::{Keyword, KeywordExtract, STOP_WORDS};
8 use crate::FxHashMap as HashMap;
9 use crate::Jieba;
10 
11 static DEFAULT_IDF: &str = include_str!("../data/idf.txt");
12 
/// Entry stored in the top-k `BinaryHeap` during keyword extraction.
///
/// The hand-written `Ord` impl below reverses the natural score order so
/// that `BinaryHeap` (a max-heap) effectively behaves as a min-heap keyed
/// on `tfidf`, which is what capped top-k selection needs.
#[derive(Debug, Clone, Eq, PartialEq)]
struct HeapNode<'a> {
    // TF-IDF weight of `word`; `OrderedFloat` provides the total order `Ord` requires.
    tfidf: OrderedFloat<f64>,
    // Borrowed keyword text (borrows from the caller's term-frequency map).
    word: &'a str,
}
18 
19 impl<'a> Ord for HeapNode<'a> {
cmp(&self, other: &HeapNode) -> Ordering20     fn cmp(&self, other: &HeapNode) -> Ordering {
21         other.tfidf.cmp(&self.tfidf).then_with(|| self.word.cmp(&other.word))
22     }
23 }
24 
25 impl<'a> PartialOrd for HeapNode<'a> {
partial_cmp(&self, other: &HeapNode) -> Option<Ordering>26     fn partial_cmp(&self, other: &HeapNode) -> Option<Ordering> {
27         Some(self.cmp(other))
28     }
29 }
30 
31 /// TF-IDF keywords extraction
32 ///
33 /// Require `tfidf` feature to be enabled
34 #[derive(Debug)]
/// TF-IDF keywords extraction
///
/// Require `tfidf` feature to be enabled
#[derive(Debug)]
pub struct TFIDF<'a> {
    // Segmenter/tagger used to split sentences into tagged words.
    jieba: &'a Jieba,
    // word -> inverse document frequency, loaded via `load_dict`.
    idf_dict: HashMap<String, f64>,
    // Median of the loaded IDF values; fallback weight for unknown words.
    median_idf: f64,
    // Words excluded from keyword candidates (matched lower-cased).
    stop_words: BTreeSet<String>,
}
41 
42 impl<'a> TFIDF<'a> {
new_with_jieba(jieba: &'a Jieba) -> Self43     pub fn new_with_jieba(jieba: &'a Jieba) -> Self {
44         let mut instance = TFIDF {
45             jieba,
46             idf_dict: HashMap::default(),
47             median_idf: 0.0,
48             stop_words: STOP_WORDS.clone(),
49         };
50 
51         let mut default_dict = BufReader::new(DEFAULT_IDF.as_bytes());
52         instance.load_dict(&mut default_dict).unwrap();
53         instance
54     }
55 
load_dict<R: BufRead>(&mut self, dict: &mut R) -> io::Result<()>56     pub fn load_dict<R: BufRead>(&mut self, dict: &mut R) -> io::Result<()> {
57         let mut buf = String::new();
58         let mut idf_heap = BinaryHeap::new();
59         while dict.read_line(&mut buf)? > 0 {
60             let parts: Vec<&str> = buf.trim().split_whitespace().collect();
61             if parts.is_empty() {
62                 continue;
63             }
64 
65             let word = parts[0];
66             if let Some(idf) = parts.get(1).and_then(|x| x.parse::<f64>().ok()) {
67                 self.idf_dict.insert(word.to_string(), idf);
68                 idf_heap.push(OrderedFloat(idf));
69             }
70 
71             buf.clear();
72         }
73 
74         let m = idf_heap.len() / 2;
75         for _ in 0..m {
76             idf_heap.pop();
77         }
78 
79         self.median_idf = idf_heap.pop().unwrap().into_inner();
80 
81         Ok(())
82     }
83 
84     /// Add a new stop word
add_stop_word(&mut self, word: String) -> bool85     pub fn add_stop_word(&mut self, word: String) -> bool {
86         self.stop_words.insert(word)
87     }
88 
89     /// Remove an existing stop word
remove_stop_word(&mut self, word: &str) -> bool90     pub fn remove_stop_word(&mut self, word: &str) -> bool {
91         self.stop_words.remove(word)
92     }
93 
94     /// Replace all stop words with new stop words set
set_stop_words(&mut self, stop_words: BTreeSet<String>)95     pub fn set_stop_words(&mut self, stop_words: BTreeSet<String>) {
96         self.stop_words = stop_words
97     }
98 
99     #[inline]
filter(&self, s: &str) -> bool100     fn filter(&self, s: &str) -> bool {
101         if s.chars().count() < 2 {
102             return false;
103         }
104 
105         if self.stop_words.contains(&s.to_lowercase()) {
106             return false;
107         }
108 
109         true
110     }
111 }
112 
113 impl<'a> KeywordExtract for TFIDF<'a> {
extract_tags(&self, sentence: &str, top_k: usize, allowed_pos: Vec<String>) -> Vec<Keyword>114     fn extract_tags(&self, sentence: &str, top_k: usize, allowed_pos: Vec<String>) -> Vec<Keyword> {
115         let tags = self.jieba.tag(sentence, false);
116         let mut allowed_pos_set = BTreeSet::new();
117 
118         for s in allowed_pos {
119             allowed_pos_set.insert(s);
120         }
121 
122         let mut term_freq: HashMap<String, u64> = HashMap::default();
123         for t in &tags {
124             if !allowed_pos_set.is_empty() && !allowed_pos_set.contains(t.tag) {
125                 continue;
126             }
127 
128             if !self.filter(t.word) {
129                 continue;
130             }
131 
132             let entry = term_freq.entry(String::from(t.word)).or_insert(0);
133             *entry += 1;
134         }
135 
136         let total: u64 = term_freq.values().sum();
137         let mut heap = BinaryHeap::new();
138         for (cnt, (k, tf)) in term_freq.iter().enumerate() {
139             let idf = self.idf_dict.get(k).unwrap_or(&self.median_idf);
140             let node = HeapNode {
141                 tfidf: OrderedFloat(*tf as f64 * idf / total as f64),
142                 word: k,
143             };
144             heap.push(node);
145             if cnt >= top_k {
146                 heap.pop();
147             }
148         }
149 
150         let mut res = Vec::new();
151         for _ in 0..top_k {
152             if let Some(w) = heap.pop() {
153                 res.push(Keyword {
154                     keyword: String::from(w.word),
155                     weight: w.tfidf.into_inner(),
156                 });
157             }
158         }
159 
160         res.reverse();
161         res
162     }
163 }
164 
#[cfg(test)]
mod tests {
    use super::*;

    /// Building the extractor with the bundled IDF dictionary must succeed.
    #[test]
    fn test_init_with_default_idf_dict() {
        let jieba = super::Jieba::new();
        let _ = TFIDF::new_with_jieba(&jieba);
    }

    #[test]
    fn test_extract_tags() {
        let jieba = super::Jieba::new();
        let keyword_extractor = TFIDF::new_with_jieba(&jieba);

        // No POS restriction, top 3.
        let top_k = keyword_extractor.extract_tags(
            "今天纽约的天气真好啊,京华大酒店的张尧经理吃了一只北京烤鸭。后天纽约的天气不好,昨天纽约的天气也不好,北京烤鸭真好吃",
            3,
            vec![],
        );
        let words: Vec<&str> = top_k.iter().map(|x| x.keyword.as_str()).collect();
        assert_eq!(words, vec!["北京烤鸭", "纽约", "天气"]);

        // No POS restriction, top 5.
        let top_k = keyword_extractor.extract_tags(
            "此外,公司拟对全资子公司吉林欧亚置业有限公司增资4.3亿元,增资后,吉林欧亚置业注册资本由7000万元增加到5亿元。吉林欧亚置业主要经营范围为房地产开发及百货零售等业务。目前在建吉林欧亚城市商业综合体项目。2013年,实现营业收入0万元,实现净利润-139.13万元。",
            5,
            vec![],
        );
        let words: Vec<&str> = top_k.iter().map(|x| x.keyword.as_str()).collect();
        assert_eq!(words, vec!["欧亚", "吉林", "置业", "万元", "增资"]);

        // Restricted to place names, nouns and verbs.
        let top_k = keyword_extractor.extract_tags(
            "此外,公司拟对全资子公司吉林欧亚置业有限公司增资4.3亿元,增资后,吉林欧亚置业注册资本由7000万元增加到5亿元。吉林欧亚置业主要经营范围为房地产开发及百货零售等业务。目前在建吉林欧亚城市商业综合体项目。2013年,实现营业收入0万元,实现净利润-139.13万元。",
            5,
            vec![String::from("ns"), String::from("n"), String::from("vn"), String::from("v")],
        );
        let words: Vec<&str> = top_k.iter().map(|x| x.keyword.as_str()).collect();
        assert_eq!(words, vec!["欧亚", "吉林", "置业", "增资", "实现"]);
    }
}
210