From ca67f0bcc5e4074240cef1cfdb2082003bfb6f2e Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Wed, 22 Nov 2023 15:34:31 +0000
Subject: [PATCH] Custom computeOutputDims() for Pow, Div, Mul, Sub operators

---
 include/aidge/operator/Div.hpp |  4 +++-
 include/aidge/operator/Mul.hpp |  1 +
 include/aidge/operator/Pow.hpp |  2 ++
 include/aidge/operator/Sub.hpp |  2 ++
 src/operator/Div.cpp           | 35 ++++++++++++++++++++++++++++++++++
 src/operator/Mul.cpp           | 35 ++++++++++++++++++++++++++++++++++
 src/operator/Pow.cpp           | 35 ++++++++++++++++++++++++++++++++++
 src/operator/Sub.cpp           | 35 ++++++++++++++++++++++++++++++++++
 8 files changed, 148 insertions(+), 1 deletion(-)
 create mode 100644 src/operator/Div.cpp
 create mode 100644 src/operator/Mul.cpp
 create mode 100644 src/operator/Pow.cpp
 create mode 100644 src/operator/Sub.cpp

diff --git a/include/aidge/operator/Div.hpp b/include/aidge/operator/Div.hpp
index b4acd79e4..ba76c0bde 100644
--- a/include/aidge/operator/Div.hpp
+++ b/include/aidge/operator/Div.hpp
@@ -22,7 +22,6 @@
 #include "aidge/data/Tensor.hpp"
 #include "aidge/graph/Node.hpp"
 #include "aidge/utils/Types.h"
-#include "aidge/utils/ErrorHandling.hpp"
 
 namespace Aidge {
 
@@ -52,6 +51,9 @@ public:
         return std::make_shared<Div_Op>(*this);
     }
 
+    void computeOutputDims() override final;
+
+
     void setBackend(const std::string& name) override {
         mImpl = Registrar<Div_Op>::create(name)(*this);
         mOutputs[0]->setBackend(name);
diff --git a/include/aidge/operator/Mul.hpp b/include/aidge/operator/Mul.hpp
index f1537f5b2..5b9ab4eb8 100644
--- a/include/aidge/operator/Mul.hpp
+++ b/include/aidge/operator/Mul.hpp
@@ -54,6 +54,7 @@ public:
         return std::make_shared<Mul_Op>(*this);
     }
 
+    void computeOutputDims() override final;
 
     void setBackend(const std::string& name) override {
         mImpl = Registrar<Mul_Op>::create(name)(*this);
diff --git a/include/aidge/operator/Pow.hpp b/include/aidge/operator/Pow.hpp
index 0ab73441f..0b0ae82f0 100644
--- a/include/aidge/operator/Pow.hpp
+++ b/include/aidge/operator/Pow.hpp
@@ -51,6 +51,8 @@ public:
         return std::make_shared<Pow_Op>(*this);
     }
 
+    void computeOutputDims() override final;
+
 
     void setBackend(const std::string& name) override {
         mImpl = Registrar<Pow_Op>::create(name)(*this);
diff --git a/include/aidge/operator/Sub.hpp b/include/aidge/operator/Sub.hpp
index 3a826bd0f..becf98926 100644
--- a/include/aidge/operator/Sub.hpp
+++ b/include/aidge/operator/Sub.hpp
@@ -56,6 +56,8 @@ public:
         return std::make_shared<Sub_Op>(*this);
     }
 
+    void computeOutputDims() override final;
+
 
     void setBackend(const std::string& name) override {
         mImpl = Registrar<Sub_Op>::create(name)(*this);
diff --git a/src/operator/Div.cpp b/src/operator/Div.cpp
new file mode 100644
index 000000000..273eac2e8
--- /dev/null
+++ b/src/operator/Div.cpp
@@ -0,0 +1,35 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <cstddef>
+#include <vector>
+#include <utility>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/Div.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/utils/ErrorHandling.hpp"
+
+void Aidge::Div_Op::computeOutputDims() {
+    // check inputs have been associated
+    if (!getInput(0) || !getInput(1)) {
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
+    }
+
+    if ((!getInput(0)->empty()) &&
+        ((getInput(1)->size() == 1) || // div by a single value
+        (getInput(1)->size() == getInput(0)->size()) || // div elem-wise
+        (getInput(1)->nbDims() == 1 && getInput(1)->size() == getInput(0)->dims()[getInput(0)->nbDims()-1]))) // div by a Tensor with one dimension of output size
+    {
+        mOutputs[0]->resize(getInput(0)->dims());
+    }
+}
\ No newline at end of file
diff --git a/src/operator/Mul.cpp b/src/operator/Mul.cpp
new file mode 100644
index 000000000..2e3e77288
--- /dev/null
+++ b/src/operator/Mul.cpp
@@ -0,0 +1,35 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <cstddef>
+#include <vector>
+#include <utility>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/Mul.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/utils/ErrorHandling.hpp"
+
+void Aidge::Mul_Op::computeOutputDims() {
+    // check inputs have been associated
+    if (!getInput(0) || !getInput(1)) {
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
+    }
+
+    if ((!getInput(0)->empty()) &&
+        ((getInput(1)->size() == 1) || // mul by a single value
+        (getInput(1)->size() == getInput(0)->size()) || // mul elem-wise
+        (getInput(1)->nbDims() == 1 && getInput(1)->size() == getInput(0)->dims()[getInput(0)->nbDims()-1]))) // mul by a Tensor with one dimension of output size
+    {
+        mOutputs[0]->resize(getInput(0)->dims());
+    }
+}
\ No newline at end of file
diff --git a/src/operator/Pow.cpp b/src/operator/Pow.cpp
new file mode 100644
index 000000000..c213a47a4
--- /dev/null
+++ b/src/operator/Pow.cpp
@@ -0,0 +1,35 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <cstddef>
+#include <vector>
+#include <utility>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/Pow.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/utils/ErrorHandling.hpp"
+
+void Aidge::Pow_Op::computeOutputDims() {
+    // check inputs have been associated
+    if (!getInput(0) || !getInput(1)) {
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
+    }
+
+    if ((!getInput(0)->empty()) &&
+        ((getInput(1)->size() == 1) || // pow by a single value
+        (getInput(1)->size() == getInput(0)->size()) || // pow elem-wise
+        (getInput(1)->nbDims() == 1 && getInput(1)->size() == getInput(0)->dims()[getInput(0)->nbDims()-1]))) // pow by a Tensor with one dimension of output size
+    {
+        mOutputs[0]->resize(getInput(0)->dims());
+    }
+}
\ No newline at end of file
diff --git a/src/operator/Sub.cpp b/src/operator/Sub.cpp
new file mode 100644
index 000000000..8175f1b7a
--- /dev/null
+++ b/src/operator/Sub.cpp
@@ -0,0 +1,35 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <cstddef>
+#include <vector>
+#include <utility>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/Sub.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/utils/ErrorHandling.hpp"
+
+void Aidge::Sub_Op::computeOutputDims() {
+    // check inputs have been associated
+    if (!getInput(0) || !getInput(1)) {
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
+    }
+
+    if ((!getInput(0)->empty()) &&
+        ((getInput(1)->size() == 1) || // sub by a single value
+        (getInput(1)->size() == getInput(0)->size()) || // sub elem-wise
+        (getInput(1)->nbDims() == 1 && getInput(1)->size() == getInput(0)->dims()[getInput(0)->nbDims()-1]))) // sub by a Tensor with one dimension of output size
+    {
+        mOutputs[0]->resize(getInput(0)->dims());
+    }
+}
\ No newline at end of file
-- 
GitLab