Skip to content
Snippets Groups Projects
Commit fbd36894 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

move computeDims to cpp

parent 7fc25a30
No related branches found
No related tags found
No related merge requests found
...@@ -68,17 +68,7 @@ public: ...@@ -68,17 +68,7 @@ public:
return std::make_shared<Gather_Op>(*this); return std::make_shared<Gather_Op>(*this);
} }
void computeOutputDims() override final { void computeOutputDims() override final;
if (!mInputs.empty() && !mInputs[0]->empty() && mInputs[1]->nbDims()==2)
{
std::vector<DimSize_t> outDims = mInputs[0]->dims();
std::vector<DimSize_t> indexesDims = mInputs[1]->dims();
int axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?this->template getAttr<GatherAttr::Axis>():this->template getAttr<GatherAttr::Axis>()+outDims.size();
outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), indexesDims.begin(),indexesDims.end());
mOutputs[0]->resize(outDims);
}
}
void setBackend(const std::string& name) override { void setBackend(const std::string& name) override {
mImpl = Registrar<Gather_Op>::create(name)(*this); mImpl = Registrar<Gather_Op>::create(name)(*this);
......
...@@ -52,19 +52,7 @@ public: ...@@ -52,19 +52,7 @@ public:
return std::make_shared<Reshape_Op>(*this); return std::make_shared<Reshape_Op>(*this);
} }
void computeOutputDims() override final { void computeOutputDims() override final;
if (!mInputs[0]->empty() && !mInputs[1]->empty())
{
std::vector<DimSize_t> outDims;
int* shapeElem = static_cast<int*>(mInputs[1]->getImpl()->rawPtr());
for(std::size_t i=0; i<mInputs[1]->size(); ++i)
{
outDims.push_back(shapeElem[i]);
}
mOutputs[0]->resize(outDims);
}
}
void setBackend(const std::string& name) override { void setBackend(const std::string& name) override {
mImpl = Registrar<Reshape_Op>::create(name)(*this); mImpl = Registrar<Reshape_Op>::create(name)(*this);
......
...@@ -50,23 +50,7 @@ public: ...@@ -50,23 +50,7 @@ public:
*/ */
std::shared_ptr<Operator> clone() const override { return std::make_shared<Slice_Op>(*this); } std::shared_ptr<Operator> clone() const override { return std::make_shared<Slice_Op>(*this); }
void computeOutputDims() override final { void computeOutputDims() override final;
if (!mInputs[0]->empty() && !mInputs[1]->empty() && !mInputs[2]->empty()&& !mInputs[3]->empty())
{
DimSize_t nbAxes = mInputs[1]->dims()[0];
const int* axes = static_cast<const int*>(mInputs[1]->getImpl()->rawPtr());
const int* starts = static_cast<const int*>(mInputs[2]->getImpl()->rawPtr());
const int* ends = static_cast<const int*>(mInputs[3]->getImpl()->rawPtr());
std::vector<DimSize_t> outDims = mInputs[0]->dims();
for(std::size_t i=0; i<nbAxes;++i)
{
std::size_t axis = axes[i]>=0?axes[i]:axes[i]+mInputs[0]->nbDims();
outDims[axis] = ends[i] - starts[i] + 1;
}
mOutputs[0]->resize(outDims);
}
}
void setBackend(const std::string& name) override { void setBackend(const std::string& name) override {
mImpl = Registrar<Slice_Op>::create(name)(*this); mImpl = Registrar<Slice_Op>::create(name)(*this);
......
...@@ -68,7 +68,6 @@ class Transpose_Op : public OperatorTensor, ...@@ -68,7 +68,6 @@ class Transpose_Op : public OperatorTensor,
} }
void computeOutputDims() override final { void computeOutputDims() override final {
printf("************** nbIn %d \n", this->nbInputs());
if (!getInput(0)->empty()) { if (!getInput(0)->empty()) {
auto attr = (this)->getStaticAttributes(); auto attr = (this)->getStaticAttributes();
const std::array<DimSize_t, DIM>& outDimsOrder = static_cast<const std::array<DimSize_t, DIM>&>(std::get<0>(attr)); const std::array<DimSize_t, DIM>& outDimsOrder = static_cast<const std::array<DimSize_t, DIM>&>(std::get<0>(attr));
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <cstddef>
#include <vector>
#include <utility>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Gather.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
void Aidge::Gather_Op::computeOutputDims() {
// check inputs have been associated
if (!getInput(0) || !getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
}
if (getInput(1)->nbDims()!=2){
AIDGE_THROW_OR_ABORT(std::runtime_error, "Indices input must be a 2D Tensor");
}
std::vector<DimSize_t> outDims = getInput(0)->dims();
std::vector<DimSize_t> indexesDims = getInput(1)->dims();
int axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?this->template getAttr<GatherAttr::Axis>():this->template getAttr<GatherAttr::Axis>()+outDims.size();
outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), indexesDims.begin(),indexesDims.end());
mOutputs[0]->resize(outDims);
}
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <cstddef>
#include <vector>
#include <utility>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Reshape.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
void Aidge::Reshape_Op::computeOutputDims() {
// check inputs have been associated
if (!getInput(0) || !getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
}
std::vector<DimSize_t> outDims;
std::size_t outSize = 1;
int* shapeElem = static_cast<int*>(getInput(1)->getImpl()->rawPtr());
for(std::size_t i=0; i<mInputs[1]->size(); ++i)
{
int dimSize = shapeElem[i];
if (dimSize < 1)
{
AIDGE_THROW_OR_ABORT(std::runtime_error, "Output shape must give the same size as input");
}
outDims.push_back(dimSize);
outSize *= dimSize;
}
if (getInput(0)->size() != outSize){
AIDGE_THROW_OR_ABORT(std::runtime_error, "Output shape must give the same size as input");
}
mOutputs[0]->resize(outDims);
}
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <cstddef>
#include <vector>
#include <utility>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Slice.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
void Aidge::Slice_Op::computeOutputDims() {
// check inputs have been associated
if (!getInput(0) || !getInput(1) || !getInput(2) || !getInput(3)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
}
if (getInput(1)->nbDims()!=1){
AIDGE_THROW_OR_ABORT(std::runtime_error, "Indices input must be a 1D Tensor");
}
if (getInput(2)->nbDims()!=1){
AIDGE_THROW_OR_ABORT(std::runtime_error, "Starts input must be a 1D Tensor");
}
if (getInput(3)->nbDims()!=1){
AIDGE_THROW_OR_ABORT(std::runtime_error, "Ends input must be a 1D Tensor");
}
DimSize_t nbAxes = getInput(1)->dims()[0];
const int* axes = static_cast<const int*>(getInput(1)->getImpl()->rawPtr());
const int* starts = static_cast<const int*>(getInput(2)->getImpl()->rawPtr());
const int* ends = static_cast<const int*>(getInput(3)->getImpl()->rawPtr());
std::vector<DimSize_t> outDims = getInput(0)->dims();
for(std::size_t i=0; i<nbAxes;++i)
{
std::size_t axis = axes[i]>=0?axes[i]:axes[i]+getInput(0)->nbDims();
outDims[axis] = ends[i] - starts[i] + 1;
}
mOutputs[0]->resize(outDims);
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment