Skip to content
Snippets Groups Projects
Commit cb229c42 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

[Producer] add constant attribute to disable setOutput method.

parent dc834182
No related branches found
No related tags found
No related merge requests found
...@@ -24,22 +24,32 @@ ...@@ -24,22 +24,32 @@
namespace Aidge { namespace Aidge {
enum class ProdAttr { Constant };
class Producer_Op class Producer_Op
: public OperatorTensor, : public OperatorTensor,
public Registrable<Producer_Op, std::string, std::unique_ptr<OperatorImpl>( public Registrable<Producer_Op, std::string, std::unique_ptr<OperatorImpl>(
const Producer_Op &)> { const Producer_Op &)>,
public StaticAttributes<ProdAttr, bool> {
public: public:
static const std::string Type; static const std::string Type;
using Attributes_ = StaticAttributes<ProdAttr, bool>;
template <ProdAttr e>
using attr = typename Attributes_::template attr<e>;
template <std::size_t DIM> template <std::size_t DIM>
Producer_Op(const std::array<DimSize_t, DIM>& dims) Producer_Op(const std::array<DimSize_t, DIM>& dims,
: OperatorTensor(Type, 0, 0, 1) bool constant = false)
: OperatorTensor(Type, 0, 0, 1),
Attributes_(attr<ProdAttr::Constant>(constant))
{ {
mOutputs[0]->resize(dims); mOutputs[0]->resize(dims);
} }
Producer_Op(const std::shared_ptr<Tensor> tensor) Producer_Op(const std::shared_ptr<Tensor> tensor, bool constant = false)
: OperatorTensor(Type, 0, 0, 1) : OperatorTensor(Type, 0, 0, 1),
Attributes_(attr<ProdAttr::Constant>(constant))
{ {
mOutputs[0] = tensor; // copy the pointer of the Tensor mOutputs[0] = tensor; // copy the pointer of the Tensor
} }
...@@ -49,7 +59,8 @@ public: ...@@ -49,7 +59,8 @@ public:
* @param op OperatorTensor to copy. * @param op OperatorTensor to copy.
*/ */
Producer_Op(const Producer_Op& op) Producer_Op(const Producer_Op& op)
: OperatorTensor(op) : OperatorTensor(op),
Attributes_(op)
{ {
for (std::size_t i = 0; i < static_cast<std::size_t>(nbOutputs()); ++i) { for (std::size_t i = 0; i < static_cast<std::size_t>(nbOutputs()); ++i) {
mOutputs[i] = std::make_shared<Tensor>(*(op.getOutput(i))); mOutputs[i] = std::make_shared<Tensor>(*(op.getOutput(i)));
...@@ -89,28 +100,41 @@ public: ...@@ -89,28 +100,41 @@ public:
} }
public: public:
void forward() override final { void forward() override final {
printf("Basic Producer forward() function.\n"); printf("Basic Producer forward() function.\n");
} }
void backward() override final { void backward() override final {
printf("Basic Producer backward() function.\n"); printf("Basic Producer backward() function.\n");
} }
void setOutput(const Aidge::IOIndex_t outputIdx, std::shared_ptr<Aidge::Data>&& data) override {
if (getAttr<ProdAttr::Constant>()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Producer is constant, cannot update output.");
}
OperatorTensor::setOutput(outputIdx, data);
}
void setOutput(const Aidge::IOIndex_t outputIdx, const std::shared_ptr<Aidge::Data>& data) override {
if (getAttr<ProdAttr::Constant>()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Producer is constant, cannot update output.");
}
OperatorTensor::setOutput(outputIdx, data);
}
}; };
template <std::array<DimSize_t, 1>::size_type DIM> template <std::array<DimSize_t, 1>::size_type DIM>
inline std::shared_ptr<Node> Producer(const std::array<DimSize_t, DIM> &dims, const std::string& name = "") { inline std::shared_ptr<Node> Producer(const std::array<DimSize_t, DIM> &dims, const std::string& name = "", bool constant = false) {
static_assert(DIM<=MaxDim,"Too many tensor dimensions required by Producer, not supported"); static_assert(DIM<=MaxDim,"Too many tensor dimensions required by Producer, not supported");
return std::make_shared<Node>(std::make_shared<Producer_Op>(dims), name); return std::make_shared<Node>(std::make_shared<Producer_Op>(dims, constant), name);
} }
// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction // helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction
template <std::size_t DIM> template <std::size_t DIM>
inline std::shared_ptr<Node> Producer(DimSize_t const (&dims)[DIM], const std::string& name = "") { inline std::shared_ptr<Node> Producer(DimSize_t const (&dims)[DIM], const std::string& name = "", bool constant = false) {
return Producer(to_array(dims), name); return Producer(to_array(dims), name, constant);
} }
inline std::shared_ptr<Node> Producer(const std::shared_ptr<Tensor> tensor, const std::string& name = "") { inline std::shared_ptr<Node> Producer(const std::shared_ptr<Tensor> tensor, const std::string& name = "", bool constant = false) {
return std::make_shared<Node>(std::make_shared<Producer_Op>(tensor), name); return std::make_shared<Node>(std::make_shared<Producer_Op>(tensor, constant), name);
} }
template <std::array<DimSize_t, 1>::size_type DIM> template <std::array<DimSize_t, 1>::size_type DIM>
...@@ -130,4 +154,10 @@ void addProducer(std::shared_ptr<Node>& otherNode, const IOIndex_t inputIdx, Dim ...@@ -130,4 +154,10 @@ void addProducer(std::shared_ptr<Node>& otherNode, const IOIndex_t inputIdx, Dim
} }
} // namespace Aidge } // namespace Aidge
#endif /* AIDGE_CORE_OPERATOR_PRODUCER_H_ */ namespace {
\ No newline at end of file template <>
const char *const EnumStrings<Aidge::ProdAttr>::data[] = {
"Constant"
};
}
#endif /* AIDGE_CORE_OPERATOR_PRODUCER_H_ */
...@@ -24,20 +24,20 @@ namespace Aidge { ...@@ -24,20 +24,20 @@ namespace Aidge {
template <DimIdx_t DIM> template <DimIdx_t DIM>
void declare_Producer(py::module &m) { void declare_Producer(py::module &m) {
// m.def(("Producer_" + std::to_string(DIM)+"D").c_str(), py::overload_cast<shared_ptr<Node>&>(&Producer<DIM>), py::arg("dims"), py::arg("name")); // m.def(("Producer_" + std::to_string(DIM)+"D").c_str(), py::overload_cast<shared_ptr<Node>&>(&Producer<DIM>), py::arg("dims"), py::arg("name"));
m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const std::string&)>(&Producer), py::arg("dims"), py::arg("name") = ""); m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const std::string&, bool)>(&Producer), py::arg("dims"), py::arg("name") = "", py::arg("constant") = false);
} }
void init_Producer(py::module &m) { void init_Producer(py::module &m) {
py::class_<Producer_Op, std::shared_ptr<Producer_Op>, OperatorTensor>( py::class_<Producer_Op, std::shared_ptr<Producer_Op>, OperatorTensor, Attributes>(
m, m,
"ProducerOp", "ProducerOp",
py::multiple_inheritance()) py::multiple_inheritance())
.def("dims", &Producer_Op::dims) .def("dims", &Producer_Op::dims)
.def("get_inputs_name", &Producer_Op::getInputsName) .def("get_inputs_name", &Producer_Op::getInputsName)
.def("get_outputs_name", &Producer_Op::getOutputsName); .def("get_outputs_name", &Producer_Op::getOutputsName);
m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&)>(&Producer), py::arg("tensor"), py::arg("name") = ""); m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&, bool)>(&Producer), py::arg("tensor"), py::arg("name") = "", py::arg("constant") = false);
declare_Producer<1>(m); declare_Producer<1>(m);
declare_Producer<2>(m); declare_Producer<2>(m);
......
...@@ -13,4 +13,4 @@ ...@@ -13,4 +13,4 @@
#include "aidge/operator/Producer.hpp" #include "aidge/operator/Producer.hpp"
const std::string Aidge::Producer_Op::Type = "Producer"; const std::string Aidge::Producer_Op::Type = "Producer";
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment