Skip to content
Snippets Groups Projects
Commit edeb04c4 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add consts

parent 8a913046
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!74Update vit operators
...@@ -27,12 +27,12 @@ void Aidge::Gather_Op::computeOutputDims() { ...@@ -27,12 +27,12 @@ void Aidge::Gather_Op::computeOutputDims() {
} }
std::vector<DimSize_t> outDims = getInput(0)->dims(); std::vector<DimSize_t> outDims = getInput(0)->dims();
std::vector<DimSize_t> gatheredShape = this->template getAttr<GatherAttr::GatheredShape>(); const std::vector<DimSize_t> gatheredShape = this->template getAttr<GatherAttr::GatheredShape>();
// TODO: check indices and gatheredShape // TODO: check indices and gatheredShape
std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ? const std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ?
this->template getAttr<GatherAttr::Axis>() : this->template getAttr<GatherAttr::Axis>() :
this->template getAttr<GatherAttr::Axis>() + outDims.size(); this->template getAttr<GatherAttr::Axis>() + outDims.size();
outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx)); outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
if (!gatheredShape.empty()) if (!gatheredShape.empty())
{ {
......
...@@ -25,7 +25,7 @@ void Aidge::Reshape_Op::computeOutputDims() { ...@@ -25,7 +25,7 @@ void Aidge::Reshape_Op::computeOutputDims() {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected"); AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected");
} }
DimSize_t nbOutDims = this->template getAttr<ReshapeAttr::Shape>().size(); const DimSize_t nbOutDims = this->template getAttr<ReshapeAttr::Shape>().size();
std::vector<DimSize_t> outDims; std::vector<DimSize_t> outDims;
std::size_t outSize = 1; std::size_t outSize = 1;
for(std::size_t i=0; i<nbOutDims; ++i) for(std::size_t i=0; i<nbOutDims; ++i)
...@@ -33,7 +33,7 @@ void Aidge::Reshape_Op::computeOutputDims() { ...@@ -33,7 +33,7 @@ void Aidge::Reshape_Op::computeOutputDims() {
int dimSize = this->template getAttr<ReshapeAttr::Shape>()[i]; int dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
if (dimSize < 1) if (dimSize < 1)
{ {
AIDGE_THROW_OR_ABORT(std::runtime_error, "bad dimension value"); AIDGE_THROW_OR_ABORT(std::runtime_error, "Bad dimension value");
} }
outDims.push_back(dimSize); outDims.push_back(dimSize);
outSize *= dimSize; outSize *= dimSize;
......
...@@ -30,7 +30,7 @@ void Aidge::Slice_Op::computeOutputDims() { ...@@ -30,7 +30,7 @@ void Aidge::Slice_Op::computeOutputDims() {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor"); AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor");
} }
DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size(); const DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
std::vector<DimSize_t> outDims = getInput(0)->dims(); std::vector<DimSize_t> outDims = getInput(0)->dims();
for (std::size_t i = 0; i < nbAxes; ++i) { for (std::size_t i = 0; i < nbAxes; ++i) {
// For each slice operation get the params and cast them to size_t // For each slice operation get the params and cast them to size_t
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment