From d9e59761905aba69ad39da6c1b31f5bd72ec27c6 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Fri, 15 Mar 2024 14:31:19 +0100
Subject: [PATCH] Updated C-P model to work with both data and tokens

---
 include/aidge/backend/OperatorImpl.hpp        |  15 ++-
 include/aidge/data/Elts.hpp                   | 124 ++++++++++++++++++
 include/aidge/operator/MetaOperator.hpp       |  10 +-
 include/aidge/operator/Operator.hpp           |  16 +--
 include/aidge/scheduler/Scheduler.hpp         |   2 +-
 include/aidge/utils/ErrorHandling.hpp         |   1 +
 .../backend/pybind_OperatorImpl.cpp           |  20 +--
 src/backend/OperatorImpl.cpp                  |  73 ++++++++---
 src/operator/MetaOperator.cpp                 |  20 +--
 src/operator/Operator.cpp                     |  10 +-
 src/scheduler/Scheduler.cpp                   |  51 ++++---
 unit_tests/scheduler/Test_Scheduler.cpp       |  42 ++++++
 12 files changed, 305 insertions(+), 79 deletions(-)
 create mode 100644 include/aidge/data/Elts.hpp

diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp
index 8b5aba10d..215ac804c 100644
--- a/include/aidge/backend/OperatorImpl.hpp
+++ b/include/aidge/backend/OperatorImpl.hpp
@@ -16,6 +16,7 @@
 #include <vector>
 #include <memory>
 #include "aidge/utils/Types.h"
+#include "aidge/data/Elts.hpp"
 
 namespace Aidge {
 class Operator;
@@ -33,13 +34,13 @@ public:
      * @param inputIdx Index of the input analysed.
      * @return std::size_t
      */
-    virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const;
+    virtual Elts_t getNbRequiredData(const IOIndex_t inputIdx) const;
 
     // Amount of input data that cannot be overwritten during the execution.
-    virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const;
+    virtual Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const;
 
     // Memory required at an output for a given input size.
-    virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const;
+    virtual Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const;
 
     /**
      * @brief Total amount of consumed data from a specific input.
@@ -47,7 +48,7 @@ public:
      * @param inputIdx Index of the input analysed.
      * @return DimSize_t
      */
-    virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const;
+    virtual Elts_t getNbConsumedData(const IOIndex_t inputIdx) const;
 
     /**
      * @brief Total amount of produced data ready to be used on a specific output.
@@ -55,7 +56,7 @@ public:
      * @param outputIdx Index of the output analysed.
      * @return DimSize_t
      */
-    virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const;
+    virtual Elts_t getNbProducedData(const IOIndex_t outputIdx) const;
 
     /**
      * @brief Update the Consummer Producer system by simulating the consumption and production of i/o
@@ -73,8 +74,8 @@ public:
 
 protected:
     const Operator &mOp;
-    std::vector<NbElts_t> mNbConsumedData;
-    std::vector<NbElts_t> mNbProducedData;
+    std::vector<Elts_t> mNbConsumedData;
+    std::vector<Elts_t> mNbProducedData;
 };
 } // namespace Aidge
 
diff --git a/include/aidge/data/Elts.hpp b/include/aidge/data/Elts.hpp
new file mode 100644
index 000000000..1a5a9e10e
--- /dev/null
+++ b/include/aidge/data/Elts.hpp
@@ -0,0 +1,124 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_ELTS_H_
+#define AIDGE_ELTS_H_
+
+#include "aidge/utils/ErrorHandling.hpp"
+#include "aidge/utils/Types.h"
+
+namespace Aidge {
+/**
+ * Base object for Aidge consumer-producer model (C-P model).
+ * It is a hybrid model: operator implementations can specify their C-P model
+ * with precise data (bytes) or with tokens.
+*/
+struct Elts_t {
+    enum EltType {
+        Data,
+        Token,
+        Undef
+    };
+
+    NbElts_t data;
+    NbElts_t token;
+    EltType type;
+
+    // Addition operator
+    inline Elts_t operator+(const Elts_t& other) const {
+        AIDGE_ASSERT(type == other.type || other.type == Undef || type == Undef,
+            "Incompatible C-P model types: {} + {}. Data and Token cannot be mixed.", type, other.type);
+        return Elts_t(data + other.data, token + other.token, (other.type == Undef) ? type : other.type);
+    }
+
+    // Addition assignment operator
+    inline Elts_t& operator+=(const Elts_t& other) {
+        AIDGE_ASSERT(type == other.type || other.type == Undef || type == Undef,
+            "Incompatible C-P model types: {} += {}. Data and Token cannot be mixed.", type, other.type);
+        data += other.data;
+        token += other.token;
+        type = (other.type == Undef) ? type : other.type;
+        return *this;
+    }
+
+    // Comparison operators
+    inline bool operator<(const Elts_t& other) const {
+        if (type == Elts_t::Undef || type == Elts_t::Token) {
+            // Nothing, or only a token is required: don't care about how much data has been produced for the token
+            return (token < other.token);
+        }
+        else if (type == Elts_t::Data && other.type != Elts_t::Token) {
+            // A precise amount of data is required, so the amount of produced data must be specified, a token is not enough
+            return (data < other.data);
+        }
+        else {
+            AIDGE_THROW_OR_ABORT(std::runtime_error,
+                "Incompatible C-P model types: {} < {}. Data is expected for right-hand side.", type, other.type);
+        }
+    }
+
+    inline bool operator>(const Elts_t& other) const {
+        if (type == Elts_t::Undef || type == Elts_t::Token) {
+            // Nothing, or only a token is required: don't care about how much data has been produced for the token
+            return (token > other.token);
+        }
+        else if (type == Elts_t::Data && other.type != Elts_t::Token) {
+            // A precise amount of data is required, so the amount of produced data must be specified, a token is not enough
+            return (data > other.data);
+        }
+        else {
+            AIDGE_THROW_OR_ABORT(std::runtime_error,
+                "Incompatible C-P model types: {} > {}. Data is expected for right-hand side.", type, other.type);
+        }
+    }
+
+    inline static Elts_t NoneElts() {
+        return Elts_t(0, 0, Elts_t::Undef);
+    }
+
+    inline static Elts_t DataElts(NbElts_t data, NbElts_t token = 1) {
+        return Elts_t(data, token, Elts_t::Data);
+    }
+
+    inline static Elts_t TokenElts(NbElts_t token) {
+        return Elts_t(0, token, Elts_t::Token);
+    }
+
+private:
+    inline Elts_t(NbElts_t data_, NbElts_t token_, EltType type_):
+        data(data_), token(token_), type(type_) {}
+};
+} // end namespace Aidge
+
+template<>
+struct fmt::formatter<Aidge::Elts_t> {
+    template<typename ParseContext>
+    inline constexpr auto parse(ParseContext& ctx) {
+        return ctx.begin();
+    }
+
+    template<typename FormatContext>
+    inline auto format(Aidge::Elts_t const& elt, FormatContext& ctx) {
+        return fmt::format_to(ctx.out(), "{}:{}", elt.data, elt.token);
+    }
+};
+
+namespace {
+template <>
+const char* const EnumStrings<Aidge::Elts_t::EltType>::data[]
+    = {"Data", "Token", "Undef"};
+}
+
+namespace Aidge {
+inline auto format_as(Elts_t::EltType elt) { return EnumStrings<Aidge::Elts_t::EltType>::data[static_cast<int>(elt)]; }
+}
+
+#endif /* AIDGE_ELTS_H_ */
diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp
index ce328c23f..cd23acd90 100644
--- a/include/aidge/operator/MetaOperator.hpp
+++ b/include/aidge/operator/MetaOperator.hpp
@@ -107,11 +107,11 @@ public:
         mGraph->setDataType(datatype);
     }
 
-    NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override;
-    NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override;
-    NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override;
-    NbElts_t getNbConsumedData(IOIndex_t inputIdx) const override;
-    NbElts_t getNbProducedData(IOIndex_t outputIdx) const override;
+    Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override;
+    Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override;
+    Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override;
+    Elts_t getNbConsumedData(IOIndex_t inputIdx) const override;
+    Elts_t getNbProducedData(IOIndex_t outputIdx) const override;
 
     void updateConsummerProducer() override;
     void forward() override;
diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp
index 396c60e46..6e2e44426 100644
--- a/include/aidge/operator/Operator.hpp
+++ b/include/aidge/operator/Operator.hpp
@@ -131,31 +131,31 @@ public:
     /**
      * @brief Minimum amount of data from a specific input for one computation pass.
      * @param inputIdx Index of the input analysed.
-     * @return NbElts_t
+     * @return Elts_t
      */
-    virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const;
+    virtual Elts_t getNbRequiredData(const IOIndex_t inputIdx) const;
 
     // Amount of input data that cannot be overwritten during the execution.
-    virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const;
+    virtual Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const;
 
     // Memory required at an output for a given input size.
-    virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const;
+    virtual Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const;
 
     /**
      * @brief Total amount of consumed data from a specific input.
      *
      * @param inputIdx Index of the input analysed.
-     * @return NbElts_t
+     * @return Elts_t
      */
-    virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const;
+    virtual Elts_t getNbConsumedData(const IOIndex_t inputIdx) const;
 
     /**
      * @brief Total amount of produced data ready to be used on a specific output.
      *
      * @param outputIdx Index of the output analysed.
-     * @return NbElts_t
+     * @return Elts_t
      */
-    virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const;
+    virtual Elts_t getNbProducedData(const IOIndex_t outputIdx) const;
 
     virtual void updateConsummerProducer();
 
diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp
index e0284f0fb..4c5b3bd4c 100644
--- a/include/aidge/scheduler/Scheduler.hpp
+++ b/include/aidge/scheduler/Scheduler.hpp
@@ -141,7 +141,7 @@ protected:
      * @return std::set<std::shared_ptr<Node>>
      */
     std::set<std::shared_ptr<Node>> getConsumers(const std::set<std::shared_ptr<Node>>& producers) const;
-    NbElts_t getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const;
+    Elts_t getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const;
     PriorProducersConsumers getPriorProducersConsumers(const std::shared_ptr<Node>& node) const;
 
     /** @brief Shared ptr to the scheduled graph view */
diff --git a/include/aidge/utils/ErrorHandling.hpp b/include/aidge/utils/ErrorHandling.hpp
index d4235d2db..f6a9aefe2 100644
--- a/include/aidge/utils/ErrorHandling.hpp
+++ b/include/aidge/utils/ErrorHandling.hpp
@@ -14,6 +14,7 @@
 #define AIDGE_ERRORHANDLING_H_
 
 #include <memory>
+#include <cassert>
 
 #include <fmt/format.h>
 #include <fmt/ranges.h>
diff --git a/python_binding/backend/pybind_OperatorImpl.cpp b/python_binding/backend/pybind_OperatorImpl.cpp
index 91d65484a..5259d877d 100644
--- a/python_binding/backend/pybind_OperatorImpl.cpp
+++ b/python_binding/backend/pybind_OperatorImpl.cpp
@@ -42,18 +42,18 @@ public:
 
         );
     }
-    NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override {
+    Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override {
         PYBIND11_OVERRIDE_NAME(
-            NbElts_t,
+            Elts_t,
             OperatorImpl,
             "get_nb_required_data",
             getNbRequiredData,
             inputIdx
         );
     }
-    NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override {
+    Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override {
         PYBIND11_OVERRIDE_NAME(
-            NbElts_t,
+            Elts_t,
             OperatorImpl,
             "get_nb_required_protected",
             getNbRequiredProtected,
@@ -61,10 +61,10 @@ public:
 
         );
     }
-    NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
+    Elts_t getRequiredMemory(const IOIndex_t outputIdx,
     const std::vector<DimSize_t> &inputsSize) const override {
         PYBIND11_OVERRIDE_NAME(
-            NbElts_t,
+            Elts_t,
             OperatorImpl,
             "get_required_memory",
             getRequiredMemory,
@@ -73,9 +73,9 @@ public:
 
         );
     }
-    NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override {
+    Elts_t getNbConsumedData(const IOIndex_t inputIdx) const override {
         PYBIND11_OVERRIDE_NAME(
-            NbElts_t,
+            Elts_t,
             OperatorImpl,
             "get_nb_consumed_data",
             getNbConsumedData,
@@ -83,9 +83,9 @@ public:
 
         );
     }
-    NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override {
+    Elts_t getNbProducedData(const IOIndex_t outputIdx) const override {
         PYBIND11_OVERRIDE_NAME(
-            NbElts_t,
+            Elts_t,
             OperatorImpl,
             "get_nb_produced_data",
             getNbProducedData,
diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp
index 1439391b2..42e8545d3 100644
--- a/src/backend/OperatorImpl.cpp
+++ b/src/backend/OperatorImpl.cpp
@@ -18,48 +18,91 @@
 
 Aidge::OperatorImpl::OperatorImpl(const Operator& op):
     mOp(op),
-    mNbConsumedData(mOp.nbInputs(), 0),
-    mNbProducedData(mOp.nbOutputs(), 0)
+    mNbConsumedData(mOp.nbInputs(), Elts_t::NoneElts()),
+    mNbProducedData(mOp.nbOutputs(), Elts_t::NoneElts())
 {
     //ctor
 }
 
-Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
+Aidge::Elts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
     AIDGE_ASSERT(mOp.getRawInput(inputIdx),
         "a valid input is required at index {} for operator type {}",
         inputIdx, mOp.type());
 
-    // Requires the whole tensor by default
-    return std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->size();
+    if (mOp.getRawInput(inputIdx)) {
+        const auto input = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx));
+        if (!input->empty()) {
+            // Known amount of data: requires the whole tensor by default
+            return Elts_t::DataElts(input->size());
+        }
+        else {
+            // Unknown amount of data: require a single token by default
+            return Elts_t::TokenElts(1);
+        }
+    }
+
+    // Input not connected, meaning it is an optional input: do no require anything!
+    return Elts_t::NoneElts();
 }
 
-Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const {
+Aidge::Elts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const {
     AIDGE_ASSERT(mOp.getRawInput(inputIdx),
         "a valid input is required at index {} for operator type {}",
         inputIdx, mOp.type());
 
-    // Protect the whole tensor by default
-    return std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->size();
+    if (mOp.getRawInput(inputIdx)) {
+        const auto input = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx));
+        if (!input->empty()) {
+            // Known amount of data: protect the whole tensor by default
+            return Elts_t::DataElts(input->size());
+        }
+        else {
+            // Unknown amount of data: protect a single token by default
+            // (this does not really make sense for now, as getNbRequiredProtected()
+            // is supposed to give a precise amount of data to protect for
+            // memory management purpose...)
+            return Elts_t::TokenElts(1);
+        }
+    }
+
+    // Input not connected, meaning it is an optional input: do no require anything!
+    return Elts_t::NoneElts();
 }
 
-Aidge::NbElts_t Aidge::OperatorImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
+Aidge::Elts_t Aidge::OperatorImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
                                                          const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const {
     AIDGE_ASSERT(mOp.getRawOutput(outputIdx),
         "a valid output is required at index {} for operator type {}",
         outputIdx, mOp.type());
 
-    // Requires the whole tensor by default, regardless of available data on inputs
-    return std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx))->size();
+    if (mOp.getRawOutput(outputIdx)) {
+        const auto output = std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx));
+        if (!output->empty()) {
+            // Known amount of data: requires the whole tensor by default,
+            // regardless of available data on inputs
+            return Elts_t::DataElts(output->size());
+        }
+        else {
+            // Unknown amount of data: require a single token by default
+            // (this does not really make sense for now, as getRequiredMemory()
+            // is supposed to give a precise amount of data to allocate for
+            // memory management purpose...)
+            return Elts_t::TokenElts(1);
+        }
+    }
+
+    // Output not set, meaning it is an optional output: do no require anything!
+    return Elts_t::NoneElts();
 }
 
-Aidge::NbElts_t Aidge::OperatorImpl::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
+Aidge::Elts_t Aidge::OperatorImpl::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
     AIDGE_ASSERT(static_cast<std::size_t>(inputIdx) < mNbConsumedData.size(),
         "input index ({}) is out of bound ({}) for operator type {}",
         inputIdx, mNbConsumedData.size(), mOp.type());
     return mNbConsumedData[static_cast<std::size_t>(inputIdx)];
 }
 
-Aidge::NbElts_t Aidge::OperatorImpl::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
+Aidge::Elts_t Aidge::OperatorImpl::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
     AIDGE_ASSERT(static_cast<std::size_t>(outputIdx) < mNbProducedData.size(),
         "output index ({}) is out of bound ({}) for operator type {}",
         outputIdx, mNbProducedData.size(), mOp.type());
