Skip to content
Snippets Groups Projects
Commit c67e943c authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

remove ArithmeticOperator class

parent 98262f26
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!65[Add] broadcasting for Arithmetic Operators
Showing
with 227 additions and 196 deletions
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_ARITHMETICOPERATOR_H_
#define AIDGE_CORE_OPERATOR_ARITHMETICOPERATOR_H_
#include <memory>
#include <string>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Types.h"
#include "aidge/graph/Node.hpp"
namespace Aidge {
class ArithmeticOperator : public OperatorTensor {
public:
ArithmeticOperator() = delete;
ArithmeticOperator(const std::string& type)
: OperatorTensor(type, 2, 0, 1) {
}
ArithmeticOperator(const ArithmeticOperator& other) : OperatorTensor(other){ }
~ArithmeticOperator();
std::shared_ptr<Operator> clone() const override {
return std::make_shared<ArithmeticOperator>(*this);
}
void setBackend(const std::string & /*name*/, DeviceIdx_t /*device*/ = 0) override { printf("setBackend: not available yet.\n"); }
public:
void computeOutputDims() override final;
static const std::vector<std::string> getInputsName(){
return {"data_input_1", "data_input_2"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
} // namespace Aidge
#endif // AIDGE_CORE_OPERATOR_ARITHMETICOPERATOR_H_
\ No newline at end of file
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <vector> #include <vector>
#include "aidge/utils/Registrar.hpp" #include "aidge/utils/Registrar.hpp"
#include "aidge/operator/ArithmeticOperator.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
...@@ -25,19 +25,21 @@ ...@@ -25,19 +25,21 @@
namespace Aidge { namespace Aidge {
class Div_Op : public ArithmeticOperator, class Div_Op : public OperatorTensor,
public Registrable<Div_Op, std::string, std::unique_ptr<OperatorImpl>(const Div_Op&)> { public Registrable<Div_Op, std::string, std::unique_ptr<OperatorImpl>(const Div_Op&)> {
public: public:
static const std::string Type; static const std::string Type;
Div_Op() : ArithmeticOperator(Type) {} Div_Op() : OperatorTensor(Type, 2, 0, 1) {}
/** /**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy. * @param op Operator to copy.
*/ */
Div_Op(const Div_Op& op) : ArithmeticOperator(op){ Div_Op(const Div_Op& op)
: OperatorTensor(op)
{
mImpl = op.mImpl ? Registrar<Div_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr; mImpl = op.mImpl ? Registrar<Div_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
} }
...@@ -49,10 +51,20 @@ public: ...@@ -49,10 +51,20 @@ public:
return std::make_shared<Div_Op>(*this); return std::make_shared<Div_Op>(*this);
} }
void computeOutputDims() override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override { void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Div_Op>::create(name)(*this); mImpl = Registrar<Div_Op>::create(name)(*this);
mOutputs[0]->setBackend(name, device); mOutputs[0]->setBackend(name, device);
} }
static const std::vector<std::string> getInputsName(){
return {"data_input_1", "data_input_2"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
}; };
inline std::shared_ptr<Node> Div(const std::string& name = "") { inline std::shared_ptr<Node> Div(const std::string& name = "") {
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <vector> #include <vector>
#include "aidge/utils/Registrar.hpp" #include "aidge/utils/Registrar.hpp"
#include "aidge/operator/ArithmeticOperator.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
...@@ -28,19 +28,21 @@ namespace Aidge { ...@@ -28,19 +28,21 @@ namespace Aidge {
/** /**
* @brief Tensor element-wise multiplication. * @brief Tensor element-wise multiplication.
*/ */
class Mul_Op : public ArithmeticOperator, class Mul_Op : public OperatorTensor,
public Registrable<Mul_Op, std::string, std::unique_ptr<OperatorImpl>(const Mul_Op&)> { public Registrable<Mul_Op, std::string, std::unique_ptr<OperatorImpl>(const Mul_Op&)> {
public: public:
static const std::string Type; static const std::string Type;
Mul_Op() : ArithmeticOperator(Type) {} Mul_Op() : OperatorTensor(Type, 2, 0, 1) {}
/** /**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), * @brief Copy-constructor. Copy the operator attributes and its output tensor(s),
* but not its input tensors (the new operator has no input associated). * but not its input tensors (the new operator has no input associated).
* @param op Operator to copy. * @param op Operator to copy.
*/ */
Mul_Op(const Mul_Op& op) : ArithmeticOperator(op){ Mul_Op(const Mul_Op& op)
: OperatorTensor(op)
{
mImpl = op.mImpl ? Registrar<Mul_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr; mImpl = op.mImpl ? Registrar<Mul_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
} }
...@@ -52,10 +54,19 @@ public: ...@@ -52,10 +54,19 @@ public:
return std::make_shared<Mul_Op>(*this); return std::make_shared<Mul_Op>(*this);
} }
void computeOutputDims() override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override { void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Mul_Op>::create(name)(*this); mImpl = Registrar<Mul_Op>::create(name)(*this);
mOutputs[0]->setBackend(name, device); mOutputs[0]->setBackend(name, device);
} }
static const std::vector<std::string> getInputsName(){
return {"data_input_1", "data_input_2"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
}; };
inline std::shared_ptr<Node> Mul(const std::string& name = "") { inline std::shared_ptr<Node> Mul(const std::string& name = "") {
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <vector> #include <vector>
#include "aidge/utils/Registrar.hpp" #include "aidge/utils/Registrar.hpp"
#include "aidge/operator/ArithmeticOperator.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/data/Data.hpp" #include "aidge/data/Data.hpp"
...@@ -26,18 +26,20 @@ ...@@ -26,18 +26,20 @@
namespace Aidge { namespace Aidge {
class Pow_Op : public ArithmeticOperator, class Pow_Op : public OperatorTensor,
public Registrable<Pow_Op, std::string, std::unique_ptr<OperatorImpl>(const Pow_Op&)> { public Registrable<Pow_Op, std::string, std::unique_ptr<OperatorImpl>(const Pow_Op&)> {
public: public:
static const std::string Type; static const std::string Type;
Pow_Op() : ArithmeticOperator(Type) {} Pow_Op() : OperatorTensor(Type, 2, 0, 1) {}
/** /**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy. * @param op Operator to copy.
*/ */
Pow_Op(const Pow_Op& op) : ArithmeticOperator(op){ Pow_Op(const Pow_Op& op)
: OperatorTensor(op)
{
mImpl = op.mImpl ? Registrar<Pow_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr; mImpl = op.mImpl ? Registrar<Pow_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
} }
...@@ -49,10 +51,20 @@ public: ...@@ -49,10 +51,20 @@ public:
return std::make_shared<Pow_Op>(*this); return std::make_shared<Pow_Op>(*this);
} }
void computeOutputDims() override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override { void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Pow_Op>::create(name)(*this); mImpl = Registrar<Pow_Op>::create(name)(*this);
mOutputs[0]->setBackend(name, device); mOutputs[0]->setBackend(name, device);
} }
static const std::vector<std::string> getInputsName(){
return {"data_input_1", "data_input_2"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
}; };
inline std::shared_ptr<Node> Pow(const std::string& name = "") { inline std::shared_ptr<Node> Pow(const std::string& name = "") {
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <vector> #include <vector>
#include "aidge/utils/Registrar.hpp" #include "aidge/utils/Registrar.hpp"
#include "aidge/operator/ArithmeticOperator.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/data/Data.hpp" #include "aidge/data/Data.hpp"
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
namespace Aidge { namespace Aidge {
class Sub_Op : public ArithmeticOperator, class Sub_Op : public OperatorTensor,
public Registrable<Sub_Op, std::string, std::unique_ptr<OperatorImpl>(const Sub_Op&)> { public Registrable<Sub_Op, std::string, std::unique_ptr<OperatorImpl>(const Sub_Op&)> {
public: public:
// FIXME: change accessibility // FIXME: change accessibility
...@@ -36,13 +36,15 @@ public: ...@@ -36,13 +36,15 @@ public:
public: public:
static const std::string Type; static const std::string Type;
Sub_Op() : ArithmeticOperator(Type) {} Sub_Op() : OperatorTensor(Type, 2, 0, 1) {}
/** /**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy. * @param op Operator to copy.
*/ */
Sub_Op(const Sub_Op& op) : ArithmeticOperator(op){ Sub_Op(const Sub_Op& op)
: OperatorTensor(op)
{
mImpl = op.mImpl ? Registrar<Sub_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr; mImpl = op.mImpl ? Registrar<Sub_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
} }
...@@ -54,10 +56,20 @@ public: ...@@ -54,10 +56,20 @@ public:
return std::make_shared<Sub_Op>(*this); return std::make_shared<Sub_Op>(*this);
} }
void computeOutputDims() override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override { void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Sub_Op>::create(name)(*this); mImpl = Registrar<Sub_Op>::create(name)(*this);
mOutputs[0]->setBackend(name, device); mOutputs[0]->setBackend(name, device);
} }
static const std::vector<std::string> getInputsName(){
return {"data_input_1", "data_input_2"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
}; };
inline std::shared_ptr<Node> Sub(const std::string& name = "") { inline std::shared_ptr<Node> Sub(const std::string& name = "") {
......
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/operator/ArithmeticOperator.hpp"
#include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11;
namespace Aidge {
void init_ArithmeticOperator(py::module& m){
py::class_<ArithmeticOperator, std::shared_ptr<ArithmeticOperator>, OperatorTensor>(m, "ArithmeticOperator")
.def("get_output", &ArithmeticOperator::getOutput, py::arg("outputIdx"))
.def("get_input", &ArithmeticOperator::getInput, py::arg("inputIdx"))
.def("set_output", (void (ArithmeticOperator::*)(const IOIndex_t, const std::shared_ptr<Data>&)) &ArithmeticOperator::setOutput, py::arg("outputIdx"), py::arg("data"))
.def("set_input", (void (ArithmeticOperator::*)(const IOIndex_t, const std::shared_ptr<Data>&)) &ArithmeticOperator::setInput, py::arg("outputIdx"), py::arg("data"))
.def("output_dims_forwarded", &ArithmeticOperator::outputDimsForwarded)
;
}
} // namespace Aidge
...@@ -12,13 +12,13 @@ ...@@ -12,13 +12,13 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include "aidge/operator/Div.hpp" #include "aidge/operator/Div.hpp"
#include "aidge/operator/ArithmeticOperator.hpp" #include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11; namespace py = pybind11;
namespace Aidge { namespace Aidge {
void init_Div(py::module& m) { void init_Div(py::module& m) {
py::class_<Div_Op, std::shared_ptr<Div_Op>, ArithmeticOperator>(m, "DivOp", py::multiple_inheritance()) py::class_<Div_Op, std::shared_ptr<Div_Op>, OperatorTensor>(m, "DivOp", py::multiple_inheritance())
.def("get_inputs_name", &Div_Op::getInputsName) .def("get_inputs_name", &Div_Op::getInputsName)
.def("get_outputs_name", &Div_Op::getOutputsName); .def("get_outputs_name", &Div_Op::getOutputsName);
......
...@@ -12,13 +12,13 @@ ...@@ -12,13 +12,13 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include "aidge/operator/Mul.hpp" #include "aidge/operator/Mul.hpp"
#include "aidge/operator/ArithmeticOperator.hpp" #include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11; namespace py = pybind11;
namespace Aidge { namespace Aidge {
void init_Mul(py::module& m) { void init_Mul(py::module& m) {
py::class_<Mul_Op, std::shared_ptr<Mul_Op>, ArithmeticOperator>(m, "MulOp", py::multiple_inheritance()) py::class_<Mul_Op, std::shared_ptr<Mul_Op>, OperatorTensor>(m, "MulOp", py::multiple_inheritance())
.def("get_inputs_name", &Mul_Op::getInputsName) .def("get_inputs_name", &Mul_Op::getInputsName)
.def("get_outputs_name", &Mul_Op::getOutputsName); .def("get_outputs_name", &Mul_Op::getOutputsName);
......
...@@ -12,13 +12,13 @@ ...@@ -12,13 +12,13 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include "aidge/operator/Pow.hpp" #include "aidge/operator/Pow.hpp"
#include "aidge/operator/ArithmeticOperator.hpp" #include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11; namespace py = pybind11;
namespace Aidge { namespace Aidge {
void init_Pow(py::module& m) { void init_Pow(py::module& m) {
py::class_<Pow_Op, std::shared_ptr<Pow_Op>, ArithmeticOperator>(m, "PowOp", py::multiple_inheritance()) py::class_<Pow_Op, std::shared_ptr<Pow_Op>, OperatorTensor>(m, "PowOp", py::multiple_inheritance())
.def("get_inputs_name", &Pow_Op::getInputsName) .def("get_inputs_name", &Pow_Op::getInputsName)
.def("get_outputs_name", &Pow_Op::getOutputsName); .def("get_outputs_name", &Pow_Op::getOutputsName);
......
...@@ -12,13 +12,13 @@ ...@@ -12,13 +12,13 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include "aidge/operator/Sub.hpp" #include "aidge/operator/Sub.hpp"
#include "aidge/operator/ArithmeticOperator.hpp" #include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11; namespace py = pybind11;
namespace Aidge { namespace Aidge {
void init_Sub(py::module& m) { void init_Sub(py::module& m) {
py::class_<Sub_Op, std::shared_ptr<Sub_Op>, ArithmeticOperator>(m, "SubOp", py::multiple_inheritance()) py::class_<Sub_Op, std::shared_ptr<Sub_Op>, OperatorTensor>(m, "SubOp", py::multiple_inheritance())
.def("get_inputs_name", &Sub_Op::getInputsName) .def("get_inputs_name", &Sub_Op::getInputsName)
.def("get_outputs_name", &Sub_Op::getOutputsName); .def("get_outputs_name", &Sub_Op::getOutputsName);
......
...@@ -22,7 +22,6 @@ void init_Operator(py::module&); ...@@ -22,7 +22,6 @@ void init_Operator(py::module&);
void init_OperatorTensor(py::module&); void init_OperatorTensor(py::module&);
void init_Add(py::module&); void init_Add(py::module&);
void init_ArithmeticOperator(py::module&);
void init_AvgPooling(py::module&); void init_AvgPooling(py::module&);
void init_BatchNorm(py::module&); void init_BatchNorm(py::module&);
void init_Concat(py::module&); void init_Concat(py::module&);
...@@ -78,7 +77,6 @@ void init_Aidge(py::module& m){ ...@@ -78,7 +77,6 @@ void init_Aidge(py::module& m){
init_Attributes(m); init_Attributes(m);
init_Operator(m); init_Operator(m);
init_OperatorTensor(m); init_OperatorTensor(m);
init_ArithmeticOperator(m);
init_Add(m); init_Add(m);
init_AvgPooling(m); init_AvgPooling(m);
init_BatchNorm(m); init_BatchNorm(m);
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <memory>
#include "aidge/operator/ArithmeticOperator.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
Aidge::ArithmeticOperator::~ArithmeticOperator() = default;
void Aidge::ArithmeticOperator::computeOutputDims() {
// check inputs have been associated
if (!getInput(0) || !getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
}
// if (getInput(0)->empty() || getInput(1)->empty()) {
// AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input is empty");
// }
std::vector<std::vector<std::size_t>> inputsDims;
for (std::size_t i = 0; i < nbInputs(); i++)
{
inputsDims.push_back(getInput(i)->dims());
}
std::size_t outNbDims = 1;
for(size_t i=0; i<inputsDims.size() ; ++i)
outNbDims = inputsDims[i].size()>outNbDims?inputsDims[i].size():outNbDims;
std::vector<std::size_t> outDims(outNbDims, 1);
std::vector<std::size_t>::iterator it = outDims.end();
while (it != outDims.begin())
{
--it;
for (size_t i = 0; i < inputsDims.size(); i++)
{
if(!inputsDims[i].empty())
{
std::size_t dim = inputsDims[i].back();
inputsDims[i].pop_back();
if (*it != dim)
{
if(dim != 1)
{
if (*it != 1)
{
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsopported Tensor shape for Arithmetic Operation");
}
else
{
*it = dim;
}
}
}
}
}
}
mOutputs[0]->resize(outDims);
}
\ No newline at end of file
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
* SPDX-License-Identifier: EPL-2.0 * SPDX-License-Identifier: EPL-2.0
* *
********************************************************************************/ ********************************************************************************/
#include <algorithm>
#include <cassert> #include <cassert>
#include <cstddef> #include <cstddef>
#include <string> #include <string>
...@@ -20,4 +20,41 @@ ...@@ -20,4 +20,41 @@
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
const std::string Aidge::Div_Op::Type = "Div"; const std::string Aidge::Div_Op::Type = "Div";
\ No newline at end of file
void Aidge::Div_Op::computeOutputDims() {
// check inputs have been associated
if (!getInput(0) || !getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
}
if (!getInput(0)->empty() && !getInput(1)->empty()) {
std::vector<std::vector<std::size_t>> inputsDims{getInput(0)->dims(), getInput(1)->dims()};
std::vector<std::size_t> outDims = (inputsDims[0].size() >= inputsDims[1].size()) ?
inputsDims[0] : inputsDims[1];
std::vector<std::size_t>::iterator it = outDims.end();
while (it != outDims.begin()) {
--it;
for (size_t i = 0; i < inputsDims.size(); i++) {
if(!inputsDims[i].empty()) {
std::size_t dim = inputsDims[i].back();
inputsDims[i].pop_back();
if (*it != dim) {
if(dim != 1) {
if (*it != 1) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsopported Tensor shape for Div Operation");
}
else {
*it = dim;
}
}
}
}
}
}
mOutputs[0]->resize(outDims);
}
}
\ No newline at end of file
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
* SPDX-License-Identifier: EPL-2.0 * SPDX-License-Identifier: EPL-2.0
* *
********************************************************************************/ ********************************************************************************/
#include <algorithm>
#include <cassert> #include <cassert>
#include <cstddef> #include <cstddef>
#include <vector> #include <vector>
...@@ -19,4 +19,41 @@ ...@@ -19,4 +19,41 @@
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
const std::string Aidge::Mul_Op::Type = "Mul"; const std::string Aidge::Mul_Op::Type = "Mul";
\ No newline at end of file
void Aidge::Mul_Op::computeOutputDims() {
// check inputs have been associated
if (!getInput(0) || !getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
}
if (!getInput(0)->empty() && !getInput(1)->empty()) {
std::vector<std::vector<std::size_t>> inputsDims{getInput(0)->dims(), getInput(1)->dims()};
std::vector<std::size_t> outDims = (inputsDims[0].size() >= inputsDims[1].size()) ?
inputsDims[0] : inputsDims[1];
std::vector<std::size_t>::iterator it = outDims.end();
while (it != outDims.begin()) {
--it;
for (size_t i = 0; i < inputsDims.size(); i++) {
if(!inputsDims[i].empty()) {
std::size_t dim = inputsDims[i].back();
inputsDims[i].pop_back();
if (*it != dim) {
if(dim != 1) {
if (*it != 1) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsopported Tensor shape for Mul Operation");
}
else {
*it = dim;
}
}
}
}
}
}
mOutputs[0]->resize(outDims);
}
}
\ No newline at end of file
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
* SPDX-License-Identifier: EPL-2.0 * SPDX-License-Identifier: EPL-2.0
* *
********************************************************************************/ ********************************************************************************/
#include <algorithm>
#include <cassert> #include <cassert>
#include <cstddef> #include <cstddef>
#include <vector> #include <vector>
...@@ -19,4 +19,41 @@ ...@@ -19,4 +19,41 @@
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
const std::string Aidge::Pow_Op::Type = "Pow"; const std::string Aidge::Pow_Op::Type = "Pow";
\ No newline at end of file
void Aidge::Pow_Op::computeOutputDims() {
// check inputs have been associated
if (!getInput(0) || !getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
}
if (!getInput(0)->empty() && !getInput(1)->empty()) {
std::vector<std::vector<std::size_t>> inputsDims{getInput(0)->dims(), getInput(1)->dims()};
std::vector<std::size_t> outDims = (inputsDims[0].size() >= inputsDims[1].size()) ?
inputsDims[0] : inputsDims[1];
std::vector<std::size_t>::iterator it = outDims.end();
while (it != outDims.begin()) {
--it;
for (size_t i = 0; i < inputsDims.size(); i++) {
if(!inputsDims[i].empty()) {
std::size_t dim = inputsDims[i].back();
inputsDims[i].pop_back();
if (*it != dim) {
if(dim != 1) {
if (*it != 1) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsopported Tensor shape for Pow Operation");
}
else {
*it = dim;
}
}
}
}
}
}
mOutputs[0]->resize(outDims);
}
}
\ No newline at end of file
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
* SPDX-License-Identifier: EPL-2.0 * SPDX-License-Identifier: EPL-2.0
* *
********************************************************************************/ ********************************************************************************/
#include <algorithm>
#include <cassert> #include <cassert>
#include <cstddef> #include <cstddef>
#include <vector> #include <vector>
...@@ -19,4 +19,41 @@ ...@@ -19,4 +19,41 @@
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
const std::string Aidge::Sub_Op::Type = "Sub"; const std::string Aidge::Sub_Op::Type = "Sub";
\ No newline at end of file
void Aidge::Sub_Op::computeOutputDims() {
// check inputs have been associated
if (!getInput(0) || !getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
}
if (!getInput(0)->empty() && !getInput(1)->empty()) {
std::vector<std::vector<std::size_t>> inputsDims{getInput(0)->dims(), getInput(1)->dims()};
std::vector<std::size_t> outDims = (inputsDims[0].size() >= inputsDims[1].size()) ?
inputsDims[0] : inputsDims[1];
std::vector<std::size_t>::iterator it = outDims.end();
while (it != outDims.begin()) {
--it;
for (size_t i = 0; i < inputsDims.size(); i++) {
if(!inputsDims[i].empty()) {
std::size_t dim = inputsDims[i].back();
inputsDims[i].pop_back();
if (*it != dim) {
if(dim != 1) {
if (*it != 1) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsopported Tensor shape for Sub Operation");
}
else {
*it = dim;
}
}
}
}
}
}
mOutputs[0]->resize(outDims);
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment