1 // Copyright (C) 2020 by Yuri Victorovich. All rights reserved.
2
3 #include "nn-widget.h"
4 #include "svg-graphics-generator.h"
5
6 #include <QMouseEvent>
7 #include <QByteArray>
8
9
NnWidget(QWidget * parent)10 NnWidget::NnWidget(QWidget *parent)
11 : ZoomableSvgWidget(parent)
12 , model(nullptr)
13 {
14 }
15
16 /// interface
17
open(const PluginInterface::Model * model_)18 void NnWidget::open(const PluginInterface::Model *model_) {
19 load(SvgGraphics::generateModelSvg(model_,
20 {&modelIndexes.allOperatorBoxes, &modelIndexes.allTensorLabelBoxes, &modelIndexes.allInputBoxes, &modelIndexes.allOutputBoxes}));
21 model = model_;
22 }
23
close()24 void NnWidget::close() {
25 clearIndices();
26 model = nullptr;
27 load(QByteArray());
28 resize(0,0);
29 }
30
31 /// overridden
32
mousePressEvent(QMouseEvent * event)33 void NnWidget::mousePressEvent(QMouseEvent *event) {
34 if (model) {
35 auto searchResult = findObjectAtThePoint(event->pos());
36 if (searchResult.operatorId != -1)
37 emit clickedOnOperator((PluginInterface::TensorId)searchResult.operatorId);
38 else if (searchResult.innerTensorId != -1)
39 emit clickedOnTensorEdge((PluginInterface::TensorId)searchResult.innerTensorId);
40 else if (searchResult.inputTensorId != -1)
41 emit clickedOnInput((PluginInterface::TensorId)searchResult.inputTensorId);
42 else if (searchResult.outputTensorId != -1)
43 emit clickedOnOutput((PluginInterface::TensorId)searchResult.outputTensorId);
44 else
45 emit clickedOnBlankSpace();
46 }
47
48 // pass
49 ZoomableSvgWidget::mousePressEvent(event);
50 }
51
52 /// internals
53
clearIndices()54 void NnWidget::clearIndices() {
55 for (auto index : {&modelIndexes.allOperatorBoxes,&modelIndexes.allTensorLabelBoxes,&modelIndexes.allInputBoxes,&modelIndexes.allOutputBoxes})
56 index->clear();
57 }
58
findObjectAtThePoint(const QPointF & pt) const59 NnWidget::AnyObject NnWidget::findObjectAtThePoint(const QPointF &pt) const {
60 const QPointF pts = pt/getScalingFactor();
61
62 // XXX ad hoc algorithm until we find some good geoindexing implementation
63
64 // operator box?
65 for (PluginInterface::OperatorId oid = 0, oide = modelIndexes.allOperatorBoxes.size(); oid < oide; oid++)
66 if (modelIndexes.allOperatorBoxes[oid].contains(pts))
67 return {(int)oid,-1,-1,-1};
68
69 // tensor label?
70 for (PluginInterface::TensorId tid = 0, tide = modelIndexes.allTensorLabelBoxes.size(); tid < tide; tid++)
71 if (modelIndexes.allTensorLabelBoxes[tid].contains(pts))
72 return {-1,(int)tid,-1,-1};
73
74 // input box?
75 for (unsigned idx = 0, idxe = modelIndexes.allInputBoxes.size(); idx < idxe; idx++)
76 if (modelIndexes.allInputBoxes[idx].contains(pts))
77 return {-1,-1,(int)model->getInputs()[idx],-1};
78
79 // output box?
80 for (unsigned idx = 0, idxe = modelIndexes.allOutputBoxes.size(); idx < idxe; idx++)
81 if (modelIndexes.allOutputBoxes[idx].contains(pts))
82 return {-1,-1,-1,(int)model->getOutputs()[idx]};
83
84 return {-1,-1,-1,-1}; // not found
85 }
86
87