Skip to content
Snippets Groups Projects
Commit 52272348 authored by Maxence Naud's avatar Maxence Naud
Browse files

Replace input()/output() by setInput()/setOutput()

parent cbcd268c
No related branches found
No related tags found
No related merge requests found
...@@ -83,7 +83,22 @@ public: ...@@ -83,7 +83,22 @@ public:
* @return std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> For each dataInput Tensor of the Operator, the first index and dimensions of the feature area. * @return std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> For each dataInput Tensor of the Operator, the first index and dimensions of the feature area.
*/ */
// virtual std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveField(const std::size_t firstIdx, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const; // virtual std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveField(const std::size_t firstIdx, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const;
/**
* @brief Set the specified input by performing a deep copy of the given data.
* The pointer itself is not changed, thus keeping the current connections.
* @param inputIdx Index of the input to set.
*/
virtual void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) = 0;
virtual void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) = 0;
virtual std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const = 0; virtual std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const = 0;
/**
* @brief Set the specified output by performing a deep copy of the given data.
* The pointer itself is not changed, thus keeping the current connections.
* @param inputIdx Index of the input to set.
*/
virtual void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) = 0;
virtual void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) = 0;
virtual std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const = 0; virtual std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const = 0;
std::shared_ptr<Hook> getHook(std::string hookName) { std::shared_ptr<Hook> getHook(std::string hookName) {
......
...@@ -56,7 +56,7 @@ public: ...@@ -56,7 +56,7 @@ public:
mInputs(std::vector<std::shared_ptr<Tensor>>(other.nbInputs(), nullptr)), mInputs(std::vector<std::shared_ptr<Tensor>>(other.nbInputs(), nullptr)),
mOutputs(std::vector<std::shared_ptr<Tensor>>(other.nbOutputs())) { mOutputs(std::vector<std::shared_ptr<Tensor>>(other.nbOutputs())) {
for (std::size_t i = 0; i < static_cast<std::size_t>(nbOutputs()); ++i) { for (std::size_t i = 0; i < static_cast<std::size_t>(nbOutputs()); ++i) {
mOutputs[i] = std::make_shared<Tensor>(other.output(i)); mOutputs[i] = std::make_shared<Tensor>(*(other.getOutput(i)));
// datatype already copied // datatype already copied
} }
} }
...@@ -72,6 +72,8 @@ public: ...@@ -72,6 +72,8 @@ public:
/////////////////////////////////////////////////// ///////////////////////////////////////////////////
// Tensor access // Tensor access
// input management // input management
void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final;
const std::shared_ptr<Tensor>& getInput(const IOIndex_t inputIdx) const; const std::shared_ptr<Tensor>& getInput(const IOIndex_t inputIdx) const;
inline Tensor& input(const IOIndex_t inputIdx) const { return *getInput(inputIdx); } inline Tensor& input(const IOIndex_t inputIdx) const { return *getInput(inputIdx); }
inline std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { inline std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
...@@ -79,6 +81,8 @@ public: ...@@ -79,6 +81,8 @@ public:
} }
// output management // output management
void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) override final;
void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) override final;
const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const; const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const;
inline Tensor& output(const IOIndex_t outputIdx) const { inline Tensor& output(const IOIndex_t outputIdx) const {
return *getOutput(outputIdx); return *getOutput(outputIdx);
......
...@@ -66,16 +66,6 @@ public: ...@@ -66,16 +66,6 @@ public:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Producer operator takes no input."); AIDGE_THROW_OR_ABORT(std::runtime_error, "Producer operator takes no input.");
} }
/**
* @brief Set the Output Tensor of the Producer operator.
* This method will create a copy of the Tensor.
*
* @param newOutput Tensor containing the values to copy
*/
void setOutput(const std::shared_ptr<Tensor>& newOutput) {
mOutputs[0] = newOutput;
}
void computeOutputDims() override final {} void computeOutputDims() override final {}
bool outputDimsForwarded() const override final {return true;} bool outputDimsForwarded() const override final {return true;}
......
...@@ -12,19 +12,25 @@ ...@@ -12,19 +12,25 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Operator.hpp" #include "aidge/operator/Operator.hpp"
#include "aidge/utils/Types.h"
#include <pybind11/stl.h> #include <pybind11/stl.h>
namespace py = pybind11; namespace py = pybind11;
namespace Aidge { namespace Aidge {
void init_Operator(py::module& m){ void init_Operator(py::module& m){
py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator") py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator")
.def("output", &Operator::output, py::arg("outputIdx")) .def("set_output", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setOutput), py::arg("outputIdx"), py::arg("data"))
.def("input", &Operator::input, py::arg("inputIdx")) // .def("set_output", py::overload_cast<const IOIndex_t, std::shared_ptr<Data>&&>(&Operator::setOutput), py::arg("outputIdx"), py::arg("data"))
.def("get_raw_output", &Operator::getRawOutput, py::arg("outputIdx"))
.def("set_input", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data"))
// .def("set_input", py::overload_cast<const IOIndex_t, std::shared_ptr<Data>&&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data"))
.def("get_raw_input", &Operator::getRawInput, py::arg("inputIdx"))
.def("nb_inputs", &Operator::nbInputs) .def("nb_inputs", &Operator::nbInputs)
.def("nb_data_inputs", &Operator::nbDataInputs) .def("nb_data", &Operator::nbData)
.def("nb_param", &Operator::nbParam)
.def("nb_outputs", &Operator::nbOutputs) .def("nb_outputs", &Operator::nbOutputs)
.def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data")) .def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data"))
.def("set_datatype", &Operator::setDatatype, py::arg("datatype")) .def("set_datatype", &Operator::setDataType, py::arg("dataType"))
.def("set_backend", &Operator::setBackend, py::arg("name")) .def("set_backend", &Operator::setBackend, py::arg("name"))
.def("forward", &Operator::forward) .def("forward", &Operator::forward)
// py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected ! // py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected !
...@@ -33,4 +39,4 @@ void init_Operator(py::module& m){ ...@@ -33,4 +39,4 @@ void init_Operator(py::module& m){
.def("add_hook", &Operator::addHook) .def("add_hook", &Operator::addHook)
; ;
} }
} }
\ No newline at end of file
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
// #include "aidge/backend/OperatorImpl.hpp" // #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Operator.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp" #include "aidge/operator/Producer.hpp"
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
...@@ -30,12 +30,11 @@ void declare_Producer(py::module &m) { ...@@ -30,12 +30,11 @@ void declare_Producer(py::module &m) {
void init_Producer(py::module &m) { void init_Producer(py::module &m) {
py::class_<Producer_Op, std::shared_ptr<Producer_Op>, Operator>( py::class_<Producer_Op, std::shared_ptr<Producer_Op>, OperatorTensor>(
m, m,
"ProducerOp", "ProducerOp",
py::multiple_inheritance()) py::multiple_inheritance())
.def("dims", &Producer_Op::dims) .def("dims", &Producer_Op::dims)
.def("set_output_tensor", &Producer_Op::setOutputTensor)
.def("get_inputs_name", &Producer_Op::getInputsName) .def("get_inputs_name", &Producer_Op::getInputsName)
.def("get_outputs_name", &Producer_Op::getOutputsName); .def("get_outputs_name", &Producer_Op::getOutputsName);
m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&)>(&Producer), py::arg("tensor"), py::arg("name") = ""); m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&)>(&Producer), py::arg("tensor"), py::arg("name") = "");
......
...@@ -29,6 +29,28 @@ void Aidge::OperatorTensor::associateInput(const Aidge::IOIndex_t inputIdx, cons ...@@ -29,6 +29,28 @@ void Aidge::OperatorTensor::associateInput(const Aidge::IOIndex_t inputIdx, cons
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
} }
void Aidge::OperatorTensor::setInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Aidge::Data>& data) {
if (strcmp(data->type(), "Tensor") != 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str());
}
if (getInput(inputIdx)) {
*mInputs[inputIdx] = *std::dynamic_pointer_cast<Tensor>(data);
} else {
mInputs[inputIdx] = std::make_shared<Tensor>(*std::dynamic_pointer_cast<Tensor>(data));
}
}
void Aidge::OperatorTensor::setInput(const Aidge::IOIndex_t inputIdx, std::shared_ptr<Aidge::Data>&& data) {
if (strcmp(data->type(), "Tensor") != 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str());
}
if (getInput(inputIdx)) {
*mInputs[inputIdx] = std::move(*std::dynamic_pointer_cast<Tensor>(data));
} else {
mInputs[inputIdx] = std::make_shared<Tensor>(std::move(*std::dynamic_pointer_cast<Tensor>(data)));
}
}
const std::shared_ptr<Aidge::Tensor>& Aidge::OperatorTensor::getInput(const Aidge::IOIndex_t inputIdx) const { const std::shared_ptr<Aidge::Tensor>& Aidge::OperatorTensor::getInput(const Aidge::IOIndex_t inputIdx) const {
if (inputIdx >= nbInputs()) { if (inputIdx >= nbInputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu inputs", type().c_str(), nbInputs()); AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu inputs", type().c_str(), nbInputs());
...@@ -36,6 +58,25 @@ const std::shared_ptr<Aidge::Tensor>& Aidge::OperatorTensor::getInput(const Aidg ...@@ -36,6 +58,25 @@ const std::shared_ptr<Aidge::Tensor>& Aidge::OperatorTensor::getInput(const Aidg
return mInputs[inputIdx]; return mInputs[inputIdx];
} }
void Aidge::OperatorTensor::setOutput(const Aidge::IOIndex_t outputIdx, const std::shared_ptr<Aidge::Data>& data) {
if (strcmp(data->type(), "Tensor") != 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str());
}
if (outputIdx >= nbOutputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbOutputs());
}
*mOutputs[outputIdx] = *std::dynamic_pointer_cast<Tensor>(data);
}
void Aidge::OperatorTensor::setOutput(const Aidge::IOIndex_t outputIdx, std::shared_ptr<Aidge::Data>&& data) {
if (strcmp(data->type(), "Tensor") != 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str());
}
if (outputIdx >= nbOutputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbOutputs());
}
*mOutputs[outputIdx] = std::move(*std::dynamic_pointer_cast<Tensor>(data));
}
const std::shared_ptr<Aidge::Tensor>& Aidge::OperatorTensor::getOutput(const Aidge::IOIndex_t outputIdx) const { const std::shared_ptr<Aidge::Tensor>& Aidge::OperatorTensor::getOutput(const Aidge::IOIndex_t outputIdx) const {
if (outputIdx >= nbOutputs()) { if (outputIdx >= nbOutputs()) {
...@@ -69,7 +110,7 @@ bool Aidge::OperatorTensor::outputDimsForwarded() const { ...@@ -69,7 +110,7 @@ bool Aidge::OperatorTensor::outputDimsForwarded() const {
bool forwarded = true; bool forwarded = true;
// check both inputs and outputs have been filled // check both inputs and outputs have been filled
for (IOIndex_t i = 0; i < nbInputs(); ++i) { for (IOIndex_t i = 0; i < nbInputs(); ++i) {
forwarded &= !(getInput(i)->empty()); forwarded &= mInputs[i] ? !(getInput(i)->empty()) : false;
} }
for (IOIndex_t i = 0; i < nbOutputs(); ++i) { for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
forwarded &= !(getOutput(i)->empty()); forwarded &= !(getOutput(i)->empty());
...@@ -82,6 +123,11 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const { ...@@ -82,6 +123,11 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
getOutput(i)->setDataType(dataType); getOutput(i)->setDataType(dataType);
} }
for (IOIndex_t i = 0; i < nbInputs(); ++i) { for (IOIndex_t i = 0; i < nbInputs(); ++i) {
getInput(i)->setDataType(dataType); if (!getInput(i)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not set");
}
else {
getInput(i)->setDataType(dataType);
}
} }
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment