Skip to content
Snippets Groups Projects

feat: Backward of Stack Operator

Merged Jerome Hue requested to merge jeromeh/aidge_core:stack-backward into dev
2 files
+ 75
2
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 17
2
@@ -61,7 +61,18 @@ void StackOpImpl::forward() {
op.forwardStep() * op.getInput(0)->size());
}
void StackOpImpl::backward() {}
void StackOpImpl::backward() {
const StackOp &op = dynamic_cast<const StackOp &>(mOp);
AIDGE_ASSERT(op.backwardStep() > 0, "Stack operator has not been run forward");
auto inputGrad = op.getInput(0)->grad();
auto outputGrad = op.getOutput(0)->grad();
Log::notice("Size of stack in grad : {}", inputGrad->size());
Log::notice("Size of stack out grad : {}", outputGrad->size());
*inputGrad = outputGrad->extract({op.backwardStep() -1 }).clone();
}
StackOp::StackOp(std::uint32_t maxElements)
: OperatorTensor(s_type, {InputCategory::Data, InputCategory::OptionalData}, 1),
@@ -137,9 +148,13 @@ std::set<std::string> StackOp::getAvailableBackends() const {
void StackOp::forward() {
OperatorTensor::forward();
++forwardStep();
backwardStep() = forwardStep();
}
void StackOp::backward() {}
void StackOp::backward() {
OperatorTensor::backward();
--backwardStep();
}
std::shared_ptr<Node> Stack(std::uint32_t maxElements,
const std::string &name) {
Loading