diff --git a/include/aidge/operator/Pop.hpp b/include/aidge/operator/Pop.hpp index 0624286f7c8a8a84dd5aac5eae16c019b7a9e88b..2cf567329496e5f8a7745ab3461dc6f74d0ea1ba 100644 --- a/include/aidge/operator/Pop.hpp +++ b/include/aidge/operator/Pop.hpp @@ -86,6 +86,11 @@ public: * @brief Executes the forward pass for the `Pop` operation. */ void forward() override; + + /** + * @brief Executes the backward pass for the `Pop` operation. + */ + void backward() override; }; /** @@ -93,7 +98,8 @@ public: * @brief Attributes specific to the `Pop` operator. */ enum class PopAttr { - ForwardStep // Tracks the current step in the forward pass + ForwardStep, // Tracks the current step in the forward pass + BackwardStep // Tracks the current step in the backward pass }; /** @@ -115,7 +121,7 @@ public: static const std::string Type; private: - using Attributes_ = StaticAttributes<PopAttr, std::uint32_t>; + using Attributes_ = StaticAttributes<PopAttr, std::uint32_t, std::uint32_t>; template <PopAttr e> using attr = typename Attributes_::template attr<e>; const std::shared_ptr<Attributes_> mAttributes; @@ -171,6 +177,11 @@ public: */ void forward() override; + /** + * @brief Executes the backward pass for the `Pop` operation. + */ + void backward() override; + /** * @brief Access the operator's attributes. * @return A shared pointer to the operator's attributes. @@ -183,6 +194,8 @@ public: */ inline std::uint32_t& forwardStep() const { return mAttributes->template getAttr<PopAttr::ForwardStep>(); } + inline std::uint32_t& backwardStep() const { return mAttributes->template getAttr<PopAttr::BackwardStep>(); } + /** * @brief Retrieve the names of the operator's input tensors. * @return A vector of strings representing input tensor names. @@ -214,7 +227,7 @@ namespace { */ template <> const char *const EnumStrings<Aidge::PopAttr>::data[] = { - "forward_step" + "forward_step", "backward_step" }; } diff --git a/src/operator/Pop.cpp b/src/operator/Pop.cpp index fa77d18e7e3c5b30466304e04cf2ad95affce20e..c93078ed159257ee52602dc2fdf675b24af05155 100644 --- a/src/operator/Pop.cpp +++ b/src/operator/Pop.cpp @@ -12,6 +12,7 @@ #include "aidge/operator/Pop.hpp" #include <memory> +#include <stdexcept> #include <string> #include "aidge/data/Tensor.hpp" @@ -36,13 +37,25 @@ void Aidge::Pop_OpImpl::forward() { *op.getOutput(0) = op.getInput(0)->extract({op.forwardStep()}).clone(); } +void Aidge::Pop_OpImpl::backward() { + const Pop_Op& op = dynamic_cast<const Pop_Op&>(mOp); + + auto outputGrad = op.getOutput(0)->grad(); + auto inputGrad = op.getInput(0)->grad(); + + inputGrad->getImpl()->copy( + outputGrad->getImpl()->rawPtr(), + outputGrad->size(), + (op.backwardStep()-1) * outputGrad->size()); +} + ////////////////////////////////////////////////////////// const std::string Aidge::Pop_Op::Type = "Pop"; Aidge::Pop_Op::Pop_Op() : OperatorTensor(Type, {InputCategory::Data}, 1), - mAttributes(std::make_shared<Attributes_>(attr<PopAttr::ForwardStep>(0))) + mAttributes(std::make_shared<Attributes_>(attr<PopAttr::ForwardStep>(0), attr<PopAttr::BackwardStep>(0))) { mImpl = std::make_shared<Pop_OpImpl>(*this); } @@ -77,6 +90,7 @@ bool Aidge::Pop_Op::forwardDims(bool /*allowDataDependency*/) { void Aidge::Pop_Op::updateConsummerProducer() { Operator::updateConsummerProducer(); mAttributes->template getAttr<PopAttr::ForwardStep>() = 0; + mAttributes->template getAttr<PopAttr::BackwardStep>() = 0; } void Aidge::Pop_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { @@ -96,10 +110,16 @@ std::set<std::string> Aidge::Pop_Op::getAvailableBackends() const { void Aidge::Pop_Op::forward() { OperatorTensor::forward(); ++mAttributes->template getAttr<PopAttr::ForwardStep>(); + backwardStep() = forwardStep(); +} + +void Aidge::Pop_Op::backward() { + OperatorTensor::backward(); + --mAttributes->template getAttr<PopAttr::BackwardStep>(); } /////////////////////////////////////////// std::shared_ptr<Aidge::Node> Aidge::Pop(const std::string& name) { return std::make_shared<Node>(std::make_shared<Pop_Op>(), name); -} \ No newline at end of file +} diff --git a/unit_tests/operator/Test_PopImpl.cpp b/unit_tests/operator/Test_PopImpl.cpp index d3c87ef7289e4516442885f7449060055c428c49..2f639a13964bd6b9fc95cbf4a22f8cc235b333f7 100644 --- a/unit_tests/operator/Test_PopImpl.cpp +++ b/unit_tests/operator/Test_PopImpl.cpp @@ -14,11 +14,12 @@ #include "aidge/data/Tensor.hpp" #include "aidge/operator/Pop.hpp" +#include "aidge/utils/ArrayHelpers.hpp" #include "aidge/utils/TensorUtils.hpp" using namespace Aidge; -TEST_CASE("[cpu/operator] Pop(forward)", "[Pop][CPU]") { +TEST_CASE("[cpu/operator] Pop", "[Pop][CPU]") { std::shared_ptr<Tensor> pop1 = std::make_shared<Tensor>(Array1D<int,3>{{4,5,6}}); std::shared_ptr<Tensor> pop2 = std::make_shared<Tensor>(Array1D<int,3>{{1,2,3}}); std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array2D<int,2,3>{{{1,2,3}, {4,5,6}}}); @@ -34,4 +35,18 @@ TEST_CASE("[cpu/operator] Pop(forward)", "[Pop][CPU]") { REQUIRE(*op->getOutput(0) == *pop2); REQUIRE_NOTHROW(pop->forward()); REQUIRE(*op->getOutput(0) == *pop1); + + + // Backward + auto expectedGrad1 = Tensor(Array2D<int, 2, 3>({{{0,0,0},{1,1,1}}})); + auto expectedGrad2 = Tensor(Array2D<int, 2, 3>({{{2,2,2},{1,1,1}}})); + + op->getOutput(0)->setGrad(std::make_shared<Tensor>(Array1D<int,3>({1,1,1}))); + REQUIRE_NOTHROW(pop->backward()); + REQUIRE(*op->getInput(0)->grad() == expectedGrad1); + + + op->getOutput(0)->setGrad(std::make_shared<Tensor>(Array1D<int,3>({2,2,2}))); + REQUIRE_NOTHROW(pop->backward()); + REQUIRE(*op->getInput(0)->grad() == expectedGrad2); }