1 /*
2  * SPDX-FileCopyrightText: 2017-2017 CSSlayer <wengxt@gmail.com>
3  *
4  * SPDX-License-Identifier: LGPL-2.1-or-later
5  */
6 
7 #include "pinyindictionary.h"
8 #include "libime/core/datrie.h"
9 #include "libime/core/lattice.h"
10 #include "libime/core/lrucache.h"
11 #include "libime/core/utils.h"
12 #include "pinyindata.h"
13 #include "pinyindecoder_p.h"
14 #include "pinyinencoder.h"
15 #include "pinyinmatchstate_p.h"
16 #include <boost/algorithm/string.hpp>
17 #include <boost/ptr_container/ptr_vector.hpp>
18 #include <boost/unordered_map.hpp>
19 #include <cmath>
20 #include <fstream>
21 #include <iomanip>
22 #include <queue>
23 #include <string_view>
24 #include <type_traits>
25 
26 namespace libime {
27 
// Cost added for every fuzzy step (and every extra syllable of an over-long
// word) when scoring a match. std::log10 is not constexpr, hence plain const.
static const float fuzzyCost = std::log10(0.5F);
// Lower bound applied to the partial-long-word limit in matchWordsForOnePath.
static constexpr size_t minimumLongWordLength = 3;
// Cost of the placeholder word emitted when no single-syllable word matched.
static constexpr float invalidPinyinCost = -100.0F;
// Byte that separates the encoded pinyin from the hanzi inside a trie key.
static constexpr char pinyinHanziSep = '!';

// Header written in front of the trie in the binary dictionary format.
static constexpr uint32_t pinyinBinaryFormatMagic = 0x000fc613;
static constexpr uint32_t pinyinBinaryFormatVersion = 0x1;
35 
36 struct PinyinSegmentGraphPathHasher {
PinyinSegmentGraphPathHasherlibime::PinyinSegmentGraphPathHasher37     PinyinSegmentGraphPathHasher(const SegmentGraph &graph) : graph_(graph) {}
38 
39     // Generate a "|" separated raw pinyin string from given path, skip all
40     // separator.
pathToPinyinslibime::PinyinSegmentGraphPathHasher41     std::string pathToPinyins(const SegmentGraphPath &path) const {
42         std::string result;
43         result.reserve(path.size() + path.back()->index() -
44                        path.front()->index() + 1);
45         const auto &data = graph_.data();
46         auto iter = path.begin();
47         while (iter + 1 < path.end()) {
48             auto begin = (*iter)->index();
49             auto end = (*std::next(iter))->index();
50             iter++;
51             if (data[begin] == '\'') {
52                 continue;
53             }
54             while (begin < end) {
55                 result.push_back(data[begin]);
56                 begin++;
57             }
58             result.push_back('|');
59         }
60         return result;
61     }
62 
63     // Generate hash for path but avoid allocate the string.
operator ()libime::PinyinSegmentGraphPathHasher64     size_t operator()(const SegmentGraphPath &path) const {
65         if (path.size() <= 1) {
66             return 0;
67         }
68         boost::hash<char> hasher;
69 
70         size_t seed = 0;
71         const auto &data = graph_.data();
72         auto iter = path.begin();
73         while (iter + 1 < path.end()) {
74             auto begin = (*iter)->index();
75             auto end = (*std::next(iter))->index();
76             iter++;
77             if (data[begin] == '\'') {
78                 continue;
79             }
80             while (begin < end) {
81                 boost::hash_combine(seed, hasher(data[begin]));
82                 begin++;
83             }
84             boost::hash_combine(seed, hasher('|'));
85         }
86         return seed;
87     }
88 
89     // Check equality of pinyin string and the path. The string s should be
90     // equal to pathToPinyins(path), but this function just try to avoid
91     // allocate a string for comparison.
operator ()libime::PinyinSegmentGraphPathHasher92     bool operator()(const SegmentGraphPath &path, const std::string &s) const {
93         if (path.size() <= 1) {
94             return false;
95         }
96         auto is = s.begin();
97         const auto &data = graph_.data();
98         auto iter = path.begin();
99         while (iter + 1 < path.end() && is != s.end()) {
100             auto begin = (*iter)->index();
101             auto end = (*std::next(iter))->index();
102             iter++;
103             if (data[begin] == '\'') {
104                 continue;
105             }
106             while (begin < end && is != s.end()) {
107                 if (*is != data[begin]) {
108                     return false;
109                 }
110                 is++;
111                 begin++;
112             }
113             if (begin != end) {
114                 return false;
115             }
116 
117             if (is == s.end() || *is != '|') {
118                 return false;
119             }
120             is++;
121         }
122         return iter + 1 == path.end() && is == s.end();
123     }
124 
125 private:
126     const SegmentGraph &graph_;
127 };
128 
129 struct SegmentGraphNodeGreater {
operator ()libime::SegmentGraphNodeGreater130     bool operator()(const SegmentGraphNode *lhs,
131                     const SegmentGraphNode *rhs) const {
132         return lhs->index() > rhs->index();
133     }
134 };
135 
136 // Check if the prev not is a pinyin. Separator always contrains in its own
137 // segment.
prevIsSeparator(const SegmentGraph & graph,const SegmentGraphNode & node)138 const SegmentGraphNode *prevIsSeparator(const SegmentGraph &graph,
139                                         const SegmentGraphNode &node) {
140     if (node.prevSize() == 1) {
141         const auto &prev = node.prevs().front();
142         auto pinyin = graph.segment(prev, node);
143         if (boost::starts_with(pinyin, "\'")) {
144             return &prev;
145         }
146     }
147     return nullptr;
148 }
149 
// Bundle of everything a single prefix-matching run needs: the graph being
// matched, the result callback, the set of nodes to ignore, the per-node match
// storage and (optionally) the caches and settings taken from a
// PinyinMatchState. It only holds references/pointers to those objects and is
// not copyable.
class PinyinMatchContext {
public:
    // Context backed by a PinyinMatchState: matched paths and both caches
    // live in the state's private data, and fuzzy flags, shuangpin profile
    // and partial-long-word limit are read from the state.
    explicit PinyinMatchContext(
        const SegmentGraph &graph, const GraphMatchCallback &callback,
        const std::unordered_set<const SegmentGraphNode *> &ignore,
        PinyinMatchState *matchState)
        : graph_(graph), hasher_(graph), callback_(callback), ignore_(ignore),
          matchedPathsMap_(&matchState->d_func()->matchedPaths_),
          nodeCacheMap_(&matchState->d_func()->nodeCacheMap_),
          matchCacheMap_(&matchState->d_func()->matchCacheMap_),
          flags_(matchState->fuzzyFlags()),
          spProfile_(matchState->shuangpinProfile()),
          partialLongWordLimit_(matchState->partialLongWordLimit()) {}

    // Context without a match state: results go into the caller-provided
    // map, and the remaining members keep their defaults (no caches, no fuzzy
    // flags, no shuangpin profile, partial long word limit of 0).
    explicit PinyinMatchContext(
        const SegmentGraph &graph, const GraphMatchCallback &callback,
        const std::unordered_set<const SegmentGraphNode *> &ignore,
        NodeToMatchedPinyinPathsMap &matchedPaths)
        : graph_(graph), hasher_(graph), callback_(callback), ignore_(ignore),
          matchedPathsMap_(&matchedPaths) {}

    PinyinMatchContext(const PinyinMatchContext &) = delete;

    // Graph being matched, and the path hasher bound to it.
    const SegmentGraph &graph_;
    PinyinSegmentGraphPathHasher hasher_;

    // Invoked for every word found (and for placeholder words).
    const GraphMatchCallback &callback_;
    // Nodes for which no word matching is attempted (see findMatchesBetween).
    const std::unordered_set<const SegmentGraphNode *> &ignore_;
    // Node -> matched paths ending at that node. Never null.
    NodeToMatchedPinyinPathsMap *matchedPathsMap_;
    // Optional caches; nullptr when no PinyinMatchState was supplied.
    PinyinTrieNodeCache *nodeCacheMap_ = nullptr;
    PinyinMatchResultCache *matchCacheMap_ = nullptr;
    PinyinFuzzyFlags flags_{PinyinFuzzyFlag::None};
    // When set, segments are parsed as shuangpin instead of full pinyin.
    std::shared_ptr<const ShuangpinProfile> spProfile_;
    // 0 disables partial long word matching (see matchWordsForOnePath).
    size_t partialLongWordLimit_ = 0;
};
185 
// Private implementation of PinyinDictionary: the graph matching steps plus
// the per-dictionary flags.
class PinyinDictionaryPrivate : fcitx::QPtrHolder<PinyinDictionary> {
public:
    PinyinDictionaryPrivate(PinyinDictionary *q)
        : fcitx::QPtrHolder<PinyinDictionary>(q) {}

    // Adds, for each dictionary, an empty match starting at currentNode so
    // that a word can begin there.
    void addEmptyMatch(const PinyinMatchContext &context,
                       const SegmentGraphNode &currentNode,
                       MatchedPinyinPaths &currentMatches) const;

    // Extends every match ending at prevNode over the segment
    // prevNode -> currentNode and appends the survivors to currentMatches.
    void findMatchesBetween(const PinyinMatchContext &context,
                            const SegmentGraphNode &prevNode,
                            const SegmentGraphNode &currentNode,
                            MatchedPinyinPaths &currentMatches) const;

    // Reports the words of every path in newPaths through the context
    // callback; returns true if any path reported a single-syllable word.
    bool matchWords(const PinyinMatchContext &context,
                    const MatchedPinyinPaths &newPaths) const;
    // Same as matchWords, for one path.
    bool matchWordsForOnePath(const PinyinMatchContext &context,
                              const MatchedPinyinPath &path) const;

    // Computes the matches ending at currentNode, unless already computed.
    void matchNode(const PinyinMatchContext &context,
                   const SegmentGraphNode &currentNode) const;

    // Connection to the dictSizeChanged signal (set up in the
    // PinyinDictionary constructor).
    fcitx::ScopedConnection conn_;
    // One flag set per dictionary, indexed like the tries.
    std::vector<PinyinDictFlags> flags_;
};
211 
addEmptyMatch(const PinyinMatchContext & context,const SegmentGraphNode & currentNode,MatchedPinyinPaths & currentMatches) const212 void PinyinDictionaryPrivate::addEmptyMatch(
213     const PinyinMatchContext &context, const SegmentGraphNode &currentNode,
214     MatchedPinyinPaths &currentMatches) const {
215     FCITX_Q();
216     const SegmentGraph &graph = context.graph_;
217     // Create a new starting point for current node, and put it in matchResult.
218     if (&currentNode != &graph.end() &&
219         !boost::starts_with(
220             graph.segment(currentNode.index(), currentNode.index() + 1),
221             "\'")) {
222         SegmentGraphPath vec;
223         if (const auto *prev = prevIsSeparator(graph, currentNode)) {
224             vec.push_back(prev);
225         }
226 
227         vec.push_back(&currentNode);
228         for (size_t i = 0; i < q->dictSize(); i++) {
229             if (flags_[i].test(PinyinDictFlag::FullMatch) &&
230                 &currentNode != &graph.start()) {
231                 continue;
232             }
233             const auto &trie = *q->trie(i);
234             currentMatches.emplace_back(&trie, 0, vec, flags_[i]);
235             currentMatches.back().triePositions().emplace_back(0, 0);
236         }
237     }
238 }
239 
240 PinyinTriePositions
traverseAlongPathOneStepBySyllables(const MatchedPinyinPath & path,const MatchedPinyinSyllables & syls)241 traverseAlongPathOneStepBySyllables(const MatchedPinyinPath &path,
242                                     const MatchedPinyinSyllables &syls) {
243     PinyinTriePositions positions;
244     for (const auto &pr : path.triePositions()) {
245         uint64_t _pos;
246         size_t fuzzies;
247         std::tie(_pos, fuzzies) = pr;
248         for (const auto &syl : syls) {
249             // make a copy
250             auto pos = _pos;
251             auto initial = static_cast<char>(syl.first);
252             auto result = path.trie()->traverse(&initial, 1, pos);
253             if (PinyinTrie::isNoPath(result)) {
254                 continue;
255             }
256             const auto &finals = syl.second;
257 
258             auto updateNext = [fuzzies, &path, &positions](auto finalPair,
259                                                            auto pos) {
260                 auto final = static_cast<char>(finalPair.first);
261                 auto result = path.trie()->traverse(&final, 1, pos);
262 
263                 if (!PinyinTrie::isNoPath(result)) {
264                     size_t newFuzzies = fuzzies + (finalPair.second ? 1 : 0);
265                     positions.emplace_back(pos, newFuzzies);
266                 }
267             };
268             if (finals.size() > 1 || finals[0].first != PinyinFinal::Invalid) {
269                 for (auto final : finals) {
270                     updateNext(final, pos);
271                 }
272             } else {
273                 for (char test = PinyinEncoder::firstFinal;
274                      test <= PinyinEncoder::lastFinal; test++) {
275                     updateNext(std::make_pair(test, true), pos);
276                 }
277             }
278         }
279     }
280     return positions;
281 }
282 
283 template <typename T>
matchWordsOnTrie(const MatchedPinyinPath & path,bool matchLongWord,const T & callback)284 void matchWordsOnTrie(const MatchedPinyinPath &path, bool matchLongWord,
285                       const T &callback) {
286     for (const auto &pr : path.triePositions()) {
287         uint64_t pos;
288         size_t fuzzies;
289         std::tie(pos, fuzzies) = pr;
290         float extraCost = fuzzies * fuzzyCost;
291         if (matchLongWord) {
292             path.trie()->foreach(
293                 [&path, &callback, extraCost](PinyinTrie::value_type value,
294                                               size_t len, uint64_t pos) {
295                     std::string s;
296                     s.reserve(len + path.size() * 2);
297                     path.trie()->suffix(s, len + path.size() * 2, pos);
298                     if (size_t separator =
299                             s.find(pinyinHanziSep, path.size() * 2);
300                         separator != std::string::npos) {
301                         std::string_view view(s);
302                         auto encodedPinyin = view.substr(0, separator);
303                         auto hanzi = view.substr(separator + 1);
304                         float overLengthCost =
305                             fuzzyCost *
306                             (encodedPinyin.size() / 2 - path.size());
307                         callback(encodedPinyin, hanzi,
308                                  value + extraCost + overLengthCost);
309                     }
310                     return true;
311                 },
312                 pos);
313         } else {
314             const char sep = pinyinHanziSep;
315             auto result = path.trie()->traverse(&sep, 1, pos);
316             if (PinyinTrie::isNoPath(result)) {
317                 continue;
318             }
319 
320             path.trie()->foreach(
321                 [&path, &callback, extraCost](PinyinTrie::value_type value,
322                                               size_t len, uint64_t pos) {
323                     std::string s;
324                     s.reserve(len + path.size() * 2 + 1);
325                     path.trie()->suffix(s, len + path.size() * 2 + 1, pos);
326                     std::string_view view(s);
327                     auto encodedPinyin = view.substr(0, path.size() * 2);
328                     auto hanzi = view.substr(path.size() * 2 + 1);
329                     callback(encodedPinyin, hanzi, value + extraCost);
330                     return true;
331                 },
332                 pos);
333         }
334     }
335 }
336 
matchWordsForOnePath(const PinyinMatchContext & context,const MatchedPinyinPath & path) const337 bool PinyinDictionaryPrivate::matchWordsForOnePath(
338     const PinyinMatchContext &context, const MatchedPinyinPath &path) const {
339     bool matched = false;
340     assert(path.path_.size() >= 2);
341     const SegmentGraphNode &prevNode = *path.path_[path.path_.size() - 2];
342 
343     if (path.flags_.test(PinyinDictFlag::FullMatch) &&
344         (path.path_.front() != &context.graph_.start() ||
345          path.path_.back() != &context.graph_.end())) {
346         return false;
347     }
348 
349     // minimumLongWordLength is to prevent algorithm runs too slow.
350     const bool matchLongWordEnabled =
351         context.partialLongWordLimit_ &&
352         std::max(minimumLongWordLength, context.partialLongWordLimit_) + 1 <=
353             path.path_.size() &&
354         !path.flags_.test(PinyinDictFlag::FullMatch);
355 
356     const bool matchLongWord =
357         (path.path_.back() == &context.graph_.end() && matchLongWordEnabled);
358 
359     auto foundOneWord = [&path, &prevNode, &matched,
360                          &context](std::string_view encodedPinyin,
361                                    WordNode &word, float cost) {
362         context.callback_(
363             path.path_, word, cost,
364             std::make_unique<PinyinLatticeNodePrivate>(encodedPinyin));
365         if (path.size() == 1 &&
366             path.path_[path.path_.size() - 2] == &prevNode) {
367             matched = true;
368         }
369     };
370 
371     if (context.matchCacheMap_) {
372         auto &matchCache = (*context.matchCacheMap_)[path.trie()];
373         auto *result =
374             matchCache.find(path.path_, context.hasher_, context.hasher_);
375         if (!result) {
376             result =
377                 matchCache.insert(context.hasher_.pathToPinyins(path.path_));
378             result->clear();
379 
380             auto &items = *result;
381             matchWordsOnTrie(path, matchLongWordEnabled,
382                              [&items](std::string_view encodedPinyin,
383                                       std::string_view hanzi, float cost) {
384                                  items.emplace_back(hanzi, cost, encodedPinyin);
385                              });
386         }
387         for (auto &item : *result) {
388             if (!matchLongWord &&
389                 item.encodedPinyin_.size() / 2 > path.size()) {
390                 continue;
391             }
392             foundOneWord(item.encodedPinyin_, item.word_, item.value_);
393         }
394     } else {
395         matchWordsOnTrie(path, matchLongWord,
396                          [&foundOneWord](std::string_view encodedPinyin,
397                                          std::string_view hanzi, float cost) {
398                              WordNode word(hanzi, InvalidWordIndex);
399                              foundOneWord(encodedPinyin, word, cost);
400                          });
401     }
402 
403     return matched;
404 }
405 
matchWords(const PinyinMatchContext & context,const MatchedPinyinPaths & newPaths) const406 bool PinyinDictionaryPrivate::matchWords(
407     const PinyinMatchContext &context,
408     const MatchedPinyinPaths &newPaths) const {
409     bool matched = false;
410     for (const auto &path : newPaths) {
411         matched |= matchWordsForOnePath(context, path);
412     }
413 
414     return matched;
415 }
416 
findMatchesBetween(const PinyinMatchContext & context,const SegmentGraphNode & prevNode,const SegmentGraphNode & currentNode,MatchedPinyinPaths & currentMatches) const417 void PinyinDictionaryPrivate::findMatchesBetween(
418     const PinyinMatchContext &context, const SegmentGraphNode &prevNode,
419     const SegmentGraphNode &currentNode,
420     MatchedPinyinPaths &currentMatches) const {
421     const SegmentGraph &graph = context.graph_;
422     auto &matchedPathsMap = *context.matchedPathsMap_;
423     auto pinyin = graph.segment(prevNode, currentNode);
424     // If predecessor is a separator, just copy every existing match result
425     // over and don't traverse on the trie.
426     if (boost::starts_with(pinyin, "\'")) {
427         const auto &prevMatches = matchedPathsMap[&prevNode];
428         for (const auto &match : prevMatches) {
429             // copy the path, and append current node.
430             auto path = match.path_;
431             path.push_back(&currentNode);
432             currentMatches.emplace_back(match.result_, std::move(path),
433                                         match.flags_);
434         }
435         // If the last segment is separator, there
436         if (&currentNode == &graph.end()) {
437             WordNode word("", 0);
438             context.callback_({&prevNode, &currentNode}, word, 0, nullptr);
439         }
440         return;
441     }
442 
443     const auto syls =
444         context.spProfile_
445             ? PinyinEncoder::shuangpinToSyllables(pinyin, *context.spProfile_,
446                                                   context.flags_)
447             : PinyinEncoder::stringToSyllables(pinyin, context.flags_);
448     const MatchedPinyinPaths &prevMatchedPaths = matchedPathsMap[&prevNode];
449     MatchedPinyinPaths newPaths;
450     for (const auto &path : prevMatchedPaths) {
451         // Make a copy of path so we can modify based on it.
452         auto segmentPath = path.path_;
453         segmentPath.push_back(&currentNode);
454 
455         // A map from trie (dict) to a lru cache.
456         if (context.nodeCacheMap_) {
457             auto &nodeCache = (*context.nodeCacheMap_)[path.trie()];
458             auto *p =
459                 nodeCache.find(segmentPath, context.hasher_, context.hasher_);
460             std::shared_ptr<MatchedPinyinTrieNodes> result;
461             if (!p) {
462                 result = std::make_shared<MatchedPinyinTrieNodes>(
463                     path.trie(), path.size() + 1);
464                 nodeCache.insert(context.hasher_.pathToPinyins(segmentPath),
465                                  result);
466                 result->triePositions_ =
467                     traverseAlongPathOneStepBySyllables(path, syls);
468             } else {
469                 result = *p;
470                 assert(result->size_ == path.size() + 1);
471             }
472 
473             if (!result->triePositions_.empty()) {
474                 newPaths.emplace_back(result, segmentPath, path.flags_);
475             }
476         } else {
477             // make an empty one
478             newPaths.emplace_back(path.trie(), path.size() + 1, segmentPath,
479                                   path.flags_);
480 
481             newPaths.back().result_->triePositions_ =
482                 traverseAlongPathOneStepBySyllables(path, syls);
483             // if there's nothing, pop it.
484             if (newPaths.back().triePositions().empty()) {
485                 newPaths.pop_back();
486             }
487         }
488     }
489 
490     if (!context.ignore_.count(&currentNode)) {
491         // after we match current syllable, we first try to match word.
492         if (!matchWords(context, newPaths)) {
493             // If we failed to match any length 1 word, add a new empty word
494             // to make lattice connect together.
495             SegmentGraphPath vec;
496             vec.reserve(3);
497             if (const auto *prevPrev =
498                     prevIsSeparator(context.graph_, prevNode)) {
499                 vec.push_back(prevPrev);
500             }
501             vec.push_back(&prevNode);
502             vec.push_back(&currentNode);
503             WordNode word(pinyin, InvalidWordIndex);
504             context.callback_(vec, word, invalidPinyinCost, nullptr);
505         }
506     }
507 
508     std::move(newPaths.begin(), newPaths.end(),
509               std::back_inserter(currentMatches));
510 }
511 
matchNode(const PinyinMatchContext & context,const SegmentGraphNode & currentNode) const512 void PinyinDictionaryPrivate::matchNode(
513     const PinyinMatchContext &context,
514     const SegmentGraphNode &currentNode) const {
515     auto &matchedPathsMap = *context.matchedPathsMap_;
516     // Check if the node has been searched already.
517     if (matchedPathsMap.count(&currentNode)) {
518         return;
519     }
520     auto &currentMatches = matchedPathsMap[&currentNode];
521     // To create a new start.
522     addEmptyMatch(context, currentNode, currentMatches);
523 
524     // Iterate all predecessor and search from them.
525     for (const auto &prevNode : currentNode.prevs()) {
526         findMatchesBetween(context, prevNode, currentNode, currentMatches);
527     }
528 }
529 
matchPrefixImpl(const SegmentGraph & graph,const GraphMatchCallback & callback,const std::unordered_set<const SegmentGraphNode * > & ignore,void * helper) const530 void PinyinDictionary::matchPrefixImpl(
531     const SegmentGraph &graph, const GraphMatchCallback &callback,
532     const std::unordered_set<const SegmentGraphNode *> &ignore,
533     void *helper) const {
534     FCITX_D();
535 
536     NodeToMatchedPinyinPathsMap localMatchedPaths;
537     PinyinMatchContext context =
538         helper ? PinyinMatchContext{graph, callback, ignore,
539                                     static_cast<PinyinMatchState *>(helper)}
540                : PinyinMatchContext{graph, callback, ignore, localMatchedPaths};
541 
542     // A queue to make sure that node with smaller index will be visted first
543     // because we want to make sure every predecessor node are visited before
544     // visit the current node.
545     using SegmentGraphNodeQueue =
546         std::priority_queue<const SegmentGraphNode *,
547                             std::vector<const SegmentGraphNode *>,
548                             SegmentGraphNodeGreater>;
549     SegmentGraphNodeQueue q;
550 
551     const auto &start = graph.start();
552     q.push(&start);
553 
554     // The match is done with a bfs.
555     // E.g
556     // xian is
557     // start - xi - an - end
558     //       \           /
559     //        -- xian ---
560     // We start with start, then xi, then an and xian, then end.
561     while (!q.empty()) {
562         const auto *currentNode = q.top();
563         q.pop();
564 
565         // Push successors into the queue.
566         for (const auto &node : currentNode->nexts()) {
567             q.push(&node);
568         }
569 
570         d->matchNode(context, *currentNode);
571     }
572 }
573 
matchWords(const char * data,size_t size,PinyinMatchCallback callback) const574 void PinyinDictionary::matchWords(const char *data, size_t size,
575                                   PinyinMatchCallback callback) const {
576     if (!PinyinEncoder::isValidUserPinyin(data, size)) {
577         return;
578     }
579 
580     std::list<std::pair<const PinyinTrie *, PinyinTrie::position_type>> nodes;
581     for (size_t i = 0; i < dictSize(); i++) {
582         const auto &trie = *this->trie(i);
583         nodes.emplace_back(&trie, 0);
584     }
585     for (size_t i = 0; i <= size && !nodes.empty(); i++) {
586         char current;
587         if (i < size) {
588             current = data[i];
589         } else {
590             current = pinyinHanziSep;
591         }
592         decltype(nodes) extraNodes;
593         auto iter = nodes.begin();
594         while (iter != nodes.end()) {
595             if (current != 0) {
596                 PinyinTrie::value_type result;
597                 result = iter->first->traverse(&current, 1, iter->second);
598 
599                 if (PinyinTrie::isNoPath(result)) {
600                     nodes.erase(iter++);
601                 } else {
602                     iter++;
603                 }
604             } else {
605                 bool changed = false;
606                 for (char test = PinyinEncoder::firstFinal;
607                      test <= PinyinEncoder::lastFinal; test++) {
608                     decltype(extraNodes)::value_type p = *iter;
609                     auto result = p.first->traverse(&test, 1, p.second);
610                     if (!PinyinTrie::isNoPath(result)) {
611                         extraNodes.push_back(p);
612                         changed = true;
613                     }
614                 }
615                 if (changed) {
616                     *iter = extraNodes.back();
617                     extraNodes.pop_back();
618                     iter++;
619                 } else {
620                     nodes.erase(iter++);
621                 }
622             }
623         }
624         nodes.splice(nodes.end(), std::move(extraNodes));
625     }
626 
627     for (auto &node : nodes) {
628         node.first->foreach(
629             [&node, &callback, size](PinyinTrie::value_type value, size_t len,
630                                      uint64_t pos) {
631                 std::string s;
632                 node.first->suffix(s, len + size + 1, pos);
633 
634                 auto view = std::string_view(s);
635                 return callback(s.substr(0, size), view.substr(size + 1),
636                                 value);
637             },
638             node.second);
639     }
640 }
PinyinDictionary()641 PinyinDictionary::PinyinDictionary()
642     : d_ptr(std::make_unique<PinyinDictionaryPrivate>(this)) {
643     FCITX_D();
644     d->conn_ = connect<TrieDictionary::dictSizeChanged>([this](size_t size) {
645         FCITX_D();
646         d->flags_.resize(size);
647     });
648     d->flags_.resize(dictSize());
649 }
650 
~PinyinDictionary()651 PinyinDictionary::~PinyinDictionary() {}
652 
load(size_t idx,const char * filename,PinyinDictFormat format)653 void PinyinDictionary::load(size_t idx, const char *filename,
654                             PinyinDictFormat format) {
655     std::ifstream in(filename, std::ios::in | std::ios::binary);
656     throw_if_io_fail(in);
657     load(idx, in, format);
658 }
659 
load(size_t idx,std::istream & in,PinyinDictFormat format)660 void PinyinDictionary::load(size_t idx, std::istream &in,
661                             PinyinDictFormat format) {
662     switch (format) {
663     case PinyinDictFormat::Text:
664         loadText(idx, in);
665         break;
666     case PinyinDictFormat::Binary:
667         loadBinary(idx, in);
668         break;
669     default:
670         throw std::invalid_argument("invalid format type");
671     }
672     emit<PinyinDictionary::dictionaryChanged>(idx);
673 }
674 
loadText(size_t idx,std::istream & in)675 void PinyinDictionary::loadText(size_t idx, std::istream &in) {
676     DATrie<float> trie;
677 
678     std::string buf;
679     auto isSpaceCheck = boost::is_any_of(" \n\t\r\v\f");
680     while (!in.eof()) {
681         if (!std::getline(in, buf)) {
682             break;
683         }
684 
685         boost::trim_if(buf, isSpaceCheck);
686         std::vector<std::string> tokens;
687         boost::split(tokens, buf, isSpaceCheck);
688         if (tokens.size() == 3 || tokens.size() == 2) {
689             const std::string &hanzi = tokens[0];
690             std::string_view pinyin = tokens[1];
691             float prob = 0.0F;
692             if (tokens.size() == 3) {
693                 prob = std::stof(tokens[2]);
694             }
695 
696             auto result = PinyinEncoder::encodeFullPinyin(pinyin);
697             result.push_back(pinyinHanziSep);
698             result.insert(result.end(), hanzi.begin(), hanzi.end());
699             trie.set(result.data(), result.size(), prob);
700         }
701     }
702     *mutableTrie(idx) = std::move(trie);
703 }
704 
loadBinary(size_t idx,std::istream & in)705 void PinyinDictionary::loadBinary(size_t idx, std::istream &in) {
706     DATrie<float> trie;
707     uint32_t magic = 0;
708     uint32_t version = 0;
709     throw_if_io_fail(unmarshall(in, magic));
710     if (magic != pinyinBinaryFormatMagic) {
711         throw std::invalid_argument("Invalid pinyin magic.");
712     }
713     throw_if_io_fail(unmarshall(in, version));
714     if (version != pinyinBinaryFormatVersion) {
715         throw std::invalid_argument("Invalid pinyin version.");
716     }
717     trie.load(in);
718     *mutableTrie(idx) = std::move(trie);
719 }
720 
save(size_t idx,const char * filename,PinyinDictFormat format)721 void PinyinDictionary::save(size_t idx, const char *filename,
722                             PinyinDictFormat format) {
723     std::ofstream fout(filename, std::ios::out | std::ios::binary);
724     throw_if_io_fail(fout);
725     save(idx, fout, format);
726 }
727 
save(size_t idx,std::ostream & out,PinyinDictFormat format)728 void PinyinDictionary::save(size_t idx, std::ostream &out,
729                             PinyinDictFormat format) {
730     switch (format) {
731     case PinyinDictFormat::Text:
732         saveText(idx, out);
733         break;
734     case PinyinDictFormat::Binary:
735         throw_if_io_fail(marshall(out, pinyinBinaryFormatMagic));
736         throw_if_io_fail(marshall(out, pinyinBinaryFormatVersion));
737         mutableTrie(idx)->save(out);
738         break;
739     default:
740         throw std::invalid_argument("invalid format type");
741     }
742 }
743 
saveText(size_t idx,std::ostream & out)744 void PinyinDictionary::saveText(size_t idx, std::ostream &out) {
745     std::string buf;
746     std::ios state(nullptr);
747     state.copyfmt(out);
748     const auto &trie = *this->trie(idx);
749     trie.foreach([&trie, &buf, &out](float value, size_t _len,
750                                      PinyinTrie::position_type pos) {
751         trie.suffix(buf, _len, pos);
752         auto sep = buf.find(pinyinHanziSep);
753         if (sep == std::string::npos) {
754             return true;
755         }
756         std::string_view ref(buf);
757         auto fullPinyin = PinyinEncoder::decodeFullPinyin(ref.data(), sep);
758         out << ref.substr(sep + 1) << " " << fullPinyin << " "
759             << std::setprecision(16) << value << std::endl;
760         return true;
761     });
762     out.copyfmt(state);
763 }
764 
addWord(size_t idx,std::string_view fullPinyin,std::string_view hanzi,float cost)765 void PinyinDictionary::addWord(size_t idx, std::string_view fullPinyin,
766                                std::string_view hanzi, float cost) {
767     auto result = PinyinEncoder::encodeFullPinyin(fullPinyin);
768     result.push_back(pinyinHanziSep);
769     result.insert(result.end(), hanzi.begin(), hanzi.end());
770     TrieDictionary::addWord(idx, std::string_view(result.data(), result.size()),
771                             cost);
772 }
773 
removeWord(size_t idx,std::string_view fullPinyin,std::string_view hanzi)774 bool PinyinDictionary::removeWord(size_t idx, std::string_view fullPinyin,
775                                   std::string_view hanzi) {
776     auto result = PinyinEncoder::encodeFullPinyin(fullPinyin);
777     result.push_back(pinyinHanziSep);
778     result.insert(result.end(), hanzi.begin(), hanzi.end());
779     return TrieDictionary::removeWord(
780         idx, std::string_view(result.data(), result.size()));
781 }
782 
setFlags(size_t idx,PinyinDictFlags flags)783 void PinyinDictionary::setFlags(size_t idx, PinyinDictFlags flags) {
784     FCITX_D();
785     if (idx >= dictSize()) {
786         return;
787     }
788     d->flags_.resize(dictSize());
789     d->flags_[idx] = flags;
790 }
791 } // namespace libime
792