1 // Copyright 2020 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "chromeos/components/quick_answers/understanding/intent_generator.h"
6
7 #include <map>
8
9 #include "base/i18n/case_conversion.h"
10 #include "base/no_destructor.h"
11 #include "base/strings/string_split.h"
12 #include "base/strings/utf_string_conversions.h"
13 #include "chromeos/components/quick_answers/quick_answers_model.h"
14 #include "chromeos/components/quick_answers/utils/quick_answers_utils.h"
15 #include "chromeos/constants/chromeos_features.h"
16 #include "chromeos/services/machine_learning/public/cpp/service_connection.h"
17 #include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
18
19 namespace chromeos {
20 namespace quick_answers {
21 namespace {
22
23 using chromeos::machine_learning::mojom::LoadModelResult;
24 using machine_learning::mojom::TextAnnotationPtr;
25 using machine_learning::mojom::TextAnnotationRequestPtr;
26 using machine_learning::mojom::TextClassifier;
27
// TODO(llin): Finalize on the threshold based on user feedback.
//
// RewriteIntent() overrides a unit-conversion intent to unknown when the
// selected text is longer than the annotated entity by more than this many
// UTF-16 code units.
constexpr int kUnitConversionIntentAndSelectionLengthDiffThreshold = 5;
// No translation intent is generated for selected text whose length() is
// above this value.
constexpr int kTranslationTextLengthThreshold = 100;
// Same as the unit-conversion threshold above, but applied to dictionary
// (definition) intents.
constexpr int kDefinitionIntentAndSelectionLengthDiffThreshold = 2;

// TODO(b/169370175): Remove the temporary invalid set after we ramp up to v2
// model.
// Characters that disqualify the selected text from definition annotations.
constexpr char kInvalidCharactersSet[] = "()[]{}<>_&|!";
37
GetIntentTypeMap()38 const std::map<std::string, IntentType>& GetIntentTypeMap() {
39 static base::NoDestructor<std::map<std::string, IntentType>> kIntentTypeMap(
40 {{"unit", IntentType::kUnit}, {"dictionary", IntentType::kDictionary}});
41 return *kIntentTypeMap;
42 }
43
ExtractEntity(const std::string & selected_text,const std::vector<TextAnnotationPtr> & annotations,std::string * entity_str,std::string * type)44 bool ExtractEntity(const std::string& selected_text,
45 const std::vector<TextAnnotationPtr>& annotations,
46 std::string* entity_str,
47 std::string* type) {
48 for (auto& annotation : annotations) {
49 // The offset in annotation result is by chars instead of by bytes. Converts
50 // to string16 to support extracting substring from string with UTF-16
51 // characters.
52 *entity_str = base::UTF16ToUTF8(
53 base::UTF8ToUTF16(selected_text)
54 .substr(annotation->start_offset,
55 annotation->end_offset - annotation->start_offset));
56
57 // Use the first entity type.
58 auto intent_type_map = GetIntentTypeMap();
59 for (const auto& entity : annotation->entities) {
60 if (intent_type_map.find(entity->name) != intent_type_map.end()) {
61 *type = entity->name;
62 return true;
63 }
64 }
65 }
66
67 return false;
68 }
69
// Returns IntentType::kUnknown in place of |intent| when the selected text is
// longer than the annotated entity by more than the threshold for that intent
// type; otherwise returns |intent| unchanged. Only unit and dictionary intents
// have a threshold. Lengths are measured on the UTF-16 form of both strings.
IntentType RewriteIntent(const std::string& selected_text,
                         const std::string& entity_str,
                         const IntentType intent) {
  const int length_diff = base::UTF8ToUTF16(selected_text).length() -
                          base::UTF8ToUTF16(entity_str).length();

  if (intent == IntentType::kUnit) {
    return length_diff > kUnitConversionIntentAndSelectionLengthDiffThreshold
               ? IntentType::kUnknown
               : intent;
  }

  if (intent == IntentType::kDictionary) {
    return length_diff > kDefinitionIntentAndSelectionLengthDiffThreshold
               ? IntentType::kUnknown
               : intent;
  }

  // Other intent types are never overridden.
  return intent;
}
89
90 // TODO(b/169370175): There is an issue with text classifier that
91 // concatenated words are annotated as definitions. Before we switch to v2
92 // model, skip such kind of queries for definition annotation for now.
ShouldSkipDefinition(const std::string & text)93 bool ShouldSkipDefinition(const std::string& text) {
94 DCHECK(text.length());
95 // Skip the query for definition annotation if the selected text contains
96 // capitalized characters in the middle and not all capitalized.
97 const auto& text_utf16 = base::UTF8ToUTF16(text);
98 bool has_capitalized_middle_characters =
99 text_utf16.substr(1) != base::i18n::ToLower(text_utf16.substr(1));
100 bool are_all_characters_capitalized =
101 text_utf16 == base::i18n::ToUpper(text_utf16);
102 if (has_capitalized_middle_characters && !are_all_characters_capitalized)
103 return true;
104 // Skip the query for definition annotation if the selected text contains
105 // invalid characters.
106 if (text.find_first_of(kInvalidCharactersSet) != std::string::npos)
107 return true;
108
109 return false;
110 }
111
IsPreferredLanguage(const std::string & detected_locale,const std::string & preferred_languages_string)112 bool IsPreferredLanguage(const std::string& detected_locale,
113 const std::string& preferred_languages_string) {
114 auto preferred_languages =
115 base::SplitString(preferred_languages_string, ",", base::TRIM_WHITESPACE,
116 base::SPLIT_WANT_NONEMPTY);
117 return base::Contains(preferred_languages, detected_locale);
118 }
119
120 } // namespace
121
IntentGenerator(IntentGeneratorCallback complete_callback)122 IntentGenerator::IntentGenerator(IntentGeneratorCallback complete_callback)
123 : complete_callback_(std::move(complete_callback)) {
124 }
125
~IntentGenerator()126 IntentGenerator::~IntentGenerator() {
127 if (complete_callback_)
128 std::move(complete_callback_)
129 .Run(IntentInfo(std::string(), IntentType::kUnknown));
130 }
131
GenerateIntent(const QuickAnswersRequest & request)132 void IntentGenerator::GenerateIntent(const QuickAnswersRequest& request) {
133 if (!features::IsQuickAnswersTextAnnotatorEnabled()) {
134 std::move(complete_callback_)
135 .Run(IntentInfo(request.selected_text, IntentType::kUnknown));
136 return;
137 }
138
139 // Load text classifier.
140 chromeos::machine_learning::ServiceConnection::GetInstance()
141 ->LoadTextClassifier(text_classifier_.BindNewPipeAndPassReceiver(),
142 base::BindOnce(&IntentGenerator::LoadModelCallback,
143 weak_factory_.GetWeakPtr(), request));
144 }
145
LoadModelCallback(const QuickAnswersRequest & request,LoadModelResult result)146 void IntentGenerator::LoadModelCallback(const QuickAnswersRequest& request,
147 LoadModelResult result) {
148 if (result != LoadModelResult::OK) {
149 LOG(ERROR) << "Failed to load TextClassifier.";
150 std::move(complete_callback_)
151 .Run(IntentInfo(request.selected_text, IntentType::kUnknown));
152 return;
153 }
154
155 if (text_classifier_) {
156 language_detector_ =
157 std::make_unique<LanguageDetector>(text_classifier_.get());
158 language_detector_->DetectLanguage(
159 request.context.surrounding_text, request.selected_text,
160 base::BindOnce(&IntentGenerator::LanguageDetectorCallback,
161 weak_factory_.GetWeakPtr(), request));
162 }
163 }
164
AnnotationCallback(const QuickAnswersRequest & request,std::vector<TextAnnotationPtr> annotations)165 void IntentGenerator::AnnotationCallback(
166 const QuickAnswersRequest& request,
167 std::vector<TextAnnotationPtr> annotations) {
168 std::string entity_str;
169 std::string type;
170
171 if (ExtractEntity(request.selected_text, annotations, &entity_str, &type)) {
172 auto intent_type_map = GetIntentTypeMap();
173 auto it = intent_type_map.find(type);
174 if (it != intent_type_map.end()) {
175 // Skip the entity for definition annonation.
176 if (it->second == IntentType::kDictionary &&
177 ShouldSkipDefinition(request.selected_text)) {
178 std::move(complete_callback_)
179 .Run(IntentInfo(request.selected_text, IntentType::kUnknown));
180 return;
181 }
182 std::move(complete_callback_)
183 .Run(IntentInfo(entity_str, RewriteIntent(request.selected_text,
184 entity_str, it->second)));
185 return;
186 }
187 }
188 std::move(complete_callback_)
189 .Run(IntentInfo(request.selected_text, IntentType::kUnknown));
190 }
191
LanguageDetectorCallback(const QuickAnswersRequest & request,base::Optional<std::string> detected_locale)192 void IntentGenerator::LanguageDetectorCallback(
193 const QuickAnswersRequest& request,
194 base::Optional<std::string> detected_locale) {
195 language_detector_.reset();
196
197 // Generate translation intent if the detected language is different to the
198 // system language and is not one of the preferred languages.
199 if (detected_locale.has_value() &&
200 detected_locale.value() != request.context.device_properties.language &&
201 !IsPreferredLanguage(
202 detected_locale.value(),
203 request.context.device_properties.preferred_languages)) {
204 MaybeGenerateTranslationIntent(request, detected_locale.value());
205 return;
206 }
207
208 TextAnnotationRequestPtr text_annotation_request =
209 machine_learning::mojom::TextAnnotationRequest::New();
210
211 // TODO(b/159664194): There is a issue with text classifier that some
212 // capitalized words are not annotated properly. Convert the text to lower
213 // case for now. Clean up after the issue is fixed.
214 text_annotation_request->text = base::UTF16ToUTF8(
215 base::i18n::ToLower(base::UTF8ToUTF16(request.selected_text)));
216 text_annotation_request->default_locales =
217 request.context.device_properties.language;
218
219 text_classifier_->Annotate(
220 std::move(text_annotation_request),
221 base::BindOnce(&IntentGenerator::AnnotationCallback,
222 weak_factory_.GetWeakPtr(), request));
223 }
224
MaybeGenerateTranslationIntent(const QuickAnswersRequest & request,const std::string & detected_locale)225 void IntentGenerator::MaybeGenerateTranslationIntent(
226 const QuickAnswersRequest& request,
227 const std::string& detected_locale) {
228 DCHECK(complete_callback_);
229
230 if (!features::IsQuickAnswersTranslationEnabled()) {
231 std::move(complete_callback_)
232 .Run(IntentInfo(request.selected_text, IntentType::kUnknown));
233 return;
234 }
235
236 // Don't generate translation intent if no device language is provided or the
237 // length of selected text is above the threshold. Returns unknown intent
238 // type.
239 if (request.context.device_properties.language.empty() ||
240 request.selected_text.length() > kTranslationTextLengthThreshold) {
241 std::move(complete_callback_)
242 .Run(IntentInfo(request.selected_text, IntentType::kUnknown));
243 return;
244 }
245
246 std::move(complete_callback_)
247 .Run(IntentInfo(request.selected_text, IntentType::kTranslation,
248 detected_locale,
249 request.context.device_properties.language));
250 }
251
252 } // namespace quick_answers
253 } // namespace chromeos
254