Skip to content
Snippets Groups Projects

Replace swich case with refCastFrom()

Merged Olivier BICHLER requested to merge better_inputs_to_attr into dev
2 unresolved threads
Files
3
+ 5
25
@@ -64,33 +64,13 @@ bool Aidge::Gather_Op::forwardDims(bool /*allowDataDependency*/) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Either indices input or attribute must be provided", type());
}
this->template getAttr<GatherAttr::GatheredShape>() = getInput(1)->dims();
std::shared_ptr<Tensor> fallback;
this->template getAttr<GatherAttr::Indices>().clear(); // If both are provided input would override attrs
this->template getAttr<GatherAttr::Indices>().reserve(getInput(1)->size());
switch (mInputs[1]->dataType()) {
case DataType::Float64:
std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
break;
case DataType::Float32:
std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
break;
case DataType::Int64:
std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
break;
case DataType::Int32:
std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type());
break;
}
const auto& indices = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
std::copy_n(static_cast<int64_t*>(indices.getImpl()->rawPtr()),
indices.size(),
std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
}
std::vector<DimSize_t> outDims = getInput(0)->dims();
Loading