@@ -79,8 +122,8 @@ void Aidge::OperatorImpl::updateConsummerProducer(){
 }
 
 void Aidge::OperatorImpl::resetConsummerProducer(){
-    std::fill(mNbConsumedData.begin(), mNbConsumedData.end(), 0);
-    std::fill(mNbProducedData.begin(), mNbProducedData.end(), 0);
+    std::fill(mNbConsumedData.begin(), mNbConsumedData.end(), Elts_t::NoneElts());
+    std::fill(mNbProducedData.begin(), mNbProducedData.end(), Elts_t::NoneElts());
 }
 
 void Aidge::OperatorImpl::forward() {
diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp
index 883185021..1d15db1fb 100644
--- a/src/operator/MetaOperator.cpp
+++ b/src/operator/MetaOperator.cpp
@@ -30,7 +30,7 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<
     }
 }
 
-Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const {
+Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const {
     if (mImpl) {
         return mImpl->getNbRequiredData(inputIdx);
     }
@@ -40,12 +40,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputI
             return inputOp.first->getOperator()->getNbRequiredData(inputOp.second);
         }
         else {
-            return 0;
+            return Elts_t::NoneElts();
         }
     }
 }
 
-Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredProtected(const IOIndex_t inputIdx) const {
+Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredProtected(const IOIndex_t inputIdx) const {
     if (mImpl) {
         return mImpl->getNbRequiredProtected(inputIdx);
     }
@@ -55,12 +55,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredProtected(const IOIndex_t i
             return inputOp.first->getOperator()->getNbRequiredProtected(inputOp.second);
         }
         else {
-            return 0;
+            return Elts_t::NoneElts();
         }
     }
 }
 
-Aidge::NbElts_t Aidge::MetaOperator_Op::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const {
+Aidge::Elts_t Aidge::MetaOperator_Op::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const {
     if (mImpl) {
         return mImpl->getRequiredMemory(outputIdx, inputsSize);
     }
@@ -70,12 +70,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getRequiredMemory(const IOIndex_t output
             return outputOp.first->getOperator()->getRequiredMemory(outputOp.second, inputsSize);
         }
         else {
-            return 0;
+            return Elts_t::NoneElts();
         }
     }
 }
 
-Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) const {
+Aidge::Elts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) const {
     if (mImpl) {
         return mImpl->getNbConsumedData(inputIdx);
     }
@@ -85,12 +85,12 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) co
             return inputOp.first->getOperator()->getNbConsumedData(inputOp.second);
         }
         else {
-            return 0;
+            return Elts_t::NoneElts();
         }
     }
 }
 
-Aidge::NbElts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) const {
+Aidge::Elts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) const {
     if (mImpl) {
         return mImpl->getNbProducedData(outputIdx);
     }
@@ -100,7 +100,7 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) c
             return outputOp.first->getOperator()->getNbProducedData(outputOp.second);
         }
         else {
-            return 0;
+            return Elts_t::NoneElts();
         }
     }
 }
diff --git a/src/operator/Operator.cpp b/src/operator/Operator.cpp
index e4213cad8..317bbd364 100644
--- a/src/operator/Operator.cpp
+++ b/src/operator/Operator.cpp
@@ -31,27 +31,27 @@ Aidge::Operator::~Operator() noexcept = default;
 //        IMPLEMENTATION
 ///////////////////////////////////////////////////////
 
-Aidge::NbElts_t Aidge::Operator::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
+Aidge::Elts_t Aidge::Operator::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
     AIDGE_ASSERT(mImpl != nullptr, "getNbRequiredData(): an implementation is required for {}!", type());
     return mImpl->getNbRequiredData(inputIdx);
 }
 
-Aidge::NbElts_t Aidge::Operator::getNbRequiredProtected(const Aidge::IOIndex_t inputIdx) const {
+Aidge::Elts_t Aidge::Operator::getNbRequiredProtected(const Aidge::IOIndex_t inputIdx) const {
     AIDGE_ASSERT(mImpl != nullptr, "getNbRequiredProtected(): an implementation is required for {}!", type());
     return mImpl->getNbRequiredProtected(inputIdx);
 }
 
-Aidge::NbElts_t Aidge::Operator::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const {
+Aidge::Elts_t Aidge::Operator::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const {
     AIDGE_ASSERT(mImpl != nullptr, "getRequiredMemory(): an implementation is required for {}!", type());
     return mImpl->getRequiredMemory(outputIdx, inputsSize);
 }
 
-Aidge::NbElts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
+Aidge::Elts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
     AIDGE_ASSERT(mImpl != nullptr, "getNbConsumedData(): an implementation is required for {}!", type());
     return mImpl->getNbConsumedData(inputIdx);
 }
 
-Aidge::NbElts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
+Aidge::Elts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
     AIDGE_ASSERT(mImpl != nullptr, "getNbProducedData(): an implementation is required for {}!", type());
     return mImpl->getNbProducedData(outputIdx);
 }
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index 639375902..906b3fa71 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -138,8 +138,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S
 
             bool isRunnable = true;
             for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) {
-                if (/*consumer->getOperator()->getNbRequiredData(inputIdx) > 0
-                    && */(consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) >
+                if ((consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) >
                             getNbAvailableData(consumer, inputIdx)) {
                     Log::debug("  not runnable: C{} + R{} > P{} for input #{}",
                         consumer->getOperator()->getNbConsumedData(inputIdx),
@@ -226,12 +225,17 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S
                         IOIndex_t inputIdx = 0;
                         for (const auto& childParent : child->getParents()) {
                             if (childParent == consumer) {
-                                if (consumer->getOperator()->getNbProducedData(outId) > child->getOperator()->getNbConsumedData(inputIdx)) {
+                                if (child->getOperator()->getNbConsumedData(inputIdx) < consumer->getOperator()->getNbProducedData(outId)) {
                                     isProducer = true;
+                                    break;
                                 }
                             }
                             ++inputIdx;
                         }
+
+                        if (isProducer) {
+                            break;
+                        }
                     }
                 }
 /*
@@ -383,17 +387,22 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr
             }
             
             const auto childs = node->getChildren();
-            AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
+            AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor,
+                "Operator must be of Tensor type for node {} (of type {}).",
+                node->name(), node->type());
             const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
 
             std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane;
 
             // Allocate a memory plane for each node's output
             for (IOIndex_t outputIdx = 0; outputIdx < node->nbOutputs(); ++outputIdx) {
-                const size_t requiredSize = op->getRequiredMemory(outputIdx, {});
+                const auto requiredSize = op->getRequiredMemory(outputIdx, {});
+                AIDGE_ASSERT(requiredSize.type == Elts_t::Data,
+                    "Cannot generate memory with token-based producer-consumer model for node {} (of type {}).",
+                    node->name(), node->type());
 
                 // By default, specifies a fully monolithic memory block
-                size_t size = requiredSize;
+                size_t size = requiredSize.data;
                 size_t stride = 0;
                 size_t length = 1;
                 size_t count = 1;
@@ -425,21 +434,27 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr
                         // memSpace should not be already released
                         && memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second].memSpace->released == -1)
                     {
-                        const bool isWrappable = (op->getNbRequiredProtected(inputIdx) < op->getNbRequiredData(inputIdx));
+                        const auto requiredData = op->getNbRequiredData(inputIdx);
+                        const auto requiredProtected = op->getNbRequiredProtected(inputIdx);
+                        AIDGE_ASSERT(requiredData.type == Elts_t::Data && requiredProtected.type == Elts_t::Data,
+                            "Cannot generate memory with token-based producer-consumer model for node {} (of type {}).",
+                            node->name(), node->type());
+
+                        const bool isWrappable = (requiredProtected.data < requiredData.data);
                         const MemoryManager::MemoryPlane& memPlane = memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second];
 
                         if (isWrappable || !memManager.isWrapAround(
                                     memPlane.memSpace,
                                     memPlane.getFinalOffset()
                                         - memPlane.memSpace->offset,
-                                    requiredSize))
+                                    requiredSize.data))
                         {
-                            if (memPlane.getSize() > wrapAroundSize + op->getNbRequiredProtected(inputIdx)
+                            if (memPlane.getSize() > wrapAroundSize + requiredProtected.data
                                 && std::find(wrapAroundMemPlane.begin(), wrapAroundMemPlane.end(), &memPlane) == wrapAroundMemPlane.end())
                             {
-                                wrapAroundSize = memPlane.getSize() - op->getNbRequiredProtected(inputIdx);
-                                if (requiredSize > wrapAroundSize) {
-                                    wrapAroundExtra = requiredSize - wrapAroundSize;
+                                wrapAroundSize = memPlane.getSize() - requiredProtected.data;
+                                if (requiredSize.data > wrapAroundSize) {
+                                    wrapAroundExtra = requiredSize.data - wrapAroundSize;
                                 }
                                 wrapAroundMemPlane[outputIdx] = &memPlane;
                             }
@@ -456,17 +471,17 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr
                 const MemoryManager::MemoryPlane& memPlane
                     = (wrapAroundBuffer && wrapAroundSize > 0)
                         ? (*wrapAroundMemPlane[outputIdx]) :
-                            memManager.allocate(requiredSize, childs, stride, length, count);
+                            memManager.allocate(requiredSize.data, childs, stride, length, count);
 
                 if (wrapAroundBuffer && wrapAroundSize > 0) {
                     memManager.reallocate(memPlane,
                         node, 0,
-                        requiredSize, true, wrapAroundExtra, childs, stride, length, count);
+                        requiredSize.data, true, wrapAroundExtra, childs, stride, length, count);
                 }
                 else {
                     memManager.reallocate(memPlane.memSpace,
                         node, memPlane.offset,
-                        requiredSize, false, 0, childs, stride, length, count);
+                        requiredSize.data, false, 0, childs, stride, length, count);
                 }
             }
 
@@ -574,7 +589,7 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getConsumers(
     return consumers;
 }
 
-Aidge::NbElts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const {
+Aidge::Elts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const {
     const auto parent = node->inputs()[inputIdx];
 
     if (parent.first) {
@@ -605,14 +620,14 @@ Aidge::NbElts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>
         // In this case, we assume a single-use data (unlike a Producer, which
         // keep producing the data each time it is needed).
         fmt::print("No producer node attached to input#{} for node {} ({})\n", inputIdx, node->name(), node->type());
-        return std::static_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputIdx))->size();
+        return Elts_t::DataElts(std::static_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputIdx))->size());
     }
     else {
         // Input is not connected, this is an error
         AIDGE_THROW_OR_ABORT(std::runtime_error, "Missing input#{} for node {} ({})\n", inputIdx, node->name(), node->type());
     }
 
-    return 0;
+    return Elts_t::NoneElts();
 }
 
 Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersConsumers(
diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp
index ab5fef1f6..75a0daed6 100644
--- a/unit_tests/scheduler/Test_Scheduler.cpp
+++ b/unit_tests/scheduler/Test_Scheduler.cpp
@@ -75,3 +75,45 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") {
 
     fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests);
 }
+
+TEST_CASE("randomScheduling_tokens", "[Scheduler][randomGen]") {
+    const size_t nbTests = 100;
+    size_t nbUnicity = 0;
+
+    for (int test = 0; test < nbTests; ++test) {
+        std::random_device rd;
+        const std::mt19937::result_type seed(rd());
+
+        RandomGraph randGraph;
+        randGraph.acyclic = true;
+        const auto g1 = std::make_shared<GraphView>("g1");
+        const bool unicity1 = g1->add(randGraph.gen(seed, 10));
+
+        if (unicity1) {
+            const auto orderedInputs = g1->getOrderedInputs();
+            for (const auto& input : orderedInputs) {
+                auto prod = Producer({16, 32});
+                prod->addChild(input.first, 0, input.second);
+                g1->add(prod);
+            }
+
+            g1->save("schedule");
+
+            auto scheduler = SequentialScheduler(g1);
+            scheduler.generateScheduling();
+            const auto sch = scheduler.getStaticScheduling();
+
+            const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})");
+
+            std::vector<std::string> nodesName;
+            std::transform(sch.begin(), sch.end(),
+                std::back_inserter(nodesName),
+                [&namePtrTable](auto val){ return namePtrTable.at(val); });
+
+            fmt::print("schedule: {}\n", nodesName);
+            REQUIRE(sch.size() == 10 + orderedInputs.size());
+        }
+    }
+
+    fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests);
+}
-- 
GitLab