Skip to content
Snippets Groups Projects
Commit c6c739fe authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add div operator

parent 1755ca8b
No related branches found
No related tags found
No related merge requests found
......@@ -31,6 +31,7 @@
#include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ConvDepthWise.hpp"
#include "aidge/operator/Div.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/MatMul.hpp"
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_DIV_H_
#define AIDGE_CORE_OPERATOR_DIV_H_
#include <cassert>
#include <memory>
#include <vector>
#include "aidge/utils/Registrar.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
class Div_Op : public Operator,
public Registrable<Div_Op, std::string, std::unique_ptr<OperatorImpl>(const Div_Op&)> {
public:
// FIXME: change accessibility
std::array<std::shared_ptr<Tensor>, 2> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>()};
const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();
public:
static constexpr const char* Type = "Div";
Div_Op()
: Operator(Type)
{
setDatatype(DataType::Float32);
}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy.
*/
Div_Op(const Div_Op& op)
: Operator(Type),
mOutput(std::make_shared<Tensor>(*op.mOutput))
{
// cpy-ctor
setDatatype(op.mOutput->dataType());
mImpl = op.mImpl ? Registrar<Div_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr;
}
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::Div_Op
*/
std::shared_ptr<Operator> clone() const override {
return std::make_shared<Div_Op>(*this);
}
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
assert(inputIdx < 2 && "operator supports only 2 inputs");
(void) inputIdx; // avoid unused warning
assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type");
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
}
void computeOutputDims() override final {
if (!mInputs[0]->empty())
mOutput->resize(mInputs[0]->dims());
}
bool outputDimsForwarded() const override final {
return !(mOutput->empty());
}
inline Tensor& input(const IOIndex_t inputIdx) const override final {
assert(static_cast<std::size_t>(inputIdx) < 2 && "wrong inputIdx for Add operator.");
return *(mInputs[inputIdx].get());
}
inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }
inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
assert((inputIdx < 2) && "Div Operator has 2 inputs");
(void) inputIdx; // avoid unused warning
return mInputs[inputIdx];
}
inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
assert((outputIdx == 0) && "Div Operator has only 1 output");
(void) outputIdx; // avoid unused warning
return mOutput;
}
std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
assert(inputIdx < 2 && "operator supports only 2 inputs");
(void) inputIdx; // avoid unused warning
return std::static_pointer_cast<Data>(mInputs[inputIdx]);
}
std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
assert(outputIdx == 0 && "operator supports only 1 output");
(void) outputIdx; // avoid unused warning
return std::static_pointer_cast<Data>(mOutput);
}
void setBackend(const std::string& name) override {
mImpl = Registrar<Div_Op>::create(name)(*this);
mOutput->setBackend(name);
// FIXME: temporary workaround
mInputs[0]->setBackend(name);
mInputs[1]->setBackend(name);
}
void setDatatype(const DataType& datatype) override {
mOutput->setDatatype(datatype);
// FIXME: temporary workaround
mInputs[0]->setDatatype(datatype);
mInputs[1]->setDatatype(datatype);
}
inline IOIndex_t nbInputs() const noexcept override final { return 2; }
inline IOIndex_t nbDataInputs() const noexcept override final { return 2; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
inline std::shared_ptr<Node> Div(const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Div_Op>(), name);
}
}
#endif /* AIDGE_CORE_OPERATOR_DIV_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/operator/Div.hpp"
#include "aidge/operator/Operator.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Div(py::module& m) {
py::class_<Div_Op, std::shared_ptr<Div_Op>, Operator>(m, "DivOp", py::multiple_inheritance())
.def("get_inputs_name", &Div_Op::getInputsName)
.def("get_outputs_name", &Div_Op::getOutputsName);
m.def("Div", &Div, py::arg("name") = "");
}
} // namespace Aidge
......@@ -25,6 +25,7 @@ void init_AvgPooling(py::module&);
void init_BatchNorm(py::module&);
void init_Conv(py::module&);
void init_ConvDepthWise(py::module&);
void init_Div(py::module&);
void init_FC(py::module&);
void init_GenericOperator(py::module&);
void init_LeakyReLU(py::module&);
......@@ -68,6 +69,7 @@ void init_Aidge(py::module& m){
init_BatchNorm(m);
init_Conv(m);
init_ConvDepthWise(m);
init_Div(m);
init_FC(m);
init_GenericOperator(m);
init_LeakyReLU(m);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment