From fbd368946f2a2cf86c832d04991ca1f2b830a36e Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Thu, 30 Nov 2023 16:34:07 +0100 Subject: [PATCH] move computeDims to cpp --- include/aidge/operator/Gather.hpp | 12 +------ include/aidge/operator/Reshape.hpp | 14 +------- include/aidge/operator/Slice.hpp | 18 +--------- include/aidge/operator/Transpose.hpp | 1 - src/operator/Gather.cpp | 38 +++++++++++++++++++++ src/operator/Reshape.cpp | 47 ++++++++++++++++++++++++++ src/operator/Slice.cpp | 49 ++++++++++++++++++++++++++++ 7 files changed, 137 insertions(+), 42 deletions(-) create mode 100644 src/operator/Gather.cpp create mode 100644 src/operator/Reshape.cpp create mode 100644 src/operator/Slice.cpp diff --git a/include/aidge/operator/Gather.hpp b/include/aidge/operator/Gather.hpp index 6579331ca..ba7d745fa 100644 --- a/include/aidge/operator/Gather.hpp +++ b/include/aidge/operator/Gather.hpp @@ -68,17 +68,7 @@ public: return std::make_shared<Gather_Op>(*this); } - void computeOutputDims() override final { - if (!mInputs.empty() && !mInputs[0]->empty() && mInputs[1]->nbDims()==2) - { - std::vector<DimSize_t> outDims = mInputs[0]->dims(); - std::vector<DimSize_t> indexesDims = mInputs[1]->dims(); - int axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?this->template getAttr<GatherAttr::Axis>():this->template getAttr<GatherAttr::Axis>()+outDims.size(); - outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx)); - outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), indexesDims.begin(),indexesDims.end()); - mOutputs[0]->resize(outDims); - } - } + void computeOutputDims() override final; void setBackend(const std::string& name) override { mImpl = Registrar<Gather_Op>::create(name)(*this); diff --git a/include/aidge/operator/Reshape.hpp b/include/aidge/operator/Reshape.hpp index 81cc7cd19..2d9372c4e 100644 --- a/include/aidge/operator/Reshape.hpp +++ b/include/aidge/operator/Reshape.hpp @@ -52,19 +52,7 @@ public: return std::make_shared<Reshape_Op>(*this); } - void computeOutputDims() override final { - if (!mInputs[0]->empty() && !mInputs[1]->empty()) - { - std::vector<DimSize_t> outDims; - int* shapeElem = static_cast<int*>(mInputs[1]->getImpl()->rawPtr()); - for(std::size_t i=0; i<mInputs[1]->size(); ++i) - { - outDims.push_back(shapeElem[i]); - } - mOutputs[0]->resize(outDims); - } - } - + void computeOutputDims() override final; void setBackend(const std::string& name) override { mImpl = Registrar<Reshape_Op>::create(name)(*this); diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp index d1e000723..e98714b02 100644 --- a/include/aidge/operator/Slice.hpp +++ b/include/aidge/operator/Slice.hpp @@ -50,23 +50,7 @@ public: */ std::shared_ptr<Operator> clone() const override { return std::make_shared<Slice_Op>(*this); } - void computeOutputDims() override final { - if (!mInputs[0]->empty() && !mInputs[1]->empty() && !mInputs[2]->empty()&& !mInputs[3]->empty()) - { - DimSize_t nbAxes = mInputs[1]->dims()[0]; - const int* axes = static_cast<const int*>(mInputs[1]->getImpl()->rawPtr()); - const int* starts = static_cast<const int*>(mInputs[2]->getImpl()->rawPtr()); - const int* ends = static_cast<const int*>(mInputs[3]->getImpl()->rawPtr()); - std::vector<DimSize_t> outDims = mInputs[0]->dims(); - for(std::size_t i=0; i<nbAxes;++i) - { - std::size_t axis = axes[i]>=0?axes[i]:axes[i]+mInputs[0]->nbDims(); - outDims[axis] = ends[i] - starts[i] + 1; - } - mOutputs[0]->resize(outDims); - } - } - + void computeOutputDims() override final; void setBackend(const std::string& name) override { mImpl = Registrar<Slice_Op>::create(name)(*this); diff --git a/include/aidge/operator/Transpose.hpp b/include/aidge/operator/Transpose.hpp index 6248dcfc5..8bf5f17ab 100644 --- a/include/aidge/operator/Transpose.hpp +++ b/include/aidge/operator/Transpose.hpp @@ -68,7 +68,6 @@ class Transpose_Op : public OperatorTensor, } void computeOutputDims() override final { - printf("************** nbIn %d \n", this->nbInputs()); if (!getInput(0)->empty()) { auto attr = (this)->getStaticAttributes(); const std::array<DimSize_t, DIM>& outDimsOrder = static_cast<const std::array<DimSize_t, DIM>&>(std::get<0>(attr)); diff --git a/src/operator/Gather.cpp b/src/operator/Gather.cpp new file mode 100644 index 000000000..26a334bb0 --- /dev/null +++ b/src/operator/Gather.cpp @@ -0,0 +1,38 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <cassert> +#include <cstddef> +#include <vector> +#include <utility> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/Gather.hpp" +#include "aidge/utils/Types.h" +#include "aidge/utils/ErrorHandling.hpp" + +void Aidge::Gather_Op::computeOutputDims() { + // check inputs have been associated + if (!getInput(0) || !getInput(1)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected"); + } + + if (getInput(1)->nbDims()!=2){ + AIDGE_THROW_OR_ABORT(std::runtime_error, "Indices input must be a 2D Tensor"); + } + + std::vector<DimSize_t> outDims = getInput(0)->dims(); + std::vector<DimSize_t> indexesDims = getInput(1)->dims(); + int axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?this->template getAttr<GatherAttr::Axis>():this->template getAttr<GatherAttr::Axis>()+outDims.size(); + outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx)); + outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), indexesDims.begin(),indexesDims.end()); + mOutputs[0]->resize(outDims); +} \ No newline at end of file diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp new file mode 100644 index 000000000..f32e8b5af --- /dev/null +++ b/src/operator/Reshape.cpp @@ -0,0 +1,47 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <cassert> +#include <cstddef> +#include <vector> +#include <utility> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/Reshape.hpp" +#include "aidge/utils/Types.h" +#include "aidge/utils/ErrorHandling.hpp" + +void Aidge::Reshape_Op::computeOutputDims() { + // check inputs have been associated + if (!getInput(0) || !getInput(1)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected"); + } + + std::vector<DimSize_t> outDims; + std::size_t outSize = 1; + int* shapeElem = static_cast<int*>(getInput(1)->getImpl()->rawPtr()); + for(std::size_t i=0; i<mInputs[1]->size(); ++i) + { + int dimSize = shapeElem[i]; + if (dimSize < 1) + { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Output shape must give the same size as input"); + } + outDims.push_back(dimSize); + outSize *= dimSize; + } + + if (getInput(0)->size() != outSize){ + AIDGE_THROW_OR_ABORT(std::runtime_error, "Output shape must give the same size as input"); + } + + mOutputs[0]->resize(outDims); +} \ No newline at end of file diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp new file mode 100644 index 000000000..0495f96c5 --- /dev/null +++ b/src/operator/Slice.cpp @@ -0,0 +1,49 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <cassert> +#include <cstddef> +#include <vector> +#include <utility> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/Slice.hpp" +#include "aidge/utils/Types.h" +#include "aidge/utils/ErrorHandling.hpp" + +void Aidge::Slice_Op::computeOutputDims() { + // check inputs have been associated + if (!getInput(0) || !getInput(1) || !getInput(2) || !getInput(3)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected"); + } + + if (getInput(1)->nbDims()!=1){ + AIDGE_THROW_OR_ABORT(std::runtime_error, "Indices input must be a 1D Tensor"); + } + if (getInput(2)->nbDims()!=1){ + AIDGE_THROW_OR_ABORT(std::runtime_error, "Starts input must be a 1D Tensor"); + } + if (getInput(3)->nbDims()!=1){ + AIDGE_THROW_OR_ABORT(std::runtime_error, "Ends input must be a 1D Tensor"); + } + + DimSize_t nbAxes = getInput(1)->dims()[0]; + const int* axes = static_cast<const int*>(getInput(1)->getImpl()->rawPtr()); + const int* starts = static_cast<const int*>(getInput(2)->getImpl()->rawPtr()); + const int* ends = static_cast<const int*>(getInput(3)->getImpl()->rawPtr()); + std::vector<DimSize_t> outDims = getInput(0)->dims(); + for(std::size_t i=0; i<nbAxes;++i) + { + std::size_t axis = axes[i]>=0?axes[i]:axes[i]+getInput(0)->nbDims(); + outDims[axis] = ends[i] - starts[i] + 1; + } + mOutputs[0]->resize(outDims); +} \ No newline at end of file -- GitLab