Skip to content

Commit

Permalink
Fix performance issues with graph edge iteration in ShaderGraph
Browse files Browse the repository at this point in the history
In complex materials graph and shader graph edge iteration can be
extremely slow, because some edges may be visited unnecessarily
multiple times. This is especially noticable in two functions:
ShaderGraph::addUpstreamDependencies and ShaderGraph::optimize() .

GraphIterator and ShaderGraphEdgeIterator classes iterate over DAGs
without marking nodes as visited, which may lead to exponential
traversal time for some DAGs:
https://stackoverflow.com/a/69326676

This patch adds two functions which efficiently generate a unique list
of (shader) graph edges and uses those lists instead of graph iterators.
  • Loading branch information
nadult committed Sep 20, 2024
1 parent 3ec6e87 commit e8c8ccf
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 3 deletions.
5 changes: 5 additions & 0 deletions source/MaterialXCore/Element.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,11 @@ GraphIterator Element::traverseGraph() const
return GraphIterator(getSelfNonConst());
}

vector<Edge> Element::uniqueGraphEdges() const
{
return GraphIterator::uniqueGraphEdges(getSelfNonConst());
}

Edge Element::getUpstreamEdge(size_t) const
{
return NULL_EDGE;
Expand Down
4 changes: 4 additions & 0 deletions source/MaterialXCore/Element.h
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,10 @@ class MX_CORE_API Element : public std::enable_shared_from_this<Element>
/// @sa getUpstreamElement
GraphIterator traverseGraph() const;

/// Returns a vector of all unique graph edges.
/// @throws ExceptionFoundCycle if a cycle is encountered.
vector<Edge> uniqueGraphEdges() const;

/// Return the Edge with the given index that lies directly upstream from
/// this element in the dataflow graph.
/// @param index An optional index of the edge to be returned, where the
Expand Down
2 changes: 1 addition & 1 deletion source/MaterialXCore/Interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ bool Output::hasUpstreamCycle() const
{
try
{
for (Edge edge : traverseGraph()) { }
uniqueGraphEdges();
}
catch (ExceptionFoundCycle&)
{
Expand Down
35 changes: 35 additions & 0 deletions source/MaterialXCore/Traversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,41 @@ void GraphIterator::returnPathDownstream(ElementPtr upstreamElem)
_connectingElem = ElementPtr();
}

static void uniqueTraverse(ElementPtr upstreamElem, std::set<Edge>& visitedEdges, std::set<ElementPtr>& pathNodes, vector<Edge>& out)
{
if (pathNodes.count(upstreamElem))
{
throw ExceptionFoundCycle("Encountered cycle at element: " + upstreamElem->asString());
}
pathNodes.emplace(upstreamElem);

for (size_t i = 0, upstreamEdgeCount = upstreamElem->getUpstreamEdgeCount(); i < upstreamEdgeCount; i++)
{
Edge edge = upstreamElem->getUpstreamEdge(i);
if (visitedEdges.find(edge) != visitedEdges.end() || !edge.getUpstreamElement())
continue;

visitedEdges.emplace(edge);
out.emplace_back(edge);
uniqueTraverse(edge.getUpstreamElement(), visitedEdges, pathNodes, out);
}

pathNodes.erase(upstreamElem);
}

vector<Edge> GraphIterator::uniqueGraphEdges(ElementPtr root)
{
vector<Edge> out;
std::set<Edge> visitedEdges;
std::set<ElementPtr> pathNodes;

if (root)
{
uniqueTraverse(root, visitedEdges, pathNodes, out);
}
return out;
}

//
// InheritanceIterator methods
//
Expand Down
3 changes: 3 additions & 0 deletions source/MaterialXCore/Traversal.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,9 @@ class MX_CORE_API GraphIterator

/// @}

/// Returns a list of all unique graph edges.
static vector<Edge> uniqueGraphEdges(ElementPtr);

private:
void extendPathUpstream(ElementPtr upstreamElem, ElementPtr connectingElem);
void returnPathDownstream(ElementPtr upstreamElem);
Expand Down
32 changes: 30 additions & 2 deletions source/MaterialXGenShader/ShaderGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ void ShaderGraph::addUpstreamDependencies(const Element& root, GenContext& conte
{
std::set<ElementPtr> processedOutputs;

for (Edge edge : root.traverseGraph())
for (Edge edge : root.uniqueGraphEdges())
{
ElementPtr upstreamElement = edge.getUpstreamElement();
if (!upstreamElement)
Expand Down Expand Up @@ -900,7 +900,7 @@ void ShaderGraph::optimize()
ShaderOutput* upstreamPort = outputSocket->getConnection();
if (upstreamPort && upstreamPort->getNode() != this)
{
for (ShaderGraphEdge edge : ShaderGraph::traverseUpstream(upstreamPort))
for (ShaderGraphEdge edge : ShaderGraphEdgeIterator::uniqueGraphEdges(upstreamPort))
{
ShaderNode* node = edge.upstream->getNode();
if (usedNodesSet.count(node) == 0)
Expand Down Expand Up @@ -1259,4 +1259,32 @@ void ShaderGraphEdgeIterator::returnPathDownstream(ShaderOutput* upstream)
_downstream = nullptr;
}

static void uniqueTraverse(ShaderNode* upstreamNode, std::set<ShaderGraphEdge>& visited, vector<ShaderGraphEdge>& out)
{
for (size_t i = 0, numInputs = upstreamNode->numInputs(); i < numInputs; i++)
{
ShaderInput* input = upstreamNode->getInput(i);
ShaderOutput* output = input->getConnection();
ShaderGraphEdge edge(output, input);

if (visited.find(edge) != visited.end() || !output)
continue;

visited.emplace(edge);
out.emplace_back(edge);
uniqueTraverse(output->getNode(), visited, out);
}
}

vector<ShaderGraphEdge> ShaderGraphEdgeIterator::uniqueGraphEdges(ShaderOutput* root)
{
vector<ShaderGraphEdge> out;
std::set<ShaderGraphEdge> visited;
if (root)
{
uniqueTraverse(root->getNode(), visited, out);
}
return out;
}

MATERIALX_NAMESPACE_END
19 changes: 19 additions & 0 deletions source/MaterialXGenShader/ShaderGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,22 @@ class MX_GENSHADER_API ShaderGraphEdge
downstream(down)
{
}

bool operator==(const ShaderGraphEdge& rhs) const
{
return upstream == rhs.upstream && downstream == rhs.downstream;
}

bool operator!=(const ShaderGraphEdge& rhs) const
{
return !(*this == rhs);
}

bool operator<(const ShaderGraphEdge& rhs) const
{
return std::tie(upstream, downstream) < std::tie(rhs.upstream, rhs.downstream);
}

ShaderOutput* upstream;
ShaderInput* downstream;
};
Expand Down Expand Up @@ -251,6 +267,9 @@ class MX_GENSHADER_API ShaderGraphEdgeIterator
/// Return the end iterator.
static const ShaderGraphEdgeIterator& end();

/// Returns a list of all unique ShaderGraph edges
static vector<ShaderGraphEdge> uniqueGraphEdges(ShaderOutput*);

private:
void extendPathUpstream(ShaderOutput* upstream, ShaderInput* downstream);
void returnPathDownstream(ShaderOutput* upstream);
Expand Down

0 comments on commit e8c8ccf

Please sign in to comment.