Skip to content
Snippets Groups Projects

Add: backward pass for Pop operator

Merged Jerome Hue requested to merge jeromeh/aidge_core:pop-backward into dev
Files
3
@@ -86,6 +86,11 @@ public:
@@ -86,6 +86,11 @@ public:
* @brief Executes the forward pass for the `Pop` operation.
* @brief Executes the forward pass for the `Pop` operation.
*/
*/
void forward() override;
void forward() override;
 
 
/**
 
* @brief Executes the backward pass for the `Pop` operation.
 
*/
 
void backward() override;
};
};
/**
/**
@@ -93,7 +98,8 @@ public:
@@ -93,7 +98,8 @@ public:
* @brief Attributes specific to the `Pop` operator.
* @brief Attributes specific to the `Pop` operator.
*/
*/
enum class PopAttr {
enum class PopAttr {
ForwardStep // Tracks the current step in the forward pass
ForwardStep, // Tracks the current step in the forward pass
 
BackwardStep // Tracks the current step in the backward pass
};
};
/**
/**
@@ -115,7 +121,7 @@ public:
@@ -115,7 +121,7 @@ public:
static const std::string Type;
static const std::string Type;
private:
private:
using Attributes_ = StaticAttributes<PopAttr, std::uint32_t>;
using Attributes_ = StaticAttributes<PopAttr, std::uint32_t, std::uint32_t>;
template <PopAttr e> using attr = typename Attributes_::template attr<e>;
template <PopAttr e> using attr = typename Attributes_::template attr<e>;
const std::shared_ptr<Attributes_> mAttributes;
const std::shared_ptr<Attributes_> mAttributes;
@@ -171,6 +177,11 @@ public:
@@ -171,6 +177,11 @@ public:
*/
*/
void forward() override;
void forward() override;
 
/**
 
* @brief Executes the backward pass for the `Pop` operation.
 
*/
 
void backward() override;
 
/**
/**
* @brief Access the operator's attributes.
* @brief Access the operator's attributes.
* @return A shared pointer to the operator's attributes.
* @return A shared pointer to the operator's attributes.
@@ -183,6 +194,8 @@ public:
@@ -183,6 +194,8 @@ public:
*/
*/
inline std::uint32_t& forwardStep() const { return mAttributes->template getAttr<PopAttr::ForwardStep>(); }
inline std::uint32_t& forwardStep() const { return mAttributes->template getAttr<PopAttr::ForwardStep>(); }
 
inline std::uint32_t& backwardStep() const { return mAttributes->template getAttr<PopAttr::BackwardStep>(); }
 
/**
/**
* @brief Retrieve the names of the operator's input tensors.
* @brief Retrieve the names of the operator's input tensors.
* @return A vector of strings representing input tensor names.
* @return A vector of strings representing input tensor names.
@@ -214,7 +227,7 @@ namespace {
@@ -214,7 +227,7 @@ namespace {
*/
*/
template <>
template <>
const char *const EnumStrings<Aidge::PopAttr>::data[] = {
const char *const EnumStrings<Aidge::PopAttr>::data[] = {
"forward_step"
"forward_step", "backward_step"
};
};
}
}
Loading