Skip to content
Snippets Groups Projects
Commit 8857ae78 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added checkIOSpec()

parent 622478e5
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!186Refactor OperatorImpl for backend/export
Pipeline #53524 failed
...@@ -32,7 +32,7 @@ class Operator; ...@@ -32,7 +32,7 @@ class Operator;
*/ */
struct ImplSpec { struct ImplSpec {
struct IOSpec { struct IOSpec {
IOSpec(DataType type_, DataFormat format_ = DataFormat::Any, std::vector<std::pair<DimSize_t, DimSize_t>> dims_ = {}): IOSpec(DataType type_, DataFormat format_ = DataFormat::Any, std::vector<std::pair<int, int>> dims_ = {}):
type(type_), type(type_),
format(format_), format(format_),
dims(dims_) dims(dims_)
...@@ -40,7 +40,7 @@ struct ImplSpec { ...@@ -40,7 +40,7 @@ struct ImplSpec {
DataType type; DataType type;
DataFormat format; DataFormat format;
std::vector<std::pair<DimSize_t, DimSize_t>> dims; std::vector<std::pair<int, int>> dims;
}; };
ImplSpec(DynamicAttributes attrs_ = DynamicAttributes()): ImplSpec(DynamicAttributes attrs_ = DynamicAttributes()):
...@@ -128,6 +128,7 @@ public: ...@@ -128,6 +128,7 @@ public:
protected: protected:
virtual std::shared_ptr<ProdConso> getProdConso() const; virtual std::shared_ptr<ProdConso> getProdConso() const;
virtual std::vector<ImplSpec> getAvailableImplSpecs() const; virtual std::vector<ImplSpec> getAvailableImplSpecs() const;
bool checkIOSpec(const ImplSpec::IOSpec& required, const ImplSpec::IOSpec& spec) const;
const Operator &mOp; const Operator &mOp;
const std::string mBackend; const std::string mBackend;
......
...@@ -40,9 +40,9 @@ Aidge::ImplSpec Aidge::OperatorImpl::getRequiredSpec() const { ...@@ -40,9 +40,9 @@ Aidge::ImplSpec Aidge::OperatorImpl::getRequiredSpec() const {
// Inputs specs // Inputs specs
for (size_t i = 0; i < opTensor.nbInputs(); ++i) { for (size_t i = 0; i < opTensor.nbInputs(); ++i) {
if (opTensor.getInput(i)) { if (opTensor.getInput(i)) {
std::vector<std::pair<DimSize_t, DimSize_t>> dims; std::vector<std::pair<int, int>> dims;
for (auto dim : opTensor.getInput(i)->dims()) { for (auto dim : opTensor.getInput(i)->dims()) {
dims.push_back(std::make_pair(dim, dim)); dims.push_back(std::make_pair<int, int>(dim, dim));
} }
requiredSpec.inputs.push_back({opTensor.getInput(i)->dataType(), opTensor.getInput(i)->dataFormat(), dims}); requiredSpec.inputs.push_back({opTensor.getInput(i)->dataType(), opTensor.getInput(i)->dataFormat(), dims});
...@@ -53,9 +53,9 @@ Aidge::ImplSpec Aidge::OperatorImpl::getRequiredSpec() const { ...@@ -53,9 +53,9 @@ Aidge::ImplSpec Aidge::OperatorImpl::getRequiredSpec() const {
} }
// Outputs specs // Outputs specs
for (size_t i = 0; i < opTensor.nbOutputs(); ++i) { for (size_t i = 0; i < opTensor.nbOutputs(); ++i) {
std::vector<std::pair<DimSize_t, DimSize_t>> dims; std::vector<std::pair<int, int>> dims;
for (auto dim : opTensor.getOutput(i)->dims()) { for (auto dim : opTensor.getOutput(i)->dims()) {
dims.push_back(std::make_pair(dim, dim)); dims.push_back(std::make_pair<int, int>(dim, dim));
} }
requiredSpec.outputs.push_back({opTensor.getOutput(i)->dataType(), opTensor.getOutput(i)->dataFormat(), dims}); requiredSpec.outputs.push_back({opTensor.getOutput(i)->dataType(), opTensor.getOutput(i)->dataFormat(), dims});
...@@ -78,42 +78,25 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(ImplSpec requiredSpecs) const ...@@ -78,42 +78,25 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(ImplSpec requiredSpecs) const
for (size_t s = 0; s < availableSpecs.size(); ++s) { for (size_t s = 0; s < availableSpecs.size(); ++s) {
auto spec = availableSpecs[s]; auto spec = availableSpecs[s];
int match = true; bool match = true;
int priority = 0; int priority = 0;
// Check inputs // Check inputs
for (size_t i = 0; i < requiredSpecs.inputs.size(); ++i) { for (size_t i = 0; i < requiredSpecs.inputs.size(); ++i) {
if (requiredSpecs.inputs[i].type != DataType::Any if (!checkIOSpec(requiredSpecs.inputs[i], spec.inputs[i])) {
&& spec.inputs[i].type != DataType::Any
&& requiredSpecs.inputs[i].type != spec.inputs[i].type)
{
match = false; match = false;
break; break;
} }
}
if (requiredSpecs.inputs[i].format != DataFormat::Any // Check outputs
&& spec.inputs[i].format != DataFormat::Any for (size_t i = 0; i < requiredSpecs.outputs.size(); ++i) {
&& requiredSpecs.inputs[i].format != spec.inputs[i].format) if (!checkIOSpec(requiredSpecs.outputs[i], spec.outputs[i])) {
{
match = false; match = false;
break; break;
} }
if (!requiredSpecs.inputs[i].dims.empty() && !spec.inputs[i].dims.empty()) {
if (requiredSpecs.inputs[i].dims.size() != spec.inputs[i].dims.size()) {
match = false;
break;
}
for (size_t dim = 0; dim < requiredSpecs.inputs[i].dims.size(); ++dim) {
// TODO
}
}
} }
// Check outputs
// TODO
// Check attributes // Check attributes
for (const auto& attrName : requiredSpecs.attrs.getAttrsName()) { for (const auto& attrName : requiredSpecs.attrs.getAttrsName()) {
std::string name = attrName; std::string name = attrName;
...@@ -173,6 +156,42 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(ImplSpec requiredSpecs) const ...@@ -173,6 +156,42 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(ImplSpec requiredSpecs) const
return requiredSpecs; return requiredSpecs;
} }
bool Aidge::OperatorImpl::checkIOSpec(const ImplSpec::IOSpec& required, const ImplSpec::IOSpec& spec) const {
if (required.type != DataType::Any
&& spec.type != DataType::Any
&& required.type != spec.type)
{
return false;
}
if (required.format != DataFormat::Any
&& spec.format != DataFormat::Any
&& required.format != spec.format)
{
return false;
}
if (!required.dims.empty() && !spec.dims.empty()) {
if (required.dims.size() != spec.dims.size()) {
return false;
}
for (size_t dim = 0; dim < required.dims.size(); ++dim) {
const auto requiredDim = required.dims[dim];
const auto specDim = spec.dims[dim];
if (requiredDim.first != -1
&& specDim.first != -1
&& !(specDim.first <= requiredDim.first && specDim.second >= requiredDim.second))
{
return false;
}
}
}
return true;
}
std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAlternative(ImplSpec /*requiredSpecs*/) { std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAlternative(ImplSpec /*requiredSpecs*/) {
// TODO: have a generic getBestAlternative() that handle at least data type and data format conversions. // TODO: have a generic getBestAlternative() that handle at least data type and data format conversions.
return nullptr; return nullptr;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment