From 8c0397d716673957cdf2e44fd94f95e75771d5e9 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Fri, 20 Oct 2023 16:24:25 +0200
Subject: [PATCH] Added default operator impl with default producer-consumer
 model

---
 include/aidge/backend/OperatorImpl.hpp | 25 ++++++---
 src/backend/OperatorImpl.cpp           | 77 ++++++++++++++++++++++++++
 2 files changed, 93 insertions(+), 9 deletions(-)
 create mode 100644 src/backend/OperatorImpl.cpp

diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp
index 453e30a86..19f083750 100644
--- a/include/aidge/backend/OperatorImpl.hpp
+++ b/include/aidge/backend/OperatorImpl.hpp
@@ -18,11 +18,13 @@
 #include "aidge/utils/Types.h"
 
 namespace Aidge {
+class Operator;
+
 class OperatorImpl {
 public:
-
-    virtual void forward(){};
-    virtual void backward(){};
+    OperatorImpl(const Operator& op);
+    virtual void forward();
+    virtual void backward();
 
     /**
      * @brief Minimum amount of data from a specific input required by the
@@ -31,13 +33,13 @@ public:
      * @param inputIdx Index of the input analysed.
      * @return std::size_t
      */
-    virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const = 0;
+    virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const;
 
     // Amount of input data that cannot be overwritten during the execution.
-    virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const = 0;
+    virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const;
 
     // Memory required at an output for a given input size.
-    virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const = 0;
+    virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const;
 
     /**
      * @brief Total amount of consumed data from a specific input.
@@ -45,7 +47,7 @@ public:
      * @param inputIdx Index of the input analysed.
      * @return DimSize_t
      */
-    virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const = 0;
+    virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const;
 
     /**
      * @brief Total amount of produced data ready to be used on a specific output.
@@ -53,15 +55,20 @@ public:
      * @param outputIdx Index of the output analysed.
      * @return DimSize_t
      */
-    virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const = 0;
+    virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const;
 
     /**
      * @brief Update the Consummer Producer system by simulating the consumption and production of i/o
      *
      */
-    virtual void updateConsummerProducer() = 0;
+    virtual void updateConsummerProducer();
 
     virtual ~OperatorImpl() = default;
+
+protected:
+    const Operator &mOp;
+    std::vector<NbElts_t> mNbConsumedData;
+    std::vector<NbElts_t> mNbProducedData;
 };
 } // namespace Aidge
 
diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp
new file mode 100644
index 000000000..166754cc9
--- /dev/null
+++ b/src/backend/OperatorImpl.cpp
@@ -0,0 +1,77 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/Operator.hpp"
+#include "aidge/data/Tensor.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
+
+Aidge::OperatorImpl::OperatorImpl(const Operator& op):
+    mOp(op),
+    mNbConsumedData(mOp.nbInputs(), 0),
+    mNbProducedData(mOp.nbOutputs(), 0)
+{
+    //ctor
+}
+
+Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
+    assert(mOp.getInput(inputIdx) && "requires valid input");
+
+    // Requires the whole tensor by default
+    return std::static_pointer_cast<Tensor>(mOp.getInput(inputIdx))->size();
+}
+
+Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const {
+    assert(mOp.getInput(inputIdx) && "requires valid input");
+
+    // Protect the whole tensor by default
+    return std::static_pointer_cast<Tensor>(mOp.getInput(inputIdx))->size();
+}
+
+Aidge::NbElts_t Aidge::OperatorImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
+                                                         const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const {
+    assert(mOp.getOutput(outputIdx) && "requires valid output");
+
+    // Requires the whole tensor by default, regardless of available data on inputs
+    return std::static_pointer_cast<Tensor>(mOp.getOutput(outputIdx))->size();
+}
+
+Aidge::NbElts_t Aidge::OperatorImpl::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
+    assert(static_cast<std::size_t>(inputIdx) < mNbConsumedData.size());
+    return mNbConsumedData[static_cast<std::size_t>(inputIdx)];
+}
+
+Aidge::NbElts_t Aidge::OperatorImpl::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
+    assert(static_cast<std::size_t>(outputIdx) < mNbProducedData.size());
+    return mNbProducedData[static_cast<std::size_t>(outputIdx)];
+}
+
+void Aidge::OperatorImpl::updateConsummerProducer(){
+    // Update producer-consumer data
+    for (std::size_t inputIdx = 0; inputIdx < mNbConsumedData.size(); ++inputIdx) {
+        // each input is consumed by the minimum amount for a forward pass
+        mNbConsumedData[inputIdx] += getNbRequiredData(static_cast<IOIndex_t>(inputIdx));
+    }
+
+    for (std::size_t outputIdx = 0; outputIdx < mNbProducedData.size(); ++outputIdx) {
+        mNbProducedData[outputIdx] += getRequiredMemory(outputIdx, {});
+    }
+}
+
+void Aidge::OperatorImpl::forward() {
+    AIDGE_THROW_OR_ABORT(std::runtime_error, "forward() not implemented");
+}
+
+void Aidge::OperatorImpl::backward() {
+    AIDGE_THROW_OR_ABORT(std::runtime_error, "backward() not implemented");
+}
-- 
GitLab