diff --git a/include/aidge/aidge.hpp b/include/aidge/aidge.hpp index 3031fc19b335f6e77bb7999f8b3a2b107e3f5323..cd36a654772d2d641b9af32bb74b1336f4a9742d 100644 --- a/include/aidge/aidge.hpp +++ b/include/aidge/aidge.hpp @@ -47,6 +47,7 @@ #include "aidge/operator/Conv.hpp" #include "aidge/operator/ConvDepthWise.hpp" #include "aidge/operator/Div.hpp" +#include "aidge/operator/Equal.hpp" #include "aidge/operator/Erf.hpp" #include "aidge/operator/FC.hpp" #include "aidge/operator/Gather.hpp" diff --git a/include/aidge/operator/Equal.hpp b/include/aidge/operator/Equal.hpp new file mode 100644 index 0000000000000000000000000000000000000000..12bc9af7812aedf52a4502f270e136c65a4a9756 --- /dev/null +++ b/include/aidge/operator/Equal.hpp @@ -0,0 +1,82 @@ +/******************************************************************************** + * Copyright (c) 2024 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_EQUAL_H_ +#define AIDGE_CORE_OPERATOR_EQUAL_H_ + +#include <memory> +#include <string> +#include <vector> + +#include "aidge/utils/Registrar.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { + +/** + * @brief Tensor element-wise logical equal operation. + */ +class Equal_Op : public OperatorTensor, + public Registrable<Equal_Op, std::string, std::function<std::shared_ptr<OperatorImpl>(const Equal_Op&)>> { +public: + static const std::string Type; + + /** + * @brief Compute element-wise Equal operation on two given inputs. + * @details supports broadcasting of both operands. + */ + Equal_Op() : OperatorTensor(Type, {InputCategory::Data, InputCategory::Data}, 1) {} + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), + * but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Equal_Op(const Equal_Op& op) + : OperatorTensor(op) + { + if (op.mImpl) { + SET_IMPL_MACRO(Equal_Op, *this, op.backend()); + } else { + mImpl = nullptr; + } + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Equal_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Equal_Op>(*this); + } + + bool forwardDims(bool allowDataDependency = false) override final; + + void setBackend(const std::string& name, DeviceIdx_t device = 0) override; + std::set<std::string> getAvailableBackends() const override; + + static const std::vector<std::string> getInputsName(){ + return {"data_input_1", "data_input_2"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +inline std::shared_ptr<Node> Equal(const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Equal_Op>(), name); +} +} // namespace Aidge + +#endif /* AIDGE_CORE_OPERATOR_EQUAL_H_ */ diff --git a/python_binding/operator/pybind_Equal.cpp b/python_binding/operator/pybind_Equal.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ef4488edce3c096c368f43a07de6b0d65f368013 --- /dev/null +++ b/python_binding/operator/pybind_Equal.cpp @@ -0,0 +1,34 @@ +/******************************************************************************** + * Copyright (c) 2024 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "aidge/data/Tensor.hpp" +#include "aidge/operator/Equal.hpp" +#include "aidge/operator/OperatorTensor.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_Equal(py::module& m) { + py::class_<Equal_Op, std::shared_ptr<Equal_Op>, OperatorTensor>(m, "Equal_Op", py::multiple_inheritance(), + R"mydelimiter( Initialize an Equal operator.)mydelimiter") + .def(py::init<>()) + .def_static("get_inputs_name", &Equal_Op::getInputsName) + .def_static("get_outputs_name", &Equal_Op::getOutputsName); + declare_registrable<Equal_Op>(m, "EqualOp"); + m.def("Equal", &Equal, py::arg("name") = "", + R"mydelimiter( + Initialize a node containing an Equal operator. + :param name : name of the node. + )mydelimiter"); +} +} // namespace Aidge diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index cc6f0bf2502027fea467b9db39561769fcebbd2b..ef1111b39a2f6fff3153dfb7441543ff5c3956c2 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -50,9 +50,11 @@ void init_Conv(py::module&); void init_ConvDepthWise(py::module&); void init_DepthToSpace(py::module&); void init_Div(py::module&); +void init_Equal(py::module&); void init_Erf(py::module&); void init_Expand(py::module&); void init_FC(py::module&); +void init_Flatten(py::module&); void init_Gather(py::module&); void init_GenericOperator(py::module&); void init_GlobalAveragePooling(py::module&); @@ -149,9 +151,11 @@ void init_Aidge(py::module& m) { init_ConstantOfShape(m); init_DepthToSpace(m); init_Div(m); + init_Equal(m); init_Erf(m); init_Expand(m); init_FC(m); + init_Flatten(m); init_Gather(m); init_GenericOperator(m); init_GlobalAveragePooling(m); diff --git a/src/operator/Equal.cpp b/src/operator/Equal.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cc0fcd984062baeac3da47d03a3d64cda63eada3 --- /dev/null +++ b/src/operator/Equal.cpp @@ -0,0 +1,62 @@ +/******************************************************************************** + * Copyright (c) 2024 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <cstddef> // std::size_t +#include <memory> +#include <stdexcept> // std::runtime_error +#include <string> +#include <vector> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/operator/Equal.hpp" +#include "aidge/utils/ErrorHandling.hpp" +#include "aidge/utils/Types.h" + +const std::string Aidge::Equal_Op::Type = "Equal"; + +bool Aidge::Equal_Op::forwardDims(bool /*allowDataDependency*/) { + if (inputsAssociated()) { + const std::vector<std::size_t>& inputsDims0 = getInput(0)->dims(); + const std::vector<std::size_t>& inputsDims1 = getInput(1)->dims(); + + std::vector<std::size_t> outDims = (inputsDims0.size() >= inputsDims1.size()) ? inputsDims0 : inputsDims1; + const std::vector<std::size_t>& lowDims = (inputsDims0.size() < inputsDims1.size()) ? inputsDims0 : inputsDims1; + + std::size_t out_id = outDims.size() - 1; + std::size_t low_id = lowDims.size() - 1; + std::size_t i = 0; + while (i++ < lowDims.size()) { + if (outDims[out_id] == 1) { + outDims[out_id] = lowDims[low_id]; + } + else if ((lowDims[low_id] != 1) && (lowDims[low_id] != outDims[out_id])) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Incompatible Tensor shape for Equal Operation: {} for input#0 vs {} for input#1", + inputsDims0, inputsDims1); + } + --out_id; + --low_id; + } + mOutputs[0]->resize(outDims); + return true; + } + + return false; +} + +void Aidge::Equal_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { + SET_IMPL_MACRO(Equal_Op, *this, name); + mOutputs[0]->setBackend(name, device); +} + +std::set<std::string> Aidge::Equal_Op::getAvailableBackends() const { + return Registrar<Equal_Op>::getKeys(); +}