From cd743cb25ca1deb4866e5defb821ef07bd0657d7 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Mon, 26 Aug 2024 12:58:12 +0200
Subject: [PATCH] Fix comparison operators and improved selection concept

---
 include/aidge/backend/OperatorImpl.hpp |  56 ++++++++------
 src/backend/OperatorImpl.cpp           | 100 +++++++++++++++++++++++++
 src/utils/Log.cpp                      |   2 +-
 3 files changed, 135 insertions(+), 23 deletions(-)

diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp
index a9f968c59..76dcdc3c4 100644
--- a/include/aidge/backend/OperatorImpl.hpp
+++ b/include/aidge/backend/OperatorImpl.hpp
@@ -26,25 +26,26 @@ namespace Aidge {
 class Node;
 class Operator;
 
+/**
+ * @brief ImplSpec stores the requirements or the specifications of an implementation.
+ * 
+ */
 struct ImplSpec {
     struct IOSpec {
-        IOSpec(DataType type_):
-            type(type_),
-            format(DataFormat::Any),
-            dims({})
-        {}
-
-        IOSpec(DataType type_, DataFormat format_):
+        IOSpec(DataType type_, DataFormat format_ = DataFormat::Any, std::vector<std::pair<DimSize_t, DimSize_t>> dims_ = {}):
             type(type_),
             format(format_),
-            dims({})
+            dims(dims_)
         {}
 
         DataType type;
         DataFormat format;
-        std::vector<std::pair<size_t, size_t>> dims;
+        std::vector<std::pair<DimSize_t, DimSize_t>> dims;
     };
 
+    ImplSpec() {
+    }
+
     ImplSpec(IOSpec io) {
         inputs.push_back(io);
         outputs.push_back(io);
@@ -60,14 +61,28 @@ struct ImplSpec {
     //DynamicAttributes attrs;
 };
 
+inline bool operator==(const ImplSpec::IOSpec& lhs, const ImplSpec::IOSpec& rhs) {
+    return (lhs.type == rhs.type)
+        && (lhs.format == rhs.format)
+        && (lhs.dims == rhs.dims);
+}
+
 inline bool operator<(const ImplSpec::IOSpec& lhs, const ImplSpec::IOSpec& rhs) {
-    return (lhs.type < rhs.type) && (lhs.format < rhs.format) && (lhs.dims < rhs.dims);
+    return (lhs.type < rhs.type)
+        || (lhs.type == rhs.type && lhs.format < rhs.format)
+        || (lhs.type == rhs.type && lhs.format == rhs.format && lhs.dims < rhs.dims);
 }
 
 inline bool operator<(const ImplSpec& lhs, const ImplSpec& rhs) {
-    return (lhs.inputs < rhs.inputs) && (lhs.outputs < rhs.outputs);
+    return (lhs.inputs < rhs.inputs)
+        || (lhs.inputs == rhs.inputs && lhs.outputs < rhs.outputs);
 }
 
+/**
+ * @brief Impl stores the details of a specific implementation.
+ * It is associated to a ImplSpec in a registry.
+ * 
+ */
 template <class FwdFunc, class BwdFunc>
 struct Impl {
     Impl(std::function<std::unique_ptr<ProdConso>(const Operator&)> prodConso_,
@@ -96,23 +111,20 @@ public:
      * to the current operator configuration.
      * 
      */
-    ImplSpec getRequiredSpec() const {
-        // TODO
-        return ImplSpec{DataType::Float32};
-    }
+    ImplSpec getRequiredSpec() const;
 
     /**
      * @brief Get the best implementation that matches \p requiredSpecs.
      * 
      */
-    ImplSpec getBestMatch(ImplSpec /*requiredSpecs*/) const {
-        // TODO:
-        return getAvailableImplSpecs()[0];
-    }
-
-    // std::shared_ptr<Node> getAdaptedOp(ImplSpec requiredSpecs) {
+    ImplSpec getBestMatch(ImplSpec requiredSpecs) const;
 
-    // }
+    /**
+     * @brief Get the best alternative node (containing a meta operator)
+     * fulfilling the requirements.
+     * 
+     */
+    std::shared_ptr<Node> getBestAlternative(ImplSpec requiredSpecs);
 
     virtual ~OperatorImpl() = default;
 
diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp
index 0eb1e635a..5e29418d4 100644
--- a/src/backend/OperatorImpl.cpp
+++ b/src/backend/OperatorImpl.cpp
@@ -14,6 +14,7 @@
 
 #include "aidge/backend/OperatorImpl.hpp"
 #include "aidge/operator/Operator.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/scheduler/ProdConso.hpp"
 #include "aidge/data/Tensor.hpp"
 #include "aidge/utils/ErrorHandling.hpp"
@@ -32,6 +33,105 @@ std::shared_ptr<Aidge::ProdConso> Aidge::OperatorImpl::prodConso() {
     return mProdConso;
 }
 
+Aidge::ImplSpec Aidge::OperatorImpl::getRequiredSpec() const {
+    const auto& opTensor = dynamic_cast<const OperatorTensor&>(mOp);
+
+    ImplSpec requiredSpec;
+    // Inputs specs
+    for (size_t i = 0; i < opTensor.nbInputs(); ++i) {
+        if (opTensor.getInput(i)) {
+            std::vector<std::pair<DimSize_t, DimSize_t>> dims;
+            for (auto dim : opTensor.getInput(i)->dims()) {
+                dims.push_back(std::make_pair(dim, dim));
+            }
+
+            requiredSpec.inputs.push_back({opTensor.getInput(i)->dataType(), opTensor.getInput(i)->dataFormat(), dims});
+        }
+        else {
+            requiredSpec.inputs.push_back({DataType::Any});
+        }
+    }
+    // Outputs specs
+    for (size_t i = 0; i < opTensor.nbOutputs(); ++i) {
+        std::vector<std::pair<DimSize_t, DimSize_t>> dims;
+        for (auto dim : opTensor.getOutput(i)->dims()) {
+            dims.push_back(std::make_pair(dim, dim));
+        }
+
+        requiredSpec.outputs.push_back({opTensor.getOutput(i)->dataType(), opTensor.getOutput(i)->dataFormat(), dims});
+    }
+    return requiredSpec;
+}
+
+Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(ImplSpec requiredSpecs) const {
+    Log::debug("getBestMatch() for requirements: {}", requiredSpecs);
+
+    const auto availableSpecs = getAvailableImplSpecs();
+    std::vector<bool> matchingSpecs(availableSpecs.size(), false);
+
+    for (size_t s = 0; s < availableSpecs.size(); ++s) {
+        auto spec = availableSpecs[s];
+        bool match = true;
+
+        // Check inputs
+        for (size_t i = 0; i < requiredSpecs.inputs.size(); ++i) {
+            if (requiredSpecs.inputs[i].type != DataType::Any
+                && spec.inputs[i].type != DataType::Any
+                && requiredSpecs.inputs[i].type != spec.inputs[i].type)
+            {
+                match = false;
+                break;
+            }
+
+            if (requiredSpecs.inputs[i].format != DataFormat::Any
+                && spec.inputs[i].format != DataFormat::Any
+                && requiredSpecs.inputs[i].format != spec.inputs[i].format)
+            {
+                match = false;
+                break;
+            }
+
+            if (!requiredSpecs.inputs[i].dims.empty() && !spec.inputs[i].dims.empty()) {
+                if (requiredSpecs.inputs[i].dims.size() != spec.inputs[i].dims.size()) {
+                    match = false;
+                    break;
+                }
+
+                for (size_t dim = 0; dim < requiredSpecs.inputs[i].dims.size(); ++dim) {
+                    // TODO
+                }
+            }
+        }
+
+        // Check outputs
+        // TODO
+
+        // Check attributes
+        // TODO
+
+        matchingSpecs[s] = match;
+
+        Log::debug("  {} - {}", (match) ? "MATCH" : "MISMATCH", spec);
+    }
+
+    // Return best match
+    // TODO: for now, returns the **first** match
+    for (size_t s = 0; s < availableSpecs.size(); ++s) {
+        if (matchingSpecs[s]) {
+            return availableSpecs[s];
+        }
+    }
+
+    // If there is no match, return the required specs for the registrar, which
+    // will throw a "missing or invalid registrar key"
+    return requiredSpecs;
+}
+
+std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAlternative(ImplSpec /*requiredSpecs*/) {
+    // TODO: have a generic getBestAlternative() that handle at least data type and data format conversions.
+    return nullptr;
+}
+
 void Aidge::OperatorImpl::forward() {
     AIDGE_THROW_OR_ABORT(std::runtime_error, "forward() not implemented yet for operator of type {}", mOp.type());
 }
diff --git a/src/utils/Log.cpp b/src/utils/Log.cpp
index ae8816e78..da32a8e0e 100644
--- a/src/utils/Log.cpp
+++ b/src/utils/Log.cpp
@@ -89,7 +89,7 @@ void Aidge::Log::log(Level level, const std::string& msg) {
             fmt::println("Context: {}", context);
         }
 
-        fmt::println(mFile.get(), msg);
+        fmt::println(mFile.get(), "{}", msg);
     }
 }
 
-- 
GitLab