From ade2edaf9e5359b9c7201789120c775c4a70f963 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Fri, 22 Sep 2023 11:36:45 +0000 Subject: [PATCH] [Producer] Add setOutputTensor method. --- include/aidge/operator/Producer.hpp | 12 +++++++++++- python_binding/operator/pybind_Producer.cpp | 11 ++++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index acdc69b69..de885d47c 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -55,6 +55,16 @@ public: assert(false && "Producer operator takes no input"); } + /** + * @brief Set the Output Tensor of the Producer operator. + * This method will create a copy of the Tensor. + * + * @param newOutput Tensor containing the values to copy + */ + void setOutputTensor(const Tensor& newOutput) { + *mOutput = newOutput; + } + void computeOutputDims() override final {} bool outputDimsForwarded() const override final {return true;} @@ -143,4 +153,4 @@ void addProducer(std::shared_ptr<Node>& otherNode, const IOIndex_t inputIdx, Dim } } // namespace Aidge -#endif /* AIDGE_CORE_OPERATOR_PRODUCER_H_ */ \ No newline at end of file +#endif /* AIDGE_CORE_OPERATOR_PRODUCER_H_ */ diff --git a/python_binding/operator/pybind_Producer.cpp b/python_binding/operator/pybind_Producer.cpp index ea9880800..4714e096f 100644 --- a/python_binding/operator/pybind_Producer.cpp +++ b/python_binding/operator/pybind_Producer.cpp @@ -26,18 +26,19 @@ template <DimIdx_t DIM> void declare_Producer(py::module &m) { // m.def(("Producer_" + std::to_string(DIM)+"D").c_str(), py::overload_cast<shared_ptr<Node>&>(&Producer<DIM>), py::arg("dims"), py::arg("name")); m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const std::string&)>(&Producer), py::arg("dims"), py::arg("name") = ""); - + } void init_Producer(py::module &m) { py::class_<Producer_Op, std::shared_ptr<Producer_Op>, Operator>( - m, - "ProducerOp", + m, + "ProducerOp", py::multiple_inheritance()) - .def("dims", &Producer_Op::dims); + .def("dims", &Producer_Op::dims) + .def("set_output_tensor", &Producer_Op::setOutputTensor); m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&)>(&Producer), py::arg("tensor"), py::arg("name") = ""); - + declare_Producer<1>(m); declare_Producer<2>(m); declare_Producer<3>(m); -- GitLab