diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp index f6f9d1ce88836b81b02525836bc59a43ae0aec9e..63ef1a0623349e7dc998e908803c3c6979b64096 100644 --- a/include/aidge/backend/OperatorImpl.hpp +++ b/include/aidge/backend/OperatorImpl.hpp @@ -32,7 +32,7 @@ class Operator; */ struct ImplSpec { struct IOSpec { - IOSpec(DataType type_, DataFormat format_ = DataFormat::Any, std::vector<std::pair<DimSize_t, DimSize_t>> dims_ = {}): + IOSpec(DataType type_, DataFormat format_ = DataFormat::Any, std::vector<std::pair<int, int>> dims_ = {}): type(type_), format(format_), dims(dims_) @@ -40,7 +40,7 @@ struct ImplSpec { DataType type; DataFormat format; - std::vector<std::pair<DimSize_t, DimSize_t>> dims; + std::vector<std::pair<int, int>> dims; }; ImplSpec(DynamicAttributes attrs_ = DynamicAttributes()): @@ -128,6 +128,7 @@ public: protected: virtual std::shared_ptr<ProdConso> getProdConso() const; virtual std::vector<ImplSpec> getAvailableImplSpecs() const; + bool checkIOSpec(const ImplSpec::IOSpec& required, const ImplSpec::IOSpec& spec) const; const Operator &mOp; const std::string mBackend; diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp index a04952c4b29a8f55ad3f7c624a2b196f604f4b6a..91db61e40b42e65fa4b12f4a73bf643d1db63367 100644 --- a/src/backend/OperatorImpl.cpp +++ b/src/backend/OperatorImpl.cpp @@ -40,9 +40,9 @@ Aidge::ImplSpec Aidge::OperatorImpl::getRequiredSpec() const { // Inputs specs for (size_t i = 0; i < opTensor.nbInputs(); ++i) { if (opTensor.getInput(i)) { - std::vector<std::pair<DimSize_t, DimSize_t>> dims; + std::vector<std::pair<int, int>> dims; for (auto dim : opTensor.getInput(i)->dims()) { - dims.push_back(std::make_pair(dim, dim)); + dims.push_back(std::make_pair<int, int>(dim, dim)); } requiredSpec.inputs.push_back({opTensor.getInput(i)->dataType(), opTensor.getInput(i)->dataFormat(), dims}); @@ -53,9 +53,9 @@ Aidge::ImplSpec Aidge::OperatorImpl::getRequiredSpec() const { } // Outputs specs for (size_t i = 0; i < opTensor.nbOutputs(); ++i) { - std::vector<std::pair<DimSize_t, DimSize_t>> dims; + std::vector<std::pair<int, int>> dims; for (auto dim : opTensor.getOutput(i)->dims()) { - dims.push_back(std::make_pair(dim, dim)); + dims.push_back(std::make_pair<int, int>(dim, dim)); } requiredSpec.outputs.push_back({opTensor.getOutput(i)->dataType(), opTensor.getOutput(i)->dataFormat(), dims}); @@ -78,42 +78,25 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(ImplSpec requiredSpecs) const for (size_t s = 0; s < availableSpecs.size(); ++s) { auto spec = availableSpecs[s]; - int match = true; + bool match = true; int priority = 0; // Check inputs for (size_t i = 0; i < requiredSpecs.inputs.size(); ++i) { - if (requiredSpecs.inputs[i].type != DataType::Any - && spec.inputs[i].type != DataType::Any - && requiredSpecs.inputs[i].type != spec.inputs[i].type) - { + if (!checkIOSpec(requiredSpecs.inputs[i], spec.inputs[i])) { match = false; break; } + } - if (requiredSpecs.inputs[i].format != DataFormat::Any - && spec.inputs[i].format != DataFormat::Any - && requiredSpecs.inputs[i].format != spec.inputs[i].format) - { + // Check outputs + for (size_t i = 0; i < requiredSpecs.outputs.size(); ++i) { + if (!checkIOSpec(requiredSpecs.outputs[i], spec.outputs[i])) { match = false; break; } - - if (!requiredSpecs.inputs[i].dims.empty() && !spec.inputs[i].dims.empty()) { - if (requiredSpecs.inputs[i].dims.size() != spec.inputs[i].dims.size()) { - match = false; - break; - } - - for (size_t dim = 0; dim < requiredSpecs.inputs[i].dims.size(); ++dim) { - // TODO - } - } } - // Check outputs - // TODO - // Check attributes for (const auto& attrName : requiredSpecs.attrs.getAttrsName()) { std::string name = attrName; @@ -173,6 +156,42 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(ImplSpec requiredSpecs) const return requiredSpecs; } +bool Aidge::OperatorImpl::checkIOSpec(const ImplSpec::IOSpec& required, const ImplSpec::IOSpec& spec) const { + if (required.type != DataType::Any + && spec.type != DataType::Any + && required.type != spec.type) + { + return false; + } + + if (required.format != DataFormat::Any + && spec.format != DataFormat::Any + && required.format != spec.format) + { + return false; + } + + if (!required.dims.empty() && !spec.dims.empty()) { + if (required.dims.size() != spec.dims.size()) { + return false; + } + + for (size_t dim = 0; dim < required.dims.size(); ++dim) { + const auto requiredDim = required.dims[dim]; + const auto specDim = spec.dims[dim]; + + if (requiredDim.first != -1 + && specDim.first != -1 + && !(specDim.first <= requiredDim.first && specDim.second >= requiredDim.second)) + { + return false; + } + } + } + + return true; +} + std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAlternative(ImplSpec /*requiredSpecs*/) { // TODO: have a generic getBestAlternative() that handle at least data type and data format conversions. return nullptr;