From cb229c4272f768aedede0895dee220b54adb8520 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Thu, 11 Jan 2024 13:37:17 +0000 Subject: [PATCH] [Producer] add constant attribute to disable setOutput method. --- include/aidge/operator/Producer.hpp | 68 +++++++++++++++------ python_binding/operator/pybind_Producer.cpp | 6 +- src/operator/Producer.cpp | 2 +- 3 files changed, 53 insertions(+), 23 deletions(-) diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index ee00ead69..1e082e73a 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -24,22 +24,32 @@ namespace Aidge { +enum class ProdAttr { Constant }; + class Producer_Op : public OperatorTensor, public Registrable<Producer_Op, std::string, std::unique_ptr<OperatorImpl>( - const Producer_Op &)> { + const Producer_Op &)>, + public StaticAttributes<ProdAttr, bool> { public: static const std::string Type; + using Attributes_ = StaticAttributes<ProdAttr, bool>; + template <ProdAttr e> + using attr = typename Attributes_::template attr<e>; + template <std::size_t DIM> - Producer_Op(const std::array<DimSize_t, DIM>& dims) - : OperatorTensor(Type, 0, 0, 1) + Producer_Op(const std::array<DimSize_t, DIM>& dims, + bool constant = false) + : OperatorTensor(Type, 0, 0, 1), + Attributes_(attr<ProdAttr::Constant>(constant)) { mOutputs[0]->resize(dims); } - Producer_Op(const std::shared_ptr<Tensor> tensor) - : OperatorTensor(Type, 0, 0, 1) + Producer_Op(const std::shared_ptr<Tensor> tensor, bool constant = false) + : OperatorTensor(Type, 0, 0, 1), + Attributes_(attr<ProdAttr::Constant>(constant)) { mOutputs[0] = tensor; // copy the pointer of the Tensor } @@ -49,7 +59,8 @@ public: * @param op OperatorTensor to copy. */ Producer_Op(const Producer_Op& op) - : OperatorTensor(op) + : OperatorTensor(op), + Attributes_(op) { for (std::size_t i = 0; i < static_cast<std::size_t>(nbOutputs()); ++i) { mOutputs[i] = std::make_shared<Tensor>(*(op.getOutput(i))); @@ -89,28 +100,41 @@ public: } public: - void forward() override final { - printf("Basic Producer forward() function.\n"); - } - void backward() override final { - printf("Basic Producer backward() function.\n"); - } + void forward() override final { + printf("Basic Producer forward() function.\n"); + } + void backward() override final { + printf("Basic Producer backward() function.\n"); + } + void setOutput(const Aidge::IOIndex_t outputIdx, std::shared_ptr<Aidge::Data>&& data) override { + if (getAttr<ProdAttr::Constant>()) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Producer is constant, cannot update output."); + } + OperatorTensor::setOutput(outputIdx, data); + } + + void setOutput(const Aidge::IOIndex_t outputIdx, const std::shared_ptr<Aidge::Data>& data) override { + if (getAttr<ProdAttr::Constant>()) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Producer is constant, cannot update output."); + } + OperatorTensor::setOutput(outputIdx, data); + } }; template <std::array<DimSize_t, 1>::size_type DIM> -inline std::shared_ptr<Node> Producer(const std::array<DimSize_t, DIM> &dims, const std::string& name = "") { +inline std::shared_ptr<Node> Producer(const std::array<DimSize_t, DIM> &dims, const std::string& name = "", bool constant = false) { static_assert(DIM<=MaxDim,"Too many tensor dimensions required by Producer, not supported"); - return std::make_shared<Node>(std::make_shared<Producer_Op>(dims), name); + return std::make_shared<Node>(std::make_shared<Producer_Op>(dims, constant), name); } // helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction template <std::size_t DIM> -inline std::shared_ptr<Node> Producer(DimSize_t const (&dims)[DIM], const std::string& name = "") { - return Producer(to_array(dims), name); +inline std::shared_ptr<Node> Producer(DimSize_t const (&dims)[DIM], const std::string& name = "", bool constant = false) { + return Producer(to_array(dims), name, constant); } -inline std::shared_ptr<Node> Producer(const std::shared_ptr<Tensor> tensor, const std::string& name = "") { - return std::make_shared<Node>(std::make_shared<Producer_Op>(tensor), name); +inline std::shared_ptr<Node> Producer(const std::shared_ptr<Tensor> tensor, const std::string& name = "", bool constant = false) { + return std::make_shared<Node>(std::make_shared<Producer_Op>(tensor, constant), name); } template <std::array<DimSize_t, 1>::size_type DIM> @@ -130,4 +154,10 @@ void addProducer(std::shared_ptr<Node>& otherNode, const IOIndex_t inputIdx, Dim } } // namespace Aidge -#endif /* AIDGE_CORE_OPERATOR_PRODUCER_H_ */ \ No newline at end of file +namespace { +template <> +const char *const EnumStrings<Aidge::ProdAttr>::data[] = { + "Constant" +}; +} +#endif /* AIDGE_CORE_OPERATOR_PRODUCER_H_ */ diff --git a/python_binding/operator/pybind_Producer.cpp b/python_binding/operator/pybind_Producer.cpp index 3dae24b62..78d9ce348 100644 --- a/python_binding/operator/pybind_Producer.cpp +++ b/python_binding/operator/pybind_Producer.cpp @@ -24,20 +24,20 @@ namespace Aidge { template <DimIdx_t DIM> void declare_Producer(py::module &m) { // m.def(("Producer_" + std::to_string(DIM)+"D").c_str(), py::overload_cast<shared_ptr<Node>&>(&Producer<DIM>), py::arg("dims"), py::arg("name")); - m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const std::string&)>(&Producer), py::arg("dims"), py::arg("name") = ""); + m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const std::string&, bool)>(&Producer), py::arg("dims"), py::arg("name") = "", py::arg("constant") = false); } void init_Producer(py::module &m) { - py::class_<Producer_Op, std::shared_ptr<Producer_Op>, OperatorTensor>( + py::class_<Producer_Op, std::shared_ptr<Producer_Op>, OperatorTensor, Attributes>( m, "ProducerOp", py::multiple_inheritance()) .def("dims", &Producer_Op::dims) .def("get_inputs_name", &Producer_Op::getInputsName) .def("get_outputs_name", &Producer_Op::getOutputsName); - m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&)>(&Producer), py::arg("tensor"), py::arg("name") = ""); + m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&, bool)>(&Producer), py::arg("tensor"), py::arg("name") = "", py::arg("constant") = false); declare_Producer<1>(m); declare_Producer<2>(m); diff --git a/src/operator/Producer.cpp b/src/operator/Producer.cpp index 443f2fa7d..7bccbe763 100644 --- a/src/operator/Producer.cpp +++ b/src/operator/Producer.cpp @@ -13,4 +13,4 @@ #include "aidge/operator/Producer.hpp" -const std::string Aidge::Producer_Op::Type = "Producer"; \ No newline at end of file +const std::string Aidge::Producer_Op::Type = "Producer"; -- GitLab