Skip to content
Snippets Groups Projects
Commit 5e406854 authored by Jerome Hue's avatar Jerome Hue
Browse files

Add backward step counter and backward() declaration for stack op

parent dba5e80d
No related branches found
No related tags found
1 merge request!342feat: Backward of Stack Operator
......@@ -89,10 +89,16 @@ public:
* @brief Executes the forward pass for the Stack operation.
*/
void forward() override;
/**
* @brief Executes the backward pass for the Stack operation.
*/
void backward() override;
};
enum class StackAttr {
ForwardStep, // Tracks the current step in the forward pass.
BackwardStep, // Tracks the current step in the forward pass.
MaxElements // Maximum number of elements that can be stacked.
};
} // namespace Aidge
......@@ -123,7 +129,7 @@ namespace Aidge {
class StackOp : public OperatorTensor,
public Registrable<StackOp, std::string, std::function<std::unique_ptr<OperatorImpl>(const StackOp&)>> {
private:
using Attributes_ = StaticAttributes<StackAttr, std::uint32_t, std::uint32_t>;
using Attributes_ = StaticAttributes<StackAttr, std::uint32_t, std::uint32_t, std::uint32_t>;
template <StackAttr e> using attr = typename Attributes_::template attr<e>;
const std::shared_ptr<Attributes_> mAttributes;
......@@ -181,6 +187,11 @@ public:
*/
void forward() override;
/**
* @brief Executes the backward pass for the `Stack` operation.
*/
void backward() override;
/**
* @brief Access the operator's attributes.
* @return A shared pointer to the operator's attributes.
......@@ -205,6 +216,15 @@ public:
return mAttributes->template getAttr<StackAttr::ForwardStep>();
}
/**
* @brief Get or set the backward step counter for the operator.
* @return A reference to the backward step counter.
*/
inline std::uint32_t& backwardStep() const {
return mAttributes->template getAttr<StackAttr::BackwardStep>();
}
/**
* @brief Retrieve the names of the operator's input tensors.
* @return A vector of strings representing input tensor names.
......@@ -239,5 +259,4 @@ public:
std::shared_ptr<Node> Stack(std::uint32_t maxElements = 0, const std::string& name = "");
} // namespace Aidge
#endif /* AIDGE_CORE_OPERATOR_STACK_H_ */
......@@ -61,10 +61,13 @@ void StackOpImpl::forward() {
op.forwardStep() * op.getInput(0)->size());
}
void StackOpImpl::backward() {}
StackOp::StackOp(std::uint32_t maxElements)
: OperatorTensor(s_type, {InputCategory::Data, InputCategory::OptionalData}, 1),
mAttributes(std::make_shared<Attributes_>(
attr<StackAttr::MaxElements>(maxElements),
attr<StackAttr::BackwardStep>(0),
attr<StackAttr::ForwardStep>(0))) {
mImpl = std::make_shared<StackOpImpl>(*this);
}
......@@ -136,6 +139,8 @@ void StackOp::forward() {
++forwardStep();
}
void StackOp::backward() {}
std::shared_ptr<Node> Stack(std::uint32_t maxElements,
const std::string &name) {
return std::make_shared<Node>(std::make_shared<StackOp>(maxElements),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment