diff --git a/include/aidge/operator/ArithmeticOperator.hpp b/include/aidge/operator/ArithmeticOperator.hpp deleted file mode 100644 index 7d207c6c216af4e369d8324fa6be303ffb5347f0..0000000000000000000000000000000000000000 --- a/include/aidge/operator/ArithmeticOperator.hpp +++ /dev/null @@ -1,59 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2024 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#ifndef AIDGE_CORE_OPERATOR_ARITHMETICOPERATOR_H_ -#define AIDGE_CORE_OPERATOR_ARITHMETICOPERATOR_H_ - -#include <memory> -#include <string> -#include <vector> - -#include "aidge/backend/OperatorImpl.hpp" -#include "aidge/data/Tensor.hpp" -#include "aidge/operator/OperatorTensor.hpp" -#include "aidge/utils/Types.h" -#include "aidge/graph/Node.hpp" - -namespace Aidge { - -class ArithmeticOperator : public OperatorTensor { - -public: - ArithmeticOperator() = delete; - - ArithmeticOperator(const std::string& type) - : OperatorTensor(type, 2, 0, 1) { - } - - ArithmeticOperator(const ArithmeticOperator& other) : OperatorTensor(other){ } - - ~ArithmeticOperator(); - - std::shared_ptr<Operator> clone() const override { - return std::make_shared<ArithmeticOperator>(*this); - } - - void setBackend(const std::string & /*name*/, DeviceIdx_t /*device*/ = 0) override { printf("setBackend: not available yet.\n"); } - -public: - - void computeOutputDims() override final; - - static const std::vector<std::string> getInputsName(){ - return {"data_input_1", "data_input_2"}; - } - static const std::vector<std::string> getOutputsName(){ - return {"data_output"}; - } -}; -} // namespace Aidge - -#endif // AIDGE_CORE_OPERATOR_ARITHMETICOPERATOR_H_ \ No newline at end of file diff --git a/include/aidge/operator/Div.hpp b/include/aidge/operator/Div.hpp index ec4319e6b74258fcc3836f5b7aaca781b418dc2d..a033c6920a374003ad869bddbf5641c48fc5f6e2 100644 --- a/include/aidge/operator/Div.hpp +++ b/include/aidge/operator/Div.hpp @@ -17,7 +17,7 @@ #include <vector> #include "aidge/utils/Registrar.hpp" -#include "aidge/operator/ArithmeticOperator.hpp" +#include "aidge/operator/OperatorTensor.hpp" #include "aidge/backend/OperatorImpl.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/graph/Node.hpp" @@ -25,19 +25,21 @@ namespace Aidge { -class Div_Op : public ArithmeticOperator, +class Div_Op : public OperatorTensor, public Registrable<Div_Op, std::string, std::unique_ptr<OperatorImpl>(const Div_Op&)> { public: static const std::string Type; - Div_Op() : ArithmeticOperator(Type) {} + Div_Op() : OperatorTensor(Type, 2, 0, 1) {} /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). * @param op Operator to copy. */ - Div_Op(const Div_Op& op) : ArithmeticOperator(op){ + Div_Op(const Div_Op& op) + : OperatorTensor(op) + { mImpl = op.mImpl ? Registrar<Div_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr; } @@ -49,10 +51,20 @@ public: return std::make_shared<Div_Op>(*this); } + void computeOutputDims() override final; + + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Div_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); } + + static const std::vector<std::string> getInputsName(){ + return {"data_input_1", "data_input_2"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; inline std::shared_ptr<Node> Div(const std::string& name = "") { diff --git a/include/aidge/operator/Mul.hpp b/include/aidge/operator/Mul.hpp index 2a1ad1d5bedc8ff64fad2e254601fdcfa4781f56..8758021a9c3de1707a96bbfafc21686ded8b7e40 100644 --- a/include/aidge/operator/Mul.hpp +++ b/include/aidge/operator/Mul.hpp @@ -17,7 +17,7 @@ #include <vector> #include "aidge/utils/Registrar.hpp" -#include "aidge/operator/ArithmeticOperator.hpp" +#include "aidge/operator/OperatorTensor.hpp" #include "aidge/backend/OperatorImpl.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/graph/Node.hpp" @@ -28,19 +28,21 @@ namespace Aidge { /** * @brief Tensor element-wise multiplication. */ -class Mul_Op : public ArithmeticOperator, +class Mul_Op : public OperatorTensor, public Registrable<Mul_Op, std::string, std::unique_ptr<OperatorImpl>(const Mul_Op&)> { public: static const std::string Type; - Mul_Op() : ArithmeticOperator(Type) {} + Mul_Op() : OperatorTensor(Type, 2, 0, 1) {} /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), * but not its input tensors (the new operator has no input associated). * @param op Operator to copy. */ - Mul_Op(const Mul_Op& op) : ArithmeticOperator(op){ + Mul_Op(const Mul_Op& op) + : OperatorTensor(op) + { mImpl = op.mImpl ? Registrar<Mul_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr; } @@ -52,10 +54,19 @@ public: return std::make_shared<Mul_Op>(*this); } + void computeOutputDims() override final; + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Mul_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); } + + static const std::vector<std::string> getInputsName(){ + return {"data_input_1", "data_input_2"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; inline std::shared_ptr<Node> Mul(const std::string& name = "") { diff --git a/include/aidge/operator/Pow.hpp b/include/aidge/operator/Pow.hpp index f1bd3ad51bf407ca756260e781564a6c87907705..ba8d3d05877f9aa543518fff1d88f4e8a436b712 100644 --- a/include/aidge/operator/Pow.hpp +++ b/include/aidge/operator/Pow.hpp @@ -17,7 +17,7 @@ #include <vector> #include "aidge/utils/Registrar.hpp" -#include "aidge/operator/ArithmeticOperator.hpp" +#include "aidge/operator/OperatorTensor.hpp" #include "aidge/backend/OperatorImpl.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/data/Data.hpp" @@ -26,18 +26,20 @@ namespace Aidge { -class Pow_Op : public ArithmeticOperator, +class Pow_Op : public OperatorTensor, public Registrable<Pow_Op, std::string, std::unique_ptr<OperatorImpl>(const Pow_Op&)> { public: static const std::string Type; - Pow_Op() : ArithmeticOperator(Type) {} + Pow_Op() : OperatorTensor(Type, 2, 0, 1) {} /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). * @param op Operator to copy. */ - Pow_Op(const Pow_Op& op) : ArithmeticOperator(op){ + Pow_Op(const Pow_Op& op) + : OperatorTensor(op) + { mImpl = op.mImpl ? Registrar<Pow_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr; } @@ -49,10 +51,20 @@ public: return std::make_shared<Pow_Op>(*this); } + void computeOutputDims() override final; + + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Pow_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); } + + static const std::vector<std::string> getInputsName(){ + return {"data_input_1", "data_input_2"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; inline std::shared_ptr<Node> Pow(const std::string& name = "") { diff --git a/include/aidge/operator/Sub.hpp b/include/aidge/operator/Sub.hpp index d00e9f5f035f693e6eca012bebf5a6d345b231d3..7d346457ead71724ba05da70b5bdf7ad145cbe0c 100644 --- a/include/aidge/operator/Sub.hpp +++ b/include/aidge/operator/Sub.hpp @@ -17,7 +17,7 @@ #include <vector> #include "aidge/utils/Registrar.hpp" -#include "aidge/operator/ArithmeticOperator.hpp" +#include "aidge/operator/OperatorTensor.hpp" #include "aidge/backend/OperatorImpl.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/data/Data.hpp" @@ -26,7 +26,7 @@ namespace Aidge { -class Sub_Op : public ArithmeticOperator, +class Sub_Op : public OperatorTensor, public Registrable<Sub_Op, std::string, std::unique_ptr<OperatorImpl>(const Sub_Op&)> { public: // FIXME: change accessibility @@ -36,13 +36,15 @@ public: public: static const std::string Type; - Sub_Op() : ArithmeticOperator(Type) {} + Sub_Op() : OperatorTensor(Type, 2, 0, 1) {} /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). * @param op Operator to copy. */ - Sub_Op(const Sub_Op& op) : ArithmeticOperator(op){ + Sub_Op(const Sub_Op& op) + : OperatorTensor(op) + { mImpl = op.mImpl ? Registrar<Sub_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr; } @@ -54,10 +56,20 @@ public: return std::make_shared<Sub_Op>(*this); } + void computeOutputDims() override final; + + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Sub_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); } + + static const std::vector<std::string> getInputsName(){ + return {"data_input_1", "data_input_2"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; inline std::shared_ptr<Node> Sub(const std::string& name = "") { diff --git a/python_binding/operator/pybind_ArithmeticOperator.cpp b/python_binding/operator/pybind_ArithmeticOperator.cpp deleted file mode 100644 index cd5d164b924d9347bf3cf25dfa123df69116c529..0000000000000000000000000000000000000000 --- a/python_binding/operator/pybind_ArithmeticOperator.cpp +++ /dev/null @@ -1,28 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2024 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <pybind11/pybind11.h> -#include "aidge/operator/ArithmeticOperator.hpp" -#include "aidge/operator/OperatorTensor.hpp" - -namespace py = pybind11; -namespace Aidge { -void init_ArithmeticOperator(py::module& m){ - py::class_<ArithmeticOperator, std::shared_ptr<ArithmeticOperator>, OperatorTensor>(m, "ArithmeticOperator") - .def("get_output", &ArithmeticOperator::getOutput, py::arg("outputIdx")) - .def("get_input", &ArithmeticOperator::getInput, py::arg("inputIdx")) - - .def("set_output", (void (ArithmeticOperator::*)(const IOIndex_t, const std::shared_ptr<Data>&)) &ArithmeticOperator::setOutput, py::arg("outputIdx"), py::arg("data")) - .def("set_input", (void (ArithmeticOperator::*)(const IOIndex_t, const std::shared_ptr<Data>&)) &ArithmeticOperator::setInput, py::arg("outputIdx"), py::arg("data")) - .def("output_dims_forwarded", &ArithmeticOperator::outputDimsForwarded) - ; -} -} // namespace Aidge diff --git a/python_binding/operator/pybind_Div.cpp b/python_binding/operator/pybind_Div.cpp index ff9933cdcc1ca8507016bde153b440d2497250fe..6d14510f34349c001289096a7fc9b08681a25bc8 100644 --- a/python_binding/operator/pybind_Div.cpp +++ b/python_binding/operator/pybind_Div.cpp @@ -12,13 +12,13 @@ #include <pybind11/pybind11.h> #include "aidge/operator/Div.hpp" -#include "aidge/operator/ArithmeticOperator.hpp" +#include "aidge/operator/OperatorTensor.hpp" namespace py = pybind11; namespace Aidge { void init_Div(py::module& m) { - py::class_<Div_Op, std::shared_ptr<Div_Op>, ArithmeticOperator>(m, "DivOp", py::multiple_inheritance()) + py::class_<Div_Op, std::shared_ptr<Div_Op>, OperatorTensor>(m, "DivOp", py::multiple_inheritance()) .def("get_inputs_name", &Div_Op::getInputsName) .def("get_outputs_name", &Div_Op::getOutputsName); diff --git a/python_binding/operator/pybind_Mul.cpp b/python_binding/operator/pybind_Mul.cpp index 7ad55c3ad90d22748977c149a2489a717111da3c..21f510d98728fbe5401288a366294241b5f10a3f 100644 --- a/python_binding/operator/pybind_Mul.cpp +++ b/python_binding/operator/pybind_Mul.cpp @@ -12,13 +12,13 @@ #include <pybind11/pybind11.h> #include "aidge/operator/Mul.hpp" -#include "aidge/operator/ArithmeticOperator.hpp" +#include "aidge/operator/OperatorTensor.hpp" namespace py = pybind11; namespace Aidge { void init_Mul(py::module& m) { - py::class_<Mul_Op, std::shared_ptr<Mul_Op>, ArithmeticOperator>(m, "MulOp", py::multiple_inheritance()) + py::class_<Mul_Op, std::shared_ptr<Mul_Op>, OperatorTensor>(m, "MulOp", py::multiple_inheritance()) .def("get_inputs_name", &Mul_Op::getInputsName) .def("get_outputs_name", &Mul_Op::getOutputsName); diff --git a/python_binding/operator/pybind_Pow.cpp b/python_binding/operator/pybind_Pow.cpp index 7bc8f05a8cc9a0748ae3853bbbaaf3b638e01bf1..09d1e4ad2ad6413901c28bc9d9fe16995483da05 100644 --- a/python_binding/operator/pybind_Pow.cpp +++ b/python_binding/operator/pybind_Pow.cpp @@ -12,13 +12,13 @@ #include <pybind11/pybind11.h> #include "aidge/operator/Pow.hpp" -#include "aidge/operator/ArithmeticOperator.hpp" +#include "aidge/operator/OperatorTensor.hpp" namespace py = pybind11; namespace Aidge { void init_Pow(py::module& m) { - py::class_<Pow_Op, std::shared_ptr<Pow_Op>, ArithmeticOperator>(m, "PowOp", py::multiple_inheritance()) + py::class_<Pow_Op, std::shared_ptr<Pow_Op>, OperatorTensor>(m, "PowOp", py::multiple_inheritance()) .def("get_inputs_name", &Pow_Op::getInputsName) .def("get_outputs_name", &Pow_Op::getOutputsName); diff --git a/python_binding/operator/pybind_Sub.cpp b/python_binding/operator/pybind_Sub.cpp index c2670389e805e33eca42f37bc13e84cca94f9f08..dce1ab6cb27cc7da02e6c817a6bc49ec64bcf364 100644 --- a/python_binding/operator/pybind_Sub.cpp +++ b/python_binding/operator/pybind_Sub.cpp @@ -12,13 +12,13 @@ #include <pybind11/pybind11.h> #include "aidge/operator/Sub.hpp" -#include "aidge/operator/ArithmeticOperator.hpp" +#include "aidge/operator/OperatorTensor.hpp" namespace py = pybind11; namespace Aidge { void init_Sub(py::module& m) { - py::class_<Sub_Op, std::shared_ptr<Sub_Op>, ArithmeticOperator>(m, "SubOp", py::multiple_inheritance()) + py::class_<Sub_Op, std::shared_ptr<Sub_Op>, OperatorTensor>(m, "SubOp", py::multiple_inheritance()) .def("get_inputs_name", &Sub_Op::getInputsName) .def("get_outputs_name", &Sub_Op::getOutputsName); diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index eded0f53ec60bbff8e574d809e6595fd25a079a8..be0d357b7f73e26aad44994f407696f70617ad71 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -22,7 +22,6 @@ void init_Operator(py::module&); void init_OperatorTensor(py::module&); void init_Add(py::module&); -void init_ArithmeticOperator(py::module&); void init_AvgPooling(py::module&); void init_BatchNorm(py::module&); void init_Concat(py::module&); @@ -78,7 +77,6 @@ void init_Aidge(py::module& m){ init_Attributes(m); init_Operator(m); init_OperatorTensor(m); - init_ArithmeticOperator(m); init_Add(m); init_AvgPooling(m); init_BatchNorm(m); diff --git a/src/operator/ArithmeticOperator.cpp b/src/operator/ArithmeticOperator.cpp deleted file mode 100644 index 2e02109d36e1e50db5f78b92c5a580ee4477922d..0000000000000000000000000000000000000000 --- a/src/operator/ArithmeticOperator.cpp +++ /dev/null @@ -1,75 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <cassert> -#include <memory> - -#include "aidge/operator/ArithmeticOperator.hpp" -#include "aidge/data/Data.hpp" -#include "aidge/data/Tensor.hpp" -#include "aidge/utils/Types.h" -#include "aidge/utils/ErrorHandling.hpp" - - -Aidge::ArithmeticOperator::~ArithmeticOperator() = default; - -void Aidge::ArithmeticOperator::computeOutputDims() { - // check inputs have been associated - if (!getInput(0) || !getInput(1)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected"); - } - - // if (getInput(0)->empty() || getInput(1)->empty()) { - // AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input is empty"); - // } - - std::vector<std::vector<std::size_t>> inputsDims; - for (std::size_t i = 0; i < nbInputs(); i++) - { - inputsDims.push_back(getInput(i)->dims()); - } - - std::size_t outNbDims = 1; - - for(size_t i=0; i<inputsDims.size() ; ++i) - outNbDims = inputsDims[i].size()>outNbDims?inputsDims[i].size():outNbDims; - - std::vector<std::size_t> outDims(outNbDims, 1); - - std::vector<std::size_t>::iterator it = outDims.end(); - while (it != outDims.begin()) - { - --it; - for (size_t i = 0; i < inputsDims.size(); i++) - { - if(!inputsDims[i].empty()) - { - std::size_t dim = inputsDims[i].back(); - inputsDims[i].pop_back(); - if (*it != dim) - { - if(dim != 1) - { - if (*it != 1) - { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsopported Tensor shape for Arithmetic Operation"); - } - else - { - *it = dim; - } - } - } - } - } - } - mOutputs[0]->resize(outDims); -} \ No newline at end of file diff --git a/src/operator/Div.cpp b/src/operator/Div.cpp index e5fe78b3119273f57aa795bc3295b640c7ec7cd3..66221ce6ae917e6cdce9bc888f66476fbc000d22 100644 --- a/src/operator/Div.cpp +++ b/src/operator/Div.cpp @@ -8,7 +8,7 @@ * SPDX-License-Identifier: EPL-2.0 * ********************************************************************************/ - +#include <algorithm> #include <cassert> #include <cstddef> #include <string> @@ -20,4 +20,41 @@ #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" -const std::string Aidge::Div_Op::Type = "Div"; \ No newline at end of file +const std::string Aidge::Div_Op::Type = "Div"; + +void Aidge::Div_Op::computeOutputDims() { + // check inputs have been associated + if (!getInput(0) || !getInput(1)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected"); + } + + if (!getInput(0)->empty() && !getInput(1)->empty()) { + + std::vector<std::vector<std::size_t>> inputsDims{getInput(0)->dims(), getInput(1)->dims()}; + + std::vector<std::size_t> outDims = (inputsDims[0].size() >= inputsDims[1].size()) ? + inputsDims[0] : inputsDims[1]; + + std::vector<std::size_t>::iterator it = outDims.end(); + while (it != outDims.begin()) { + --it; + for (size_t i = 0; i < inputsDims.size(); i++) { + if(!inputsDims[i].empty()) { + std::size_t dim = inputsDims[i].back(); + inputsDims[i].pop_back(); + if (*it != dim) { + if(dim != 1) { + if (*it != 1) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsopported Tensor shape for Div Operation"); + } + else { + *it = dim; + } + } + } + } + } + } + mOutputs[0]->resize(outDims); + } +} \ No newline at end of file diff --git a/src/operator/Mul.cpp b/src/operator/Mul.cpp index 4de84e45b58e05fd5668dafacc4361b64968e9bb..5b705c603233ffc2b299e66315e59caab181c0ac 100644 --- a/src/operator/Mul.cpp +++ b/src/operator/Mul.cpp @@ -8,7 +8,7 @@ * SPDX-License-Identifier: EPL-2.0 * ********************************************************************************/ - +#include <algorithm> #include <cassert> #include <cstddef> #include <vector> @@ -19,4 +19,41 @@ #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" -const std::string Aidge::Mul_Op::Type = "Mul"; \ No newline at end of file +const std::string Aidge::Mul_Op::Type = "Mul"; + +void Aidge::Mul_Op::computeOutputDims() { + // check inputs have been associated + if (!getInput(0) || !getInput(1)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected"); + } + + if (!getInput(0)->empty() && !getInput(1)->empty()) { + + std::vector<std::vector<std::size_t>> inputsDims{getInput(0)->dims(), getInput(1)->dims()}; + + std::vector<std::size_t> outDims = (inputsDims[0].size() >= inputsDims[1].size()) ? + inputsDims[0] : inputsDims[1]; + + std::vector<std::size_t>::iterator it = outDims.end(); + while (it != outDims.begin()) { + --it; + for (size_t i = 0; i < inputsDims.size(); i++) { + if(!inputsDims[i].empty()) { + std::size_t dim = inputsDims[i].back(); + inputsDims[i].pop_back(); + if (*it != dim) { + if(dim != 1) { + if (*it != 1) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsopported Tensor shape for Mul Operation"); + } + else { + *it = dim; + } + } + } + } + } + } + mOutputs[0]->resize(outDims); + } +} \ No newline at end of file diff --git a/src/operator/Pow.cpp b/src/operator/Pow.cpp index 932b31e977889aea79248cf6163f89692cdde0b9..850e32173c6da2fe8501d94b048b80a8ef8179f5 100644 --- a/src/operator/Pow.cpp +++ b/src/operator/Pow.cpp @@ -8,7 +8,7 @@ * SPDX-License-Identifier: EPL-2.0 * ********************************************************************************/ - +#include <algorithm> #include <cassert> #include <cstddef> #include <vector> @@ -19,4 +19,41 @@ #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" -const std::string Aidge::Pow_Op::Type = "Pow"; \ No newline at end of file +const std::string Aidge::Pow_Op::Type = "Pow"; + +void Aidge::Pow_Op::computeOutputDims() { + // check inputs have been associated + if (!getInput(0) || !getInput(1)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected"); + } + + if (!getInput(0)->empty() && !getInput(1)->empty()) { + + std::vector<std::vector<std::size_t>> inputsDims{getInput(0)->dims(), getInput(1)->dims()}; + + std::vector<std::size_t> outDims = (inputsDims[0].size() >= inputsDims[1].size()) ? + inputsDims[0] : inputsDims[1]; + + std::vector<std::size_t>::iterator it = outDims.end(); + while (it != outDims.begin()) { + --it; + for (size_t i = 0; i < inputsDims.size(); i++) { + if(!inputsDims[i].empty()) { + std::size_t dim = inputsDims[i].back(); + inputsDims[i].pop_back(); + if (*it != dim) { + if(dim != 1) { + if (*it != 1) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsopported Tensor shape for Pow Operation"); + } + else { + *it = dim; + } + } + } + } + } + } + mOutputs[0]->resize(outDims); + } +} \ No newline at end of file diff --git a/src/operator/Sub.cpp b/src/operator/Sub.cpp index 74f9e8ca1f3cc3f36cb7333a013bb132be89f1e1..1fe41140d31f74d47a6313c23ddbe3ef2e5073df 100644 --- a/src/operator/Sub.cpp +++ b/src/operator/Sub.cpp @@ -8,7 +8,7 @@ * SPDX-License-Identifier: EPL-2.0 * ********************************************************************************/ - +#include <algorithm> #include <cassert> #include <cstddef> #include <vector> @@ -19,4 +19,41 @@ #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" -const std::string Aidge::Sub_Op::Type = "Sub"; \ No newline at end of file +const std::string Aidge::Sub_Op::Type = "Sub"; + +void Aidge::Sub_Op::computeOutputDims() { + // check inputs have been associated + if (!getInput(0) || !getInput(1)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected"); + } + + if (!getInput(0)->empty() && !getInput(1)->empty()) { + + std::vector<std::vector<std::size_t>> inputsDims{getInput(0)->dims(), getInput(1)->dims()}; + + std::vector<std::size_t> outDims = (inputsDims[0].size() >= inputsDims[1].size()) ? + inputsDims[0] : inputsDims[1]; + + std::vector<std::size_t>::iterator it = outDims.end(); + while (it != outDims.begin()) { + --it; + for (size_t i = 0; i < inputsDims.size(); i++) { + if(!inputsDims[i].empty()) { + std::size_t dim = inputsDims[i].back(); + inputsDims[i].pop_back(); + if (*it != dim) { + if(dim != 1) { + if (*it != 1) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsopported Tensor shape for Sub Operation"); + } + else { + *it = dim; + } + } + } + } + } + } + mOutputs[0]->resize(outDims); + } +} \ No newline at end of file