1 use std::cmp::Ordering; 2 use std::collections::{BTreeSet, BinaryHeap}; 3 use std::io::{self, BufRead, BufReader}; 4 5 use ordered_float::OrderedFloat; 6 7 use super::{Keyword, KeywordExtract, STOP_WORDS}; 8 use crate::FxHashMap as HashMap; 9 use crate::Jieba; 10 11 static DEFAULT_IDF: &str = include_str!("../data/idf.txt"); 12 13 #[derive(Debug, Clone, Eq, PartialEq)] 14 struct HeapNode<'a> { 15 tfidf: OrderedFloat<f64>, 16 word: &'a str, 17 } 18 19 impl<'a> Ord for HeapNode<'a> { cmp(&self, other: &HeapNode) -> Ordering20 fn cmp(&self, other: &HeapNode) -> Ordering { 21 other.tfidf.cmp(&self.tfidf).then_with(|| self.word.cmp(&other.word)) 22 } 23 } 24 25 impl<'a> PartialOrd for HeapNode<'a> { partial_cmp(&self, other: &HeapNode) -> Option<Ordering>26 fn partial_cmp(&self, other: &HeapNode) -> Option<Ordering> { 27 Some(self.cmp(other)) 28 } 29 } 30 31 /// TF-IDF keywords extraction 32 /// 33 /// Require `tfidf` feature to be enabled 34 #[derive(Debug)] 35 pub struct TFIDF<'a> { 36 jieba: &'a Jieba, 37 idf_dict: HashMap<String, f64>, 38 median_idf: f64, 39 stop_words: BTreeSet<String>, 40 } 41 42 impl<'a> TFIDF<'a> { new_with_jieba(jieba: &'a Jieba) -> Self43 pub fn new_with_jieba(jieba: &'a Jieba) -> Self { 44 let mut instance = TFIDF { 45 jieba, 46 idf_dict: HashMap::default(), 47 median_idf: 0.0, 48 stop_words: STOP_WORDS.clone(), 49 }; 50 51 let mut default_dict = BufReader::new(DEFAULT_IDF.as_bytes()); 52 instance.load_dict(&mut default_dict).unwrap(); 53 instance 54 } 55 load_dict<R: BufRead>(&mut self, dict: &mut R) -> io::Result<()>56 pub fn load_dict<R: BufRead>(&mut self, dict: &mut R) -> io::Result<()> { 57 let mut buf = String::new(); 58 let mut idf_heap = BinaryHeap::new(); 59 while dict.read_line(&mut buf)? > 0 { 60 let parts: Vec<&str> = buf.trim().split_whitespace().collect(); 61 if parts.is_empty() { 62 continue; 63 } 64 65 let word = parts[0]; 66 if let Some(idf) = parts.get(1).and_then(|x| x.parse::<f64>().ok()) { 67 self.idf_dict.insert(word.to_string(), idf); 68 idf_heap.push(OrderedFloat(idf)); 69 } 70 71 buf.clear(); 72 } 73 74 let m = idf_heap.len() / 2; 75 for _ in 0..m { 76 idf_heap.pop(); 77 } 78 79 self.median_idf = idf_heap.pop().unwrap().into_inner(); 80 81 Ok(()) 82 } 83 84 /// Add a new stop word add_stop_word(&mut self, word: String) -> bool85 pub fn add_stop_word(&mut self, word: String) -> bool { 86 self.stop_words.insert(word) 87 } 88 89 /// Remove an existing stop word remove_stop_word(&mut self, word: &str) -> bool90 pub fn remove_stop_word(&mut self, word: &str) -> bool { 91 self.stop_words.remove(word) 92 } 93 94 /// Replace all stop words with new stop words set set_stop_words(&mut self, stop_words: BTreeSet<String>)95 pub fn set_stop_words(&mut self, stop_words: BTreeSet<String>) { 96 self.stop_words = stop_words 97 } 98 99 #[inline] filter(&self, s: &str) -> bool100 fn filter(&self, s: &str) -> bool { 101 if s.chars().count() < 2 { 102 return false; 103 } 104 105 if self.stop_words.contains(&s.to_lowercase()) { 106 return false; 107 } 108 109 true 110 } 111 } 112 113 impl<'a> KeywordExtract for TFIDF<'a> { extract_tags(&self, sentence: &str, top_k: usize, allowed_pos: Vec<String>) -> Vec<Keyword>114 fn extract_tags(&self, sentence: &str, top_k: usize, allowed_pos: Vec<String>) -> Vec<Keyword> { 115 let tags = self.jieba.tag(sentence, false); 116 let mut allowed_pos_set = BTreeSet::new(); 117 118 for s in allowed_pos { 119 allowed_pos_set.insert(s); 120 } 121 122 let mut term_freq: HashMap<String, u64> = HashMap::default(); 123 for t in &tags { 124 if !allowed_pos_set.is_empty() && !allowed_pos_set.contains(t.tag) { 125 continue; 126 } 127 128 if !self.filter(t.word) { 129 continue; 130 } 131 132 let entry = term_freq.entry(String::from(t.word)).or_insert(0); 133 *entry += 1; 134 } 135 136 let total: u64 = term_freq.values().sum(); 137 let mut heap = BinaryHeap::new(); 138 for (cnt, (k, tf)) in term_freq.iter().enumerate() { 139 let idf = self.idf_dict.get(k).unwrap_or(&self.median_idf); 140 let node = HeapNode { 141 tfidf: OrderedFloat(*tf as f64 * idf / total as f64), 142 word: k, 143 }; 144 heap.push(node); 145 if cnt >= top_k { 146 heap.pop(); 147 } 148 } 149 150 let mut res = Vec::new(); 151 for _ in 0..top_k { 152 if let Some(w) = heap.pop() { 153 res.push(Keyword { 154 keyword: String::from(w.word), 155 weight: w.tfidf.into_inner(), 156 }); 157 } 158 } 159 160 res.reverse(); 161 res 162 } 163 } 164 165 #[cfg(test)] 166 mod tests { 167 use super::*; 168 169 #[test] test_init_with_default_idf_dict()170 fn test_init_with_default_idf_dict() { 171 let jieba = super::Jieba::new(); 172 let _ = TFIDF::new_with_jieba(&jieba); 173 } 174 175 #[test] test_extract_tags()176 fn test_extract_tags() { 177 let jieba = super::Jieba::new(); 178 let keyword_extractor = TFIDF::new_with_jieba(&jieba); 179 let mut top_k = keyword_extractor.extract_tags( 180 "今天纽约的天气真好啊,京华大酒店的张尧经理吃了一只北京烤鸭。后天纽约的天气不好,昨天纽约的天气也不好,北京烤鸭真好吃", 181 3, 182 vec![], 183 ); 184 assert_eq!( 185 top_k.iter().map(|x| &x.keyword).collect::<Vec<&String>>(), 186 vec!["北京烤鸭", "纽约", "天气"] 187 ); 188 189 top_k = keyword_extractor.extract_tags( 190 "此外,公司拟对全资子公司吉林欧亚置业有限公司增资4.3亿元,增资后,吉林欧亚置业注册资本由7000万元增加到5亿元。吉林欧亚置业主要经营范围为房地产开发及百货零售等业务。目前在建吉林欧亚城市商业综合体项目。2013年,实现营业收入0万元,实现净利润-139.13万元。", 191 5, 192 vec![], 193 ); 194 assert_eq!( 195 top_k.iter().map(|x| &x.keyword).collect::<Vec<&String>>(), 196 vec!["欧亚", "吉林", "置业", "万元", "增资"] 197 ); 198 199 top_k = keyword_extractor.extract_tags( 200 "此外,公司拟对全资子公司吉林欧亚置业有限公司增资4.3亿元,增资后,吉林欧亚置业注册资本由7000万元增加到5亿元。吉林欧亚置业主要经营范围为房地产开发及百货零售等业务。目前在建吉林欧亚城市商业综合体项目。2013年,实现营业收入0万元,实现净利润-139.13万元。", 201 5, 202 vec![String::from("ns"), String::from("n"), String::from("vn"), String::from("v")], 203 ); 204 assert_eq!( 205 top_k.iter().map(|x| &x.keyword).collect::<Vec<&String>>(), 206 vec!["欧亚", "吉林", "置业", "增资", "实现"] 207 ); 208 } 209 } 210