Skip to content
Snippets Groups Projects

feat: Backward of Stack Operator

Merged Jerome Hue requested to merge jeromeh/aidge_core:stack-backward into dev
Files
3
@@ -89,10 +89,16 @@ public:
@@ -89,10 +89,16 @@ public:
* @brief Executes the forward pass for the Stack operation.
* @brief Executes the forward pass for the Stack operation.
*/
*/
void forward() override;
void forward() override;
 
 
/**
 
* @brief Executes the backward pass for the Stack operation.
 
*/
 
void backward() override;
};
};
enum class StackAttr {
enum class StackAttr {
ForwardStep, // Tracks the current step in the forward pass.
ForwardStep, // Tracks the current step in the forward pass.
 
BackwardStep, // Tracks the current step in the forward pass.
MaxElements // Maximum number of elements that can be stacked.
MaxElements // Maximum number of elements that can be stacked.
};
};
} // namespace Aidge
} // namespace Aidge
@@ -123,7 +129,7 @@ namespace Aidge {
@@ -123,7 +129,7 @@ namespace Aidge {
class StackOp : public OperatorTensor,
class StackOp : public OperatorTensor,
public Registrable<StackOp, std::string, std::function<std::unique_ptr<OperatorImpl>(const StackOp&)>> {
public Registrable<StackOp, std::string, std::function<std::unique_ptr<OperatorImpl>(const StackOp&)>> {
private:
private:
using Attributes_ = StaticAttributes<StackAttr, std::uint32_t, std::uint32_t>;
using Attributes_ = StaticAttributes<StackAttr, std::uint32_t, std::uint32_t, std::uint32_t>;
template <StackAttr e> using attr = typename Attributes_::template attr<e>;
template <StackAttr e> using attr = typename Attributes_::template attr<e>;
const std::shared_ptr<Attributes_> mAttributes;
const std::shared_ptr<Attributes_> mAttributes;
@@ -181,6 +187,11 @@ public:
@@ -181,6 +187,11 @@ public:
*/
*/
void forward() override;
void forward() override;
 
/**
 
* @brief Executes the backward pass for the `Stack` operation.
 
*/
 
void backward() override;
 
/**
/**
* @brief Access the operator's attributes.
* @brief Access the operator's attributes.
* @return A shared pointer to the operator's attributes.
* @return A shared pointer to the operator's attributes.
@@ -205,6 +216,15 @@ public:
@@ -205,6 +216,15 @@ public:
return mAttributes->template getAttr<StackAttr::ForwardStep>();
return mAttributes->template getAttr<StackAttr::ForwardStep>();
}
}
 
/**
 
* @brief Get or set the backward step counter for the operator.
 
* @return A reference to the backward step counter.
 
*/
 
inline std::uint32_t& backwardStep() const {
 
return mAttributes->template getAttr<StackAttr::BackwardStep>();
 
}
 
 
/**
/**
* @brief Retrieve the names of the operator's input tensors.
* @brief Retrieve the names of the operator's input tensors.
* @return A vector of strings representing input tensor names.
* @return A vector of strings representing input tensor names.
@@ -239,5 +259,4 @@ public:
@@ -239,5 +259,4 @@ public:
std::shared_ptr<Node> Stack(std::uint32_t maxElements = 0, const std::string& name = "");
std::shared_ptr<Node> Stack(std::uint32_t maxElements = 0, const std::string& name = "");
} // namespace Aidge
} // namespace Aidge
#endif /* AIDGE_CORE_OPERATOR_STACK_H_ */
#endif /* AIDGE_CORE_OPERATOR_STACK_H_ */
Loading