import functools
import html

from AnyQt import QtGui, QtCore
from AnyQt.QtCore import pyqtSignal, QSize
from AnyQt.QtWidgets import (QApplication, QVBoxLayout, QButtonGroup,
                             QRadioButton, QGroupBox, QTreeWidgetItem,
                             QTreeWidget, QStyleOptionViewItem,
                             QStyledItemDelegate, QStyle)

from Orange.widgets import settings
from Orange.widgets import gui
from Orange.widgets.settings import DomainContextHandler
from Orange.widgets.widget import OWWidget, Input, Output, Msg
from Orange.data import Table
from orangecontrib.text.corpus import Corpus
from orangecontrib.text.topics import Topic, LdaWrapper, HdpWrapper, LsiWrapper
from orangecontrib.text.widgets.utils.concurrent import asynchronous


class TopicWidget(gui.OWComponent, QGroupBox):
    """Group box holding spin boxes for the parameters of one topic model.

    Subclasses set ``Model`` (the wrapper class that gets instantiated) and
    ``parameters``, a tuple of
    ``(attribute, description, min, max, step, type)`` entries. Every
    ``attribute`` must also be declared as a setting on the subclass, since
    its current value is read with ``getattr`` when building the model.
    """
    Model = NotImplemented
    # Emitted with this widget after an edited parameter rebuilt the model.
    valueChanged = pyqtSignal(object)

    parameters = ()
    spin_format = '{description}:'

    def __init__(self, master, **kwargs):
        QGroupBox.__init__(self, **kwargs)
        gui.OWComponent.__init__(self, master)
        self.model = self.create_model()
        QVBoxLayout(self)  # installs itself as this group box's layout
        for parameter, description, minv, maxv, step, _type in self.parameters:
            label = self.spin_format.format(description=description,
                                            parameter=parameter)
            spin = gui.spin(self, self, parameter, minv=minv, maxv=maxv,
                            step=step, label=label, labelWidth=220,
                            spinType=_type)
            spin.clearFocus()
            # Rebuild the model only once the user finishes editing.
            spin.editingFinished.connect(self.on_change)

    def on_change(self):
        """Rebuild the model from current values and emit ``valueChanged``."""
        self.model = self.create_model()
        self.valueChanged.emit(self)

    def create_model(self):
        """Instantiate ``Model`` with the current parameter values as kwargs."""
        return self.Model(**{par[0]: getattr(self, par[0])
                             for par in self.parameters})

    def report_model(self):
        """Return ``(model name, iterable of (description, value) pairs)``."""
        return self.model.name, ((par[1], getattr(self, par[0]))
                                 for par in self.parameters)


class LdaWidget(TopicWidget):
    """Parameter box for Latent Dirichlet Allocation."""
    Model = LdaWrapper

    parameters = (
        ('num_topics', 'Number of topics', 1, 500, 1, int),
    )
    num_topics = settings.Setting(10)


class LsiWidget(TopicWidget):
    """Parameter box for Latent Semantic Indexing."""
    Model = LsiWrapper

    parameters = (
        ('num_topics', 'Number of topics', 1, 500, 1, int),
    )
    num_topics = settings.Setting(10)


class HdpWidget(TopicWidget):
    """Parameter box for the Hierarchical Dirichlet Process."""
    Model = HdpWrapper

    parameters = (
        ('gamma', 'First level concentration (γ)', .1, 10, .5, float),
        ('alpha', 'Second level concentration (α)', 1, 10, 1, int),
        # eta is labelled with η; α already denotes `alpha` above.
        ('eta', 'The topic Dirichlet (η)', 0.001, .5, .01, float),
        ('T', 'Top level truncation level (Τ)', 10, 150, 1, int),
        ('K', 'Second level truncation level (Κ)', 1, 50, 1, int),
        ('kappa', 'Learning rate (κ)', .1, 10., .1, float),
        ('tau', 'Slow down parameter (τ)', 16., 256., 1., float),
    )
    gamma = settings.Setting(1)
    alpha = settings.Setting(1)
    eta = settings.Setting(.01)
    T = settings.Setting(150)
    K = settings.Setting(15)
    kappa = settings.Setting(1)
    tau = settings.Setting(64)


def require(attribute):
    """Decorator factory: run the method only if ``self.<attribute>`` is set.

    When the attribute is missing or None, the wrapped method is skipped
    and the wrapper returns None.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if getattr(self, attribute, None) is not None:
                return func(self, *args, **kwargs)
        return wrapper
    return decorator


class OWTopicModeling(OWWidget):
    """Widget that fits a topic model (LSI, LDA or HDP) on an input corpus.

    Fitting runs asynchronously via ``learning_task``. The widget outputs
    the transformed corpus, the table of the topic selected in the viewer,
    and a table with all topics.
    """
    name = "Topic Modelling"
    description = "Uncover the hidden thematic structure in a corpus."
    icon = "icons/TopicModeling.svg"
    priority = 400
    keywords = ["LDA"]

    settingsHandler = DomainContextHandler()

    # Input/output
    class Inputs:
        corpus = Input("Corpus", Corpus)

    class Outputs:
        corpus = Output("Corpus", Table)
        selected_topic = Output("Selected Topic", Topic)
        all_topics = Output("All Topics", Table)

    want_main_area = True

    # (parameter widget class, attribute name it is stored under);
    # the list position is the radio-button id / ``method_index``.
    methods = [
        (LsiWidget, 'lsi'),
        (LdaWidget, 'lda'),
        (HdpWidget, 'hdp'),
    ]

    # Settings
    autocommit = settings.Setting(True)
    method_index = settings.Setting(0)

    lsi = settings.SettingProvider(LsiWidget)
    hdp = settings.SettingProvider(HdpWidget)
    lda = settings.SettingProvider(LdaWidget)

    selection = settings.Setting(None, schema_only=True)

    control_area_width = 300

    class Warning(OWWidget.Warning):
        less_topics_found = Msg('Less topics found than requested.')

    def __init__(self):
        super().__init__()
        self.corpus = None
        self.learning_thread = None
        # Saved topic id, restored after the first successful fit.
        self.__pending_selection = self.selection

        # Commit button
        gui.auto_commit(self.buttonsArea, self, 'autocommit', 'Commit',
                        box=False)

        button_group = QButtonGroup(self, exclusive=True)
        button_group.buttonClicked[int].connect(self.change_method)

        self.widgets = []
        method_layout = QVBoxLayout()
        self.controlArea.layout().addLayout(method_layout)
        for i, (method, attr_name) in enumerate(self.methods):
            widget = method(self, title='Options')
            widget.setFixedWidth(self.control_area_width)
            widget.valueChanged.connect(self.commit)
            self.widgets.append(widget)
            setattr(self, attr_name, widget)

            rb = QRadioButton(text=widget.Model.name)
            button_group.addButton(rb, i)
            method_layout.addWidget(rb)
            method_layout.addWidget(widget)

        button_group.button(self.method_index).setChecked(True)
        self.toggle_widgets()
        method_layout.addStretch()

        # Topics description
        self.topic_desc = TopicViewer()
        self.topic_desc.topicSelected.connect(self.send_topic_by_id)
        self.mainArea.layout().addWidget(self.topic_desc)
        self.topic_desc.setFocus()

    @Inputs.corpus
    def set_data(self, data=None):
        """Store the new corpus and (re)fit, or clear outputs if None."""
        self.Warning.less_topics_found.clear()
        self.corpus = data
        self.apply()

    def commit(self):
        """Refit the model when a corpus is present."""
        if self.corpus is not None:
            self.apply()

    @property
    def model(self):
        """The model wrapper of the currently selected method."""
        return self.widgets[self.method_index].model

    def change_method(self, new_index):
        """Switch to the method with radio-button id ``new_index``."""
        if self.method_index != new_index:
            self.method_index = new_index
            self.toggle_widgets()
            self.commit()

    def toggle_widgets(self):
        """Show only the parameter box of the selected method."""
        for i, widget in enumerate(self.widgets):
            widget.setVisible(i == self.method_index)

    def apply(self):
        """Stop any running fit, then start a new one or clear outputs."""
        self.learning_task.stop()
        if self.corpus is not None:
            self.learning_task()
        else:
            self.on_result(None)

    @asynchronous
    def learning_task(self):
        return self.model.fit_transform(self.corpus.copy(), chunk_number=100,
                                        on_progress=self.on_progress)

    @learning_task.on_start
    def on_start(self):
        self.Warning.less_topics_found.clear()
        self.progressBarInit()
        self.topic_desc.clear()

    @learning_task.on_result
    def on_result(self, corpus):
        """Send outputs for a finished fit; ``corpus`` None clears them."""
        self.progressBarFinished()
        self.Outputs.corpus.send(corpus)
        if corpus is None:
            self.topic_desc.clear()
            self.Outputs.selected_topic.send(None)
            self.Outputs.all_topics.send(None)
        else:
            self.topic_desc.show_model(self.model)
            # Topic id 0 is a valid pending selection, so test against None.
            if self.__pending_selection is not None:
                self.topic_desc.select(self.__pending_selection)
                self.__pending_selection = None
            if self.model.actual_topics != self.model.num_topics:
                self.Warning.less_topics_found()
            self.Outputs.all_topics.send(self.model.get_all_topics_table())

    @learning_task.callback
    def on_progress(self, p):
        # p is a fraction; the progress bar expects a percentage.
        self.progressBarSet(100 * p)

    def send_report(self):
        self.report_items(*self.widgets[self.method_index].report_model())
        if self.corpus is not None:
            self.report_items('Topics', self.topic_desc.report())

    def send_topic_by_id(self, topic_id=None):
        """Remember the selected topic and send its table if a model exists.

        When ``topic_id`` is None only the stored selection is updated; the
        ``selected_topic`` output is left unchanged.
        """
        self.selection = topic_id
        if self.model.model and topic_id is not None:
            self.Outputs.selected_topic.send(
                self.model.get_topics_table_by_id(topic_id))


class TopicViewerTreeWidgetItem(QTreeWidgetItem):
    """Tree row showing a 1-based topic number and its top words.

    The words column is HTML (rendered by ``HTMLDelegate``). Each word is
    escaped and, when ``color_by_weights`` is set, colored green for
    positive weight and red otherwise.
    """

    def __init__(self, topic_id, words, weights, parent,
                 color_by_weights=False):
        super().__init__(parent)
        self.topic_id = topic_id
        self.words = words
        self.weights = weights
        self.color_by_weights = color_by_weights

        self.setText(0, '{:d}'.format(topic_id + 1))
        self.setText(1, ', '.join(self._color(word, weight)
                                  for word, weight in zip(words, weights)))

    def _color(self, word, weight):
        """Return ``word`` as HTML-escaped text, optionally wrapped in a span."""
        # The column is rendered as HTML, so escape characters like < and &.
        escaped = html.escape(str(word))
        if self.color_by_weights:
            red = '#ff6600'
            green = '#00cc00'
            color = green if weight > 0 else red
            return '<span style="color: {}">{}</span>'.format(color, escaped)
        return escaped

    def report(self):
        return self.text(0), self.text(1)


class TopicViewer(QTreeWidget):
    """Tree view listing a topic model's topics and their top keywords.

    Emits ``topicSelected`` with the selected topic id, or None when the
    selection is cleared.
    """

    columns = ['Topic', 'Topic keywords']
    topicSelected = pyqtSignal(object)

    def __init__(self):
        super().__init__()
        self.setColumnCount(len(self.columns))
        self.setHeaderLabels(self.columns)
        self.resize_columns()
        self.itemSelectionChanged.connect(self.selected_topic_changed)
        self.setItemDelegate(HTMLDelegate())  # enable colors
        self.selected_id = None

    def resize_columns(self):
        for i in range(self.columnCount()):
            self.resizeColumnToContents(i)

    def show_model(self, topic_model):
        """Fill the view with ``topic_model``'s topics; select the first row.

        Topics without any top words are skipped, so row numbers need not
        match topic ids.
        """
        self.clear()
        if topic_model.model:
            for i in range(topic_model.num_topics):
                words, weights = topic_model.get_top_words_by_id(i)
                if words:
                    it = TopicViewerTreeWidgetItem(
                        i, words, weights, self,
                        color_by_weights=topic_model.has_negative_weights)
                    self.addTopLevelItem(it)

            self.resize_columns()
            self.setCurrentItem(self.topLevelItem(0))

    def selected_topic_changed(self):
        selected = self.selectedItems()
        if selected:
            self.select(selected[0].topic_id)
            self.topicSelected.emit(self.selected_id)
        else:
            self.topicSelected.emit(None)

    def report(self):
        root = self.invisibleRootItem()
        return [root.child(i).report() for i in range(root.childCount())]

    def sizeHint(self):
        return QSize(700, 300)

    def _item_for_topic(self, topic_id):
        """Return the row item whose ``topic_id`` matches, or None."""
        for i in range(self.topLevelItemCount()):
            item = self.topLevelItem(i)
            if item.topic_id == topic_id:
                return item
        return None

    def select(self, index):
        """Make the row for topic id ``index`` current.

        Rows are looked up by topic id rather than position, because
        ``show_model`` skips empty topics. If no row matches, the current
        item is cleared.
        """
        self.selected_id = index
        self.setCurrentItem(self._item_for_topic(index))


class HTMLDelegate(QStyledItemDelegate):
    """ This delegate enables coloring of words in QTreeWidgetItem.
    Adopted from https://stackoverflow.com/a/5443112/892987 """

    def paint(self, painter, option, index):
        options = QStyleOptionViewItem(option)
        self.initStyleOption(options, index)

        style = (QApplication.style() if options.widget is None
                 else options.widget.style())

        doc = QtGui.QTextDocument()
        doc.setHtml(options.text)

        # Draw the item background/decoration without text, then paint the
        # rich-text document on top of it.
        options.text = ""
        style.drawControl(QStyle.CE_ItemViewItem, options, painter)

        ctx = QtGui.QAbstractTextDocumentLayout.PaintContext()

        if options.state & QStyle.State_Selected:
            ctx.palette.setColor(QtGui.QPalette.Text,
                                 options.palette.color(QtGui.QPalette.Active,
                                                       QtGui.QPalette.HighlightedText))

        textRect = style.subElementRect(QStyle.SE_ItemViewItemText, options)
        painter.save()
        painter.translate(textRect.topLeft())
        painter.setClipRect(textRect.translated(-textRect.topLeft()))
        doc.documentLayout().draw(painter, ctx)

        painter.restore()

    def sizeHint(self, option, index):
        options = QStyleOptionViewItem(option)
        self.initStyleOption(options, index)

        doc = QtGui.QTextDocument()
        doc.setHtml(options.text)
        doc.setTextWidth(options.rect.width())
        return QtCore.QSize(int(doc.idealWidth()), int(doc.size().height()))


if __name__ == '__main__':
    app = QApplication([])
    widget = OWTopicModeling()
    # widget.set_data(Corpus.from_file('book-excerpts'))
    widget.set_data(Corpus.from_file('deerwester'))
    widget.show()
    app.exec()
    widget.saveSettings()