import functools
import html

from AnyQt import QtGui, QtCore
from AnyQt.QtCore import pyqtSignal, QSize
from AnyQt.QtWidgets import (QApplication, QVBoxLayout, QButtonGroup,
                             QRadioButton, QGroupBox, QTreeWidgetItem,
                             QTreeWidget, QStyleOptionViewItem,
                             QStyledItemDelegate, QStyle)

from Orange.widgets import settings
from Orange.widgets import gui
from Orange.widgets.settings import DomainContextHandler
from Orange.widgets.widget import OWWidget, Input, Output, Msg
from Orange.data import Table

from orangecontrib.text.corpus import Corpus
from orangecontrib.text.topics import Topic, LdaWrapper, HdpWrapper, LsiWrapper
from orangecontrib.text.widgets.utils.concurrent import asynchronous
17
18
class TopicWidget(gui.OWComponent, QGroupBox):
    """Group box holding one spin box per model hyper-parameter.

    Subclasses set ``Model`` (the class instantiated by ``create_model``) and
    ``parameters``, a tuple of rows
    ``(attribute, description, min, max, step, type)``. Each ``attribute``
    must exist on the instance; the subclasses in this module declare them
    as settings. When a spin box finishes editing, a new model is built from
    the current values and ``valueChanged`` is emitted with this widget as
    its argument.
    """
    Model = NotImplemented
    valueChanged = pyqtSignal(object)

    parameters = ()
    spin_format = '{description}:'

    def __init__(self, master, **kwargs):
        QGroupBox.__init__(self, **kwargs)
        gui.OWComponent.__init__(self, master)
        self.model = self.create_model()
        QVBoxLayout(self)  # a layout constructed with a parent installs itself on it
        for name, description, low, high, step, kind in self.parameters:
            label = self.spin_format.format(description=description,
                                            parameter=name)
            spin_box = gui.spin(self, self, name, minv=low, maxv=high,
                                step=step, label=label, labelWidth=220,
                                spinType=kind)
            spin_box.clearFocus()
            spin_box.editingFinished.connect(self.on_change)

    def on_change(self):
        """Rebuild the model from the current values and notify listeners."""
        self.model = self.create_model()
        self.valueChanged.emit(self)

    def create_model(self):
        """Instantiate ``Model`` with every parameter passed as a keyword."""
        kwargs = {name: getattr(self, name) for name, *_ in self.parameters}
        return self.Model(**kwargs)

    def report_model(self):
        """Return ``(model name, lazy iterable of (description, value))``."""
        values = ((description, getattr(self, name))
                  for name, description, *_ in self.parameters)
        return self.model.name, values
47
48
class LdaWidget(TopicWidget):
    """Options box for ``LdaWrapper``: a single "number of topics" spin box."""
    Model = LdaWrapper

    num_topics = settings.Setting(10)
    parameters = (
        ('num_topics', 'Number of topics', 1, 500, 1, int),
    )
56
57
class LsiWidget(TopicWidget):
    """Options box for ``LsiWrapper``: a single "number of topics" spin box."""
    Model = LsiWrapper

    num_topics = settings.Setting(10)
    parameters = (
        ('num_topics', 'Number of topics', 1, 500, 1, int),
    )
65
66
class HdpWidget(TopicWidget):
    """Options box for ``HdpWrapper``.

    Each hyper-parameter gets a spin box whose label ends with the
    parameter's symbol in parentheses. The label format is inherited
    from ``TopicWidget.spin_format``.
    """
    Model = HdpWrapper

    # (attribute, description, min, max, step, type)
    parameters = (
        ('gamma', 'First level concentration (γ)', .1, 10, .5, float),
        ('alpha', 'Second level concentration (α)', 1, 10, 1, int),
        # eta is shown as η; it previously reused alpha's symbol by mistake.
        ('eta', 'The topic Dirichlet (η)', 0.001, .5, .01, float),
        ('T', 'Top level truncation level (Τ)', 10, 150, 1, int),
        ('K', 'Second level truncation level (Κ)', 1, 50, 1, int),
        ('kappa', 'Learning rate (κ)', .1, 10., .1, float),
        ('tau', 'Slow down parameter (τ)', 16., 256., 1., float),
    )
    gamma = settings.Setting(1)
    alpha = settings.Setting(1)
    eta = settings.Setting(.01)
    T = settings.Setting(150)
    K = settings.Setting(15)
    kappa = settings.Setting(1)
    tau = settings.Setting(64)
87
88
def require(attribute):
    """Build a method decorator that runs the method only if an attribute is set.

    The wrapped method is called only when ``getattr(self, attribute, None)``
    is not ``None``; otherwise the call is skipped and returns ``None``.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Guard clause: a missing or None attribute short-circuits the call.
            if getattr(self, attribute, None) is None:
                return None
            return func(self, *args, **kwargs)
        return wrapper
    return decorator
97
98
class OWTopicModeling(OWWidget):
    """Fit a topic model to a corpus and display the resulting topics.

    The control area shows one radio button and one options box per entry
    of ``methods`` (LSI, LDA and HDP). Only the box of the chosen method is
    visible. Fitting runs through the ``asynchronous`` ``learning_task``, and
    the topics are listed in a ``TopicViewer`` in the main area.
    """
    name = "Topic Modelling"
    description = "Uncover the hidden thematic structure in a corpus."
    icon = "icons/TopicModeling.svg"
    priority = 400
    keywords = ["LDA"]

    settingsHandler = DomainContextHandler()

    # Input/output
    class Inputs:
        corpus = Input("Corpus", Corpus)

    class Outputs:
        corpus = Output("Corpus", Table)
        selected_topic = Output("Selected Topic", Topic)
        all_topics = Output("All Topics", Table)

    want_main_area = True

    # (options widget class, attribute name for its instance). The list
    # position doubles as the radio-button id and as ``method_index``.
    methods = [
        (LsiWidget, 'lsi'),
        (LdaWidget, 'lda'),
        (HdpWidget, 'hdp'),
    ]

    # Settings
    autocommit = settings.Setting(True)
    method_index = settings.Setting(0)

    lsi = settings.SettingProvider(LsiWidget)
    hdp = settings.SettingProvider(HdpWidget)
    lda = settings.SettingProvider(LdaWidget)

    # Topic id selected in the viewer. The saved value is re-applied after
    # the first successful fit following construction.
    selection = settings.Setting(None, schema_only=True)

    control_area_width = 300

    class Warning(OWWidget.Warning):
        less_topics_found = Msg('Less topics found than requested.')

    def __init__(self):
        super().__init__()
        self.corpus = None
        self.learning_thread = None
        # Saved selection, consumed by on_result once topics are shown.
        self.__pending_selection = self.selection

        # Commit button
        gui.auto_commit(self.buttonsArea, self, 'autocommit', 'Commit', box=False)

        button_group = QButtonGroup(self, exclusive=True)
        button_group.buttonClicked[int].connect(self.change_method)

        self.widgets = []
        method_layout = QVBoxLayout()
        self.controlArea.layout().addLayout(method_layout)
        for i, (method, attr_name) in enumerate(self.methods):
            widget = method(self, title='Options')
            widget.setFixedWidth(self.control_area_width)
            widget.valueChanged.connect(self.commit)
            self.widgets.append(widget)
            # Bind the options widget under its SettingProvider name.
            setattr(self, attr_name, widget)

            rb = QRadioButton(text=widget.Model.name)
            button_group.addButton(rb, i)
            method_layout.addWidget(rb)
            method_layout.addWidget(widget)

        button_group.button(self.method_index).setChecked(True)
        self.toggle_widgets()
        method_layout.addStretch()

        # Topics description
        self.topic_desc = TopicViewer()
        self.topic_desc.topicSelected.connect(self.send_topic_by_id)
        self.mainArea.layout().addWidget(self.topic_desc)
        self.topic_desc.setFocus()

    @Inputs.corpus
    def set_data(self, data=None):
        """Store the new corpus and immediately refit (or clear on None)."""
        self.Warning.less_topics_found.clear()
        self.corpus = data
        self.apply()

    def commit(self):
        """Refit the model if a corpus is present."""
        if self.corpus is not None:
            self.apply()

    @property
    def model(self):
        """Topic-model wrapper of the currently selected method."""
        return self.widgets[self.method_index].model

    def change_method(self, new_index):
        """Switch to method ``new_index``, show its options and refit."""
        if self.method_index != new_index:
            self.method_index = new_index
            self.toggle_widgets()
            self.commit()

    def toggle_widgets(self):
        """Show only the options box of the selected method."""
        for i, widget in enumerate(self.widgets):
            widget.setVisible(i == self.method_index)

    def apply(self):
        """Stop any running fit, then start a new one or clear outputs."""
        self.learning_task.stop()
        if self.corpus is not None:
            self.learning_task()
        else:
            self.on_result(None)

    @asynchronous
    def learning_task(self):
        """Fit the current model on a copy of the corpus."""
        return self.model.fit_transform(self.corpus.copy(), chunk_number=100,
                                        on_progress=self.on_progress)

    @learning_task.on_start
    def on_start(self):
        """Reset warnings, progress and the topic list before fitting."""
        self.Warning.less_topics_found.clear()
        self.progressBarInit()
        self.topic_desc.clear()

    @learning_task.on_result
    def on_result(self, corpus):
        """Send the transformed corpus and show/output the fitted topics.

        ``corpus`` is ``None`` when there is no input, in which case every
        output is cleared.
        """
        self.progressBarFinished()
        self.Outputs.corpus.send(corpus)
        if corpus is None:
            self.topic_desc.clear()
            self.Outputs.selected_topic.send(None)
            self.Outputs.all_topics.send(None)
        else:
            self.topic_desc.show_model(self.model)
            # Topic id 0 is a valid selection, so test against None
            # rather than relying on truthiness.
            if self.__pending_selection is not None:
                self.topic_desc.select(self.__pending_selection)
                self.__pending_selection = None
            if self.model.actual_topics != self.model.num_topics:
                self.Warning.less_topics_found()
            self.Outputs.all_topics.send(self.model.get_all_topics_table())

    @learning_task.callback
    def on_progress(self, p):
        """Update the progress bar; ``p`` presumably lies in [0, 1]."""
        self.progressBarSet(100 * p)

    def send_report(self):
        """Report the method's parameters and, with data, the topic list."""
        self.report_items(*self.widgets[self.method_index].report_model())
        if self.corpus is not None:
            self.report_items('Topics', self.topic_desc.report())

    def send_topic_by_id(self, topic_id=None):
        """Remember ``topic_id`` as the selection and output that topic.

        ``None`` (no row selected) clears the "Selected Topic" output, so
        downstream widgets do not keep a topic that is no longer selected.
        """
        self.selection = topic_id
        if topic_id is None:
            self.Outputs.selected_topic.send(None)
        elif self.model.model:
            self.Outputs.selected_topic.send(
                self.model.get_topics_table_by_id(topic_id))
250
251
class TopicViewerTreeWidgetItem(QTreeWidgetItem):
    """One row of the topic viewer: a topic number and its top keywords.

    Column 0 holds the 1-based topic number. Column 1 holds the keywords
    joined by ", " as an HTML fragment; the view draws item text with an
    HTML-rendering delegate. When ``color_by_weights`` is set, each word is
    coloured by the sign of its weight.
    """

    def __init__(self, topic_id, words, weights, parent,
                 color_by_weights=False):
        super().__init__(parent)
        self.topic_id = topic_id
        self.words = words
        self.weights = weights
        self.color_by_weights = color_by_weights

        self.setText(0, '{:d}'.format(topic_id + 1))
        self.setText(1, ', '.join(self._color(word, weight)
                                  for word, weight in zip(words, weights)))

    def _color(self, word, weight):
        """Return ``word`` as an HTML fragment.

        The word is HTML-escaped, so characters such as ``<`` or ``&`` in a
        token are displayed literally instead of being parsed as markup.
        With ``color_by_weights`` set, the word is wrapped in a span coloured
        '#00cc00' for a positive weight and '#ff6600' otherwise.
        """
        text = html.escape(str(word))
        if not self.color_by_weights:
            return text
        color = '#00cc00' if weight > 0 else '#ff6600'
        return '<span style="color: {}">{}</span>'.format(color, text)

    def report(self):
        """Return ``(topic number, keywords HTML)`` for the widget report."""
        return self.text(0), self.text(1)
276
277
class TopicViewer(QTreeWidget):
    """Tree widget listing the topics of a fitted topic model.

    Each row is a ``TopicViewerTreeWidgetItem``; keyword text is rendered as
    HTML by ``HTMLDelegate`` so words can be coloured. ``topicSelected``
    emits the ``topic_id`` of the selected row, or ``None`` when no row is
    selected.
    """

    columns = ['Topic', 'Topic keywords']
    topicSelected = pyqtSignal(object)

    def __init__(self):
        super().__init__()
        self.setColumnCount(len(self.columns))
        self.setHeaderLabels(self.columns)
        self.resize_columns()
        self.itemSelectionChanged.connect(self.selected_topic_changed)
        # Parent the delegate to the view so its lifetime is tied to it.
        self.setItemDelegate(HTMLDelegate(self))    # enable colors
        self.selected_id = None

    def resize_columns(self):
        """Fit every column to its contents."""
        for i in range(self.columnCount()):
            self.resizeColumnToContents(i)

    def show_model(self, topic_model):
        """Replace the rows with the topics of ``topic_model``.

        Topics for which the model returns no words get no row, so row
        positions need not equal topic ids. The first row becomes current.
        """
        self.clear()
        if topic_model.model:
            for i in range(topic_model.num_topics):
                words, weights = topic_model.get_top_words_by_id(i)
                if words:
                    it = TopicViewerTreeWidgetItem(
                        i, words, weights, self,
                        color_by_weights=topic_model.has_negative_weights)
                    self.addTopLevelItem(it)

            self.resize_columns()
            self.setCurrentItem(self.topLevelItem(0))

    def selected_topic_changed(self):
        """Emit ``topicSelected`` with the selected topic id (or ``None``)."""
        selected = self.selectedItems()
        if selected:
            self.select(selected[0].topic_id)
            self.topicSelected.emit(self.selected_id)
        else:
            self.topicSelected.emit(None)

    def report(self):
        """Return ``(topic number, keywords)`` pairs for every row."""
        root = self.invisibleRootItem()
        child_count = root.childCount()
        return [root.child(i).report()
                for i in range(child_count)]

    def sizeHint(self):
        return QSize(700, 300)

    def select(self, index):
        """Make the row of topic ``index`` current and remember the id.

        Rows are found by their ``topic_id`` rather than by position, because
        ``show_model`` skips empty topics. If no row has that id, the current
        item is cleared.
        """
        self.selected_id = index
        self.setCurrentItem(self._item_for_topic(index))

    def _item_for_topic(self, topic_id):
        """Return the top-level item whose ``topic_id`` matches, or None."""
        for row in range(self.topLevelItemCount()):
            item = self.topLevelItem(row)
            if item.topic_id == topic_id:
                return item
        return None
333
334
class HTMLDelegate(QStyledItemDelegate):
    """Item delegate that renders item text as HTML.

    This enables coloured words (``<span style="color: ...">``) in
    QTreeWidgetItem text.
    Adopted from https://stackoverflow.com/a/5443112/892987
    """

    @staticmethod
    def _document(options):
        """Return a QTextDocument with ``options.text`` as HTML, in the item font."""
        doc = QtGui.QTextDocument()
        # Use the view's font so HTML text matches plain item text.
        doc.setDefaultFont(options.font)
        doc.setHtml(options.text)
        return doc

    def paint(self, painter, option, index):
        """Draw the item chrome via the style, then the text as HTML."""
        options = QStyleOptionViewItem(option)
        self.initStyleOption(options, index)

        # QApplication is needed when the option carries no widget; it must
        # be imported at module level, not only under ``__main__``.
        style = QApplication.style() if options.widget is None else options.widget.style()

        doc = self._document(options)

        # Let the style draw background/selection/focus without any text;
        # the text itself is drawn from the HTML document below.
        options.text = ""
        style.drawControl(QStyle.CE_ItemViewItem, options, painter)

        ctx = QtGui.QAbstractTextDocumentLayout.PaintContext()

        if options.state & QStyle.State_Selected:
            # Use the highlighted-text colour on selected rows for contrast.
            ctx.palette.setColor(QtGui.QPalette.Text,
                                 options.palette.color(QtGui.QPalette.Active,
                                                       QtGui.QPalette.HighlightedText))

        textRect = style.subElementRect(QStyle.SE_ItemViewItemText, options)
        painter.save()
        painter.translate(textRect.topLeft())
        painter.setClipRect(textRect.translated(-textRect.topLeft()))
        doc.documentLayout().draw(painter, ctx)

        painter.restore()

    def sizeHint(self, option, index):
        """Return the size of the item's HTML text laid out at the item width."""
        options = QStyleOptionViewItem(option)
        self.initStyleOption(options, index)

        doc = self._document(options)
        doc.setTextWidth(options.rect.width())
        return QtCore.QSize(int(doc.idealWidth()), int(doc.size().height()))
373
374
if __name__ == '__main__':
    # Manual smoke test: open the widget on the bundled 'deerwester' corpus
    # and persist its settings after the window is closed.
    from AnyQt.QtWidgets import QApplication

    application = QApplication([])
    ow = OWTopicModeling()
    ow.set_data(Corpus.from_file('deerwester'))
    ow.show()
    application.exec()
    ow.saveSettings()
385