Skip to content
Snippets Groups Projects
Commit 94230383 authored by Jerome Hue's avatar Jerome Hue Committed by Olivier BICHLER
Browse files

override updateConsumerProducer method in Stack operator

parent dddb8ab3
No related branches found
No related tags found
2 merge requests!279v0.4.0,!256Add a Stack operator
...@@ -18,6 +18,7 @@ class StackProdConso : public ProdConso { ...@@ -18,6 +18,7 @@ class StackProdConso : public ProdConso {
Elts_t getRequiredMemory( Elts_t getRequiredMemory(
const IOIndex_t outputIdx, const IOIndex_t outputIdx,
const std::vector<DimSize_t> &inputsSize) const override final; const std::vector<DimSize_t> &inputsSize) const override final;
void resetConsummerProducer() override final;
}; };
class StackOpImpl : public OperatorImpl { class StackOpImpl : public OperatorImpl {
......
...@@ -24,7 +24,7 @@ namespace Aidge { ...@@ -24,7 +24,7 @@ namespace Aidge {
// TODO: Check why getRequiredMemory is always called with empty vector as // TODO: Check why getRequiredMemory is always called with empty vector as
// inputSize // inputSize
Aidge::Elts_t Aidge::StackProdConso::getRequiredMemory( Elts_t StackProdConso::getRequiredMemory(
const Aidge::IOIndex_t inputIdx, const Aidge::IOIndex_t inputIdx,
const std::vector<DimSize_t> &inputsSize) const { const std::vector<DimSize_t> &inputsSize) const {
assert(mOp.getRawInput(inputIdx) && "requires valid input"); assert(mOp.getRawInput(inputIdx) && "requires valid input");
...@@ -35,6 +35,13 @@ Aidge::Elts_t Aidge::StackProdConso::getRequiredMemory( ...@@ -35,6 +35,13 @@ Aidge::Elts_t Aidge::StackProdConso::getRequiredMemory(
return Elts_t::DataElts(op.getInput(inputIdx)->size()); return Elts_t::DataElts(op.getInput(inputIdx)->size());
} }
void StackProdConso::resetConsummerProducer() {
ProdConso::updateConsummerProducer();
const StackOp &op = dynamic_cast<const StackOp &>(mOp);
op.forwardStep() = 0;
}
const std::string StackOp::s_type = "Stack"; const std::string StackOp::s_type = "Stack";
void StackOpImpl::forward() { void StackOpImpl::forward() {
...@@ -87,6 +94,7 @@ bool Aidge::StackOp::forwardDims(bool /*allowDataDependency*/) { ...@@ -87,6 +94,7 @@ bool Aidge::StackOp::forwardDims(bool /*allowDataDependency*/) {
return false; return false;
} }
void StackOp::setBackend(const std::string &name, DeviceIdx_t device) { void StackOp::setBackend(const std::string &name, DeviceIdx_t device) {
if (Registrar<StackOp>::exists({name})) { if (Registrar<StackOp>::exists({name})) {
SET_IMPL_MACRO(StackOp, *this, name); SET_IMPL_MACRO(StackOp, *this, name);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment