diff --git a/src/recipes/GraphViewHelper.cpp b/src/recipes/GraphViewHelper.cpp index 79f255ac3205540998584bf414820c56c046c042..3b42db7fe18d2269b95cf35fd92851d1e3684bad 100644 --- a/src/recipes/GraphViewHelper.cpp +++ b/src/recipes/GraphViewHelper.cpp @@ -12,11 +12,12 @@ #include <memory> #include <set> +#include "aidge/data/Tensor.hpp" #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/operator/OperatorTensor.hpp" #include "aidge/utils/ErrorHandling.hpp" -#include "aidge/recipies/GraphViewHelper.hpp" +#include "aidge/recipes/GraphViewHelper.hpp" std::set<std::shared_ptr<Aidge::Tensor>> Aidge::producers(std::shared_ptr<Aidge::GraphView> graphview) { @@ -47,12 +48,10 @@ std::set<std::shared_ptr<Aidge::Tensor>> Aidge::parameters(std::shared_ptr<Aidge void Aidge::compile_gradient(std::shared_ptr<Aidge::GraphView> gv) { for (const auto& node : gv->getNodes()) { // TODO: check that each node is an OperatorTensor - AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator that doesn't use Tensor."); - const std::shared_ptr<OperatorTensor> op = std::static_pointer_cast<OperatorTensor>(node -> getOperator()); + AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator ({}) that doesn't use Tensor.", node->getOperator()->type()); + const std::shared_ptr<OperatorTensor> op = std::dynamic_pointer_cast<OperatorTensor>(node -> getOperator()); for (std::size_t o = 0; o < node -> nbOutputs(); ++o) { - const auto& t = op->getOutput(o); - t -> grad() -> setDataType(t -> dataType()); - t -> grad() -> setBackend(t -> getImpl() -> backend()); + op->getOutput(o)->initGradient(); } } } \ No newline at end of file