Skip to content
Snippets Groups Projects
Commit 8857ae78 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added checkIOSpec()

parent 622478e5
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!186Refactor OperatorImpl for backend/export
Pipeline #53524 failed
......@@ -32,7 +32,7 @@ class Operator;
*/
struct ImplSpec {
struct IOSpec {
IOSpec(DataType type_, DataFormat format_ = DataFormat::Any, std::vector<std::pair<DimSize_t, DimSize_t>> dims_ = {}):
IOSpec(DataType type_, DataFormat format_ = DataFormat::Any, std::vector<std::pair<int, int>> dims_ = {}):
type(type_),
format(format_),
dims(dims_)
......@@ -40,7 +40,7 @@ struct ImplSpec {
DataType type;
DataFormat format;
std::vector<std::pair<DimSize_t, DimSize_t>> dims;
std::vector<std::pair<int, int>> dims;
};
ImplSpec(DynamicAttributes attrs_ = DynamicAttributes()):
......@@ -128,6 +128,7 @@ public:
protected:
virtual std::shared_ptr<ProdConso> getProdConso() const;
virtual std::vector<ImplSpec> getAvailableImplSpecs() const;
bool checkIOSpec(const ImplSpec::IOSpec& required, const ImplSpec::IOSpec& spec) const;
const Operator &mOp;
const std::string mBackend;
......
......@@ -40,9 +40,9 @@ Aidge::ImplSpec Aidge::OperatorImpl::getRequiredSpec() const {
// Inputs specs
for (size_t i = 0; i < opTensor.nbInputs(); ++i) {
if (opTensor.getInput(i)) {
std::vector<std::pair<DimSize_t, DimSize_t>> dims;
std::vector<std::pair<int, int>> dims;
for (auto dim : opTensor.getInput(i)->dims()) {
dims.push_back(std::make_pair(dim, dim));
dims.push_back(std::make_pair<int, int>(dim, dim));
}
requiredSpec.inputs.push_back({opTensor.getInput(i)->dataType(), opTensor.getInput(i)->dataFormat(), dims});
......@@ -53,9 +53,9 @@ Aidge::ImplSpec Aidge::OperatorImpl::getRequiredSpec() const {
}
// Outputs specs
for (size_t i = 0; i < opTensor.nbOutputs(); ++i) {
std::vector<std::pair<DimSize_t, DimSize_t>> dims;
std::vector<std::pair<int, int>> dims;
for (auto dim : opTensor.getOutput(i)->dims()) {
dims.push_back(std::make_pair(dim, dim));
dims.push_back(std::make_pair<int, int>(dim, dim));
}
requiredSpec.outputs.push_back({opTensor.getOutput(i)->dataType(), opTensor.getOutput(i)->dataFormat(), dims});
......@@ -78,42 +78,25 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(ImplSpec requiredSpecs) const
for (size_t s = 0; s < availableSpecs.size(); ++s) {
auto spec = availableSpecs[s];
int match = true;
bool match = true;
int priority = 0;
// Check inputs
for (size_t i = 0; i < requiredSpecs.inputs.size(); ++i) {
if (requiredSpecs.inputs[i].type != DataType::Any
&& spec.inputs[i].type != DataType::Any
&& requiredSpecs.inputs[i].type != spec.inputs[i].type)
{
if (!checkIOSpec(requiredSpecs.inputs[i], spec.inputs[i])) {
match = false;
break;
}
}
if (requiredSpecs.inputs[i].format != DataFormat::Any
&& spec.inputs[i].format != DataFormat::Any
&& requiredSpecs.inputs[i].format != spec.inputs[i].format)
{
// Check outputs
for (size_t i = 0; i < requiredSpecs.outputs.size(); ++i) {
if (!checkIOSpec(requiredSpecs.outputs[i], spec.outputs[i])) {
match = false;
break;
}
if (!requiredSpecs.inputs[i].dims.empty() && !spec.inputs[i].dims.empty()) {
if (requiredSpecs.inputs[i].dims.size() != spec.inputs[i].dims.size()) {
match = false;
break;
}
for (size_t dim = 0; dim < requiredSpecs.inputs[i].dims.size(); ++dim) {
// TODO
}
}
}
// Check outputs
// TODO
// Check attributes
for (const auto& attrName : requiredSpecs.attrs.getAttrsName()) {
std::string name = attrName;
......@@ -173,6 +156,42 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(ImplSpec requiredSpecs) const
return requiredSpecs;
}
bool Aidge::OperatorImpl::checkIOSpec(const ImplSpec::IOSpec& required, const ImplSpec::IOSpec& spec) const {
if (required.type != DataType::Any
&& spec.type != DataType::Any
&& required.type != spec.type)
{
return false;
}
if (required.format != DataFormat::Any
&& spec.format != DataFormat::Any
&& required.format != spec.format)
{
return false;
}
if (!required.dims.empty() && !spec.dims.empty()) {
if (required.dims.size() != spec.dims.size()) {
return false;
}
for (size_t dim = 0; dim < required.dims.size(); ++dim) {
const auto requiredDim = required.dims[dim];
const auto specDim = spec.dims[dim];
if (requiredDim.first != -1
&& specDim.first != -1
&& !(specDim.first <= requiredDim.first && specDim.second >= requiredDim.second))
{
return false;
}
}
}
return true;
}
std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAlternative(ImplSpec /*requiredSpecs*/) {
// TODO: have a generic getBestAlternative() that handle at least data type and data format conversions.
return nullptr;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment