diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp index ed64f569598b763ac13a71f389403b24d1cc172d..863aa21bb5422ac7c953b18d814fc2f46662bbae 100644 --- a/src/backend/OperatorImpl.cpp +++ b/src/backend/OperatorImpl.cpp @@ -86,7 +86,8 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs) // Check inputs for (size_t i = 0; i < requiredSpecs.inputs.size(); ++i) { - if (!checkIOSpec(requiredSpecs.inputs[i], spec.inputs[i])) { + const auto inputSpec = (i < spec.inputs.size()) ? spec.inputs[i] : spec.inputs.back(); + if (!checkIOSpec(requiredSpecs.inputs[i], inputSpec)) { match = false; break; } @@ -94,7 +95,8 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs) // Check outputs for (size_t i = 0; i < requiredSpecs.outputs.size(); ++i) { - if (!checkIOSpec(requiredSpecs.outputs[i], spec.outputs[i])) { + const auto outputSpec = (i < spec.outputs.size()) ? spec.outputs[i] : spec.outputs.back(); + if (!checkIOSpec(requiredSpecs.outputs[i], outputSpec)) { match = false; break; } @@ -171,7 +173,13 @@ bool Aidge::OperatorImpl::checkIOSpec(const ImplSpec::IOSpec& required, const Im && spec.format != DataFormat::Any && required.format != spec.format) { - return false; + const auto transpose = getDataFormatTranspose(required.format, spec.format); + std::vector<size_t> identity(transpose.size()); + std::iota(std::begin(identity), std::end(identity), 0); + + if (!std::equal(transpose.begin(), transpose.end(), identity.begin())) { + return false; + } } if (!required.dims.empty() && !spec.dims.empty()) {