Skip to content
Snippets Groups Projects
Commit 3ee75f49 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Merge branch 'stack-backward' into 'dev'

feat: Backward of Stack Operator

See merge request !342
parents 03de25a0 0c6e3302
No related branches found
No related tags found
3 merge requests!414Update version 0.5.1 -> 0.6.0,!408[Add] Dropout Operator,!342feat: Backward of Stack Operator
Pipeline #66019 passed
......@@ -89,10 +89,16 @@ public:
* @brief Executes the forward pass for the Stack operation.
*/
void forward() override;
/**
* @brief Executes the backward pass for the Stack operation.
*/
void backward() override;
};
enum class StackAttr {
ForwardStep, // Tracks the current step in the forward pass.
BackwardStep, // Tracks the current step in the forward pass.
MaxElements // Maximum number of elements that can be stacked.
};
} // namespace Aidge
......@@ -123,7 +129,7 @@ namespace Aidge {
class StackOp : public OperatorTensor,
public Registrable<StackOp, std::string, std::function<std::unique_ptr<OperatorImpl>(const StackOp&)>> {
private:
using Attributes_ = StaticAttributes<StackAttr, std::uint32_t, std::uint32_t>;
using Attributes_ = StaticAttributes<StackAttr, std::uint32_t, std::uint32_t, std::uint32_t>;
template <StackAttr e> using attr = typename Attributes_::template attr<e>;
const std::shared_ptr<Attributes_> mAttributes;
......@@ -181,6 +187,11 @@ public:
*/
void forward() override;
/**
* @brief Executes the backward pass for the `Stack` operation.
*/
void backward() override;
/**
* @brief Access the operator's attributes.
* @return A shared pointer to the operator's attributes.
......@@ -205,6 +216,15 @@ public:
return mAttributes->template getAttr<StackAttr::ForwardStep>();
}
/**
* @brief Get or set the backward step counter for the operator.
* @return A reference to the backward step counter.
*/
inline std::uint32_t& backwardStep() const {
return mAttributes->template getAttr<StackAttr::BackwardStep>();
}
/**
* @brief Retrieve the names of the operator's input tensors.
* @return A vector of strings representing input tensor names.
......@@ -239,5 +259,4 @@ public:
std::shared_ptr<Node> Stack(std::uint32_t maxElements = 0, const std::string& name = "");
} // namespace Aidge
#endif /* AIDGE_CORE_OPERATOR_STACK_H_ */
......@@ -61,10 +61,21 @@ void StackOpImpl::forward() {
op.forwardStep() * op.getInput(0)->size());
}
void StackOpImpl::backward() {
const StackOp &op = dynamic_cast<const StackOp &>(mOp);
AIDGE_ASSERT(op.backwardStep() > 0, "Stack operator has not been run forward");
auto inputGrad = op.getInput(0)->grad();
auto outputGrad = op.getOutput(0)->grad();
*inputGrad = outputGrad->extract({op.backwardStep() -1 }).clone();
}
StackOp::StackOp(std::uint32_t maxElements)
: OperatorTensor(s_type, {InputCategory::Data, InputCategory::OptionalData}, 1),
mAttributes(std::make_shared<Attributes_>(
attr<StackAttr::MaxElements>(maxElements),
attr<StackAttr::BackwardStep>(0),
attr<StackAttr::ForwardStep>(0))) {
mImpl = std::make_shared<StackOpImpl>(*this);
}
......@@ -132,8 +143,14 @@ std::set<std::string> StackOp::getAvailableBackends() const {
}
void StackOp::forward() {
Operator::forward();
OperatorTensor::forward();
++forwardStep();
backwardStep() = forwardStep();
}
void StackOp::backward() {
OperatorTensor::backward();
--backwardStep();
}
std::shared_ptr<Node> Stack(std::uint32_t maxElements,
......
......@@ -166,4 +166,62 @@ TEST_CASE("[core/operator] Stack(forward)", "[Stack]") {
}
}
}
TEST_CASE("[core/operator] Stack(backward)", "[Stack][Backward]") {
SECTION("Stack backward with fixed values") {
std::shared_ptr<Tensor> stack1 =
std::make_shared<Tensor>(Array1D<int, 3>{{1, 2, 3}});
std::shared_ptr<Tensor> stack2 =
std::make_shared<Tensor>(Array1D<int, 3>{{4, 5, 6}});
auto stack = Stack(2, "stack");
std::shared_ptr<StackOp> op =
std::static_pointer_cast<StackOp>(stack->getOperator());
//op->associateInput(0, stack1);
op->associateInput(0, stack1);
op->setBackend("cpu");
op->setDataType(DataType::Int32);
op->forwardDims();
// Simulate forward pass
op->forward();
op->forward();
auto newGrad = std::make_shared<Tensor>(
Tensor(Array2D<int, 2, 3>({{{1, 2, 3}, {4, 5, 6}}})));
op->getOutput(0)->setGrad(newGrad);
REQUIRE_NOTHROW(op->backward());
REQUIRE(*op->getInput(0)->grad() == *stack2);
REQUIRE_NOTHROW(op->backward());
REQUIRE(*op->getInput(0)->grad() == *stack1);
}
SECTION("Edge cases") {
std::shared_ptr<Tensor> stack1 =
std::make_shared<Tensor>(Array1D<int, 3>{{1, 2, 3}});
std::shared_ptr<Tensor> stack2 =
std::make_shared<Tensor>(Array1D<int, 3>{{4, 5, 6}});
auto stack = Stack(2, "stack");
std::shared_ptr<StackOp> op =
std::static_pointer_cast<StackOp>(stack->getOperator());
op->associateInput(0, stack1);
op->setBackend("cpu");
op->setDataType(DataType::Int32);
// Need to run forward before
REQUIRE_THROWS(op->backward());
op->forward();
op->backward();
REQUIRE(*op->getInput(0)->grad() == Tensor(Array1D<int, 3>({{0,0,0}})));
}
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment