diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 7fca2a46fdfe13baabafcb6790d9e30e9b04f23d..5ff1152e4bb2195e3cc0f70cc9efe9c6986cd634 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -83,7 +83,22 @@ public: * @return std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> For each dataInput Tensor of the Operator, the first index and dimensions of the feature area. */ // virtual std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveField(const std::size_t firstIdx, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const; + + /** + * @brief Set the specified input by performing a deep copy of the given data. + * The pointer itself is not changed, thus keeping the current connections. + * @param inputIdx Index of the input to set. + */ + virtual void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) = 0; + virtual void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) = 0; virtual std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const = 0; + /** + * @brief Set the specified output by performing a deep copy of the given data. + * The pointer itself is not changed, thus keeping the current connections. + * @param inputIdx Index of the input to set. + */ + virtual void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) = 0; + virtual void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) = 0; virtual std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const = 0; std::shared_ptr<Hook> getHook(std::string hookName) { diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp index b8c180e6d5dbe4959d29b2806173606ec245cb3a..0ecccb3713afbf141a6afbcc0ab69a51383f712d 100644 --- a/include/aidge/operator/OperatorTensor.hpp +++ b/include/aidge/operator/OperatorTensor.hpp @@ -56,7 +56,7 @@ public: mInputs(std::vector<std::shared_ptr<Tensor>>(other.nbInputs(), nullptr)), mOutputs(std::vector<std::shared_ptr<Tensor>>(other.nbOutputs())) { for (std::size_t i = 0; i < static_cast<std::size_t>(nbOutputs()); ++i) { - mOutputs[i] = std::make_shared<Tensor>(other.output(i)); + mOutputs[i] = std::make_shared<Tensor>(*(other.getOutput(i))); // datatype already copied } } @@ -72,6 +72,8 @@ public: /////////////////////////////////////////////////// // Tensor access // input management + void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final; + void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final; const std::shared_ptr<Tensor>& getInput(const IOIndex_t inputIdx) const; inline Tensor& input(const IOIndex_t inputIdx) const { return *getInput(inputIdx); } inline std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { @@ -79,6 +81,8 @@ public: } // output management + void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) override final; + void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) override final; const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const; inline Tensor& output(const IOIndex_t outputIdx) const { return *getOutput(outputIdx); diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index 0b63e5a32afe2fa21cbec30855468fbef85881d2..fb6a20403adc1ee5cddb5869fd9d39ef59fb776e 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -66,16 +66,6 @@ public: AIDGE_THROW_OR_ABORT(std::runtime_error, "Producer operator takes no input."); } - /** - * @brief Set the Output Tensor of the Producer operator. - * This method will create a copy of the Tensor. - * - * @param newOutput Tensor containing the values to copy - */ - void setOutput(const std::shared_ptr<Tensor>& newOutput) { - mOutputs[0] = newOutput; - } - void computeOutputDims() override final {} bool outputDimsForwarded() const override final {return true;} diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp index ef02b8aaef9f4ea3bd97559ad9e94c38c5b1d29e..09b1abef1e5ea66c6843594e3fa3beb20ec10740 100644 --- a/python_binding/operator/pybind_Operator.cpp +++ b/python_binding/operator/pybind_Operator.cpp @@ -12,19 +12,25 @@ #include <pybind11/pybind11.h> #include "aidge/backend/OperatorImpl.hpp" #include "aidge/operator/Operator.hpp" +#include "aidge/utils/Types.h" #include <pybind11/stl.h> namespace py = pybind11; namespace Aidge { void init_Operator(py::module& m){ py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator") - .def("output", &Operator::output, py::arg("outputIdx")) - .def("input", &Operator::input, py::arg("inputIdx")) + .def("set_output", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setOutput), py::arg("outputIdx"), py::arg("data")) + // .def("set_output", py::overload_cast<const IOIndex_t, std::shared_ptr<Data>&&>(&Operator::setOutput), py::arg("outputIdx"), py::arg("data")) + .def("get_raw_output", &Operator::getRawOutput, py::arg("outputIdx")) + .def("set_input", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data")) + // .def("set_input", py::overload_cast<const IOIndex_t, std::shared_ptr<Data>&&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data")) + .def("get_raw_input", &Operator::getRawInput, py::arg("inputIdx")) .def("nb_inputs", &Operator::nbInputs) - .def("nb_data_inputs", &Operator::nbDataInputs) + .def("nb_data", &Operator::nbData) + .def("nb_param", &Operator::nbParam) .def("nb_outputs", &Operator::nbOutputs) .def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data")) - .def("set_datatype", &Operator::setDatatype, py::arg("datatype")) + .def("set_datatype", &Operator::setDataType, py::arg("dataType")) .def("set_backend", &Operator::setBackend, py::arg("name")) .def("forward", &Operator::forward) // py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected ! @@ -33,4 +39,4 @@ void init_Operator(py::module& m){ .def("add_hook", &Operator::addHook) ; } -} +} \ No newline at end of file diff --git a/python_binding/operator/pybind_Producer.cpp b/python_binding/operator/pybind_Producer.cpp index 107b7ba00e4077d9f7c215257bf7fd46629481c1..3dae24b620fe99098205d7d5f23591780f1e9cb7 100644 --- a/python_binding/operator/pybind_Producer.cpp +++ b/python_binding/operator/pybind_Producer.cpp @@ -14,7 +14,7 @@ #include "aidge/utils/Types.h" // #include "aidge/backend/OperatorImpl.hpp" -#include "aidge/operator/Operator.hpp" +#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/data/Tensor.hpp" @@ -30,12 +30,11 @@ void declare_Producer(py::module &m) { void init_Producer(py::module &m) { - py::class_<Producer_Op, std::shared_ptr<Producer_Op>, Operator>( + py::class_<Producer_Op, std::shared_ptr<Producer_Op>, OperatorTensor>( m, "ProducerOp", py::multiple_inheritance()) .def("dims", &Producer_Op::dims) - .def("set_output_tensor", &Producer_Op::setOutputTensor) .def("get_inputs_name", &Producer_Op::getInputsName) .def("get_outputs_name", &Producer_Op::getOutputsName); m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&)>(&Producer), py::arg("tensor"), py::arg("name") = ""); diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp index 1594548ce82bca45d3adee04596af82f1cc7256b..6fcb70f0f779d352cc64be07f38b9673ff322e11 100644 --- a/src/operator/OperatorTensor.cpp +++ b/src/operator/OperatorTensor.cpp @@ -29,6 +29,28 @@ void Aidge::OperatorTensor::associateInput(const Aidge::IOIndex_t inputIdx, cons mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); } +void Aidge::OperatorTensor::setInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Aidge::Data>& data) { + if (strcmp(data->type(), "Tensor") != 0) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str()); + } + if (getInput(inputIdx)) { + *mInputs[inputIdx] = *std::dynamic_pointer_cast<Tensor>(data); + } else { + mInputs[inputIdx] = std::make_shared<Tensor>(*std::dynamic_pointer_cast<Tensor>(data)); + } +} + +void Aidge::OperatorTensor::setInput(const Aidge::IOIndex_t inputIdx, std::shared_ptr<Aidge::Data>&& data) { + if (strcmp(data->type(), "Tensor") != 0) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str()); + } + if (getInput(inputIdx)) { + *mInputs[inputIdx] = std::move(*std::dynamic_pointer_cast<Tensor>(data)); + } else { + mInputs[inputIdx] = std::make_shared<Tensor>(std::move(*std::dynamic_pointer_cast<Tensor>(data))); + } +} + const std::shared_ptr<Aidge::Tensor>& Aidge::OperatorTensor::getInput(const Aidge::IOIndex_t inputIdx) const { if (inputIdx >= nbInputs()) { AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu inputs", type().c_str(), nbInputs()); @@ -36,6 +58,25 @@ const std::shared_ptr<Aidge::Tensor>& Aidge::OperatorTensor::getInput(const Aidg return mInputs[inputIdx]; } +void Aidge::OperatorTensor::setOutput(const Aidge::IOIndex_t outputIdx, const std::shared_ptr<Aidge::Data>& data) { + if (strcmp(data->type(), "Tensor") != 0) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str()); + } + if (outputIdx >= nbOutputs()) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbOutputs()); + } + *mOutputs[outputIdx] = *std::dynamic_pointer_cast<Tensor>(data); +} + +void Aidge::OperatorTensor::setOutput(const Aidge::IOIndex_t outputIdx, std::shared_ptr<Aidge::Data>&& data) { + if (strcmp(data->type(), "Tensor") != 0) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str()); + } + if (outputIdx >= nbOutputs()) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbOutputs()); + } + *mOutputs[outputIdx] = std::move(*std::dynamic_pointer_cast<Tensor>(data)); +} const std::shared_ptr<Aidge::Tensor>& Aidge::OperatorTensor::getOutput(const Aidge::IOIndex_t outputIdx) const { if (outputIdx >= nbOutputs()) { @@ -69,7 +110,7 @@ bool Aidge::OperatorTensor::outputDimsForwarded() const { bool forwarded = true; // check both inputs and outputs have been filled for (IOIndex_t i = 0; i < nbInputs(); ++i) { - forwarded &= !(getInput(i)->empty()); + forwarded &= mInputs[i] ? !(getInput(i)->empty()) : false; } for (IOIndex_t i = 0; i < nbOutputs(); ++i) { forwarded &= !(getOutput(i)->empty()); @@ -82,6 +123,11 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const { getOutput(i)->setDataType(dataType); } for (IOIndex_t i = 0; i < nbInputs(); ++i) { - getInput(i)->setDataType(dataType); + if (!getInput(i)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not set"); + } + else { + getInput(i)->setDataType(dataType); + } } } \ No newline at end of file