Skip to content
Snippets Groups Projects
Commit cb229c42 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

[Producer] add constant attribute to disable setOutput method.

parent dc834182
No related branches found
No related tags found
No related merge requests found
......@@ -24,22 +24,32 @@
namespace Aidge {
enum class ProdAttr { Constant };
class Producer_Op
: public OperatorTensor,
public Registrable<Producer_Op, std::string, std::unique_ptr<OperatorImpl>(
const Producer_Op &)> {
const Producer_Op &)>,
public StaticAttributes<ProdAttr, bool> {
public:
static const std::string Type;
using Attributes_ = StaticAttributes<ProdAttr, bool>;
template <ProdAttr e>
using attr = typename Attributes_::template attr<e>;
template <std::size_t DIM>
Producer_Op(const std::array<DimSize_t, DIM>& dims)
: OperatorTensor(Type, 0, 0, 1)
Producer_Op(const std::array<DimSize_t, DIM>& dims,
bool constant = false)
: OperatorTensor(Type, 0, 0, 1),
Attributes_(attr<ProdAttr::Constant>(constant))
{
mOutputs[0]->resize(dims);
}
Producer_Op(const std::shared_ptr<Tensor> tensor)
: OperatorTensor(Type, 0, 0, 1)
Producer_Op(const std::shared_ptr<Tensor> tensor, bool constant = false)
: OperatorTensor(Type, 0, 0, 1),
Attributes_(attr<ProdAttr::Constant>(constant))
{
mOutputs[0] = tensor; // copy the pointer of the Tensor
}
......@@ -49,7 +59,8 @@ public:
* @param op OperatorTensor to copy.
*/
Producer_Op(const Producer_Op& op)
: OperatorTensor(op)
: OperatorTensor(op),
Attributes_(op)
{
for (std::size_t i = 0; i < static_cast<std::size_t>(nbOutputs()); ++i) {
mOutputs[i] = std::make_shared<Tensor>(*(op.getOutput(i)));
......@@ -89,28 +100,41 @@ public:
}
public:
void forward() override final {
printf("Basic Producer forward() function.\n");
}
void backward() override final {
printf("Basic Producer backward() function.\n");
}
void forward() override final {
printf("Basic Producer forward() function.\n");
}
void backward() override final {
printf("Basic Producer backward() function.\n");
}
void setOutput(const Aidge::IOIndex_t outputIdx, std::shared_ptr<Aidge::Data>&& data) override {
if (getAttr<ProdAttr::Constant>()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Producer is constant, cannot update output.");
}
OperatorTensor::setOutput(outputIdx, data);
}
void setOutput(const Aidge::IOIndex_t outputIdx, const std::shared_ptr<Aidge::Data>& data) override {
if (getAttr<ProdAttr::Constant>()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Producer is constant, cannot update output.");
}
OperatorTensor::setOutput(outputIdx, data);
}
};
template <std::array<DimSize_t, 1>::size_type DIM>
inline std::shared_ptr<Node> Producer(const std::array<DimSize_t, DIM> &dims, const std::string& name = "") {
inline std::shared_ptr<Node> Producer(const std::array<DimSize_t, DIM> &dims, const std::string& name = "", bool constant = false) {
static_assert(DIM<=MaxDim,"Too many tensor dimensions required by Producer, not supported");
return std::make_shared<Node>(std::make_shared<Producer_Op>(dims), name);
return std::make_shared<Node>(std::make_shared<Producer_Op>(dims, constant), name);
}
// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction
template <std::size_t DIM>
inline std::shared_ptr<Node> Producer(DimSize_t const (&dims)[DIM], const std::string& name = "") {
return Producer(to_array(dims), name);
inline std::shared_ptr<Node> Producer(DimSize_t const (&dims)[DIM], const std::string& name = "", bool constant = false) {
return Producer(to_array(dims), name, constant);
}
inline std::shared_ptr<Node> Producer(const std::shared_ptr<Tensor> tensor, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Producer_Op>(tensor), name);
inline std::shared_ptr<Node> Producer(const std::shared_ptr<Tensor> tensor, const std::string& name = "", bool constant = false) {
return std::make_shared<Node>(std::make_shared<Producer_Op>(tensor, constant), name);
}
template <std::array<DimSize_t, 1>::size_type DIM>
......@@ -130,4 +154,10 @@ void addProducer(std::shared_ptr<Node>& otherNode, const IOIndex_t inputIdx, Dim
}
} // namespace Aidge
#endif /* AIDGE_CORE_OPERATOR_PRODUCER_H_ */
\ No newline at end of file
namespace {
template <>
const char *const EnumStrings<Aidge::ProdAttr>::data[] = {
"Constant"
};
}
#endif /* AIDGE_CORE_OPERATOR_PRODUCER_H_ */
......@@ -24,20 +24,20 @@ namespace Aidge {
template <DimIdx_t DIM>
void declare_Producer(py::module &m) {
// m.def(("Producer_" + std::to_string(DIM)+"D").c_str(), py::overload_cast<shared_ptr<Node>&>(&Producer<DIM>), py::arg("dims"), py::arg("name"));
m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const std::string&)>(&Producer), py::arg("dims"), py::arg("name") = "");
m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const std::string&, bool)>(&Producer), py::arg("dims"), py::arg("name") = "", py::arg("constant") = false);
}
void init_Producer(py::module &m) {
py::class_<Producer_Op, std::shared_ptr<Producer_Op>, OperatorTensor>(
py::class_<Producer_Op, std::shared_ptr<Producer_Op>, OperatorTensor, Attributes>(
m,
"ProducerOp",
py::multiple_inheritance())
.def("dims", &Producer_Op::dims)
.def("get_inputs_name", &Producer_Op::getInputsName)
.def("get_outputs_name", &Producer_Op::getOutputsName);
m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&)>(&Producer), py::arg("tensor"), py::arg("name") = "");
m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&, bool)>(&Producer), py::arg("tensor"), py::arg("name") = "", py::arg("constant") = false);
declare_Producer<1>(m);
declare_Producer<2>(m);
......
......@@ -13,4 +13,4 @@
#include "aidge/operator/Producer.hpp"
const std::string Aidge::Producer_Op::Type = "Producer";
\ No newline at end of file
const std::string Aidge::Producer_Op::Type = "Producer";
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment