From c3acda68c60b609525863bcc44e3e73ae1ca7403 Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Tue, 21 Jan 2025 14:44:58 +0100
Subject: [PATCH] add Equal operator

---
 include/aidge/aidge.hpp                  |  1 +
 include/aidge/operator/Equal.hpp         | 82 ++++++++++++++++++++++++
 python_binding/operator/pybind_Equal.cpp | 34 ++++++++++
 python_binding/pybind_core.cpp           |  4 ++
 src/operator/Equal.cpp                   | 62 ++++++++++++++++++
 5 files changed, 183 insertions(+)
 create mode 100644 include/aidge/operator/Equal.hpp
 create mode 100644 python_binding/operator/pybind_Equal.cpp
 create mode 100644 src/operator/Equal.cpp

diff --git a/include/aidge/aidge.hpp b/include/aidge/aidge.hpp
index 3031fc19b..cd36a6547 100644
--- a/include/aidge/aidge.hpp
+++ b/include/aidge/aidge.hpp
@@ -47,6 +47,7 @@
 #include "aidge/operator/Conv.hpp"
 #include "aidge/operator/ConvDepthWise.hpp"
 #include "aidge/operator/Div.hpp"
+#include "aidge/operator/Equal.hpp"
 #include "aidge/operator/Erf.hpp"
 #include "aidge/operator/FC.hpp"
 #include "aidge/operator/Gather.hpp"
diff --git a/include/aidge/operator/Equal.hpp b/include/aidge/operator/Equal.hpp
new file mode 100644
index 000000000..12bc9af78
--- /dev/null
+++ b/include/aidge/operator/Equal.hpp
@@ -0,0 +1,82 @@
+/********************************************************************************
+ * Copyright (c) 2024 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CORE_OPERATOR_EQUAL_H_
+#define AIDGE_CORE_OPERATOR_EQUAL_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/graph/Node.hpp"
+#include "aidge/utils/Types.h"
+
+namespace Aidge {
+
+/**
+ * @brief Tensor element-wise logical equal operation.
+ */
+class Equal_Op : public OperatorTensor,
+    public Registrable<Equal_Op, std::string, std::function<std::shared_ptr<OperatorImpl>(const Equal_Op&)>> {
+public:
+    static const std::string Type;
+
+    /**
+     * @brief Compute element-wise Equal operation on two given inputs.
+     * @details supports broadcasting of both operands.
+     */
+    Equal_Op() : OperatorTensor(Type, {InputCategory::Data, InputCategory::Data}, 1) {}
+
+    /**
+     * @brief Copy-constructor. Copy the operator attributes and its output tensor(s),
+     * but not its input tensors (the new operator has no input associated).
+     * @param op Operator to copy.
+     */
+    Equal_Op(const Equal_Op& op)
+        : OperatorTensor(op)
+    {
+        if (op.mImpl) {
+            SET_IMPL_MACRO(Equal_Op, *this, op.backend());
+        } else {
+            mImpl = nullptr;
+        }
+    }
+
+    /**
+     * @brief Clone the operator using its copy-constructor.
+     * @see Operator::Equal_Op
+     */
+    std::shared_ptr<Operator> clone() const override {
+        return std::make_shared<Equal_Op>(*this);
+    }
+
+    bool forwardDims(bool allowDataDependency = false) override final;
+
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
+    std::set<std::string> getAvailableBackends() const override;
+
+    static const std::vector<std::string> getInputsName(){
+        return {"data_input_1", "data_input_2"};
+    }
+    static const std::vector<std::string> getOutputsName(){
+        return {"data_output"};
+    }
+};
+
+inline std::shared_ptr<Node> Equal(const std::string& name = "") {
+    return std::make_shared<Node>(std::make_shared<Equal_Op>(), name);
+}
+} // namespace Aidge
+
+#endif /* AIDGE_CORE_OPERATOR_EQUAL_H_ */
diff --git a/python_binding/operator/pybind_Equal.cpp b/python_binding/operator/pybind_Equal.cpp
new file mode 100644
index 000000000..ef4488edc
--- /dev/null
+++ b/python_binding/operator/pybind_Equal.cpp
@@ -0,0 +1,34 @@
+/********************************************************************************
+ * Copyright (c) 2024 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <pybind11/pybind11.h>
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/operator/Equal.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+
+namespace py = pybind11;
+namespace Aidge {
+
+void init_Equal(py::module& m) {
+    py::class_<Equal_Op, std::shared_ptr<Equal_Op>, OperatorTensor>(m, "Equal_Op", py::multiple_inheritance(),
+          R"mydelimiter( Initialize an Equal operator.)mydelimiter")
+    .def(py::init<>())
+    .def_static("get_inputs_name", &Equal_Op::getInputsName)
+    .def_static("get_outputs_name", &Equal_Op::getOutputsName);
+    declare_registrable<Equal_Op>(m, "EqualOp");
+    m.def("Equal", &Equal, py::arg("name") = "",
+	   R"mydelimiter(
+        Initialize a node containing an Equal operator.
+			:param name : name of the node.
+		)mydelimiter");
+}
+}  // namespace Aidge
diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp
index cc6f0bf25..ef1111b39 100644
--- a/python_binding/pybind_core.cpp
+++ b/python_binding/pybind_core.cpp
@@ -50,9 +50,11 @@ void init_Conv(py::module&);
 void init_ConvDepthWise(py::module&);
 void init_DepthToSpace(py::module&);
 void init_Div(py::module&);
+void init_Equal(py::module&);
 void init_Erf(py::module&);
 void init_Expand(py::module&);
 void init_FC(py::module&);
+void init_Flatten(py::module&);
 void init_Gather(py::module&);
 void init_GenericOperator(py::module&);
 void init_GlobalAveragePooling(py::module&);
@@ -149,9 +151,11 @@ void init_Aidge(py::module& m) {
     init_ConstantOfShape(m);
     init_DepthToSpace(m);
     init_Div(m);
+    init_Equal(m);
     init_Erf(m);
     init_Expand(m);
     init_FC(m);
+    init_Flatten(m);
     init_Gather(m);
     init_GenericOperator(m);
     init_GlobalAveragePooling(m);
diff --git a/src/operator/Equal.cpp b/src/operator/Equal.cpp
new file mode 100644
index 000000000..cc0fcd984
--- /dev/null
+++ b/src/operator/Equal.cpp
@@ -0,0 +1,62 @@
+/********************************************************************************
+ * Copyright (c) 2024 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cstddef>    // std::size_t
+#include <memory>
+#include <stdexcept>  // std::runtime_error
+#include <string>
+#include <vector>
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/data/Tensor.hpp"
+#include "aidge/operator/Equal.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
+#include "aidge/utils/Types.h"
+
+const std::string Aidge::Equal_Op::Type = "Equal";
+
+bool Aidge::Equal_Op::forwardDims(bool /*allowDataDependency*/) {
+    if (inputsAssociated()) {
+        const std::vector<std::size_t>& inputsDims0 = getInput(0)->dims();
+        const std::vector<std::size_t>& inputsDims1 = getInput(1)->dims();
+
+        std::vector<std::size_t> outDims = (inputsDims0.size() >= inputsDims1.size()) ? inputsDims0 : inputsDims1;
+        const std::vector<std::size_t>& lowDims = (inputsDims0.size() < inputsDims1.size()) ? inputsDims0 : inputsDims1;
+
+        std::size_t out_id = outDims.size() - 1;
+        std::size_t low_id = lowDims.size() - 1;
+        std::size_t i = 0;
+        while (i++ < lowDims.size()) {
+            if (outDims[out_id] == 1) {
+                outDims[out_id] = lowDims[low_id];
+            }
+            else if ((lowDims[low_id] != 1) && (lowDims[low_id] != outDims[out_id])) {
+                AIDGE_THROW_OR_ABORT(std::runtime_error, "Incompatible Tensor shape for Equal Operation: {} for input#0 vs {} for input#1",
+                    inputsDims0, inputsDims1);
+            }
+            --out_id;
+            --low_id;
+        }
+        mOutputs[0]->resize(outDims);
+        return true;
+    }
+
+    return false;
+}
+
+void Aidge::Equal_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
+    SET_IMPL_MACRO(Equal_Op, *this, name);
+    mOutputs[0]->setBackend(name, device);
+}
+
+std::set<std::string> Aidge::Equal_Op::getAvailableBackends() const {
+    return Registrar<Equal_Op>::getKeys();
+}
-- 
GitLab