Skip to content
Snippets Groups Projects
Commit cd743cb2 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fix comparison operators and improved selection concept

parent 924d4a7d
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!186Refactor OperatorImpl for backend/export
Pipeline #53353 failed
...@@ -26,25 +26,26 @@ namespace Aidge { ...@@ -26,25 +26,26 @@ namespace Aidge {
class Node; class Node;
class Operator; class Operator;
/**
* @brief ImplSpec stores the requirements or the specifications of an implementation.
*
*/
struct ImplSpec { struct ImplSpec {
struct IOSpec { struct IOSpec {
IOSpec(DataType type_): IOSpec(DataType type_, DataFormat format_ = DataFormat::Any, std::vector<std::pair<DimSize_t, DimSize_t>> dims_ = {}):
type(type_),
format(DataFormat::Any),
dims({})
{}
IOSpec(DataType type_, DataFormat format_):
type(type_), type(type_),
format(format_), format(format_),
dims({}) dims(dims_)
{} {}
DataType type; DataType type;
DataFormat format; DataFormat format;
std::vector<std::pair<size_t, size_t>> dims; std::vector<std::pair<DimSize_t, DimSize_t>> dims;
}; };
ImplSpec() {
}
ImplSpec(IOSpec io) { ImplSpec(IOSpec io) {
inputs.push_back(io); inputs.push_back(io);
outputs.push_back(io); outputs.push_back(io);
...@@ -60,14 +61,28 @@ struct ImplSpec { ...@@ -60,14 +61,28 @@ struct ImplSpec {
//DynamicAttributes attrs; //DynamicAttributes attrs;
}; };
inline bool operator==(const ImplSpec::IOSpec& lhs, const ImplSpec::IOSpec& rhs) {
return (lhs.type == rhs.type)
&& (lhs.format == rhs.format)
&& (lhs.dims == rhs.dims);
}
inline bool operator<(const ImplSpec::IOSpec& lhs, const ImplSpec::IOSpec& rhs) { inline bool operator<(const ImplSpec::IOSpec& lhs, const ImplSpec::IOSpec& rhs) {
return (lhs.type < rhs.type) && (lhs.format < rhs.format) && (lhs.dims < rhs.dims); return (lhs.type < rhs.type)
|| (lhs.type == rhs.type && lhs.format < rhs.format)
|| (lhs.type == rhs.type && lhs.format == rhs.format && lhs.dims < rhs.dims);
} }
inline bool operator<(const ImplSpec& lhs, const ImplSpec& rhs) { inline bool operator<(const ImplSpec& lhs, const ImplSpec& rhs) {
return (lhs.inputs < rhs.inputs) && (lhs.outputs < rhs.outputs); return (lhs.inputs < rhs.inputs)
|| (lhs.inputs == rhs.inputs && lhs.outputs < rhs.outputs);
} }
/**
* @brief Impl stores the details of a specific implementation.
* It is associated to a ImplSpec in a registry.
*
*/
template <class FwdFunc, class BwdFunc> template <class FwdFunc, class BwdFunc>
struct Impl { struct Impl {
Impl(std::function<std::unique_ptr<ProdConso>(const Operator&)> prodConso_, Impl(std::function<std::unique_ptr<ProdConso>(const Operator&)> prodConso_,
...@@ -96,23 +111,20 @@ public: ...@@ -96,23 +111,20 @@ public:
* to the current operator configuration. * to the current operator configuration.
* *
*/ */
ImplSpec getRequiredSpec() const { ImplSpec getRequiredSpec() const;
// TODO
return ImplSpec{DataType::Float32};
}
/** /**
* @brief Get the best implementation that matches \p requiredSpecs. * @brief Get the best implementation that matches \p requiredSpecs.
* *
*/ */
ImplSpec getBestMatch(ImplSpec /*requiredSpecs*/) const { ImplSpec getBestMatch(ImplSpec requiredSpecs) const;
// TODO:
return getAvailableImplSpecs()[0];
}
// std::shared_ptr<Node> getAdaptedOp(ImplSpec requiredSpecs) {
// } /**
* @brief Get the best alternative node (containing a meta operator)
* fulfilling the requirements.
*
*/
std::shared_ptr<Node> getBestAlternative(ImplSpec requiredSpecs);
virtual ~OperatorImpl() = default; virtual ~OperatorImpl() = default;
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Operator.hpp" #include "aidge/operator/Operator.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/scheduler/ProdConso.hpp" #include "aidge/scheduler/ProdConso.hpp"
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
...@@ -32,6 +33,105 @@ std::shared_ptr<Aidge::ProdConso> Aidge::OperatorImpl::prodConso() { ...@@ -32,6 +33,105 @@ std::shared_ptr<Aidge::ProdConso> Aidge::OperatorImpl::prodConso() {
return mProdConso; return mProdConso;
} }
Aidge::ImplSpec Aidge::OperatorImpl::getRequiredSpec() const {
const auto& opTensor = dynamic_cast<const OperatorTensor&>(mOp);
ImplSpec requiredSpec;
// Inputs specs
for (size_t i = 0; i < opTensor.nbInputs(); ++i) {
if (opTensor.getInput(i)) {
std::vector<std::pair<DimSize_t, DimSize_t>> dims;
for (auto dim : opTensor.getInput(i)->dims()) {
dims.push_back(std::make_pair(dim, dim));
}
requiredSpec.inputs.push_back({opTensor.getInput(i)->dataType(), opTensor.getInput(i)->dataFormat(), dims});
}
else {
requiredSpec.inputs.push_back({DataType::Any});
}
}
// Outputs specs
for (size_t i = 0; i < opTensor.nbOutputs(); ++i) {
std::vector<std::pair<DimSize_t, DimSize_t>> dims;
for (auto dim : opTensor.getOutput(i)->dims()) {
dims.push_back(std::make_pair(dim, dim));
}
requiredSpec.outputs.push_back({opTensor.getOutput(i)->dataType(), opTensor.getOutput(i)->dataFormat(), dims});
}
return requiredSpec;
}
Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(ImplSpec requiredSpecs) const {
Log::debug("getBestMatch() for requirements: {}", requiredSpecs);
const auto availableSpecs = getAvailableImplSpecs();
std::vector<bool> matchingSpecs(availableSpecs.size(), false);
for (size_t s = 0; s < availableSpecs.size(); ++s) {
auto spec = availableSpecs[s];
bool match = true;
// Check inputs
for (size_t i = 0; i < requiredSpecs.inputs.size(); ++i) {
if (requiredSpecs.inputs[i].type != DataType::Any
&& spec.inputs[i].type != DataType::Any
&& requiredSpecs.inputs[i].type != spec.inputs[i].type)
{
match = false;
break;
}
if (requiredSpecs.inputs[i].format != DataFormat::Any
&& spec.inputs[i].format != DataFormat::Any
&& requiredSpecs.inputs[i].format != spec.inputs[i].format)
{
match = false;
break;
}
if (!requiredSpecs.inputs[i].dims.empty() && !spec.inputs[i].dims.empty()) {
if (requiredSpecs.inputs[i].dims.size() != spec.inputs[i].dims.size()) {
match = false;
break;
}
for (size_t dim = 0; dim < requiredSpecs.inputs[i].dims.size(); ++dim) {
// TODO
}
}
}
// Check outputs
// TODO
// Check attributes
// TODO
matchingSpecs[s] = match;
Log::debug(" {} - {}", (match) ? "MATCH" : "MISMATCH", spec);
}
// Return best match
// TODO: for now, returns the **first** match
for (size_t s = 0; s < availableSpecs.size(); ++s) {
if (matchingSpecs[s]) {
return availableSpecs[s];
}
}
// If there is no match, return the required specs for the registrar, which
// will throw a "missing or invalid registrar key"
return requiredSpecs;
}
std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAlternative(ImplSpec /*requiredSpecs*/) {
// TODO: have a generic getBestAlternative() that handle at least data type and data format conversions.
return nullptr;
}
void Aidge::OperatorImpl::forward() { void Aidge::OperatorImpl::forward() {
AIDGE_THROW_OR_ABORT(std::runtime_error, "forward() not implemented yet for operator of type {}", mOp.type()); AIDGE_THROW_OR_ABORT(std::runtime_error, "forward() not implemented yet for operator of type {}", mOp.type());
} }
......
...@@ -89,7 +89,7 @@ void Aidge::Log::log(Level level, const std::string& msg) { ...@@ -89,7 +89,7 @@ void Aidge::Log::log(Level level, const std::string& msg) {
fmt::println("Context: {}", context); fmt::println("Context: {}", context);
} }
fmt::println(mFile.get(), msg); fmt::println(mFile.get(), "{}", msg);
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment