Skip to content
Snippets Groups Projects
Commit 3934a754 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

[Upd] Make GenericOperator registrable.

parent fe1e304b
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!247Doc export
Pipeline #58968 canceled
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
namespace Aidge { namespace Aidge {
class GenericOperator_Op class GenericOperator_Op
: public OperatorTensor, : public OperatorTensor,
public Registrable<GenericOperator_Op, std::string, std::function<std::unique_ptr<OperatorImpl>(std::shared_ptr<GenericOperator_Op>)>> { public Registrable<GenericOperator_Op, std::array<std::string, 2>, std::function<std::shared_ptr<OperatorImpl>(const GenericOperator_Op &)>> {
private: private:
using ComputeDimsFunc = std::function<std::vector<std::vector<size_t>>(const std::vector<std::vector<size_t>>&)>; using ComputeDimsFunc = std::function<std::vector<std::vector<size_t>>(const std::vector<std::vector<size_t>>&)>;
......
...@@ -64,5 +64,7 @@ void init_GenericOperator(py::module& m) { ...@@ -64,5 +64,7 @@ void init_GenericOperator(py::module& m) {
} }
return genericNode; return genericNode;
}, py::arg("type"), py::arg("nb_data"), py::arg("nb_param"), py::arg("nb_out"), py::arg("name") = ""); }, py::arg("type"), py::arg("nb_data"), py::arg("nb_param"), py::arg("nb_out"), py::arg("name") = "");
declare_registrable<GenericOperator_Op>(m, "GenericOperatorOp");
} }
} // namespace Aidge } // namespace Aidge
...@@ -86,7 +86,15 @@ bool Aidge::GenericOperator_Op::forwardDims(bool /*allowDataDependency*/) { ...@@ -86,7 +86,15 @@ bool Aidge::GenericOperator_Op::forwardDims(bool /*allowDataDependency*/) {
} }
void Aidge::GenericOperator_Op::setBackend(const std::string & name, DeviceIdx_t device) { void Aidge::GenericOperator_Op::setBackend(const std::string & name, DeviceIdx_t device) {
Log::warn("GenericOperator::setBackend(): cannot set backend for a generic operator, as no implementation has been provided!"); if (Registrar<GenericOperator_Op>::exists({name, type()})) {
// A custom implementation exists for this meta operator
mImpl = Registrar<GenericOperator_Op>::create({name, type()})(*this);
}else{
Log::warn("GenericOperator::setBackend(): cannot set backend for a generic operator, as no implementation has been provided!");
}
for (std::size_t i = 0; i < nbOutputs(); ++i) { for (std::size_t i = 0; i < nbOutputs(); ++i) {
mOutputs[i]->setBackend(name, device); mOutputs[i]->setBackend(name, device);
...@@ -108,4 +116,4 @@ std::shared_ptr<Aidge::Node> Aidge::GenericOperator(const std::string& type, ...@@ -108,4 +116,4 @@ std::shared_ptr<Aidge::Node> Aidge::GenericOperator(const std::string& type,
Aidge::IOIndex_t nbOut, Aidge::IOIndex_t nbOut,
const std::string& name) { const std::string& name) {
return std::make_shared<Node>(std::make_shared<GenericOperator_Op>(type, nbData, nbParam, nbOut), name); return std::make_shared<Node>(std::make_shared<GenericOperator_Op>(type, nbData, nbParam, nbOut), name);
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment