Skip to content
Snippets Groups Projects
Commit 93942ddc authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Replaced swich case with refCastFrom()

parent 97d2fef9
No related branches found
No related tags found
No related merge requests found
......@@ -64,33 +64,13 @@ bool Aidge::Gather_Op::forwardDims(bool /*allowDataDependency*/) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Either indices input or attribute must be provided", type());
}
this->template getAttr<GatherAttr::GatheredShape>() = getInput(1)->dims();
std::shared_ptr<Tensor> fallback;
this->template getAttr<GatherAttr::Indices>().clear(); // If both are provided input would override attrs
this->template getAttr<GatherAttr::Indices>().reserve(getInput(1)->size());
switch (mInputs[1]->dataType()) {
case DataType::Float64:
std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
break;
case DataType::Float32:
std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
break;
case DataType::Int64:
std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
break;
case DataType::Int32:
std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type());
break;
}
const auto& indices = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
std::copy_n(static_cast<int64_t*>(indices.getImpl()->rawPtr()),
indices.size(),
std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
}
std::vector<DimSize_t> outDims = getInput(0)->dims();
......
......@@ -48,33 +48,13 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #1 should be associated with a Tensor", type());
}
if(!getInput(1)->empty()) {
std::shared_ptr<Tensor> fallback;
this->template getAttr<ReshapeAttr::Shape>().clear(); // If both are provided input would override attrs
this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size());
switch (mInputs[1]->dataType()) {
case DataType::Float64:
std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
break;
case DataType::Float32:
std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
break;
case DataType::Int64:
std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
break;
case DataType::Int32:
std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape input DataType is not supported.");
break;
}
const auto& shape = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
std::copy_n(static_cast<int64_t*>(shape.getImpl()->rawPtr()),
shape.size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
}
else {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape attribute or Input is needed");
......
......@@ -127,61 +127,27 @@ bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) {
AIDGE_ASSERT((mInputs[1]->dataType() == mInputs[2]->dataType()) && (mInputs[1]->dataType() == mInputs[3]->dataType()), "Slice inputs must have the same dataType.");
std::shared_ptr<Tensor> fallback;
this->template getAttr<SliceAttr::Starts>().clear(); // If both are provided input would override attrs
this->template getAttr<SliceAttr::Starts>().reserve(getInput(1)->size());
this->template getAttr<SliceAttr::Ends>().clear();
this->template getAttr<SliceAttr::Ends>().reserve(getInput(1)->size());
this->template getAttr<SliceAttr::Axes>().clear();
this->template getAttr<SliceAttr::Axes>().reserve(getInput(1)->size());
switch (mInputs[1]->dataType()) {
case DataType::Float64:
std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
std::copy_n(static_cast<double*>(mInputs[2]->getImpl()->rawPtr()),
getInput(2)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
std::copy_n(static_cast<double*>(mInputs[3]->getImpl()->rawPtr()),
getInput(3)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
break;
case DataType::Float32:
std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
std::copy_n(static_cast<float*>(mInputs[2]->getImpl()->rawPtr()),
getInput(2)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
std::copy_n(static_cast<float*>(mInputs[3]->getImpl()->rawPtr()),
getInput(3)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
break;
case DataType::Int64:
std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
std::copy_n(static_cast<std::int64_t*>(mInputs[2]->getImpl()->rawPtr()),
getInput(2)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
std::copy_n(static_cast<std::int64_t*>(mInputs[3]->getImpl()->rawPtr()),
getInput(3)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
break;
case DataType::Int32:
std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
std::copy_n(static_cast<std::int32_t*>(mInputs[2]->getImpl()->rawPtr()),
getInput(2)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
std::copy_n(static_cast<std::int32_t*>(mInputs[3]->getImpl()->rawPtr()),
getInput(3)->size(),
std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type());
break;
}
const auto& starts = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
std::copy_n(static_cast<int64_t*>(starts.getImpl()->rawPtr()),
starts.size(),
std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
this->template getAttr<SliceAttr::Ends>().clear(); // If both are provided input would override attrs
this->template getAttr<SliceAttr::Ends>().reserve(getInput(2)->size());
const auto& ends = mInputs[2]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
std::copy_n(static_cast<int64_t*>(ends.getImpl()->rawPtr()),
ends.size(),
std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
this->template getAttr<SliceAttr::Axes>().clear(); // If both are provided input would override attrs
this->template getAttr<SliceAttr::Axes>().reserve(getInput(3)->size());
const auto& axes = mInputs[3]->refCastFrom(fallback, NativeType<int8_t>::type, "cpu");
std::copy_n(static_cast<int8_t*>(axes.getImpl()->rawPtr()),
axes.size(),
std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
}
DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment