Skip to content
Snippets Groups Projects

Fuse bn

Merged Cyril Moineau requested to merge fuseBN into main
1 file
+ 11
15
Compare changes
  • Side-by-side
  • Inline
+ 11
15
@@ -197,7 +197,7 @@ void Aidge::GraphView::forwardDims() {
{
assert(!std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty());
}
}
}
// Compute dimensions of every node
@@ -533,28 +533,24 @@ bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) {
assert(mNodes.size()>0 && "There must be at least one Node to replace");
bool replacable;
std::shared_ptr<Node> previousInputNode;
std::shared_ptr<Node> newInputNode;
std::shared_ptr<Node> previousOutputNode;
std::shared_ptr<Node> previousInputNode = (*inputNodes().begin());
std::shared_ptr<Node> previousOutputNode = (*outputNodes().begin());
std::shared_ptr<Node> newOutputNode;
auto gNew = std::make_shared<GraphView>();
gNew->add(newNodes, false);
if (newNodes.empty()) {
replacable = (outputNodes().size() == 1) &&
(inputNodes().size() == 1) &&
((*outputNodes().begin())->nbOutputs() == 1) &&
((*inputNodes().begin())->nbInputs() == 1);
previousOutputNode = (*outputNodes().begin());
previousInputNode = (*inputNodes().begin());
(inputNodes().size() == 1) &&
((*outputNodes().begin())->nbOutputs() == 1) &&
((*inputNodes().begin())->nbDataInputs() == 1);
newOutputNode = previousInputNode->input(0).first;
} else {
replacable = ((outputNodes().size() == gNew->outputNodes().size()) &&
(outputNodes().size() == 1));
previousOutputNode = (*outputNodes().begin());
newOutputNode = (*gNew->outputNodes().begin());
replacable = replacable && (previousOutputNode->nbOutputs() == newOutputNode->nbOutputs());
replacable = (outputNodes().size() == gNew->outputNodes().size()) &&
(outputNodes().size() == 1) &&
(previousOutputNode->nbOutputs() == newOutputNode->nbOutputs());
}
if (replacable) {
@@ -673,4 +669,4 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) {
mOutputNodes.erase(val);
}
}
}
\ No newline at end of file
}
Loading