diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp index 4b4fee0147462298dd08a47853d2b168e3d7e380..9a184acfdd3b1bae9827c1c9cab7dd64ae044f12 100644 --- a/include/aidge/backend/OperatorImpl.hpp +++ b/include/aidge/backend/OperatorImpl.hpp @@ -43,14 +43,11 @@ struct ImplSpec { std::vector<std::pair<int, int>> dims; }; - ImplSpec(DynamicAttributes attrs_ = DynamicAttributes()): - attrs(attrs_) {} - - ImplSpec(IOSpec io, DynamicAttributes attrs_ = DynamicAttributes()): - inputs(1, io), outputs(1, io), attrs(attrs_) {} - - ImplSpec(IOSpec i, IOSpec o, DynamicAttributes attrs_ = DynamicAttributes()): - inputs(1, i), outputs(1, o), attrs(attrs_) {} + ImplSpec(DynamicAttributes attrs_ = DynamicAttributes()); + ImplSpec(IOSpec io, DynamicAttributes attrs_ = DynamicAttributes()); + ImplSpec(IOSpec i, IOSpec o, DynamicAttributes attrs_ = DynamicAttributes()); + ImplSpec(const Aidge::ImplSpec&); + ~ImplSpec() noexcept; std::vector<IOSpec> inputs; std::vector<IOSpec> outputs; diff --git a/include/aidge/data/Data.hpp b/include/aidge/data/Data.hpp index ea6d9f98b5ca2dcfc624ed71eb63b8aa02c1ffb3..23221e653ba725e4463b06cfabb5483a20756701 100644 --- a/include/aidge/data/Data.hpp +++ b/include/aidge/data/Data.hpp @@ -84,35 +84,7 @@ constexpr std::array<DataFormatTranspose, 7> DataFormatTransposeDict = {{ * @return DataFormatTranspose Permutation array to achieve a transposition * from src to dst DataFormat. */ -constexpr inline DataFormatTranspose getDataFormatTranspose(const DataFormat& src, const DataFormat& dst) { - // Permutation array from default format to src format - const auto srcDefToFormat = DataFormatTransposeDict[static_cast<int>(src)]; - // Permutation array from default format to dst format - const auto dstDefToFormat = DataFormatTransposeDict[static_cast<int>(dst)]; - // Compute permutation array from src format to default format: - DataFormatTranspose srcFormatToDef{}; - for (size_t i = 0; i < srcDefToFormat.size(); ++i) { - if (srcDefToFormat[i] > 0) { - srcFormatToDef[srcDefToFormat[i] - 1] = i; - } - else { - srcFormatToDef[i] = i; - } - } - - // Compute permutation array from src format to dst format: - DataFormatTranspose srcToDst{}; - for (size_t i = 0; i < dstDefToFormat.size(); ++i) { - if (dstDefToFormat[srcFormatToDef[i]] > 0) { - srcToDst[i] = dstDefToFormat[srcFormatToDef[i]] - 1; - } - else { - srcToDst[i] = i; - } - } - - return srcToDst; -} +DataFormatTranspose getDataFormatTranspose(const DataFormat& src, const DataFormat& dst); class Data { public: diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index 3be17d6d21d18d63e75e384f2c6e037452db3a82..ecc47c74578a6ec8bba6c47c07df3f2be6d43078 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -77,7 +77,7 @@ public: */ Node(std::shared_ptr<Operator> op, const std::string& name = ""); - virtual ~Node() = default; + virtual ~Node(); friend bool operator==(const Node &lhs, const Node &rhs) { return lhs.shared_from_this() == rhs.shared_from_this(); diff --git a/include/aidge/utils/DynamicAttributes.hpp b/include/aidge/utils/DynamicAttributes.hpp index ee881c94f7af75b6edb6017299115297ccb29185..651e1ed419a2e63dc8723b6291e8ee2b4f664b07 100644 --- a/include/aidge/utils/DynamicAttributes.hpp +++ b/include/aidge/utils/DynamicAttributes.hpp @@ -18,6 +18,7 @@ #include <typeinfo> #include <cassert> #include <string> +#include <typeindex> #include "aidge/utils/future_std/any.hpp" #include "aidge/utils/Attributes.hpp" @@ -48,7 +49,32 @@ public: */ template<class T> const T& getAttr(const std::string& name) const { - return future_std::any_cast<const T&>(get(name)); + mAnyCompare.emplace(std::make_pair<std::type_index, bool(*)(const future_std::any&, const future_std::any&)>(typeid(T), + [](const future_std::any& lhs, const future_std::any& rhs) { + return (future_std::any_cast<T>(lhs) < future_std::any_cast<T>(rhs)); + })); + + const auto dot = name.find('.'); + if (dot == name.npos) { +#ifdef PYBIND + // If attribute does not exist in C++, it might have been created or modified in Python + auto it = mAttrs.find(name); + if (it == mAttrs.end()) { + auto itPy = mAttrsPy.find(name); + if (itPy != mAttrsPy.end()) { + // Insert the attribute back in C++ + mAttrs.emplace(std::make_pair(name, future_std::any(itPy->second.cast<T>()))); + } + } +#endif + + return future_std::any_cast<const T&>(mAttrs.at(name)); + } + else { + const auto ns = name.substr(0, dot); + const auto nsName = name.substr(dot + 1); + return future_std::any_cast<const DynamicAttributes&>(mAttrs.at(ns)).getAttr<T>(nsName); + } } template<class T> T& getAttr(const std::string& name) { @@ -63,7 +89,31 @@ public: ///\param value Attribute value template<class T> void addAttr(const std::string& name, const T& value) { - add(name, future_std::any(value)); + mAnyCompare.emplace(std::make_pair<std::type_index, bool(*)(const future_std::any&, const future_std::any&)>(typeid(T), + [](const future_std::any& lhs, const future_std::any& rhs) { + return (future_std::any_cast<T>(lhs) < future_std::any_cast<T>(rhs)); + })); + + const auto dot = name.find('.'); + if (dot == name.npos) { + const auto& res = mAttrs.emplace(std::make_pair(name, future_std::any(value))); + AIDGE_ASSERT(res.second, "addAttr(): attribute \"{}\" already exists. Use setAttr() if this is expected.", name); + +#ifdef PYBIND + // We cannot handle Python object if the Python interpreter is not running + if (Py_IsInitialized()) { + // Keep a copy of the attribute in py::object that is updated everytime + const auto& resPy = mAttrsPy.emplace(std::make_pair(name, py::cast(value))); + AIDGE_ASSERT(resPy.second, "addAttr(): attribute \"{}\" already exists (added in Python). Use setAttr() if this is expected.", name); + } +#endif + } + else { + const auto ns = name.substr(0, dot); + const auto nsName = name.substr(dot + 1); + const auto& res = mAttrs.emplace(std::make_pair(ns, future_std::any(DynamicAttributes()))); + future_std::any_cast<DynamicAttributes&>(res.first->second).addAttr(nsName, value); + } } ///\brief Set an Attribute value, identified by its name. If it already exists, its value (and type, if different) is changed. @@ -72,7 +122,33 @@ public: ///\param value Attribute value template<class T> void setAttr(const std::string& name, const T& value) { - set(name, future_std::any(value)); + mAnyCompare.emplace(std::make_pair<std::type_index, bool(*)(const future_std::any&, const future_std::any&)>(typeid(T), + [](const future_std::any& lhs, const future_std::any& rhs) { + return (future_std::any_cast<T>(lhs) < future_std::any_cast<T>(rhs)); + })); + + const auto dot = name.find('.'); + if (dot == name.npos) { + auto res = mAttrs.emplace(std::make_pair(name, future_std::any(value))); + if (!res.second) + res.first->second = future_std::any(value); + +#ifdef PYBIND + // We cannot handle Python object if the Python interpreter is not running + if (Py_IsInitialized()) { + // Keep a copy of the attribute in py::object that is updated everytime + auto resPy = mAttrsPy.emplace(std::make_pair(name, py::cast(value))); + if (!resPy.second) + resPy.first->second = std::move(py::cast(value)); + } +#endif + } + else { + const auto ns = name.substr(0, dot); + const auto nsName = name.substr(dot + 1); + auto res = mAttrs.emplace(std::make_pair(ns, future_std::any(DynamicAttributes()))); + future_std::any_cast<DynamicAttributes&>(res.first->second).setAttr<T>(nsName, value); + } } void delAttr(const std::string& name) { @@ -268,7 +344,7 @@ public: }; #endif - const future_std::any& get(const std::string& name) const + future_std::any getAny(const std::string& name) const { const auto dot = name.find('.'); if (dot == name.npos) { @@ -278,8 +354,14 @@ public: if (it == mAttrs.end()) { auto itPy = mAttrsPy.find(name); if (itPy != mAttrsPy.end()) { - // Insert the attribute back in C++ - mAttrs.emplace(std::make_pair(name, future_std::any(itPy->second.cast<T>()))); + // Attribute exists in Python, but its type is not known + // Return a std::any of py::object, which will be comparable + mAnyCompare.emplace(std::make_pair<std::type_index, bool(*)(const future_std::any&, const future_std::any&)>(typeid(py::object), + [](const future_std::any& lhs, const future_std::any& rhs) { + return (future_std::any_cast<py::object>(lhs) < future_std::any_cast<py::object>(rhs)); + })); + + return future_std::any(itPy->second); } } #endif @@ -289,57 +371,7 @@ public: else { const auto ns = name.substr(0, dot); const auto nsName = name.substr(dot + 1); - return future_std::any_cast<const DynamicAttributes&>(mAttrs.at(ns)).get(nsName); - } - } - - void add(const std::string& name, const future_std::any& value) - { - const auto dot = name.find('.'); - if (dot == name.npos) { - const auto& res = mAttrs.emplace(std::make_pair(name, value)); - AIDGE_ASSERT(res.second, "addAttr(): attribute \"{}\" already exists. Use setAttr() if this is expected.", name); - -#ifdef PYBIND - // We cannot handle Python object if the Python interpreter is not running - if (Py_IsInitialized()) { - // Keep a copy of the attribute in py::object that is updated everytime - const auto& resPy = mAttrsPy.emplace(std::make_pair(name, py::cast(value))); - AIDGE_ASSERT(resPy.second, "addAttr(): attribute \"{}\" already exists (added in Python). Use setAttr() if this is expected.", name); - } -#endif - } - else { - const auto ns = name.substr(0, dot); - const auto nsName = name.substr(dot + 1); - const auto& res = mAttrs.emplace(std::make_pair(ns, future_std::any(DynamicAttributes()))); - future_std::any_cast<DynamicAttributes&>(res.first->second).add(nsName, value); - } - } - - void set(const std::string& name, const future_std::any& value) - { - const auto dot = name.find('.'); - if (dot == name.npos) { - auto res = mAttrs.emplace(std::make_pair(name, value)); - if (!res.second) - res.first->second = future_std::any(value); - -#ifdef PYBIND - // We cannot handle Python object if the Python interpreter is not running - if (Py_IsInitialized()) { - // Keep a copy of the attribute in py::object that is updated everytime - auto resPy = mAttrsPy.emplace(std::make_pair(name, py::cast(value))); - if (!resPy.second) - resPy.first->second = std::move(py::cast(value)); - } -#endif - } - else { - const auto ns = name.substr(0, dot); - const auto nsName = name.substr(dot + 1); - auto res = mAttrs.emplace(std::make_pair(ns, future_std::any(DynamicAttributes()))); - future_std::any_cast<DynamicAttributes&>(res.first->second).set(nsName, value); + return future_std::any_cast<const DynamicAttributes&>(mAttrs.at(ns)).getAny(nsName); } } @@ -362,6 +394,10 @@ private: #else std::map<std::string, future_std::any> mAttrs; #endif + +public: + // Stores the comparison function for each attribute type ever used + static std::map<std::type_index, bool(*)(const future_std::any&, const future_std::any&)> mAnyCompare; }; inline bool operator<(const DynamicAttributes& lhs, const DynamicAttributes& rhs) { @@ -370,39 +406,7 @@ inline bool operator<(const DynamicAttributes& lhs, const DynamicAttributes& rhs } namespace future_std { -inline bool operator<(const future_std::any& lhs, const future_std::any& rhs) { - bool result = (lhs.type().before(rhs.type())); - if (lhs.type() == rhs.type()) { - if (lhs.type() == typeid(std::string)) - result = (future_std::any_cast<std::string>(lhs) < future_std::any_cast<std::string>(rhs)); - else if (lhs.type() == typeid(bool)) - result = (future_std::any_cast<bool>(lhs) < future_std::any_cast<bool>(rhs)); - else if (lhs.type() == typeid(char)) - result = (future_std::any_cast<char>(lhs) < future_std::any_cast<char>(rhs)); - else if (lhs.type() == typeid(unsigned char)) - result = (future_std::any_cast<unsigned char>(lhs) < future_std::any_cast<unsigned char>(rhs)); - else if (lhs.type() == typeid(short)) - result = (future_std::any_cast<short>(lhs) < future_std::any_cast<short>(rhs)); - else if (lhs.type() == typeid(unsigned short)) - result = (future_std::any_cast<unsigned short>(lhs) < future_std::any_cast<unsigned short>(rhs)); - else if (lhs.type() == typeid(int)) - result = (future_std::any_cast<int>(lhs) < future_std::any_cast<int>(rhs)); - else if (lhs.type() == typeid(unsigned int)) - result = (future_std::any_cast<unsigned int>(lhs) < future_std::any_cast<unsigned int>(rhs)); - else if (lhs.type() == typeid(long long int)) - result = (future_std::any_cast<long long int>(lhs) < future_std::any_cast<long long int>(rhs)); - else if (lhs.type() == typeid(unsigned long long int)) - result = (future_std::any_cast<unsigned long long int>(lhs) < future_std::any_cast<unsigned long long int>(rhs)); - else if (lhs.type() == typeid(float)) - result = (future_std::any_cast<float>(lhs) < future_std::any_cast<float>(rhs)); - else if (lhs.type() == typeid(double)) - result = (future_std::any_cast<double>(lhs) < future_std::any_cast<double>(rhs)); - else { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported type {} in std::any operator<", lhs.type().name()); - } - } - return result; -} +bool operator<(const future_std::any& lhs, const future_std::any& rhs); } #endif /* AIDGE_CORE_UTILS_DYNAMICATTRIBUTES_H_ */ diff --git a/python_binding/backend/pybind_OperatorImpl.cpp b/python_binding/backend/pybind_OperatorImpl.cpp index 6a83805fc1af2e111dd1c9f49c669e0c2f9422aa..bc2321ab1b823fdf03155050e714f6f60af6153c 100644 --- a/python_binding/backend/pybind_OperatorImpl.cpp +++ b/python_binding/backend/pybind_OperatorImpl.cpp @@ -13,6 +13,7 @@ #include <pybind11/stl.h> #include <string> +#include "aidge/graph/Node.hpp" #include "aidge/operator/Operator.hpp" #include "aidge/backend/OperatorImpl.hpp" @@ -31,102 +32,67 @@ public: PYBIND11_OVERRIDE( void, OperatorImpl, - forward, - + forward ); } + void backward() override { PYBIND11_OVERRIDE( void, OperatorImpl, - backward, - + backward ); } - Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override { - PYBIND11_OVERRIDE_NAME( - Elts_t, - OperatorImpl, - "get_nb_required_data", - getNbRequiredData, - inputIdx - ); - } - Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override { - PYBIND11_OVERRIDE_NAME( - Elts_t, - OperatorImpl, - "get_nb_required_protected", - getNbRequiredProtected, - inputIdx - ); - } - Elts_t getRequiredMemory(const IOIndex_t outputIdx, - const std::vector<DimSize_t> &inputsSize) const override { + std::shared_ptr<ProdConso> getProdConso() const override { PYBIND11_OVERRIDE_NAME( - Elts_t, + std::shared_ptr<ProdConso>, OperatorImpl, - "get_required_memory", - getRequiredMemory, - outputIdx, - inputsSize - + "get_prod_conso", + getProdConso ); } - Elts_t getNbConsumedData(const IOIndex_t inputIdx) const override { - PYBIND11_OVERRIDE_NAME( - Elts_t, - OperatorImpl, - "get_nb_consumed_data", - getNbConsumedData, - inputIdx - ); - } - Elts_t getNbProducedData(const IOIndex_t outputIdx) const override { + std::vector<ImplSpec> getAvailableImplSpecs() const noexcept override { PYBIND11_OVERRIDE_NAME( - Elts_t, + std::vector<ImplSpec>, OperatorImpl, - "get_nb_produced_data", - getNbProducedData, - outputIdx - + "get_available_impl_specs", + getAvailableImplSpecs ); } - void updateConsummerProducer() override { - PYBIND11_OVERRIDE_NAME( - void, - OperatorImpl, - "update_consummer_producer", - updateConsummerProducer, - - ); - } - void resetConsummerProducer() override { - PYBIND11_OVERRIDE_NAME( - void, - OperatorImpl, - "reset_consummer_producer", - resetConsummerProducer, +}; - ); - } +// See https://pybind11.readthedocs.io/en/stable/advanced/classes.html#binding-protected-member-functions +class OperatorImpl_Publicist : public OperatorImpl { +public: + using OperatorImpl::getProdConso; + using OperatorImpl::getAvailableImplSpecs; }; void init_OperatorImpl(py::module& m){ + py::class_<ImplSpec>(m, "ImplSpec") + .def(py::init<DynamicAttributes>()) + .def(py::init<ImplSpec::IOSpec, DynamicAttributes>()) + .def(py::init<ImplSpec::IOSpec, ImplSpec::IOSpec, DynamicAttributes>()) + ; + + py::class_<ImplSpec::IOSpec>(m, "IOSpec") + .def(py::init<DataType, DataFormat, std::vector<std::pair<int, int>>>()) + ; py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr()) .def(py::init<const Operator&, const std::string&>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>(), py::keep_alive<1,3>()) .def("forward", &OperatorImpl::forward) .def("backward", &OperatorImpl::backward) - .def("get_nb_required_data", &OperatorImpl::getNbRequiredData) - .def("get_nb_required_protected", &OperatorImpl::getNbRequiredProtected) - .def("get_required_memory", &OperatorImpl::getRequiredMemory) - .def("get_nb_consumed_data", &OperatorImpl::getNbConsumedData) - .def("get_nb_produced_data", &OperatorImpl::getNbProducedData) - .def("update_consummer_producer", &OperatorImpl::updateConsummerProducer) - .def("reset_consummer_producer", &OperatorImpl::resetConsummerProducer) + .def("prod_conso", &OperatorImpl::prodConso) + .def("backend", &OperatorImpl::backend) + .def("get_required_spec", &OperatorImpl::getRequiredSpec) + .def("get_best_match", &OperatorImpl::getBestMatch) + .def("get_adaptation", &OperatorImpl::getAdaptation) + .def("get_best_adaptation", &OperatorImpl::getBestAdaptation) + .def("get_prod_conso", &OperatorImpl_Publicist::getProdConso) + .def("get_available_impl_specs", &OperatorImpl_Publicist::getAvailableImplSpecs) ; } } diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index 918143213f3dd490ef0e448f086c09135b05f6af..0ed7f4e4c59e5326dd1640c65432207bdafe8023 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -78,6 +78,7 @@ void init_GraphViewHelper(py::module&); void init_Scheduler(py::module&); void init_MemoryManager(py::module&); +void init_ProdConso(py::module& m); void init_TensorUtils(py::module&); void init_Filler(py::module&); @@ -146,6 +147,7 @@ void init_Aidge(py::module& m) { init_GraphViewHelper(m); init_Scheduler(m); init_MemoryManager(m); + init_ProdConso(m); init_TensorUtils(m); init_Filler(m); } diff --git a/python_binding/scheduler/pybind_ProdConso.cpp b/python_binding/scheduler/pybind_ProdConso.cpp new file mode 100644 index 0000000000000000000000000000000000000000..abd6d5379178916b5842095d50a1de2155345b6f --- /dev/null +++ b/python_binding/scheduler/pybind_ProdConso.cpp @@ -0,0 +1,116 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> +#include <string> + +#include "aidge/operator/Operator.hpp" +#include "aidge/scheduler/ProdConso.hpp" + +namespace py = pybind11; +namespace Aidge { + +/** + * @brief Trampoline class for binding + * + */ +class pyProdConso: public ProdConso { +public: + using ProdConso::ProdConso; // Inherit constructors + + Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override { + PYBIND11_OVERRIDE_NAME( + Elts_t, + ProdConso, + "get_nb_required_data", + getNbRequiredData, + inputIdx + ); + } + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override { + PYBIND11_OVERRIDE_NAME( + Elts_t, + ProdConso, + "get_nb_required_protected", + getNbRequiredProtected, + inputIdx + + ); + } + Elts_t getRequiredMemory(const IOIndex_t outputIdx, + const std::vector<DimSize_t> &inputsSize) const override { + PYBIND11_OVERRIDE_NAME( + Elts_t, + ProdConso, + "get_required_memory", + getRequiredMemory, + outputIdx, + inputsSize + + ); + } + Elts_t getNbConsumedData(const IOIndex_t inputIdx) const override { + PYBIND11_OVERRIDE_NAME( + Elts_t, + ProdConso, + "get_nb_consumed_data", + getNbConsumedData, + inputIdx + + ); + } + Elts_t getNbProducedData(const IOIndex_t outputIdx) const override { + PYBIND11_OVERRIDE_NAME( + Elts_t, + ProdConso, + "get_nb_produced_data", + getNbProducedData, + outputIdx + + ); + } + void updateConsummerProducer() override { + PYBIND11_OVERRIDE_NAME( + void, + ProdConso, + "update_consummer_producer", + updateConsummerProducer, + + ); + } + void resetConsummerProducer() override { + PYBIND11_OVERRIDE_NAME( + void, + ProdConso, + "reset_consummer_producer", + resetConsummerProducer, + + ); + } +}; + +void init_ProdConso(py::module& m){ + + py::class_<ProdConso, std::shared_ptr<ProdConso>, pyProdConso>(m, "ProdConso", py::dynamic_attr()) + .def(py::init<const Operator&, bool>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>(), py::keep_alive<1,3>()) + .def_static("default_model", &ProdConso::defaultModel) + .def_static("in_place_model", &ProdConso::inPlaceModel) + .def("get_nb_required_data", &ProdConso::getNbRequiredData) + .def("get_nb_required_protected", &ProdConso::getNbRequiredProtected) + .def("get_required_memory", &ProdConso::getRequiredMemory) + .def("get_nb_consumed_data", &ProdConso::getNbConsumedData) + .def("get_nb_produced_data", &ProdConso::getNbProducedData) + .def("update_consummer_producer", &ProdConso::updateConsummerProducer) + .def("reset_consummer_producer", &ProdConso::resetConsummerProducer) + ; +} +} diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp index 863aa21bb5422ac7c953b18d814fc2f46662bbae..4112eee079284dd83844430d532bd80c3086f5b9 100644 --- a/src/backend/OperatorImpl.cpp +++ b/src/backend/OperatorImpl.cpp @@ -22,6 +22,15 @@ #include "aidge/data/Tensor.hpp" #include "aidge/utils/ErrorHandling.hpp" +Aidge::ImplSpec::ImplSpec(DynamicAttributes attrs_): + attrs(attrs_) {} +Aidge::ImplSpec::ImplSpec(IOSpec io, DynamicAttributes attrs_): + inputs(1, io), outputs(1, io), attrs(attrs_) {} +Aidge::ImplSpec::ImplSpec(IOSpec i, IOSpec o, DynamicAttributes attrs_): + inputs(1, i), outputs(1, o), attrs(attrs_) {} +Aidge::ImplSpec::ImplSpec(const Aidge::ImplSpec&) = default; +Aidge::ImplSpec::~ImplSpec() noexcept = default; + Aidge::OperatorImpl::OperatorImpl(const Operator& op, const std::string& backend): mOp(op), mBackend(backend) @@ -121,8 +130,8 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs) match = false; break; } - else if (requiredSpecs.attrs.get(attrName) < spec.attrs.get(name) - || spec.attrs.get(name) < requiredSpecs.attrs.get(attrName)) + else if (requiredSpecs.attrs.getAny(attrName) < spec.attrs.getAny(name) + || spec.attrs.getAny(name) < requiredSpecs.attrs.getAny(attrName)) { // Attribute value mismatch match = false; @@ -133,8 +142,8 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs) const int attrPriority = (!qualifier.empty()) ? std::stoi(qualifier) : 0; if (spec.attrs.hasAttr(name) - && !(requiredSpecs.attrs.get(attrName) < spec.attrs.get(name)) - && !(spec.attrs.get(name) < requiredSpecs.attrs.get(attrName))) + && !(requiredSpecs.attrs.getAny(attrName) < spec.attrs.getAny(name)) + && !(spec.attrs.getAny(name) < requiredSpecs.attrs.getAny(attrName))) { // Attribute value match priority = std::max(priority, attrPriority); @@ -162,6 +171,7 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs) } bool Aidge::OperatorImpl::checkIOSpec(const ImplSpec::IOSpec& required, const ImplSpec::IOSpec& spec) const { + // Check type if (required.type != DataType::Any && spec.type != DataType::Any && required.type != spec.type) @@ -169,6 +179,7 @@ bool Aidge::OperatorImpl::checkIOSpec(const ImplSpec::IOSpec& required, const Im return false; } + // Check format if (required.format != DataFormat::Any && spec.format != DataFormat::Any && required.format != spec.format) @@ -182,6 +193,7 @@ bool Aidge::OperatorImpl::checkIOSpec(const ImplSpec::IOSpec& required, const Im } } + // Check dims if (!required.dims.empty() && !spec.dims.empty()) { if (required.dims.size() != spec.dims.size()) { return false; @@ -213,6 +225,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& const ImplSpec::IOSpec& requiredIOSpec = requiredSpecs.inputs[i]; std::shared_ptr<Node> parent = node; + // Input type if (requiredIOSpec.type != DataType::Any && IOSpec.type != DataType::Any && requiredIOSpec.type != IOSpec.type) @@ -223,6 +236,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& op->getInput(i)->setDataType(IOSpec.type); } + // Input format if (requiredIOSpec.format != DataFormat::Any && IOSpec.format != DataFormat::Any && requiredIOSpec.format != IOSpec.format) @@ -236,6 +250,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& op->getInput(i)->setDataFormat(IOSpec.format); } + // Input dims if (!requiredIOSpec.dims.empty() && !IOSpec.dims.empty()) { if (requiredIOSpec.dims.size() != IOSpec.dims.size()) { return nullptr; @@ -261,6 +276,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& const ImplSpec::IOSpec& requiredIOSpec = requiredSpecs.outputs[i]; std::shared_ptr<Node> parent = node; + // Output type if (requiredIOSpec.type != DataType::Any && IOSpec.type != DataType::Any && requiredIOSpec.type != IOSpec.type) @@ -271,6 +287,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& op->getOutput(i)->setDataType(IOSpec.type); } + // Output format if (requiredIOSpec.format != DataFormat::Any && IOSpec.format != DataFormat::Any && requiredIOSpec.format != IOSpec.format) @@ -284,6 +301,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& op->getOutput(i)->setDataFormat(IOSpec.format); } + // Output dims if (!requiredIOSpec.dims.empty() && !IOSpec.dims.empty()) { if (requiredIOSpec.dims.size() != IOSpec.dims.size()) { return nullptr; diff --git a/src/data/Data.cpp b/src/data/Data.cpp new file mode 100644 index 0000000000000000000000000000000000000000..62a883d08a401e02c86408214a061f893ffbfb4a --- /dev/null +++ b/src/data/Data.cpp @@ -0,0 +1,42 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/data/Data.hpp" + +Aidge::DataFormatTranspose Aidge::getDataFormatTranspose(const DataFormat& src, const DataFormat& dst) { + // Permutation array from default format to src format + const auto srcDefToFormat = DataFormatTransposeDict[static_cast<int>(src)]; + // Permutation array from default format to dst format + const auto dstDefToFormat = DataFormatTransposeDict[static_cast<int>(dst)]; + // Compute permutation array from src format to default format: + DataFormatTranspose srcFormatToDef{}; + for (size_t i = 0; i < srcDefToFormat.size(); ++i) { + if (srcDefToFormat[i] > 0) { + srcFormatToDef[srcDefToFormat[i] - 1] = i; + } + else { + srcFormatToDef[i] = i; + } + } + + // Compute permutation array from src format to dst format: + DataFormatTranspose srcToDst{}; + for (size_t i = 0; i < dstDefToFormat.size(); ++i) { + if (dstDefToFormat[srcFormatToDef[i]] > 0) { + srcToDst[i] = dstDefToFormat[srcFormatToDef[i]] - 1; + } + else { + srcToDst[i] = i; + } + } + + return srcToDst; +} diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 6f24e017028535edeaa7d6186e76de0dcb09828d..e91ce25e5ed670bf7e386d88ee6f58a5e4a76996 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -430,6 +430,9 @@ std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta, std::set<Aidge::No return out; } + +Aidge::Node::~Node() = default; + // namespace Aidge { // std::ostream& operator << (std::ostream& os, Aidge::Node& n) { // using namespace std; diff --git a/src/utils/DynamicAttributes.cpp b/src/utils/DynamicAttributes.cpp new file mode 100644 index 0000000000000000000000000000000000000000..56731f234cec303c4b5329d69634a7e4589236d3 --- /dev/null +++ b/src/utils/DynamicAttributes.cpp @@ -0,0 +1,20 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/utils/DynamicAttributes.hpp" + +std::map<std::type_index, bool(*)(const future_std::any&, const future_std::any&)> Aidge::DynamicAttributes::mAnyCompare; + +bool future_std::operator<(const future_std::any& lhs, const future_std::any& rhs) { + return (lhs.type() == rhs.type()) + ? Aidge::DynamicAttributes::mAnyCompare.at(lhs.type())(lhs, rhs) + : (lhs.type().before(rhs.type())); +}