diff --git a/src/operator/Stack.cpp b/src/operator/Stack.cpp index 32a413174739f0b773441c431c3c4013b01a455d..064b51aa9a72c439b6af7d26b05ed04205261cc8 100644 --- a/src/operator/Stack.cpp +++ b/src/operator/Stack.cpp @@ -61,7 +61,18 @@ void StackOpImpl::forward() { op.forwardStep() * op.getInput(0)->size()); } -void StackOpImpl::backward() {} +void StackOpImpl::backward() { + const StackOp &op = dynamic_cast<const StackOp &>(mOp); + AIDGE_ASSERT(op.backwardStep() > 0, "Stack operator has not been run forward"); + + auto inputGrad = op.getInput(0)->grad(); + auto outputGrad = op.getOutput(0)->grad(); + + Log::notice("Size of stack in grad : {}", inputGrad->size()); + Log::notice("Size of stack out grad : {}", outputGrad->size()); + + *inputGrad = outputGrad->extract({op.backwardStep() -1 }).clone(); +} StackOp::StackOp(std::uint32_t maxElements) : OperatorTensor(s_type, {InputCategory::Data, InputCategory::OptionalData}, 1), @@ -137,9 +148,13 @@ std::set<std::string> StackOp::getAvailableBackends() const { void StackOp::forward() { OperatorTensor::forward(); ++forwardStep(); + backwardStep() = forwardStep(); } -void StackOp::backward() {} +void StackOp::backward() { + OperatorTensor::backward(); + --backwardStep(); +} std::shared_ptr<Node> Stack(std::uint32_t maxElements, const std::string &name) { diff --git a/unit_tests/operator/Test_StackImpl.cpp b/unit_tests/operator/Test_StackImpl.cpp index ccdf5787d666f030b8856704eb0e4fb108089075..fe9ac0519c740bdf5d6be96f2dde187425c043a1 100644 --- a/unit_tests/operator/Test_StackImpl.cpp +++ b/unit_tests/operator/Test_StackImpl.cpp @@ -166,4 +166,62 @@ TEST_CASE("[core/operator] Stack(forward)", "[Stack]") { } } } + +TEST_CASE("[core/operator] Stack(backward)", "[Stack][Backward]") { + SECTION("Stack backward with fixed values") { + std::shared_ptr<Tensor> stack1 = + std::make_shared<Tensor>(Array1D<int, 3>{{1, 2, 3}}); + std::shared_ptr<Tensor> stack2 = + std::make_shared<Tensor>(Array1D<int, 3>{{4, 5, 6}}); + + auto stack = Stack(2, "stack"); + std::shared_ptr<StackOp> op = + std::static_pointer_cast<StackOp>(stack->getOperator()); + + //op->associateInput(0, stack1); + op->associateInput(0, stack1); + op->setBackend("cpu"); + op->setDataType(DataType::Int32); + op->forwardDims(); + + // Simulate forward pass + op->forward(); + op->forward(); + + auto newGrad = std::make_shared<Tensor>( + Tensor(Array2D<int, 2, 3>({{{1, 2, 3}, {4, 5, 6}}}))); + op->getOutput(0)->setGrad(newGrad); + + REQUIRE_NOTHROW(op->backward()); + REQUIRE(*op->getInput(0)->grad() == *stack2); + + REQUIRE_NOTHROW(op->backward()); + REQUIRE(*op->getInput(0)->grad() == *stack1); + } + + SECTION("Edge cases") { + std::shared_ptr<Tensor> stack1 = + std::make_shared<Tensor>(Array1D<int, 3>{{1, 2, 3}}); + std::shared_ptr<Tensor> stack2 = + std::make_shared<Tensor>(Array1D<int, 3>{{4, 5, 6}}); + + auto stack = Stack(2, "stack"); + std::shared_ptr<StackOp> op = + std::static_pointer_cast<StackOp>(stack->getOperator()); + + op->associateInput(0, stack1); + op->setBackend("cpu"); + op->setDataType(DataType::Int32); + + + // Need to run forward before + REQUIRE_THROWS(op->backward()); + + op->forward(); + op->backward(); + REQUIRE(*op->getInput(0)->grad() == Tensor(Array1D<int, 3>({{0,0,0}}))); + } +} + + } // namespace Aidge