Skip to content
Snippets Groups Projects
Commit ca67f0bc authored by Maxence Naud's avatar Maxence Naud
Browse files

Custom computeOutputDims() for Pow, Div, Mul, Sub operators

parent f1b503a0
No related branches found
No related tags found
No related merge requests found
......@@ -22,7 +22,6 @@
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
namespace Aidge {
......@@ -52,6 +51,9 @@ public:
return std::make_shared<Div_Op>(*this);
}
void computeOutputDims() override final;
void setBackend(const std::string& name) override {
mImpl = Registrar<Div_Op>::create(name)(*this);
mOutputs[0]->setBackend(name);
......
......@@ -54,6 +54,7 @@ public:
return std::make_shared<Mul_Op>(*this);
}
void computeOutputDims() override final;
void setBackend(const std::string& name) override {
mImpl = Registrar<Mul_Op>::create(name)(*this);
......
......@@ -51,6 +51,8 @@ public:
return std::make_shared<Pow_Op>(*this);
}
void computeOutputDims() override final;
void setBackend(const std::string& name) override {
mImpl = Registrar<Pow_Op>::create(name)(*this);
......
......@@ -56,6 +56,8 @@ public:
return std::make_shared<Sub_Op>(*this);
}
void computeOutputDims() override final;
void setBackend(const std::string& name) override {
mImpl = Registrar<Sub_Op>::create(name)(*this);
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <cstddef>
#include <vector>
#include <utility>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Div.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
void Aidge::Div_Op::computeOutputDims() {
// check inputs have been associated
if (!getInput(0) || !getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
}
if ((!getInput(0)->empty()) &&
((getInput(1)->size() == 1) || // div by a single value
(getInput(1)->size() == getInput(0)->size()) || // div elem-wise
(getInput(1)->nbDims() == 1 && getInput(1)->size() == getInput(0)->dims()[getInput(0)->nbDims()-1]))) // div by a Tensor with one dimension of output size
{
mOutputs[0]->resize(getInput(0)->dims());
}
}
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <cstddef>
#include <vector>
#include <utility>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Mul.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
void Aidge::Mul_Op::computeOutputDims() {
// check inputs have been associated
if (!getInput(0) || !getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
}
if ((!getInput(0)->empty()) &&
((getInput(1)->size() == 1) || // mul by a single value
(getInput(1)->size() == getInput(0)->size()) || // mul elem-wise
(getInput(1)->nbDims() == 1 && getInput(1)->size() == getInput(0)->dims()[getInput(0)->nbDims()-1]))) // mul by a Tensor with one dimension of output size
{
mOutputs[0]->resize(getInput(0)->dims());
}
}
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <cstddef>
#include <vector>
#include <utility>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Pow.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
void Aidge::Pow_Op::computeOutputDims() {
// check inputs have been associated
if (!getInput(0) || !getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
}
if ((!getInput(0)->empty()) &&
((getInput(1)->size() == 1) || // pow by a single value
(getInput(1)->size() == getInput(0)->size()) || // pow elem-wise
(getInput(1)->nbDims() == 1 && getInput(1)->size() == getInput(0)->dims()[getInput(0)->nbDims()-1]))) // pow by a Tensor with one dimension of output size
{
mOutputs[0]->resize(getInput(0)->dims());
}
}
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <cstddef>
#include <vector>
#include <utility>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Sub.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
void Aidge::Sub_Op::computeOutputDims() {
// check inputs have been associated
if (!getInput(0) || !getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
}
if ((!getInput(0)->empty()) &&
((getInput(1)->size() == 1) || // sub by a single value
(getInput(1)->size() == getInput(0)->size()) || // sub elem-wise
(getInput(1)->nbDims() == 1 && getInput(1)->size() == getInput(0)->dims()[getInput(0)->nbDims()-1]))) // sub by a Tensor with one dimension of output size
{
mOutputs[0]->resize(getInput(0)->dims());
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment