Commit 1b2b1699 authored by Jerome Hue

Add pop backward

parent cb7ee7d5
3 merge requests: !414 Update version 0.5.1 -> 0.6.0, !408 [Add] Dropout Operator, !340 Add: backward pass for Pop operator
@@ -86,6 +86,11 @@ public:
     * @brief Executes the forward pass for the `Pop` operation.
     */
    void forward() override;
+
+    /**
+     * @brief Executes the backward pass for the `Pop` operation.
+     */
+    void backward() override;
};
/**
@@ -93,7 +98,8 @@ public:
 * @brief Attributes specific to the `Pop` operator.
 */
enum class PopAttr {
-    ForwardStep // Tracks the current step in the forward pass
+    ForwardStep,  // Tracks the current step in the forward pass
+    BackwardStep  // Tracks the current step in the backward pass
};
/**
@@ -115,7 +121,7 @@ public:
    static const std::string Type;

private:
-    using Attributes_ = StaticAttributes<PopAttr, std::uint32_t>;
+    using Attributes_ = StaticAttributes<PopAttr, std::uint32_t, std::uint32_t>;
    template <PopAttr e> using attr = typename Attributes_::template attr<e>;
    const std::shared_ptr<Attributes_> mAttributes;
@@ -171,6 +177,11 @@ public:
     */
    void forward() override;
+
+    /**
+     * @brief Executes the backward pass for the `Pop` operation.
+     */
+    void backward() override;

    /**
     * @brief Access the operator's attributes.
     * @return A shared pointer to the operator's attributes.
@@ -183,6 +194,8 @@ public:
     */
    inline std::uint32_t& forwardStep() const { return mAttributes->template getAttr<PopAttr::ForwardStep>(); }
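+    /**
+     * @brief Retrieve the current backward step; decremented by each backward() call.
+     */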
+    inline std::uint32_t& backwardStep() const { return mAttributes->template getAttr<PopAttr::BackwardStep>(); }
    /**
     * @brief Retrieve the names of the operator's input tensors.
     * @return A vector of strings representing input tensor names.
@@ -214,7 +227,7 @@ namespace {
 */
template <>
const char *const EnumStrings<Aidge::PopAttr>::data[] = {
-    "forward_step"
+    "forward_step", "backward_step"
};
}
@@ -12,6 +12,7 @@
#include "aidge/operator/Pop.hpp"
#include <memory>
#include <stdexcept>
+#include <string>
#include "aidge/data/Tensor.hpp"
@@ -36,13 +37,25 @@ void Aidge::Pop_OpImpl::forward() {
    *op.getOutput(0) = op.getInput(0)->extract({op.forwardStep()}).clone();
}
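+// Backward of Pop: scatter the current output gradient back into the
+// slice of the input gradient that the matching forward step extracted.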
+void Aidge::Pop_OpImpl::backward() {
+    const Pop_Op& op = dynamic_cast<const Pop_Op&>(mOp);
+    auto outputGrad = op.getOutput(0)->grad();
+    auto inputGrad = op.getInput(0)->grad();
+
+    inputGrad->getImpl()->copy(
+        outputGrad->getImpl()->rawPtr(),
+        outputGrad->size(),
+        (op.backwardStep() - 1) * outputGrad->size());
+}
//////////////////////////////////////////////////////////
const std::string Aidge::Pop_Op::Type = "Pop";
Aidge::Pop_Op::Pop_Op()
    : OperatorTensor(Type, {InputCategory::Data}, 1),
-      mAttributes(std::make_shared<Attributes_>(attr<PopAttr::ForwardStep>(0)))
+      mAttributes(std::make_shared<Attributes_>(attr<PopAttr::ForwardStep>(0), attr<PopAttr::BackwardStep>(0)))
{
    mImpl = std::make_shared<Pop_OpImpl>(*this);
}
@@ -77,6 +90,7 @@ bool Aidge::Pop_Op::forwardDims(bool /*allowDataDependency*/) {
void Aidge::Pop_Op::updateConsummerProducer() {
    Operator::updateConsummerProducer();
    mAttributes->template getAttr<PopAttr::ForwardStep>() = 0;
+    mAttributes->template getAttr<PopAttr::BackwardStep>() = 0;
}
void Aidge::Pop_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
@@ -96,10 +110,16 @@ std::set<std::string> Aidge::Pop_Op::getAvailableBackends() const {
void Aidge::Pop_Op::forward() {
    OperatorTensor::forward();
    ++mAttributes->template getAttr<PopAttr::ForwardStep>();
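+    // Keep the backward counter in sync with the number of completed
+    // forward steps; backward() then replays them in reverse order.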
+    backwardStep() = forwardStep();
}

+void Aidge::Pop_Op::backward() {
+    OperatorTensor::backward();
+    --mAttributes->template getAttr<PopAttr::BackwardStep>();
+}
///////////////////////////////////////////
std::shared_ptr<Aidge::Node> Aidge::Pop(const std::string& name) {
    return std::make_shared<Node>(std::make_shared<Pop_Op>(), name);
-}
\ No newline at end of file
+}
@@ -14,6 +14,7 @@
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Pop.hpp"
#include "aidge/utils/ArrayHelpers.hpp"
+#include "aidge/utils/TensorUtils.hpp"
using namespace Aidge;
@@ -34,4 +35,19 @@ TEST_CASE("[cpu/operator] Pop(forward)", "[Pop][CPU]") {
    REQUIRE(*op->getOutput(0) == *pop2);
    REQUIRE_NOTHROW(pop->forward());
    REQUIRE(*op->getOutput(0) == *pop1);
+
+    // Backward
+    auto expectedGrad1 = Tensor(Array2D<int, 2, 3>({{{0,0,0},{1,1,1}}}));
+    auto expectedGrad2 = Tensor(Array2D<int, 2, 3>({{{2,2,2},{1,1,1}}}));
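+    // After the forward passes above, backwardStep() == 2, so the first
+    // backward() call writes its gradient into row 1 and the second into row 0.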
+    op->getOutput(0)->setGrad(std::make_shared<Tensor>(Array1D<int,3>({1,1,1})));
+    REQUIRE_NOTHROW(pop->backward());
+    REQUIRE(*op->getInput(0)->grad() == expectedGrad1);
+
+    op->getOutput(0)->setGrad(std::make_shared<Tensor>(Array1D<int,3>({2,2,2})));
+    REQUIRE_NOTHROW(pop->backward());
+    REQUIRE(*op->getInput(0)->grad() == expectedGrad2);
}
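
Note: the scatter performed by Pop_OpImpl::backward() can be checked against a small standalone sketch, independent of the Aidge API (all names below are illustrative, not part of the library). Each backward call copies an output gradient of size S into the flat input-gradient buffer at offset (step - 1) * S, then decrements step, which reproduces the expected gradients in the test above.

#include <cstddef>
#include <vector>

// Illustrative stand-in for Pop's backward scatter (not Aidge code):
// writes one per-step output gradient back into the slice of the
// flattened input gradient that the matching forward step extracted.
void scatterStep(std::vector<int>& inputGrad,
                 const std::vector<int>& outputGrad,
                 std::size_t& step) {         // plays the role of backwardStep()
    const std::size_t S = outputGrad.size(); // elements per slice
    for (std::size_t i = 0; i < S; ++i)
        inputGrad[(step - 1) * S + i] = outputGrad[i];
    --step;                                  // next call targets the previous slice
}

int main() {
    std::vector<int> inputGrad(6, 0);        // input of shape {2, 3}, flattened
    std::size_t step = 2;                    // two forward steps have run
    scatterStep(inputGrad, {1, 1, 1}, step); // inputGrad -> {0,0,0, 1,1,1}
    scatterStep(inputGrad, {2, 2, 2}, step); // inputGrad -> {2,2,2, 1,1,1}
    return 0;
}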