From ade2edaf9e5359b9c7201789120c775c4a70f963 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Fri, 22 Sep 2023 11:36:45 +0000
Subject: [PATCH] [Producer] Add setOutputTensor method.

---
 include/aidge/operator/Producer.hpp         | 12 +++++++++++-
 python_binding/operator/pybind_Producer.cpp | 11 ++++++-----
 2 files changed, 17 insertions(+), 6 deletions(-)

diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp
index acdc69b69..de885d47c 100644
--- a/include/aidge/operator/Producer.hpp
+++ b/include/aidge/operator/Producer.hpp
@@ -55,6 +55,16 @@ public:
         assert(false && "Producer operator takes no input");
     }
 
+    /**
+     * @brief Set the Output Tensor of the Producer operator.
+     * This method will create a copy of the Tensor.
+     *
+     * @param newOutput Tensor containing the values to copy 
+     */
+    void setOutputTensor(const Tensor& newOutput) {
+        *mOutput = newOutput;
+    }
+
     void computeOutputDims() override final {}
 
     bool outputDimsForwarded() const override final {return true;}
@@ -143,4 +153,4 @@ void addProducer(std::shared_ptr<Node>& otherNode, const IOIndex_t inputIdx, Dim
 }
 } // namespace Aidge
 
-#endif /* AIDGE_CORE_OPERATOR_PRODUCER_H_ */
\ No newline at end of file
+#endif /* AIDGE_CORE_OPERATOR_PRODUCER_H_ */
diff --git a/python_binding/operator/pybind_Producer.cpp b/python_binding/operator/pybind_Producer.cpp
index ea9880800..4714e096f 100644
--- a/python_binding/operator/pybind_Producer.cpp
+++ b/python_binding/operator/pybind_Producer.cpp
@@ -26,18 +26,19 @@ template <DimIdx_t DIM>
 void declare_Producer(py::module &m) {
     // m.def(("Producer_" + std::to_string(DIM)+"D").c_str(), py::overload_cast<shared_ptr<Node>&>(&Producer<DIM>), py::arg("dims"), py::arg("name"));
     m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const std::string&)>(&Producer), py::arg("dims"), py::arg("name") = "");
-    
+
 }
 
 
 void init_Producer(py::module &m) {
     py::class_<Producer_Op,  std::shared_ptr<Producer_Op>, Operator>(
-        m, 
-        "ProducerOp", 
+        m,
+        "ProducerOp",
         py::multiple_inheritance())
-    .def("dims", &Producer_Op::dims);
+    .def("dims", &Producer_Op::dims)
+    .def("set_output_tensor", &Producer_Op::setOutputTensor);
     m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&)>(&Producer), py::arg("tensor"), py::arg("name") = "");
-    
+
     declare_Producer<1>(m);
     declare_Producer<2>(m);
     declare_Producer<3>(m);
-- 
GitLab