1 /****************************************************************************
2 **
3 ** Copyright (C) 2017 Klaralvdalens Datakonsult AB (KDAB).
4 ** Contact: https://www.qt.io/licensing/
5 **
6 ** This file is part of the QtGui module of the Qt Toolkit.
7 **
8 ** $QT_BEGIN_LICENSE:LGPL$
9 ** Commercial License Usage
10 ** Licensees holding valid commercial Qt licenses may use this file in
11 ** accordance with the commercial license agreement provided with the
12 ** Software or, alternatively, in accordance with the terms contained in
13 ** a written agreement between you and The Qt Company. For licensing terms
14 ** and conditions see https://www.qt.io/terms-conditions. For further
15 ** information use the contact form at https://www.qt.io/contact-us.
16 **
17 ** GNU Lesser General Public License Usage
18 ** Alternatively, this file may be used under the terms of the GNU Lesser
19 ** General Public License version 3 as published by the Free Software
20 ** Foundation and appearing in the file LICENSE.LGPL3 included in the
21 ** packaging of this file. Please review the following information to
22 ** ensure the GNU Lesser General Public License version 3 requirements
23 ** will be met: https://www.gnu.org/licenses/lgpl-3.0.html.
24 **
25 ** GNU General Public License Usage
26 ** Alternatively, this file may be used under the terms of the GNU
27 ** General Public License version 2.0 or (at your option) the GNU General
28 ** Public license version 3 or any later version approved by the KDE Free
29 ** Qt Foundation. The licenses are as published by the Free Software
30 ** Foundation and appearing in the file LICENSE.GPL2 and LICENSE.GPL3
31 ** included in the packaging of this file. Please review the following
32 ** information to ensure the GNU General Public License requirements will
33 ** be met: https://www.gnu.org/licenses/gpl-2.0.html and
34 ** https://www.gnu.org/licenses/gpl-3.0.html.
35 **
36 ** $QT_END_LICENSE$
37 **
38 ****************************************************************************/
39 
40 #include "qshadergraph_p.h"
41 
42 QT_BEGIN_NAMESPACE
43 
44 
45 namespace
46 {
copyOutputNodes(const QVector<QShaderNode> & nodes,const QVector<QShaderGraph::Edge> & edges)47     QVector<QShaderNode> copyOutputNodes(const QVector<QShaderNode> &nodes, const QVector<QShaderGraph::Edge> &edges)
48     {
49         auto res = QVector<QShaderNode>();
50         std::copy_if(nodes.cbegin(), nodes.cend(),
51                      std::back_inserter(res),
52                      [&edges] (const QShaderNode &node) {
53                          return node.type() == QShaderNode::Output ||
54                                 (node.type() == QShaderNode::Function &&
55                                  !std::any_of(edges.cbegin(),
56                                               edges.cend(),
57                                               [&node] (const QShaderGraph::Edge &edge) {
58                                                   return edge.sourceNodeUuid ==
59                                                          node.uuid();
60                                               }));
61                      });
62         return res;
63     }
64 
incomingEdges(const QVector<QShaderGraph::Edge> & edges,const QUuid & uuid)65     QVector<QShaderGraph::Edge> incomingEdges(const QVector<QShaderGraph::Edge> &edges, const QUuid &uuid)
66     {
67         auto res = QVector<QShaderGraph::Edge>();
68         std::copy_if(edges.cbegin(), edges.cend(),
69                      std::back_inserter(res),
70                      [uuid] (const QShaderGraph::Edge &edge) {
71                          return edge.sourceNodeUuid == uuid;
72                      });
73         return res;
74     }
75 
outgoingEdges(const QVector<QShaderGraph::Edge> & edges,const QUuid & uuid)76     QVector<QShaderGraph::Edge> outgoingEdges(const QVector<QShaderGraph::Edge> &edges, const QUuid &uuid)
77     {
78         auto res = QVector<QShaderGraph::Edge>();
79         std::copy_if(edges.cbegin(), edges.cend(),
80                      std::back_inserter(res),
81                      [uuid] (const QShaderGraph::Edge &edge) {
82                          return edge.targetNodeUuid == uuid;
83                      });
84         return res;
85     }
86 
nodeToStatement(const QShaderNode & node,int & nextVarId)87     QShaderGraph::Statement nodeToStatement(const QShaderNode &node, int &nextVarId)
88     {
89         auto statement = QShaderGraph::Statement();
90         statement.node = node;
91 
92         const QVector<QShaderNodePort> ports = node.ports();
93         for (const QShaderNodePort &port : ports) {
94             if (port.direction == QShaderNodePort::Input) {
95                 statement.inputs.append(-1);
96             } else {
97                 statement.outputs.append(nextVarId);
98                 nextVarId++;
99             }
100         }
101         return statement;
102     }
103 
completeStatement(const QHash<QUuid,QShaderGraph::Statement> & idHash,const QVector<QShaderGraph::Edge> edges,const QUuid & uuid)104     QShaderGraph::Statement completeStatement(const QHash<QUuid, QShaderGraph::Statement> &idHash,
105                                               const QVector<QShaderGraph::Edge> edges,
106                                               const QUuid &uuid)
107     {
108         auto targetStatement = idHash.value(uuid);
109         for (const QShaderGraph::Edge &edge : edges) {
110             if (edge.targetNodeUuid != uuid)
111                 continue;
112 
113             const QShaderGraph::Statement sourceStatement = idHash.value(edge.sourceNodeUuid);
114             const int sourcePortIndex = sourceStatement.portIndex(QShaderNodePort::Output, edge.sourcePortName);
115             const int targetPortIndex = targetStatement.portIndex(QShaderNodePort::Input, edge.targetPortName);
116 
117             if (sourcePortIndex < 0 || targetPortIndex < 0)
118                 continue;
119 
120             const QVector<int> sourceOutputs = sourceStatement.outputs;
121             QVector<int> &targetInputs = targetStatement.inputs;
122             targetInputs[targetPortIndex] = sourceOutputs[sourcePortIndex];
123         }
124         return targetStatement;
125     }
126 
removeNodesWithUnboundInputs(QVector<QShaderGraph::Statement> & statements,const QVector<QShaderGraph::Edge> & allEdges)127     void removeNodesWithUnboundInputs(QVector<QShaderGraph::Statement> &statements,
128                                       const QVector<QShaderGraph::Edge> &allEdges)
129     {
130         // A node is invalid if any of its input ports is disconected
131         // or connected to the output port of another invalid node.
132 
133         // Keeps track of the edges from the nodes we know to be valid
134         // to unvisited nodes
135         auto currentEdges = QVector<QShaderGraph::Edge>();
136 
137         statements.erase(std::remove_if(statements.begin(),
138                                         statements.end(),
139                                         [&currentEdges, &allEdges] (const QShaderGraph::Statement &statement) {
140             const QShaderNode &node = statement.node;
141             const QVector<QShaderGraph::Edge> outgoing = outgoingEdges(currentEdges, node.uuid());
142             const QVector<QShaderNodePort> ports = node.ports();
143 
144             bool allInputsConnected = true;
145             for (const QShaderNodePort &port : node.ports()) {
146                 if (port.direction == QShaderNodePort::Output)
147                     continue;
148 
149                 const auto edgeIt = std::find_if(outgoing.cbegin(),
150                                                  outgoing.cend(),
151                                                  [&port] (const QShaderGraph::Edge &edge) {
152                     return edge.targetPortName == port.name;
153                 });
154 
155                 if (edgeIt != outgoing.cend())
156                     currentEdges.removeAll(*edgeIt);
157                 else
158                     allInputsConnected = false;
159             }
160 
161             if (allInputsConnected) {
162                 const QVector<QShaderGraph::Edge> incoming = incomingEdges(allEdges, node.uuid());
163                 currentEdges.append(incoming);
164             }
165 
166             return !allInputsConnected;
167         }),
168                          statements.end());
169     }
170 }
171 
uuid() const172 QUuid QShaderGraph::Statement::uuid() const noexcept
173 {
174     return node.uuid();
175 }
176 
portIndex(QShaderNodePort::Direction direction,const QString & portName) const177 int QShaderGraph::Statement::portIndex(QShaderNodePort::Direction direction, const QString &portName) const noexcept
178 {
179     const QVector<QShaderNodePort> ports = node.ports();
180     int index = 0;
181     for (const QShaderNodePort &port : ports) {
182         if (port.name == portName && port.direction == direction)
183             return index;
184         else if (port.direction == direction)
185             index++;
186     }
187     return -1;
188 }
189 
addNode(const QShaderNode & node)190 void QShaderGraph::addNode(const QShaderNode &node)
191 {
192     removeNode(node);
193     m_nodes.append(node);
194 }
195 
removeNode(const QShaderNode & node)196 void QShaderGraph::removeNode(const QShaderNode &node)
197 {
198     const auto it = std::find_if(m_nodes.begin(), m_nodes.end(),
199                                  [node] (const QShaderNode &n) { return n.uuid() == node.uuid(); });
200     if (it != m_nodes.end())
201         m_nodes.erase(it);
202 }
203 
nodes() const204 QVector<QShaderNode> QShaderGraph::nodes() const noexcept
205 {
206     return m_nodes;
207 }
208 
addEdge(const QShaderGraph::Edge & edge)209 void QShaderGraph::addEdge(const QShaderGraph::Edge &edge)
210 {
211     if (m_edges.contains(edge))
212         return;
213     m_edges.append(edge);
214 }
215 
removeEdge(const QShaderGraph::Edge & edge)216 void QShaderGraph::removeEdge(const QShaderGraph::Edge &edge)
217 {
218     m_edges.removeAll(edge);
219 }
220 
edges() const221 QVector<QShaderGraph::Edge> QShaderGraph::edges() const noexcept
222 {
223     return m_edges;
224 }
225 
createStatements(const QStringList & enabledLayers) const226 QVector<QShaderGraph::Statement> QShaderGraph::createStatements(const QStringList &enabledLayers) const
227 {
228     const auto intersectsEnabledLayers = [enabledLayers] (const QStringList &layers) {
229         return layers.isEmpty()
230             || std::any_of(layers.cbegin(), layers.cend(),
231                            [enabledLayers] (const QString &s) { return enabledLayers.contains(s); });
232     };
233 
234     const QVector<QShaderNode> enabledNodes = [this, intersectsEnabledLayers] {
235         auto res = QVector<QShaderNode>();
236         std::copy_if(m_nodes.cbegin(), m_nodes.cend(),
237                      std::back_inserter(res),
238                      [intersectsEnabledLayers] (const QShaderNode &node) {
239                          return intersectsEnabledLayers(node.layers());
240                      });
241         return res;
242     }();
243 
244     const QVector<Edge> enabledEdges = [this, intersectsEnabledLayers] {
245         auto res = QVector<Edge>();
246         std::copy_if(m_edges.cbegin(), m_edges.cend(),
247                      std::back_inserter(res),
248                      [intersectsEnabledLayers] (const Edge &edge) {
249                          return intersectsEnabledLayers(edge.layers);
250                      });
251         return res;
252     }();
253 
254     const QHash<QUuid, Statement> idHash = [enabledNodes] {
255         auto nextVarId = 0;
256         auto res = QHash<QUuid, Statement>();
257         for (const QShaderNode &node : enabledNodes)
258             res.insert(node.uuid(), nodeToStatement(node, nextVarId));
259         return res;
260     }();
261 
262     auto result = QVector<Statement>();
263     QVector<Edge> currentEdges = enabledEdges;
264     QVector<QUuid> currentUuids = [enabledNodes, enabledEdges] {
265         const QVector<QShaderNode> inputs = copyOutputNodes(enabledNodes, enabledEdges);
266         auto res = QVector<QUuid>();
267         std::transform(inputs.cbegin(), inputs.cend(),
268                        std::back_inserter(res),
269                        [](const QShaderNode &node) { return node.uuid(); });
270         return res;
271     }();
272 
273     // Implements Kahn's algorithm to flatten the graph
274     // https://en.wikipedia.org/wiki/Topological_sorting#Kahn.27s_algorithm
275     //
276     // We implement it with a small twist though, we follow the edges backward
277     // because we want to track the dependencies from the output nodes and not the
278     // input nodes
279     while (!currentUuids.isEmpty()) {
280         const QUuid uuid = currentUuids.takeFirst();
281         result.append(completeStatement(idHash, enabledEdges, uuid));
282 
283         const QVector<QShaderGraph::Edge> outgoing = outgoingEdges(currentEdges, uuid);
284         for (const QShaderGraph::Edge &outgoingEdge : outgoing) {
285             currentEdges.removeAll(outgoingEdge);
286             const QUuid nextUuid = outgoingEdge.sourceNodeUuid;
287             const QVector<QShaderGraph::Edge> incoming = incomingEdges(currentEdges, nextUuid);
288             if (incoming.isEmpty()) {
289                 currentUuids.append(nextUuid);
290             }
291         }
292     }
293 
294     std::reverse(result.begin(), result.end());
295 
296     removeNodesWithUnboundInputs(result, enabledEdges);
297 
298     return result;
299 }
300 
operator ==(const QShaderGraph::Edge & lhs,const QShaderGraph::Edge & rhs)301 bool operator==(const QShaderGraph::Edge &lhs, const QShaderGraph::Edge &rhs) noexcept
302 {
303     return lhs.sourceNodeUuid == rhs.sourceNodeUuid
304         && lhs.sourcePortName == rhs.sourcePortName
305         && lhs.targetNodeUuid == rhs.targetNodeUuid
306         && lhs.targetPortName == rhs.targetPortName;
307 }
308 
operator ==(const QShaderGraph::Statement & lhs,const QShaderGraph::Statement & rhs)309 bool operator==(const QShaderGraph::Statement &lhs, const QShaderGraph::Statement &rhs) noexcept
310 {
311     return lhs.inputs == rhs.inputs
312         && lhs.outputs == rhs.outputs
313         && lhs.node.uuid() == rhs.node.uuid();
314 }
315 
316 QT_END_NAMESPACE
317