diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index a0641032281c6bedb4459a0d08da1193d6375129..9798dfe639475b78c761f0450c80635c5c80a63d 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -197,7 +197,7 @@ void Aidge::GraphView::forwardDims() { { assert(!std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty()); } - + } } // Compute dimensions of every node @@ -533,28 +533,24 @@ bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) { assert(mNodes.size()>0 && "There must be at least one Node to replace"); bool replacable; - std::shared_ptr<Node> previousInputNode; - std::shared_ptr<Node> newInputNode; - std::shared_ptr<Node> previousOutputNode; + std::shared_ptr<Node> previousInputNode = (*inputNodes().begin()); + std::shared_ptr<Node> previousOutputNode = (*outputNodes().begin()); std::shared_ptr<Node> newOutputNode; - + auto gNew = std::make_shared<GraphView>(); gNew->add(newNodes, false); if (newNodes.empty()) { replacable = (outputNodes().size() == 1) && - (inputNodes().size() == 1) && - ((*outputNodes().begin())->nbOutputs() == 1) && - ((*inputNodes().begin())->nbInputs() == 1); - previousOutputNode = (*outputNodes().begin()); - previousInputNode = (*inputNodes().begin()); + (inputNodes().size() == 1) && + ((*outputNodes().begin())->nbOutputs() == 1) && + ((*inputNodes().begin())->nbDataInputs() == 1); newOutputNode = previousInputNode->input(0).first; } else { - replacable = ((outputNodes().size() == gNew->outputNodes().size()) && - (outputNodes().size() == 1)); - previousOutputNode = (*outputNodes().begin()); newOutputNode = (*gNew->outputNodes().begin()); - replacable = replacable && (previousOutputNode->nbOutputs() == newOutputNode->nbOutputs()); + replacable = (outputNodes().size() == gNew->outputNodes().size()) && + (outputNodes().size() == 1) && + (previousOutputNode->nbOutputs() == newOutputNode->nbOutputs()); } if (replacable) { @@ -673,4 +669,4 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) { mOutputNodes.erase(val); } } -} \ No newline at end of file +}