Skip to content
Snippets Groups Projects
Commit ade2edaf authored by Cyril Moineau's avatar Cyril Moineau
Browse files

[Producer] Add setOutputTensor method.

parent ec958ea7
No related branches found
No related tags found
1 merge request!9Fuse bn
...@@ -55,6 +55,16 @@ public: ...@@ -55,6 +55,16 @@ public:
assert(false && "Producer operator takes no input"); assert(false && "Producer operator takes no input");
} }
/**
* @brief Set the Output Tensor of the Producer operator.
* This method will create a copy of the Tensor.
*
* @param newOutput Tensor containing the values to copy
*/
void setOutputTensor(const Tensor& newOutput) {
*mOutput = newOutput;
}
void computeOutputDims() override final {} void computeOutputDims() override final {}
bool outputDimsForwarded() const override final {return true;} bool outputDimsForwarded() const override final {return true;}
...@@ -143,4 +153,4 @@ void addProducer(std::shared_ptr<Node>& otherNode, const IOIndex_t inputIdx, Dim ...@@ -143,4 +153,4 @@ void addProducer(std::shared_ptr<Node>& otherNode, const IOIndex_t inputIdx, Dim
} }
} // namespace Aidge } // namespace Aidge
#endif /* AIDGE_CORE_OPERATOR_PRODUCER_H_ */ #endif /* AIDGE_CORE_OPERATOR_PRODUCER_H_ */
\ No newline at end of file
...@@ -26,18 +26,19 @@ template <DimIdx_t DIM> ...@@ -26,18 +26,19 @@ template <DimIdx_t DIM>
void declare_Producer(py::module &m) { void declare_Producer(py::module &m) {
// m.def(("Producer_" + std::to_string(DIM)+"D").c_str(), py::overload_cast<shared_ptr<Node>&>(&Producer<DIM>), py::arg("dims"), py::arg("name")); // m.def(("Producer_" + std::to_string(DIM)+"D").c_str(), py::overload_cast<shared_ptr<Node>&>(&Producer<DIM>), py::arg("dims"), py::arg("name"));
m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const std::string&)>(&Producer), py::arg("dims"), py::arg("name") = ""); m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const std::string&)>(&Producer), py::arg("dims"), py::arg("name") = "");
} }
void init_Producer(py::module &m) { void init_Producer(py::module &m) {
py::class_<Producer_Op, std::shared_ptr<Producer_Op>, Operator>( py::class_<Producer_Op, std::shared_ptr<Producer_Op>, Operator>(
m, m,
"ProducerOp", "ProducerOp",
py::multiple_inheritance()) py::multiple_inheritance())
.def("dims", &Producer_Op::dims); .def("dims", &Producer_Op::dims)
.def("set_output_tensor", &Producer_Op::setOutputTensor);
m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&)>(&Producer), py::arg("tensor"), py::arg("name") = ""); m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&)>(&Producer), py::arg("tensor"), py::arg("name") = "");
declare_Producer<1>(m); declare_Producer<1>(m);
declare_Producer<2>(m); declare_Producer<2>(m);
declare_Producer<3>(m); declare_Producer<3>(m);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment