Skip to content
Snippets Groups Projects
Commit 1be1a5cc authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Add Atan operator to aidge_core.

parent 86e00a0a
No related branches found
No related tags found
No related merge requests found
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_ATAN_H_
#define AIDGE_CORE_OPERATOR_ATAN_H_
#include <cassert>
#include <memory>
#include <vector>
#include "aidge/utils/Registrar.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
class Atan_Op : public OperatorTensor,
public Registrable<Atan_Op, std::string, std::function<std::shared_ptr<OperatorImpl>(const Atan_Op&)>> {
public:
static const std::string Type;
Atan_Op();
Atan_Op(const Atan_Op& op);
std::shared_ptr<Operator> clone() const override;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override final;
std::set<std::string> getAvailableBackends() const override;
static const std::vector<std::string> getInputsName(){
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
std::shared_ptr<Node> Atan(const std::string& name = "");
}
#endif /* AIDGE_CORE_OPERATOR_ATAN_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Atan.hpp"
#include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Atan(py::module& m) {
py::class_<Atan_Op, std::shared_ptr<Atan_Op>, OperatorTensor>(m, "AtanOp", py::multiple_inheritance())
.def(py::init<>())
.def_static("get_inputs_name", &Atan_Op::getInputsName)
.def_static("get_outputs_name", &Atan_Op::getOutputsName);
declare_registrable<Atan_Op>(m, "AtanOp");
m.def("Atan", &Atan, py::arg("name") = "");
}
} // namespace Aidge
...@@ -31,6 +31,7 @@ void init_OperatorTensor(py::module&); ...@@ -31,6 +31,7 @@ void init_OperatorTensor(py::module&);
void init_Add(py::module&); void init_Add(py::module&);
void init_And(py::module&); void init_And(py::module&);
void init_ArgMax(py::module&); void init_ArgMax(py::module&);
void init_Atan(py::module&);
void init_AvgPooling(py::module&); void init_AvgPooling(py::module&);
void init_BatchNorm(py::module&); void init_BatchNorm(py::module&);
void init_Concat(py::module&); void init_Concat(py::module&);
...@@ -113,6 +114,7 @@ void init_Aidge(py::module& m) { ...@@ -113,6 +114,7 @@ void init_Aidge(py::module& m) {
init_Add(m); init_Add(m);
init_And(m); init_And(m);
init_ArgMax(m); init_ArgMax(m);
init_Atan(m);
init_AvgPooling(m); init_AvgPooling(m);
init_BatchNorm(m); init_BatchNorm(m);
init_Concat(m); init_Concat(m);
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/operator/Atan.hpp"
#include <memory>
#include <string>
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
const std::string Aidge::Atan_Op::Type = "Atan";
Aidge::Atan_Op::Atan_Op() : OperatorTensor(Type, {InputCategory::Data}, 1) {}
Aidge::Atan_Op::Atan_Op(const Aidge::Atan_Op& op)
: OperatorTensor(op)
{
if (op.mImpl){
SET_IMPL_MACRO(Atan_Op, *this, op.backend());
} else {
mImpl = nullptr;
}
}
std::shared_ptr<Aidge::Operator> Aidge::Atan_Op::clone() const {
return std::make_shared<Atan_Op>(*this);
}
void Aidge::Atan_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
mImpl = Registrar<Atan_Op>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
}
std::set<std::string> Aidge::Atan_Op::getAvailableBackends() const {
return Registrar<Atan_Op>::getKeys();
}
///////////////////////////////////////////////////
std::shared_ptr<Aidge::Node> Aidge::Atan(const std::string& name) {
return std::make_shared<Node>(std::make_shared<Atan_Op>(), name);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment