Commit 52272348 authored by Maxence Naud

Replace input()/output() by setInput()/setOutput()

parent cbcd268c
......@@ -83,7 +83,22 @@ public:
* @return std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> For each dataInput Tensor of the Operator, the first index and dimensions of the feature area.
*/
// virtual std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveField(const std::size_t firstIdx, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const;
/**
* @brief Set the specified input by performing a deep copy of the given data.
* The pointer itself is not changed, thus keeping the current connections.
* @param inputIdx Index of the input to set.
* @param data Data to deep-copy into the specified input.
*/
virtual void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) = 0;
virtual void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) = 0;
virtual std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const = 0;
/**
* @brief Set the specified output by performing a deep copy of the given data.
* The pointer itself is not changed, thus keeping the current connections.
* @param outputIdx Index of the output to set.
* @param data Data to deep-copy into the specified output.
*/
virtual void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) = 0;
virtual void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) = 0;
virtual std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const = 0;
std::shared_ptr<Hook> getHook(std::string hookName) {
......
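For readers skimming the header change above, here is a minimal usage sketch of the new copy-based setter, assuming `op` is any OperatorTensor-derived operator built elsewhere; the helper name feedInput and the index 0 are illustrative, not part of this commit.

#include <memory>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/OperatorTensor.hpp"

// Hypothetical helper: copy the values of a Tensor (passed as Data) into input 0.
void feedInput(const std::shared_ptr<Aidge::OperatorTensor>& op,
               const std::shared_ptr<Aidge::Data>& tensorData) {
    // setInput() deep-copies the values into the operator's own input Tensor;
    // the stored pointer is not swapped, so existing graph connections survive.
    op->setInput(0, tensorData);

    // getRawInput() returns the operator's stored Data pointer, not `tensorData`.
    std::shared_ptr<Aidge::Data> stored = op->getRawInput(0);
    (void)stored;
}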
......@@ -56,7 +56,7 @@ public:
mInputs(std::vector<std::shared_ptr<Tensor>>(other.nbInputs(), nullptr)),
mOutputs(std::vector<std::shared_ptr<Tensor>>(other.nbOutputs())) {
for (std::size_t i = 0; i < static_cast<std::size_t>(nbOutputs()); ++i) {
mOutputs[i] = std::make_shared<Tensor>(other.output(i));
mOutputs[i] = std::make_shared<Tensor>(*(other.getOutput(i)));
// datatype already copied
}
}
......@@ -72,6 +72,8 @@ public:
///////////////////////////////////////////////////
// Tensor access
// input management
void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final;
const std::shared_ptr<Tensor>& getInput(const IOIndex_t inputIdx) const;
inline Tensor& input(const IOIndex_t inputIdx) const { return *getInput(inputIdx); }
inline std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
......@@ -79,6 +81,8 @@ public:
}
// output management
void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) override final;
void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) override final;
const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const;
inline Tensor& output(const IOIndex_t outputIdx) const {
return *getOutput(outputIdx);
......
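As a side note on the accessors declared in this hunk, a minimal sketch of how the three views of the same stored input relate, assuming `op` already has its input 0 associated or set (the function name inspectInput is illustrative):

#include <memory>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/OperatorTensor.hpp"

void inspectInput(const std::shared_ptr<Aidge::OperatorTensor>& op) {
    // getInput() returns the stored shared_ptr, which may still be null.
    const std::shared_ptr<Aidge::Tensor>& in = op->getInput(0);
    if (in) {
        // input() simply dereferences that pointer for convenience.
        Aidge::Tensor& ref = op->input(0);
        (void)ref;
    }
    // getRawInput() exposes the same pointer through the generic Operator API, typed as Data.
    std::shared_ptr<Aidge::Data> raw = op->getRawInput(0);
    (void)raw;
}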
......@@ -66,16 +66,6 @@ public:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Producer operator takes no input.");
}
/**
* @brief Set the Output Tensor of the Producer operator.
* This method will create a copy of the Tensor.
*
* @param newOutput Tensor containing the values to copy
*/
void setOutput(const std::shared_ptr<Tensor>& newOutput) {
mOutputs[0] = newOutput;
}
void computeOutputDims() override final {}
bool outputDimsForwarded() const override final {return true;}
......
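With the Producer-specific setter removed, updating a Producer's values now goes through the inherited OperatorTensor::setOutput(outputIdx, data) declared above. A minimal sketch, assuming a Producer_Op is already built from an initial Tensor (the function name refreshValues is illustrative); unlike the removed method, which simply re-assigned mOutputs[0], the inherited setter deep-copies the values:

#include <memory>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Producer.hpp"

// Hypothetical helper: overwrite the single output of a Producer with new values.
void refreshValues(const std::shared_ptr<Aidge::Producer_Op>& prod,
                   const std::shared_ptr<Aidge::Data>& newValues) {
    // Index 0 is the Producer's only output. The values are copied into the
    // existing output Tensor, so downstream nodes keep the pointer they hold.
    prod->setOutput(0, newValues);
}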
......@@ -12,19 +12,25 @@
#include <pybind11/pybind11.h>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/utils/Types.h"
#include <pybind11/stl.h>
namespace py = pybind11;
namespace Aidge {
void init_Operator(py::module& m){
py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator")
.def("output", &Operator::output, py::arg("outputIdx"))
.def("input", &Operator::input, py::arg("inputIdx"))
.def("set_output", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setOutput), py::arg("outputIdx"), py::arg("data"))
// .def("set_output", py::overload_cast<const IOIndex_t, std::shared_ptr<Data>&&>(&Operator::setOutput), py::arg("outputIdx"), py::arg("data"))
.def("get_raw_output", &Operator::getRawOutput, py::arg("outputIdx"))
.def("set_input", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data"))
// .def("set_input", py::overload_cast<const IOIndex_t, std::shared_ptr<Data>&&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data"))
.def("get_raw_input", &Operator::getRawInput, py::arg("inputIdx"))
.def("nb_inputs", &Operator::nbInputs)
.def("nb_data_inputs", &Operator::nbDataInputs)
.def("nb_data", &Operator::nbData)
.def("nb_param", &Operator::nbParam)
.def("nb_outputs", &Operator::nbOutputs)
.def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data"))
.def("set_datatype", &Operator::setDatatype, py::arg("datatype"))
.def("set_datatype", &Operator::setDataType, py::arg("dataType"))
.def("set_backend", &Operator::setBackend, py::arg("name"))
.def("forward", &Operator::forward)
// py::keep_alive forbids Python from garbage collecting the implementation while the Operator is not garbage collected!
......@@ -33,4 +39,4 @@ void init_Operator(py::module& m){
.def("add_hook", &Operator::addHook)
;
}
}
}
\ No newline at end of file
......@@ -14,7 +14,7 @@
#include "aidge/utils/Types.h"
// #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/data/Tensor.hpp"
......@@ -30,12 +30,11 @@ void declare_Producer(py::module &m) {
void init_Producer(py::module &m) {
py::class_<Producer_Op, std::shared_ptr<Producer_Op>, Operator>(
py::class_<Producer_Op, std::shared_ptr<Producer_Op>, OperatorTensor>(
m,
"ProducerOp",
py::multiple_inheritance())
.def("dims", &Producer_Op::dims)
.def("set_output_tensor", &Producer_Op::setOutputTensor)
.def("get_inputs_name", &Producer_Op::getInputsName)
.def("get_outputs_name", &Producer_Op::getOutputsName);
m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&)>(&Producer), py::arg("tensor"), py::arg("name") = "");
......
......@@ -29,6 +29,28 @@ void Aidge::OperatorTensor::associateInput(const Aidge::IOIndex_t inputIdx, cons
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
}
void Aidge::OperatorTensor::setInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Aidge::Data>& data) {
if (strcmp(data->type(), "Tensor") != 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str());
}
if (getInput(inputIdx)) {
*mInputs[inputIdx] = *std::dynamic_pointer_cast<Tensor>(data);
} else {
mInputs[inputIdx] = std::make_shared<Tensor>(*std::dynamic_pointer_cast<Tensor>(data));
}
}
void Aidge::OperatorTensor::setInput(const Aidge::IOIndex_t inputIdx, std::shared_ptr<Aidge::Data>&& data) {
if (strcmp(data->type(), "Tensor") != 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str());
}
if (getInput(inputIdx)) {
*mInputs[inputIdx] = std::move(*std::dynamic_pointer_cast<Tensor>(data));
} else {
mInputs[inputIdx] = std::make_shared<Tensor>(std::move(*std::dynamic_pointer_cast<Tensor>(data)));
}
}
const std::shared_ptr<Aidge::Tensor>& Aidge::OperatorTensor::getInput(const Aidge::IOIndex_t inputIdx) const {
if (inputIdx >= nbInputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu inputs", type().c_str(), nbInputs());
......@@ -36,6 +58,25 @@ const std::shared_ptr<Aidge::Tensor>& Aidge::OperatorTensor::getInput(const Aidg
return mInputs[inputIdx];
}
void Aidge::OperatorTensor::setOutput(const Aidge::IOIndex_t outputIdx, const std::shared_ptr<Aidge::Data>& data) {
if (strcmp(data->type(), "Tensor") != 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as outputs", type().c_str());
}
if (outputIdx >= nbOutputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbOutputs());
}
*mOutputs[outputIdx] = *std::dynamic_pointer_cast<Tensor>(data);
}
void Aidge::OperatorTensor::setOutput(const Aidge::IOIndex_t outputIdx, std::shared_ptr<Aidge::Data>&& data) {
if (strcmp(data->type(), "Tensor") != 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as outputs", type().c_str());
}
if (outputIdx >= nbOutputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbOutputs());
}
*mOutputs[outputIdx] = std::move(*std::dynamic_pointer_cast<Tensor>(data));
}
const std::shared_ptr<Aidge::Tensor>& Aidge::OperatorTensor::getOutput(const Aidge::IOIndex_t outputIdx) const {
if (outputIdx >= nbOutputs()) {
......@@ -69,7 +110,7 @@ bool Aidge::OperatorTensor::outputDimsForwarded() const {
bool forwarded = true;
// check both inputs and outputs have been filled
for (IOIndex_t i = 0; i < nbInputs(); ++i) {
forwarded &= !(getInput(i)->empty());
forwarded &= mInputs[i] ? !(getInput(i)->empty()) : false;
}
for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
forwarded &= !(getOutput(i)->empty());
......@@ -82,6 +123,11 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
getOutput(i)->setDataType(dataType);
}
for (IOIndex_t i = 0; i < nbInputs(); ++i) {
getInput(i)->setDataType(dataType);
if (!getInput(i)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not set");
}
else {
getInput(i)->setDataType(dataType);
}
}
}
\ No newline at end of file
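To close, a minimal sketch contrasting the two overloads implemented above and the new precondition introduced in setDataType(), assuming `op` is an OperatorTensor-derived operator with at least two inputs (the function name and indices are illustrative):

#include <memory>
#include <utility>
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/OperatorTensor.hpp"

void setBothInputs(const std::shared_ptr<Aidge::OperatorTensor>& op,
                   const std::shared_ptr<Aidge::Data>& keepMe,
                   std::shared_ptr<Aidge::Data> scratch) {
    // Const-reference overload: the Tensor behind `keepMe` is copied and stays usable.
    op->setInput(0, keepMe);

    // Rvalue overload: the Tensor behind `scratch` is moved from; its contents
    // should not be relied upon afterwards.
    op->setInput(1, std::move(scratch));

    // setDataType() now aborts if any input is still unset, so call it only
    // after every input has been associated or set.
    op->setDataType(Aidge::DataType::Float32);
}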