Skip to content
Snippets Groups Projects
Commit c3acda68 authored by Houssem ROUIS's avatar Houssem ROUIS Committed by Olivier BICHLER
Browse files

add Equal operator

parent 3d892e42
No related branches found
No related tags found
2 merge requests!341Error,!311fix failed onnx tests
......@@ -47,6 +47,7 @@
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ConvDepthWise.hpp"
#include "aidge/operator/Div.hpp"
#include "aidge/operator/Equal.hpp"
#include "aidge/operator/Erf.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/Gather.hpp"
......
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_EQUAL_H_
#define AIDGE_CORE_OPERATOR_EQUAL_H_
#include <memory>
#include <string>
#include <vector>
#include "aidge/utils/Registrar.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
/**
* @brief Tensor element-wise logical equal operation.
*/
class Equal_Op : public OperatorTensor,
public Registrable<Equal_Op, std::string, std::function<std::shared_ptr<OperatorImpl>(const Equal_Op&)>> {
public:
static const std::string Type;
/**
* @brief Compute element-wise Equal operation on two given inputs.
* @details supports broadcasting of both operands.
*/
Equal_Op() : OperatorTensor(Type, {InputCategory::Data, InputCategory::Data}, 1) {}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s),
* but not its input tensors (the new operator has no input associated).
* @param op Operator to copy.
*/
Equal_Op(const Equal_Op& op)
: OperatorTensor(op)
{
if (op.mImpl) {
SET_IMPL_MACRO(Equal_Op, *this, op.backend());
} else {
mImpl = nullptr;
}
}
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::Equal_Op
*/
std::shared_ptr<Operator> clone() const override {
return std::make_shared<Equal_Op>(*this);
}
bool forwardDims(bool allowDataDependency = false) override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
std::set<std::string> getAvailableBackends() const override;
static const std::vector<std::string> getInputsName(){
return {"data_input_1", "data_input_2"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
inline std::shared_ptr<Node> Equal(const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Equal_Op>(), name);
}
} // namespace Aidge
#endif /* AIDGE_CORE_OPERATOR_EQUAL_H_ */
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Equal.hpp"
#include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Equal(py::module& m) {
py::class_<Equal_Op, std::shared_ptr<Equal_Op>, OperatorTensor>(m, "Equal_Op", py::multiple_inheritance(),
R"mydelimiter( Initialize an Equal operator.)mydelimiter")
.def(py::init<>())
.def_static("get_inputs_name", &Equal_Op::getInputsName)
.def_static("get_outputs_name", &Equal_Op::getOutputsName);
declare_registrable<Equal_Op>(m, "EqualOp");
m.def("Equal", &Equal, py::arg("name") = "",
R"mydelimiter(
Initialize a node containing an Equal operator.
:param name : name of the node.
)mydelimiter");
}
} // namespace Aidge
......@@ -50,9 +50,11 @@ void init_Conv(py::module&);
void init_ConvDepthWise(py::module&);
void init_DepthToSpace(py::module&);
void init_Div(py::module&);
void init_Equal(py::module&);
void init_Erf(py::module&);
void init_Expand(py::module&);
void init_FC(py::module&);
void init_Flatten(py::module&);
void init_Gather(py::module&);
void init_GenericOperator(py::module&);
void init_GlobalAveragePooling(py::module&);
......@@ -149,9 +151,11 @@ void init_Aidge(py::module& m) {
init_ConstantOfShape(m);
init_DepthToSpace(m);
init_Div(m);
init_Equal(m);
init_Erf(m);
init_Expand(m);
init_FC(m);
init_Flatten(m);
init_Gather(m);
init_GenericOperator(m);
init_GlobalAveragePooling(m);
......
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cstddef> // std::size_t
#include <memory>
#include <stdexcept> // std::runtime_error
#include <string>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Equal.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
const std::string Aidge::Equal_Op::Type = "Equal";
bool Aidge::Equal_Op::forwardDims(bool /*allowDataDependency*/) {
if (inputsAssociated()) {
const std::vector<std::size_t>& inputsDims0 = getInput(0)->dims();
const std::vector<std::size_t>& inputsDims1 = getInput(1)->dims();
std::vector<std::size_t> outDims = (inputsDims0.size() >= inputsDims1.size()) ? inputsDims0 : inputsDims1;
const std::vector<std::size_t>& lowDims = (inputsDims0.size() < inputsDims1.size()) ? inputsDims0 : inputsDims1;
std::size_t out_id = outDims.size() - 1;
std::size_t low_id = lowDims.size() - 1;
std::size_t i = 0;
while (i++ < lowDims.size()) {
if (outDims[out_id] == 1) {
outDims[out_id] = lowDims[low_id];
}
else if ((lowDims[low_id] != 1) && (lowDims[low_id] != outDims[out_id])) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Incompatible Tensor shape for Equal Operation: {} for input#0 vs {} for input#1",
inputsDims0, inputsDims1);
}
--out_id;
--low_id;
}
mOutputs[0]->resize(outDims);
return true;
}
return false;
}
void Aidge::Equal_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
SET_IMPL_MACRO(Equal_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
std::set<std::string> Aidge::Equal_Op::getAvailableBackends() const {
return Registrar<Equal_Op>::getKeys();
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment