diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp index 327f4f7c3d43b5194f23cfaed8674ee0b47bd6a2..f2e4722aa6b02d6f1d5ffa13cecb9578dd8cf034 100644 --- a/include/aidge/operator/GenericOperator.hpp +++ b/include/aidge/operator/GenericOperator.hpp @@ -75,6 +75,10 @@ public: inline void addAttr(const std::string& name, const T& value) const { mAttributes -> template addAttr<T>(name, value); } + inline void setAttrs(const std::map<std::string, future_std::any>& attrs) { + *mAttributes = attrs; + } + // Helper functions that can be used with setForwardDims(): static const ComputeDimsFunc Identity; static const ComputeDimsFunc InputIdentity(IOIndex_t inputIdx, IOIndex_t nbOutputs); @@ -84,9 +88,9 @@ public: }; /** - * @brief Fictive custom operator not associated with any implementation. + * @brief Generic operator not associated with any implementation. * Allows to import unknown operators and simulate new ones. - * @param type Type of the fictive operator. + * @param type Type of the generic operator. * @param inputCategory List inputs with their category * @param nbOut Number of output data. * @param name (optional) name of the Operator. @@ -96,9 +100,9 @@ std::shared_ptr<Node> GenericOperator(const std::string& type, const std::vector const std::string& name = ""); /** - * @brief Fictive custom operator not associated with any implementation. + * @brief Generic operator not associated with any implementation. * Allows to import unknown operators and simulate new ones. - * @param type Type of the fictive operator. + * @param type Type of the generic operator. * @param nbData Number of input data. * @param nbParam Number of parameters. * @param nbOut Number of output data. @@ -107,6 +111,18 @@ std::shared_ptr<Node> GenericOperator(const std::string& type, const std::vector */ std::shared_ptr<Node> GenericOperator(const std::string& type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut, const std::string& name = ""); + +/** + * @brief Generic operator not associated with any implementation. + * Create a generic operator from another existing operator. + * @param type Type of the generic operator. + * @param op Original operator from witch one wants to derive a generic operator. + * @param name (optional) name of the Operator. + * @return std::shared_ptr<Node> Node associated with the Generic Operator. + */ +std::shared_ptr<Aidge::Node> GenericOperator(const std::string& type, + std::shared_ptr<OperatorTensor> op, + const std::string& name = ""); } // namespace Aidge #endif /* AIDGE_CORE_OPERATOR_GENERICOPERATOR_H_ */ diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index e9988b4421b785a91ec170796be49c0c8df52142..95698b751a9f0f4c0cc8e716eb5140ee74e21a3f 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -196,6 +196,10 @@ public: return mOperatorType; } + inline std::vector<InputCategory> inputCategory() const { + return mInputsCategory; + } + inline InputCategory inputCategory(IOIndex_t idx) const { // AIDGE_ASSERT(idx < mInputsCategory.size(), "Input #{} out of range (number of inputs is {})", idx, mInputsCategory.size()); return mInputsCategory.at(idx); diff --git a/include/aidge/utils/DynamicAttributes.hpp b/include/aidge/utils/DynamicAttributes.hpp index 3ecd4da393eaac9881d008e27989a52e883ecb6a..16d5d94a16b6c8ad01cf26bb8680f79f594b2190 100644 --- a/include/aidge/utils/DynamicAttributes.hpp +++ b/include/aidge/utils/DynamicAttributes.hpp @@ -41,6 +41,10 @@ class DynamicAttributes : public Attributes { public: DynamicAttributes() = default; DynamicAttributes(const std::map<std::string, future_std::any>& attrs): mAttrs(attrs) {} + DynamicAttributes& operator=(const std::map<std::string, future_std::any>& attrs) { + mAttrs = attrs; + return *this; + } /** * \brief Returning an Attribute identified by its name diff --git a/python_binding/operator/pybind_GenericOperator.cpp b/python_binding/operator/pybind_GenericOperator.cpp index f125291fafb89ec7ae81678a37e2bde2222a1054..f5ab29c679b7cbb06e5bd86876b63117fd8ce56d 100644 --- a/python_binding/operator/pybind_GenericOperator.cpp +++ b/python_binding/operator/pybind_GenericOperator.cpp @@ -39,6 +39,30 @@ void init_GenericOperator(py::module& m) { .def("set_forward_dims", &GenericOperator_Op::setForwardDims, py::arg("computation_function")); // &GenericOperator + m.def("GenericOperator", + []( const std::string& type, + const std::vector<Aidge::InputCategory>& inputCategory, + IOIndex_t nbOut, + const std::string& name, + const py::kwargs kwargs){ + std::shared_ptr<Node> genericNode = GenericOperator( + type, + inputCategory, + nbOut, + name + ); + if (kwargs){ + std::shared_ptr<GenericOperator_Op> gop = std::static_pointer_cast<GenericOperator_Op>(genericNode->getOperator()); + std::shared_ptr<DynamicAttributes> attr = std::dynamic_pointer_cast<DynamicAttributes>(gop->attributes()); + for (auto item : kwargs) { + std::string key = py::cast<std::string>(item.first); + py::object value = py::reinterpret_borrow<py::object>(item.second); + attr->setAttrPy(key, std::move(value)); + } + } + return genericNode; + }, py::arg("type"), py::arg("input_category"), py::arg("nb_out"), py::arg("name") = ""); + m.def("GenericOperator", []( const std::string& type, IOIndex_t nbData, @@ -65,6 +89,8 @@ void init_GenericOperator(py::module& m) { return genericNode; }, py::arg("type"), py::arg("nb_data"), py::arg("nb_param"), py::arg("nb_out"), py::arg("name") = ""); + m.def("GenericOperator", py::overload_cast<const std::string&, std::shared_ptr<OperatorTensor>, const std::string&>(&GenericOperator), py::arg("type"), py::arg("op"), py::arg("name") = ""); + declare_registrable<GenericOperator_Op>(m, "GenericOperatorOp"); } } // namespace Aidge diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp index a1d1889c9a1881d3aa7b6eb9ccb4c23c5314cc80..ce70a4d7a6f2d5acc1bb69ba43ba7509c074a99a 100644 --- a/python_binding/operator/pybind_Operator.cpp +++ b/python_binding/operator/pybind_Operator.cpp @@ -44,7 +44,14 @@ void init_Operator(py::module& m){ .def("get_raw_input", &Operator::getRawInput, py::arg("inputIdx")) .def("nb_inputs", &Operator::nbInputs) .def("nb_outputs", &Operator::nbOutputs) - .def("input_category", &Operator::inputCategory, py::arg("idx"), + .def("input_category", static_cast<std::vector<InputCategory>(Operator::*)() const>(&Operator::inputCategory), + R"mydelimiter( + Category of the inputs (Data or Param, optional or not). + Data inputs exclude inputs expecting parameters (weights or bias). + + :rtype: list(InputCategory) + )mydelimiter") + .def("input_category", static_cast<InputCategory(Operator::*)(IOIndex_t) const>(&Operator::inputCategory), py::arg("idx"), R"mydelimiter( Category of a specific input (Data or Param, optional or not). Data inputs exclude inputs expecting parameters (weights or bias). diff --git a/src/operator/GenericOperator.cpp b/src/operator/GenericOperator.cpp index c5bca92406e518df593fcc6c3a40525a4ba81dfa..b24c353524be0c795d9908206a6f9550a1e59d7b 100644 --- a/src/operator/GenericOperator.cpp +++ b/src/operator/GenericOperator.cpp @@ -73,7 +73,8 @@ bool Aidge::GenericOperator_Op::forwardDims(bool /*allowDataDependency*/) { } const auto& outputsDims = mForwardDims(inputsDims); - AIDGE_ASSERT((outputsDims.size() == nbOutputs()), "The provided ComputeDimsFunc function returns the wrong number of outputs"); + AIDGE_ASSERT(!outputsDims.empty(), "The provided ComputeDimsFunc cannot compute the output dims (an empty vector was returned)"); + AIDGE_ASSERT(outputsDims.size() == nbOutputs(), "The provided ComputeDimsFunc function returned the wrong number of outputs: {}, but {} are expected", outputsDims.size(), nbOutputs()); for (std::size_t i = 0; i < nbOutputs(); ++i) { mOutputs[i]->resize(outputsDims[i]); } @@ -117,3 +118,40 @@ std::shared_ptr<Aidge::Node> Aidge::GenericOperator(const std::string& type, const std::string& name) { return std::make_shared<Node>(std::make_shared<GenericOperator_Op>(type, nbData, nbParam, nbOut), name); } + +std::shared_ptr<Aidge::Node> Aidge::GenericOperator(const std::string& type, + std::shared_ptr<OperatorTensor> op, + const std::string& name) +{ + // Create a generic op with the same inputs/outputs + auto genericOp = std::make_shared<GenericOperator_Op>(type, op->inputCategory(), op->nbOutputs()); + + // Copy attributes + genericOp->setAttrs(op->attributes()->getAttrs()); + + // Set a default forward dims if possible + if (op->dimsForwarded()) { + auto opInputDims = std::vector<std::vector<DimSize_t>>(op->nbInputs()); + for (size_t i = 0; i < op->nbInputs(); ++i) { + opInputDims[i] = op->getInput(i)->dims(); + } + + auto opOutputDims = std::vector<std::vector<DimSize_t>>(op->nbOutputs()); + for (size_t o = 0; o < op->nbOutputs(); ++o) { + opOutputDims[o] = op->getOutput(o)->dims(); + } + + genericOp->setForwardDims([opInputDims, opOutputDims](const std::vector<std::vector<std::size_t>>& inputsDims) { + // Check input dims + for (size_t i = 0; i < opInputDims.size(); ++i) { + if (inputsDims[i] != opInputDims[i]) { + // No matching => unable to compute output dims! + return std::vector<std::vector<std::size_t>>(); + } + } + return opOutputDims; + }); + } + + return std::make_shared<Node>(genericOp, name); +}