Skip to content
Snippets Groups Projects

Replace swich case with refCastFrom()

Merged Olivier BICHLER requested to merge better_inputs_to_attr into dev
1 file
+ 9
18
Compare changes
  • Side-by-side
  • Inline
+ 9
18
@@ -36,33 +36,24 @@ void Aidge::Slice_OpImpl::forward() {
@@ -36,33 +36,24 @@ void Aidge::Slice_OpImpl::forward() {
(op.template getAttr<SliceAttr::Starts>().size() == op.template getAttr<SliceAttr::Axes>().size()),
(op.template getAttr<SliceAttr::Starts>().size() == op.template getAttr<SliceAttr::Axes>().size()),
"start, end and axes arguments should be the same size.");
"start, end and axes arguments should be the same size.");
const std::size_t nbDims = op.getInput(0)->nbDims();
const auto nbDims = op.getInput(0)->nbDims();
const auto& inputDims = op.getInput(0)->dims();
const std::vector<std::size_t>& inputDims = op.getInput(0)->dims();
const auto& outputDims = op.getOutput(0)->dims();
auto outputDims = op.getInput(0)->dims();
// compute index of the output's first element
// compute index of the output's first element
// compute output dimension at the same time (may change between two forward calls)
std::size_t beginning = 0;
std::size_t beginning = 0;
const std::size_t nbAxes = op.template getAttr<SliceAttr::Axes>().size();
const std::size_t nbAxes = op.template getAttr<SliceAttr::Axes>().size();
for (std::size_t i = 0; i < nbAxes; ++i) {
for (std::size_t i = 0; i < nbAxes; ++i) {
// For each slice operation get the params and cast them to size_t
// For each slice operation get the params and cast them to size_t
DimIdx_t axis = op.template getAttr<SliceAttr::Axes>()[i] >= 0 ?
const DimIdx_t axis = op.template getAttr<SliceAttr::Axes>()[i] >= 0 ?
static_cast<DimIdx_t>(op.template getAttr<SliceAttr::Axes>()[i]) :
static_cast<DimIdx_t>(op.template getAttr<SliceAttr::Axes>()[i]) :
static_cast<DimIdx_t>(op.template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(inputDims.size()));
static_cast<DimIdx_t>(op.template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(inputDims.size()));
DimSize_t start = op.template getAttr<SliceAttr::Starts>()[i] >= 0 ?
const DimSize_t start = op.template getAttr<SliceAttr::Starts>()[i] >= 0 ?
static_cast<DimSize_t>(op.template getAttr<SliceAttr::Starts>()[i]) :
static_cast<DimSize_t>(op.template getAttr<SliceAttr::Starts>()[i]) :
static_cast<DimSize_t>(op.template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(inputDims[axis]));
static_cast<DimSize_t>(op.template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(inputDims[axis]));
DimSize_t end = op.template getAttr<SliceAttr::Ends>()[i] >= 0 ?
static_cast<DimSize_t>(op.template getAttr<SliceAttr::Ends>()[i]) :
static_cast<DimSize_t>(op.template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(inputDims[axis]));
const std::size_t stridePostAxis = std::accumulate(inputDims.cbegin()+axis+1, inputDims.cend(), std::size_t(1), std::multiplies<std::size_t>());
const std::size_t stridePostAxis = std::accumulate(inputDims.cbegin()+axis+1, inputDims.cend(), std::size_t(1), std::multiplies<std::size_t>());
beginning += start * stridePostAxis;
beginning += start * stridePostAxis;
const std::size_t sliceLength = end - start;
outputDims[axis] = sliceLength;
}
}
op.getOutput(0)->resize(outputDims);
// for inputDims = {4,5,5,3} & outputDims = {3,2,2,1}: substractDims = {1,5,5,3}
// for inputDims = {4,5,5,3} & outputDims = {3,2,2,1}: substractDims = {1,5,5,3}
std::vector<std::size_t> substractedDims = std::vector<std::size_t>(nbDims);
std::vector<std::size_t> substractedDims = std::vector<std::size_t>(nbDims);
@@ -195,16 +186,16 @@ bool Aidge::Slice_Op::forwardDims(bool allowDataDependency) {
@@ -195,16 +186,16 @@ bool Aidge::Slice_Op::forwardDims(bool allowDataDependency) {
AIDGE_ASSERT(!this->template getAttr<SliceAttr::Axes>().empty(), "Missing input#3 or Axes attribute");
AIDGE_ASSERT(!this->template getAttr<SliceAttr::Axes>().empty(), "Missing input#3 or Axes attribute");
DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
const DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
std::vector<DimSize_t> outDims = getInput(0)->dims();
std::vector<DimSize_t> outDims = getInput(0)->dims();
for (std::size_t i = 0; i < nbAxes; ++i) {
for (std::size_t i = 0; i < nbAxes; ++i) {
DimIdx_t axis = this->template getAttr<SliceAttr::Axes>()[i] >= 0 ?
const DimIdx_t axis = this->template getAttr<SliceAttr::Axes>()[i] >= 0 ?
static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i]) :
static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i]) :
static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(getInput(0)->nbDims()));
static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(getInput(0)->nbDims()));
DimSize_t start = this->template getAttr<SliceAttr::Starts>()[i] >= 0 ?
const DimSize_t start = this->template getAttr<SliceAttr::Starts>()[i] >= 0 ?
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i]) :
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i]) :
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ?
const DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ?
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) :
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) :
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
Loading