From dba5e80d76ff31c6b4afc18c5009386d1c9fbab4 Mon Sep 17 00:00:00 2001 From: Jerome Hue <jerome.hue@cea.fr> Date: Wed, 19 Feb 2025 14:53:10 +0100 Subject: [PATCH 1/4] Fix StackOp forward() Call OperatorTensor::forward() instead of Operator::forward(), ensuring dims are forwarded. --- src/operator/Stack.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/Stack.cpp b/src/operator/Stack.cpp index a938f470d..17502e0f4 100644 --- a/src/operator/Stack.cpp +++ b/src/operator/Stack.cpp @@ -132,7 +132,7 @@ std::set<std::string> StackOp::getAvailableBackends() const { } void StackOp::forward() { - Operator::forward(); + OperatorTensor::forward(); ++forwardStep(); } -- GitLab From 5e40685480cbb187b1d1f0adff6122663a47e4ea Mon Sep 17 00:00:00 2001 From: Jerome Hue <jerome.hue@cea.fr> Date: Wed, 19 Feb 2025 15:00:12 +0100 Subject: [PATCH 2/4] Add backward step counter and backward() declaration for stack op --- include/aidge/operator/Stack.hpp | 23 +++++++++++++++++++++-- src/operator/Stack.cpp | 5 +++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/include/aidge/operator/Stack.hpp b/include/aidge/operator/Stack.hpp index 214428447..83e4f68e1 100644 --- a/include/aidge/operator/Stack.hpp +++ b/include/aidge/operator/Stack.hpp @@ -89,10 +89,16 @@ public: * @brief Executes the forward pass for the Stack operation. */ void forward() override; + + /** + * @brief Executes the backward pass for the Stack operation. + */ + void backward() override; }; enum class StackAttr { ForwardStep, // Tracks the current step in the forward pass. + BackwardStep, // Tracks the current step in the forward pass. MaxElements // Maximum number of elements that can be stacked. }; } // namespace Aidge @@ -123,7 +129,7 @@ namespace Aidge { class StackOp : public OperatorTensor, public Registrable<StackOp, std::string, std::function<std::unique_ptr<OperatorImpl>(const StackOp&)>> { private: - using Attributes_ = StaticAttributes<StackAttr, std::uint32_t, std::uint32_t>; + using Attributes_ = StaticAttributes<StackAttr, std::uint32_t, std::uint32_t, std::uint32_t>; template <StackAttr e> using attr = typename Attributes_::template attr<e>; const std::shared_ptr<Attributes_> mAttributes; @@ -181,6 +187,11 @@ public: */ void forward() override; + /** + * @brief Executes the backward pass for the `Stack` operation. + */ + void backward() override; + /** * @brief Access the operator's attributes. * @return A shared pointer to the operator's attributes. @@ -205,6 +216,15 @@ public: return mAttributes->template getAttr<StackAttr::ForwardStep>(); } + /** + * @brief Get or set the backward step counter for the operator. + * @return A reference to the backward step counter. + */ + inline std::uint32_t& backwardStep() const { + return mAttributes->template getAttr<StackAttr::BackwardStep>(); + } + + /** * @brief Retrieve the names of the operator's input tensors. * @return A vector of strings representing input tensor names. @@ -239,5 +259,4 @@ public: std::shared_ptr<Node> Stack(std::uint32_t maxElements = 0, const std::string& name = ""); } // namespace Aidge - #endif /* AIDGE_CORE_OPERATOR_STACK_H_ */ diff --git a/src/operator/Stack.cpp b/src/operator/Stack.cpp index 17502e0f4..32a413174 100644 --- a/src/operator/Stack.cpp +++ b/src/operator/Stack.cpp @@ -61,10 +61,13 @@ void StackOpImpl::forward() { op.forwardStep() * op.getInput(0)->size()); } +void StackOpImpl::backward() {} + StackOp::StackOp(std::uint32_t maxElements) : OperatorTensor(s_type, {InputCategory::Data, InputCategory::OptionalData}, 1), mAttributes(std::make_shared<Attributes_>( attr<StackAttr::MaxElements>(maxElements), + attr<StackAttr::BackwardStep>(0), attr<StackAttr::ForwardStep>(0))) { mImpl = std::make_shared<StackOpImpl>(*this); } @@ -136,6 +139,8 @@ void StackOp::forward() { ++forwardStep(); } +void StackOp::backward() {} + std::shared_ptr<Node> Stack(std::uint32_t maxElements, const std::string &name) { return std::make_shared<Node>(std::make_shared<StackOp>(maxElements), -- GitLab From 01c74a6af3b59687d211e15eb98cbd8cc73e7635 Mon Sep 17 00:00:00 2001 From: Jerome Hue <jerome.hue@cea.fr> Date: Wed, 19 Feb 2025 15:03:48 +0100 Subject: [PATCH 3/4] feat: Implement backward() for stack, add unit test --- src/operator/Stack.cpp | 19 ++++++++- unit_tests/operator/Test_StackImpl.cpp | 58 ++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 2 deletions(-) diff --git a/src/operator/Stack.cpp b/src/operator/Stack.cpp index 32a413174..064b51aa9 100644 --- a/src/operator/Stack.cpp +++ b/src/operator/Stack.cpp @@ -61,7 +61,18 @@ void StackOpImpl::forward() { op.forwardStep() * op.getInput(0)->size()); } -void StackOpImpl::backward() {} +void StackOpImpl::backward() { + const StackOp &op = dynamic_cast<const StackOp &>(mOp); + AIDGE_ASSERT(op.backwardStep() > 0, "Stack operator has not been run forward"); + + auto inputGrad = op.getInput(0)->grad(); + auto outputGrad = op.getOutput(0)->grad(); + + Log::notice("Size of stack in grad : {}", inputGrad->size()); + Log::notice("Size of stack out grad : {}", outputGrad->size()); + + *inputGrad = outputGrad->extract({op.backwardStep() -1 }).clone(); +} StackOp::StackOp(std::uint32_t maxElements) : OperatorTensor(s_type, {InputCategory::Data, InputCategory::OptionalData}, 1), @@ -137,9 +148,13 @@ std::set<std::string> StackOp::getAvailableBackends() const { void StackOp::forward() { OperatorTensor::forward(); ++forwardStep(); + backwardStep() = forwardStep(); } -void StackOp::backward() {} +void StackOp::backward() { + OperatorTensor::backward(); + --backwardStep(); +} std::shared_ptr<Node> Stack(std::uint32_t maxElements, const std::string &name) { diff --git a/unit_tests/operator/Test_StackImpl.cpp b/unit_tests/operator/Test_StackImpl.cpp index ccdf5787d..fe9ac0519 100644 --- a/unit_tests/operator/Test_StackImpl.cpp +++ b/unit_tests/operator/Test_StackImpl.cpp @@ -166,4 +166,62 @@ TEST_CASE("[core/operator] Stack(forward)", "[Stack]") { } } } + +TEST_CASE("[core/operator] Stack(backward)", "[Stack][Backward]") { + SECTION("Stack backward with fixed values") { + std::shared_ptr<Tensor> stack1 = + std::make_shared<Tensor>(Array1D<int, 3>{{1, 2, 3}}); + std::shared_ptr<Tensor> stack2 = + std::make_shared<Tensor>(Array1D<int, 3>{{4, 5, 6}}); + + auto stack = Stack(2, "stack"); + std::shared_ptr<StackOp> op = + std::static_pointer_cast<StackOp>(stack->getOperator()); + + //op->associateInput(0, stack1); + op->associateInput(0, stack1); + op->setBackend("cpu"); + op->setDataType(DataType::Int32); + op->forwardDims(); + + // Simulate forward pass + op->forward(); + op->forward(); + + auto newGrad = std::make_shared<Tensor>( + Tensor(Array2D<int, 2, 3>({{{1, 2, 3}, {4, 5, 6}}}))); + op->getOutput(0)->setGrad(newGrad); + + REQUIRE_NOTHROW(op->backward()); + REQUIRE(*op->getInput(0)->grad() == *stack2); + + REQUIRE_NOTHROW(op->backward()); + REQUIRE(*op->getInput(0)->grad() == *stack1); + } + + SECTION("Edge cases") { + std::shared_ptr<Tensor> stack1 = + std::make_shared<Tensor>(Array1D<int, 3>{{1, 2, 3}}); + std::shared_ptr<Tensor> stack2 = + std::make_shared<Tensor>(Array1D<int, 3>{{4, 5, 6}}); + + auto stack = Stack(2, "stack"); + std::shared_ptr<StackOp> op = + std::static_pointer_cast<StackOp>(stack->getOperator()); + + op->associateInput(0, stack1); + op->setBackend("cpu"); + op->setDataType(DataType::Int32); + + + // Need to run forward before + REQUIRE_THROWS(op->backward()); + + op->forward(); + op->backward(); + REQUIRE(*op->getInput(0)->grad() == Tensor(Array1D<int, 3>({{0,0,0}}))); + } +} + + } // namespace Aidge -- GitLab From 0c6e330242d3b61a8e7d45439254881199e98d0c Mon Sep 17 00:00:00 2001 From: Jerome Hue <jerome.hue@cea.fr> Date: Fri, 21 Feb 2025 08:12:28 +0100 Subject: [PATCH 4/4] Remove debug print --- src/operator/Stack.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/operator/Stack.cpp b/src/operator/Stack.cpp index 064b51aa9..9f8cd1639 100644 --- a/src/operator/Stack.cpp +++ b/src/operator/Stack.cpp @@ -68,9 +68,6 @@ void StackOpImpl::backward() { auto inputGrad = op.getInput(0)->grad(); auto outputGrad = op.getOutput(0)->grad(); - Log::notice("Size of stack in grad : {}", inputGrad->size()); - Log::notice("Size of stack out grad : {}", outputGrad->size()); - *inputGrad = outputGrad->extract({op.backwardStep() -1 }).clone(); } -- GitLab