Skip to content
Snippets Groups Projects
Commit edeb04c4 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add consts

parent 8a913046
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!74Update vit operators
......@@ -27,12 +27,12 @@ void Aidge::Gather_Op::computeOutputDims() {
}
std::vector<DimSize_t> outDims = getInput(0)->dims();
std::vector<DimSize_t> gatheredShape = this->template getAttr<GatherAttr::GatheredShape>();
const std::vector<DimSize_t> gatheredShape = this->template getAttr<GatherAttr::GatheredShape>();
// TODO: check indices and gatheredShape
std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ?
this->template getAttr<GatherAttr::Axis>() :
this->template getAttr<GatherAttr::Axis>() + outDims.size();
const std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ?
this->template getAttr<GatherAttr::Axis>() :
this->template getAttr<GatherAttr::Axis>() + outDims.size();
outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
if (!gatheredShape.empty())
{
......
......@@ -25,7 +25,7 @@ void Aidge::Reshape_Op::computeOutputDims() {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected");
}
DimSize_t nbOutDims = this->template getAttr<ReshapeAttr::Shape>().size();
const DimSize_t nbOutDims = this->template getAttr<ReshapeAttr::Shape>().size();
std::vector<DimSize_t> outDims;
std::size_t outSize = 1;
for(std::size_t i=0; i<nbOutDims; ++i)
......@@ -33,7 +33,7 @@ void Aidge::Reshape_Op::computeOutputDims() {
int dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
if (dimSize < 1)
{
AIDGE_THROW_OR_ABORT(std::runtime_error, "bad dimension value");
AIDGE_THROW_OR_ABORT(std::runtime_error, "Bad dimension value");
}
outDims.push_back(dimSize);
outSize *= dimSize;
......
......@@ -30,7 +30,7 @@ void Aidge::Slice_Op::computeOutputDims() {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor");
}
DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
const DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
std::vector<DimSize_t> outDims = getInput(0)->dims();
for (std::size_t i = 0; i < nbAxes; ++i) {
// For each slice operation get the params and cast them to size_t
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment