Skip to content
Snippets Groups Projects
Commit 01c74a6a authored by Jerome Hue's avatar Jerome Hue
Browse files

feat: Implement backward() for stack, add unit test

parent 5e406854
No related branches found
No related tags found
1 merge request!342feat: Backward of Stack Operator
Pipeline #66007 passed
@@ -61,7 +61,18 @@ void StackOpImpl::forward() {
               op.forwardStep() * op.getInput(0)->size());
}
void StackOpImpl::backward() {
    const StackOp &op = dynamic_cast<const StackOp &>(mOp);
    // backward() consumes one stacked element per call; it is only valid
    // after at least one forward() has populated the step counter.
    AIDGE_ASSERT(op.backwardStep() > 0, "Stack operator has not been run forward");

    auto inputGrad = op.getInput(0)->grad();
    auto outputGrad = op.getOutput(0)->grad();

    // The input gradient for the current step is the corresponding slice of
    // the stacked output gradient (steps are unwound in reverse order of the
    // forward calls; backwardStep() is 1-based, extract() is 0-based).
    // NOTE(review): this overwrites rather than accumulates into inputGrad —
    // confirm that matches the framework's gradient-accumulation contract.
    *inputGrad = outputGrad->extract({op.backwardStep() - 1}).clone();
}
StackOp::StackOp(std::uint32_t maxElements)
    : OperatorTensor(s_type, {InputCategory::Data, InputCategory::OptionalData}, 1),
@@ -137,9 +148,13 @@ std::set<std::string> StackOp::getAvailableBackends() const {
void StackOp::forward() { void StackOp::forward() {
OperatorTensor::forward(); OperatorTensor::forward();
++forwardStep(); ++forwardStep();
backwardStep() = forwardStep();
} }
void StackOp::backward() {
    // Run the backend backward kernel for the current step, then move the
    // cursor back so the next call consumes the previously stacked element.
    OperatorTensor::backward();
    backwardStep() -= 1;
}
std::shared_ptr<Node> Stack(std::uint32_t maxElements,
                            const std::string &name) {
...
@@ -166,4 +166,62 @@ TEST_CASE("[core/operator] Stack(forward)", "[Stack]") {
        }
    }
}
TEST_CASE("[core/operator] Stack(backward)", "[Stack][Backward]") {
    SECTION("Stack backward with fixed values") {
        // Two 1-D inputs that forward() stacks into a 2x3 output.
        std::shared_ptr<Tensor> stack1 =
            std::make_shared<Tensor>(Array1D<int, 3>{{1, 2, 3}});
        std::shared_ptr<Tensor> stack2 =
            std::make_shared<Tensor>(Array1D<int, 3>{{4, 5, 6}});

        auto stack = Stack(2, "stack");
        std::shared_ptr<StackOp> op =
            std::static_pointer_cast<StackOp>(stack->getOperator());

        op->associateInput(0, stack1);
        op->setBackend("cpu");
        op->setDataType(DataType::Int32);
        op->forwardDims();

        // Simulate forward pass: two forward() calls advance the internal
        // step counter to 2, so backward() unwinds from the last slice.
        op->forward();
        op->forward();

        // The output gradient holds one row per stacked element.
        auto newGrad = std::make_shared<Tensor>(
            Tensor(Array2D<int, 2, 3>({{{1, 2, 3}, {4, 5, 6}}})));
        op->getOutput(0)->setGrad(newGrad);

        // First backward() returns the gradient of the SECOND stacked element,
        // the next one the gradient of the first (reverse order of forward).
        REQUIRE_NOTHROW(op->backward());
        REQUIRE(*op->getInput(0)->grad() == *stack2);
        REQUIRE_NOTHROW(op->backward());
        REQUIRE(*op->getInput(0)->grad() == *stack1);
    }

    SECTION("Edge cases") {
        std::shared_ptr<Tensor> stack1 =
            std::make_shared<Tensor>(Array1D<int, 3>{{1, 2, 3}});
        std::shared_ptr<Tensor> stack2 =
            std::make_shared<Tensor>(Array1D<int, 3>{{4, 5, 6}});

        auto stack = Stack(2, "stack");
        std::shared_ptr<StackOp> op =
            std::static_pointer_cast<StackOp>(stack->getOperator());
        op->associateInput(0, stack1);
        op->setBackend("cpu");
        op->setDataType(DataType::Int32);

        // Need to run forward before: backward() asserts backwardStep() > 0.
        REQUIRE_THROWS(op->backward());

        // After a single forward() with no output gradient set, the extracted
        // input gradient is expected to be zero-initialized.
        op->forward();
        op->backward();
        REQUIRE(*op->getInput(0)->grad() == Tensor(Array1D<int, 3>({{0,0,0}})));
    }
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment