Skip to content
Snippets Groups Projects
Commit 62ef56c4 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add sub operator

parent 3dab8e7c
No related branches found
No related tags found
No related merge requests found
Pipeline #33583 failed
...@@ -44,8 +44,9 @@ ...@@ -44,8 +44,9 @@
#include "aidge/operator/Producer.hpp" #include "aidge/operator/Producer.hpp"
#include "aidge/operator/Pow.hpp" #include "aidge/operator/Pow.hpp"
#include "aidge/operator/ReLU.hpp" #include "aidge/operator/ReLU.hpp"
#include "aidge/operator/Softmax.hpp"
#include "aidge/operator/Scaling.hpp" #include "aidge/operator/Scaling.hpp"
#include "aidge/operator/Softmax.hpp"
#include "aidge/operator/Sub.hpp"
#include "aidge/scheduler/Scheduler.hpp" #include "aidge/scheduler/Scheduler.hpp"
#include "aidge/utils/Attributes.hpp" #include "aidge/utils/Attributes.hpp"
#include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/StaticAttributes.hpp"
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_SUB_H_
#define AIDGE_CORE_OPERATOR_SUB_H_
#include <cassert>
#include <memory>
#include <vector>
#include "aidge/utils/Registrar.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
class Sub_Op : public Operator,
public Registrable<Sub_Op, std::string, std::unique_ptr<OperatorImpl>(const Sub_Op&)> {
public:
// FIXME: change accessibility
std::array<std::shared_ptr<Tensor>, 2> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>()};
const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();
public:
static constexpr const char* Type = "Sub";
Sub_Op()
: Operator(Type)
{
setDatatype(DataType::Float32);
}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy.
*/
Sub_Op(const Sub_Op& op)
: Operator(Type),
mOutput(std::make_shared<Tensor>(*op.mOutput))
{
// cpy-ctor
setDatatype(op.mOutput->dataType());
mImpl = op.mImpl ? Registrar<Sub_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr;
}
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::Sub_Op
*/
std::shared_ptr<Operator> clone() const override {
return std::make_shared<Sub_Op>(*this);
}
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
assert(inputIdx < 2 && "operator supports only 2 inputs");
(void) inputIdx; // avoid unused warning
assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type");
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
}
void computeOutputDims() override final {
if (!mInputs[0]->empty())
mOutput->resize(mInputs[0]->dims());
}
bool outputDimsForwarded() const override final {
return !(mOutput->empty());
}
inline Tensor& input(const IOIndex_t inputIdx) const override final {
assert(static_cast<std::size_t>(inputIdx) < 2 && "wrong inputIdx for Add operator.");
return *(mInputs[inputIdx].get());
}
inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }
inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
assert((inputIdx < 2) && "Sub Operator has 2 inputs");
(void) inputIdx; // avoid unused warning
return mInputs[inputIdx];
}
inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
assert((outputIdx == 0) && "Sub Operator has only 1 output");
(void) outputIdx; // avoid unused warning
return mOutput;
}
std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
assert(inputIdx < 2 && "operator supports only 2 inputs");
(void) inputIdx; // avoid unused warning
return std::static_pointer_cast<Data>(mInputs[inputIdx]);
}
std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
assert(outputIdx == 0 && "operator supports only 1 output");
(void) outputIdx; // avoid unused warning
return std::static_pointer_cast<Data>(mOutput);
}
void setBackend(const std::string& name) override {
mImpl = Registrar<Sub_Op>::create(name)(*this);
mOutput->setBackend(name);
// FIXME: temporary workaround
mInputs[0]->setBackend(name);
mInputs[1]->setBackend(name);
}
void setDatatype(const DataType& datatype) override {
mOutput->setDatatype(datatype);
// FIXME: temporary workaround
mInputs[0]->setDatatype(datatype);
mInputs[1]->setDatatype(datatype);
}
inline IOIndex_t nbInputs() const noexcept override final { return 2; }
inline IOIndex_t nbDataInputs() const noexcept override final { return 2; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
inline std::shared_ptr<Node> Sub(const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Sub_Op>(), name);
}
}
#endif /* AIDGE_CORE_OPERATOR_SUB_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/operator/Sub.hpp"
#include "aidge/operator/Operator.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Sub(py::module& m) {
py::class_<Sub_Op, std::shared_ptr<Sub_Op>, Operator>(m, "SubOp", py::multiple_inheritance())
.def("get_inputs_name", &Sub_Op::getInputsName)
.def("get_outputs_name", &Sub_Op::getOutputsName);
m.def("Sub", &Sub, py::arg("name") = "");
}
} // namespace Aidge
...@@ -37,6 +37,7 @@ void init_Producer(py::module&); ...@@ -37,6 +37,7 @@ void init_Producer(py::module&);
void init_Pow(py::module&); void init_Pow(py::module&);
void init_ReLU(py::module&); void init_ReLU(py::module&);
void init_Softmax(py::module&); void init_Softmax(py::module&);
void init_Sub(py::module&);
void init_Node(py::module&); void init_Node(py::module&);
void init_GraphView(py::module&); void init_GraphView(py::module&);
...@@ -81,6 +82,7 @@ void init_Aidge(py::module& m){ ...@@ -81,6 +82,7 @@ void init_Aidge(py::module& m){
init_Pow(m); init_Pow(m);
init_ReLU(m); init_ReLU(m);
init_Softmax(m); init_Softmax(m);
init_Sub(m);
init_Producer(m); init_Producer(m);
init_Match(m); init_Match(m);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment