Skip to content
Snippets Groups Projects
Commit bc0db7b2 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed Python binding

parent 737d8ad1
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!186Refactor OperatorImpl for backend/export
Pipeline #53745 passed
...@@ -43,14 +43,11 @@ struct ImplSpec { ...@@ -43,14 +43,11 @@ struct ImplSpec {
std::vector<std::pair<int, int>> dims; std::vector<std::pair<int, int>> dims;
}; };
ImplSpec(DynamicAttributes attrs_ = DynamicAttributes()): ImplSpec(DynamicAttributes attrs_ = DynamicAttributes());
attrs(attrs_) {} ImplSpec(IOSpec io, DynamicAttributes attrs_ = DynamicAttributes());
ImplSpec(IOSpec i, IOSpec o, DynamicAttributes attrs_ = DynamicAttributes());
ImplSpec(IOSpec io, DynamicAttributes attrs_ = DynamicAttributes()): ImplSpec(const Aidge::ImplSpec&);
inputs(1, io), outputs(1, io), attrs(attrs_) {} ~ImplSpec() noexcept;
ImplSpec(IOSpec i, IOSpec o, DynamicAttributes attrs_ = DynamicAttributes()):
inputs(1, i), outputs(1, o), attrs(attrs_) {}
std::vector<IOSpec> inputs; std::vector<IOSpec> inputs;
std::vector<IOSpec> outputs; std::vector<IOSpec> outputs;
......
...@@ -84,35 +84,7 @@ constexpr std::array<DataFormatTranspose, 7> DataFormatTransposeDict = {{ ...@@ -84,35 +84,7 @@ constexpr std::array<DataFormatTranspose, 7> DataFormatTransposeDict = {{
* @return DataFormatTranspose Permutation array to achieve a transposition * @return DataFormatTranspose Permutation array to achieve a transposition
* from src to dst DataFormat. * from src to dst DataFormat.
*/ */
constexpr inline DataFormatTranspose getDataFormatTranspose(const DataFormat& src, const DataFormat& dst) { DataFormatTranspose getDataFormatTranspose(const DataFormat& src, const DataFormat& dst);
// Permutation array from default format to src format
const auto srcDefToFormat = DataFormatTransposeDict[static_cast<int>(src)];
// Permutation array from default format to dst format
const auto dstDefToFormat = DataFormatTransposeDict[static_cast<int>(dst)];
// Compute permutation array from src format to default format:
DataFormatTranspose srcFormatToDef{};
for (size_t i = 0; i < srcDefToFormat.size(); ++i) {
if (srcDefToFormat[i] > 0) {
srcFormatToDef[srcDefToFormat[i] - 1] = i;
}
else {
srcFormatToDef[i] = i;
}
}
// Compute permutation array from src format to dst format:
DataFormatTranspose srcToDst{};
for (size_t i = 0; i < dstDefToFormat.size(); ++i) {
if (dstDefToFormat[srcFormatToDef[i]] > 0) {
srcToDst[i] = dstDefToFormat[srcFormatToDef[i]] - 1;
}
else {
srcToDst[i] = i;
}
}
return srcToDst;
}
class Data { class Data {
public: public:
......
...@@ -77,7 +77,7 @@ public: ...@@ -77,7 +77,7 @@ public:
*/ */
Node(std::shared_ptr<Operator> op, const std::string& name = ""); Node(std::shared_ptr<Operator> op, const std::string& name = "");
virtual ~Node() = default; virtual ~Node();
friend bool operator==(const Node &lhs, const Node &rhs) { friend bool operator==(const Node &lhs, const Node &rhs) {
return lhs.shared_from_this() == rhs.shared_from_this(); return lhs.shared_from_this() == rhs.shared_from_this();
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <typeinfo> #include <typeinfo>
#include <cassert> #include <cassert>
#include <string> #include <string>
#include <typeindex>
#include "aidge/utils/future_std/any.hpp" #include "aidge/utils/future_std/any.hpp"
#include "aidge/utils/Attributes.hpp" #include "aidge/utils/Attributes.hpp"
...@@ -48,7 +49,32 @@ public: ...@@ -48,7 +49,32 @@ public:
*/ */
template<class T> const T& getAttr(const std::string& name) const template<class T> const T& getAttr(const std::string& name) const
{ {
return future_std::any_cast<const T&>(get(name)); mAnyCompare.emplace(std::make_pair<std::type_index, bool(*)(const future_std::any&, const future_std::any&)>(typeid(T),
[](const future_std::any& lhs, const future_std::any& rhs) {
return (future_std::any_cast<T>(lhs) < future_std::any_cast<T>(rhs));
}));
const auto dot = name.find('.');
if (dot == name.npos) {
#ifdef PYBIND
// If attribute does not exist in C++, it might have been created or modified in Python
auto it = mAttrs.find(name);
if (it == mAttrs.end()) {
auto itPy = mAttrsPy.find(name);
if (itPy != mAttrsPy.end()) {
// Insert the attribute back in C++
mAttrs.emplace(std::make_pair(name, future_std::any(itPy->second.cast<T>())));
}
}
#endif
return future_std::any_cast<const T&>(mAttrs.at(name));
}
else {
const auto ns = name.substr(0, dot);
const auto nsName = name.substr(dot + 1);
return future_std::any_cast<const DynamicAttributes&>(mAttrs.at(ns)).getAttr<T>(nsName);
}
} }
template<class T> T& getAttr(const std::string& name) { template<class T> T& getAttr(const std::string& name) {
...@@ -63,7 +89,31 @@ public: ...@@ -63,7 +89,31 @@ public:
///\param value Attribute value ///\param value Attribute value
template<class T> void addAttr(const std::string& name, const T& value) template<class T> void addAttr(const std::string& name, const T& value)
{ {
add(name, future_std::any(value)); mAnyCompare.emplace(std::make_pair<std::type_index, bool(*)(const future_std::any&, const future_std::any&)>(typeid(T),
[](const future_std::any& lhs, const future_std::any& rhs) {
return (future_std::any_cast<T>(lhs) < future_std::any_cast<T>(rhs));
}));
const auto dot = name.find('.');
if (dot == name.npos) {
const auto& res = mAttrs.emplace(std::make_pair(name, future_std::any(value)));
AIDGE_ASSERT(res.second, "addAttr(): attribute \"{}\" already exists. Use setAttr() if this is expected.", name);
#ifdef PYBIND
// We cannot handle Python object if the Python interpreter is not running
if (Py_IsInitialized()) {
// Keep a copy of the attribute in py::object that is updated everytime
const auto& resPy = mAttrsPy.emplace(std::make_pair(name, py::cast(value)));
AIDGE_ASSERT(resPy.second, "addAttr(): attribute \"{}\" already exists (added in Python). Use setAttr() if this is expected.", name);
}
#endif
}
else {
const auto ns = name.substr(0, dot);
const auto nsName = name.substr(dot + 1);
const auto& res = mAttrs.emplace(std::make_pair(ns, future_std::any(DynamicAttributes())));
future_std::any_cast<DynamicAttributes&>(res.first->second).addAttr(nsName, value);
}
} }
///\brief Set an Attribute value, identified by its name. If it already exists, its value (and type, if different) is changed. ///\brief Set an Attribute value, identified by its name. If it already exists, its value (and type, if different) is changed.
...@@ -72,7 +122,33 @@ public: ...@@ -72,7 +122,33 @@ public:
///\param value Attribute value ///\param value Attribute value
template<class T> void setAttr(const std::string& name, const T& value) template<class T> void setAttr(const std::string& name, const T& value)
{ {
set(name, future_std::any(value)); mAnyCompare.emplace(std::make_pair<std::type_index, bool(*)(const future_std::any&, const future_std::any&)>(typeid(T),
[](const future_std::any& lhs, const future_std::any& rhs) {
return (future_std::any_cast<T>(lhs) < future_std::any_cast<T>(rhs));
}));
const auto dot = name.find('.');
if (dot == name.npos) {
auto res = mAttrs.emplace(std::make_pair(name, future_std::any(value)));
if (!res.second)
res.first->second = future_std::any(value);
#ifdef PYBIND
// We cannot handle Python object if the Python interpreter is not running
if (Py_IsInitialized()) {
// Keep a copy of the attribute in py::object that is updated everytime
auto resPy = mAttrsPy.emplace(std::make_pair(name, py::cast(value)));
if (!resPy.second)
resPy.first->second = std::move(py::cast(value));
}
#endif
}
else {
const auto ns = name.substr(0, dot);
const auto nsName = name.substr(dot + 1);
auto res = mAttrs.emplace(std::make_pair(ns, future_std::any(DynamicAttributes())));
future_std::any_cast<DynamicAttributes&>(res.first->second).setAttr<T>(nsName, value);
}
} }
void delAttr(const std::string& name) { void delAttr(const std::string& name) {
...@@ -268,7 +344,7 @@ public: ...@@ -268,7 +344,7 @@ public:
}; };
#endif #endif
const future_std::any& get(const std::string& name) const future_std::any getAny(const std::string& name) const
{ {
const auto dot = name.find('.'); const auto dot = name.find('.');
if (dot == name.npos) { if (dot == name.npos) {
...@@ -278,8 +354,14 @@ public: ...@@ -278,8 +354,14 @@ public:
if (it == mAttrs.end()) { if (it == mAttrs.end()) {
auto itPy = mAttrsPy.find(name); auto itPy = mAttrsPy.find(name);
if (itPy != mAttrsPy.end()) { if (itPy != mAttrsPy.end()) {
// Insert the attribute back in C++ // Attribute exists in Python, but its type is not known
mAttrs.emplace(std::make_pair(name, future_std::any(itPy->second.cast<T>()))); // Return a std::any of py::object, which will be comparable
mAnyCompare.emplace(std::make_pair<std::type_index, bool(*)(const future_std::any&, const future_std::any&)>(typeid(py::object),
[](const future_std::any& lhs, const future_std::any& rhs) {
return (future_std::any_cast<py::object>(lhs) < future_std::any_cast<py::object>(rhs));
}));
return future_std::any(itPy->second);
} }
} }
#endif #endif
...@@ -289,57 +371,7 @@ public: ...@@ -289,57 +371,7 @@ public:
else { else {
const auto ns = name.substr(0, dot); const auto ns = name.substr(0, dot);
const auto nsName = name.substr(dot + 1); const auto nsName = name.substr(dot + 1);
return future_std::any_cast<const DynamicAttributes&>(mAttrs.at(ns)).get(nsName); return future_std::any_cast<const DynamicAttributes&>(mAttrs.at(ns)).getAny(nsName);
}
}
void add(const std::string& name, const future_std::any& value)
{
const auto dot = name.find('.');
if (dot == name.npos) {
const auto& res = mAttrs.emplace(std::make_pair(name, value));
AIDGE_ASSERT(res.second, "addAttr(): attribute \"{}\" already exists. Use setAttr() if this is expected.", name);
#ifdef PYBIND
// We cannot handle Python object if the Python interpreter is not running
if (Py_IsInitialized()) {
// Keep a copy of the attribute in py::object that is updated everytime
const auto& resPy = mAttrsPy.emplace(std::make_pair(name, py::cast(value)));
AIDGE_ASSERT(resPy.second, "addAttr(): attribute \"{}\" already exists (added in Python). Use setAttr() if this is expected.", name);
}
#endif
}
else {
const auto ns = name.substr(0, dot);
const auto nsName = name.substr(dot + 1);
const auto& res = mAttrs.emplace(std::make_pair(ns, future_std::any(DynamicAttributes())));
future_std::any_cast<DynamicAttributes&>(res.first->second).add(nsName, value);
}
}
void set(const std::string& name, const future_std::any& value)
{
const auto dot = name.find('.');
if (dot == name.npos) {
auto res = mAttrs.emplace(std::make_pair(name, value));
if (!res.second)
res.first->second = future_std::any(value);
#ifdef PYBIND
// We cannot handle Python object if the Python interpreter is not running
if (Py_IsInitialized()) {
// Keep a copy of the attribute in py::object that is updated everytime
auto resPy = mAttrsPy.emplace(std::make_pair(name, py::cast(value)));
if (!resPy.second)
resPy.first->second = std::move(py::cast(value));
}
#endif
}
else {
const auto ns = name.substr(0, dot);
const auto nsName = name.substr(dot + 1);
auto res = mAttrs.emplace(std::make_pair(ns, future_std::any(DynamicAttributes())));
future_std::any_cast<DynamicAttributes&>(res.first->second).set(nsName, value);
} }
} }
...@@ -362,6 +394,10 @@ private: ...@@ -362,6 +394,10 @@ private:
#else #else
std::map<std::string, future_std::any> mAttrs; std::map<std::string, future_std::any> mAttrs;
#endif #endif
public:
// Stores the comparison function for each attribute type ever used
static std::map<std::type_index, bool(*)(const future_std::any&, const future_std::any&)> mAnyCompare;
}; };
inline bool operator<(const DynamicAttributes& lhs, const DynamicAttributes& rhs) { inline bool operator<(const DynamicAttributes& lhs, const DynamicAttributes& rhs) {
...@@ -370,39 +406,7 @@ inline bool operator<(const DynamicAttributes& lhs, const DynamicAttributes& rhs ...@@ -370,39 +406,7 @@ inline bool operator<(const DynamicAttributes& lhs, const DynamicAttributes& rhs
} }
namespace future_std { namespace future_std {
inline bool operator<(const future_std::any& lhs, const future_std::any& rhs) { bool operator<(const future_std::any& lhs, const future_std::any& rhs);
bool result = (lhs.type().before(rhs.type()));
if (lhs.type() == rhs.type()) {
if (lhs.type() == typeid(std::string))
result = (future_std::any_cast<std::string>(lhs) < future_std::any_cast<std::string>(rhs));
else if (lhs.type() == typeid(bool))
result = (future_std::any_cast<bool>(lhs) < future_std::any_cast<bool>(rhs));
else if (lhs.type() == typeid(char))
result = (future_std::any_cast<char>(lhs) < future_std::any_cast<char>(rhs));
else if (lhs.type() == typeid(unsigned char))
result = (future_std::any_cast<unsigned char>(lhs) < future_std::any_cast<unsigned char>(rhs));
else if (lhs.type() == typeid(short))
result = (future_std::any_cast<short>(lhs) < future_std::any_cast<short>(rhs));
else if (lhs.type() == typeid(unsigned short))
result = (future_std::any_cast<unsigned short>(lhs) < future_std::any_cast<unsigned short>(rhs));
else if (lhs.type() == typeid(int))
result = (future_std::any_cast<int>(lhs) < future_std::any_cast<int>(rhs));
else if (lhs.type() == typeid(unsigned int))
result = (future_std::any_cast<unsigned int>(lhs) < future_std::any_cast<unsigned int>(rhs));
else if (lhs.type() == typeid(long long int))
result = (future_std::any_cast<long long int>(lhs) < future_std::any_cast<long long int>(rhs));
else if (lhs.type() == typeid(unsigned long long int))
result = (future_std::any_cast<unsigned long long int>(lhs) < future_std::any_cast<unsigned long long int>(rhs));
else if (lhs.type() == typeid(float))
result = (future_std::any_cast<float>(lhs) < future_std::any_cast<float>(rhs));
else if (lhs.type() == typeid(double))
result = (future_std::any_cast<double>(lhs) < future_std::any_cast<double>(rhs));
else {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported type {} in std::any operator<", lhs.type().name());
}
}
return result;
}
} }
#endif /* AIDGE_CORE_UTILS_DYNAMICATTRIBUTES_H_ */ #endif /* AIDGE_CORE_UTILS_DYNAMICATTRIBUTES_H_ */
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <string> #include <string>
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Operator.hpp" #include "aidge/operator/Operator.hpp"
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
...@@ -31,102 +32,67 @@ public: ...@@ -31,102 +32,67 @@ public:
PYBIND11_OVERRIDE( PYBIND11_OVERRIDE(
void, void,
OperatorImpl, OperatorImpl,
forward, forward
); );
} }
void backward() override { void backward() override {
PYBIND11_OVERRIDE( PYBIND11_OVERRIDE(
void, void,
OperatorImpl, OperatorImpl,
backward, backward
); );
} }
Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override {
PYBIND11_OVERRIDE_NAME(
Elts_t,
OperatorImpl,
"get_nb_required_data",
getNbRequiredData,
inputIdx
);
}
Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override {
PYBIND11_OVERRIDE_NAME(
Elts_t,
OperatorImpl,
"get_nb_required_protected",
getNbRequiredProtected,
inputIdx
); std::shared_ptr<ProdConso> getProdConso() const override {
}
Elts_t getRequiredMemory(const IOIndex_t outputIdx,
const std::vector<DimSize_t> &inputsSize) const override {
PYBIND11_OVERRIDE_NAME( PYBIND11_OVERRIDE_NAME(
Elts_t, std::shared_ptr<ProdConso>,
OperatorImpl, OperatorImpl,
"get_required_memory", "get_prod_conso",
getRequiredMemory, getProdConso
outputIdx,
inputsSize
); );
} }
Elts_t getNbConsumedData(const IOIndex_t inputIdx) const override {
PYBIND11_OVERRIDE_NAME(
Elts_t,
OperatorImpl,
"get_nb_consumed_data",
getNbConsumedData,
inputIdx
); std::vector<ImplSpec> getAvailableImplSpecs() const noexcept override {
}
Elts_t getNbProducedData(const IOIndex_t outputIdx) const override {
PYBIND11_OVERRIDE_NAME( PYBIND11_OVERRIDE_NAME(
Elts_t, std::vector<ImplSpec>,
OperatorImpl, OperatorImpl,
"get_nb_produced_data", "get_available_impl_specs",
getNbProducedData, getAvailableImplSpecs
outputIdx
); );
} }
void updateConsummerProducer() override { };
PYBIND11_OVERRIDE_NAME(
void,
OperatorImpl,
"update_consummer_producer",
updateConsummerProducer,
);
}
void resetConsummerProducer() override {
PYBIND11_OVERRIDE_NAME(
void,
OperatorImpl,
"reset_consummer_producer",
resetConsummerProducer,
); // See https://pybind11.readthedocs.io/en/stable/advanced/classes.html#binding-protected-member-functions
} class OperatorImpl_Publicist : public OperatorImpl {
public:
using OperatorImpl::getProdConso;
using OperatorImpl::getAvailableImplSpecs;
}; };
void init_OperatorImpl(py::module& m){ void init_OperatorImpl(py::module& m){
py::class_<ImplSpec>(m, "ImplSpec")
.def(py::init<DynamicAttributes>())
.def(py::init<ImplSpec::IOSpec, DynamicAttributes>())
.def(py::init<ImplSpec::IOSpec, ImplSpec::IOSpec, DynamicAttributes>())
;
py::class_<ImplSpec::IOSpec>(m, "IOSpec")
.def(py::init<DataType, DataFormat, std::vector<std::pair<int, int>>>())
;
py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr()) py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr())
.def(py::init<const Operator&, const std::string&>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>(), py::keep_alive<1,3>()) .def(py::init<const Operator&, const std::string&>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>(), py::keep_alive<1,3>())
.def("forward", &OperatorImpl::forward) .def("forward", &OperatorImpl::forward)
.def("backward", &OperatorImpl::backward) .def("backward", &OperatorImpl::backward)
.def("get_nb_required_data", &OperatorImpl::getNbRequiredData) .def("prod_conso", &OperatorImpl::prodConso)
.def("get_nb_required_protected", &OperatorImpl::getNbRequiredProtected) .def("backend", &OperatorImpl::backend)
.def("get_required_memory", &OperatorImpl::getRequiredMemory) .def("get_required_spec", &OperatorImpl::getRequiredSpec)
.def("get_nb_consumed_data", &OperatorImpl::getNbConsumedData) .def("get_best_match", &OperatorImpl::getBestMatch)
.def("get_nb_produced_data", &OperatorImpl::getNbProducedData) .def("get_adaptation", &OperatorImpl::getAdaptation)
.def("update_consummer_producer", &OperatorImpl::updateConsummerProducer) .def("get_best_adaptation", &OperatorImpl::getBestAdaptation)
.def("reset_consummer_producer", &OperatorImpl::resetConsummerProducer) .def("get_prod_conso", &OperatorImpl_Publicist::getProdConso)
.def("get_available_impl_specs", &OperatorImpl_Publicist::getAvailableImplSpecs)
; ;
} }
} }
...@@ -78,6 +78,7 @@ void init_GraphViewHelper(py::module&); ...@@ -78,6 +78,7 @@ void init_GraphViewHelper(py::module&);
void init_Scheduler(py::module&); void init_Scheduler(py::module&);
void init_MemoryManager(py::module&); void init_MemoryManager(py::module&);
void init_ProdConso(py::module& m);
void init_TensorUtils(py::module&); void init_TensorUtils(py::module&);
void init_Filler(py::module&); void init_Filler(py::module&);
...@@ -146,6 +147,7 @@ void init_Aidge(py::module& m) { ...@@ -146,6 +147,7 @@ void init_Aidge(py::module& m) {
init_GraphViewHelper(m); init_GraphViewHelper(m);
init_Scheduler(m); init_Scheduler(m);
init_MemoryManager(m); init_MemoryManager(m);
init_ProdConso(m);
init_TensorUtils(m); init_TensorUtils(m);
init_Filler(m); init_Filler(m);
} }
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <string>
#include "aidge/operator/Operator.hpp"
#include "aidge/scheduler/ProdConso.hpp"
namespace py = pybind11;
namespace Aidge {
/**
* @brief Trampoline class for binding
*
*/
class pyProdConso: public ProdConso {
public:
using ProdConso::ProdConso; // Inherit constructors
Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override {
PYBIND11_OVERRIDE_NAME(
Elts_t,
ProdConso,
"get_nb_required_data",
getNbRequiredData,
inputIdx
);
}
Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override {
PYBIND11_OVERRIDE_NAME(
Elts_t,
ProdConso,
"get_nb_required_protected",
getNbRequiredProtected,
inputIdx
);
}
Elts_t getRequiredMemory(const IOIndex_t outputIdx,
const std::vector<DimSize_t> &inputsSize) const override {
PYBIND11_OVERRIDE_NAME(
Elts_t,
ProdConso,
"get_required_memory",
getRequiredMemory,
outputIdx,
inputsSize
);
}
Elts_t getNbConsumedData(const IOIndex_t inputIdx) const override {
PYBIND11_OVERRIDE_NAME(
Elts_t,
ProdConso,
"get_nb_consumed_data",
getNbConsumedData,
inputIdx
);
}
Elts_t getNbProducedData(const IOIndex_t outputIdx) const override {
PYBIND11_OVERRIDE_NAME(
Elts_t,
ProdConso,
"get_nb_produced_data",
getNbProducedData,
outputIdx
);
}
void updateConsummerProducer() override {
PYBIND11_OVERRIDE_NAME(
void,
ProdConso,
"update_consummer_producer",
updateConsummerProducer,
);
}
void resetConsummerProducer() override {
PYBIND11_OVERRIDE_NAME(
void,
ProdConso,
"reset_consummer_producer",
resetConsummerProducer,
);
}
};
void init_ProdConso(py::module& m){
py::class_<ProdConso, std::shared_ptr<ProdConso>, pyProdConso>(m, "ProdConso", py::dynamic_attr())
.def(py::init<const Operator&, bool>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>(), py::keep_alive<1,3>())
.def_static("default_model", &ProdConso::defaultModel)
.def_static("in_place_model", &ProdConso::inPlaceModel)
.def("get_nb_required_data", &ProdConso::getNbRequiredData)
.def("get_nb_required_protected", &ProdConso::getNbRequiredProtected)
.def("get_required_memory", &ProdConso::getRequiredMemory)
.def("get_nb_consumed_data", &ProdConso::getNbConsumedData)
.def("get_nb_produced_data", &ProdConso::getNbProducedData)
.def("update_consummer_producer", &ProdConso::updateConsummerProducer)
.def("reset_consummer_producer", &ProdConso::resetConsummerProducer)
;
}
}
...@@ -22,6 +22,15 @@ ...@@ -22,6 +22,15 @@
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
Aidge::ImplSpec::ImplSpec(DynamicAttributes attrs_):
attrs(attrs_) {}
Aidge::ImplSpec::ImplSpec(IOSpec io, DynamicAttributes attrs_):
inputs(1, io), outputs(1, io), attrs(attrs_) {}
Aidge::ImplSpec::ImplSpec(IOSpec i, IOSpec o, DynamicAttributes attrs_):
inputs(1, i), outputs(1, o), attrs(attrs_) {}
Aidge::ImplSpec::ImplSpec(const Aidge::ImplSpec&) = default;
Aidge::ImplSpec::~ImplSpec() noexcept = default;
Aidge::OperatorImpl::OperatorImpl(const Operator& op, const std::string& backend): Aidge::OperatorImpl::OperatorImpl(const Operator& op, const std::string& backend):
mOp(op), mOp(op),
mBackend(backend) mBackend(backend)
...@@ -121,8 +130,8 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs) ...@@ -121,8 +130,8 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs)
match = false; match = false;
break; break;
} }
else if (requiredSpecs.attrs.get(attrName) < spec.attrs.get(name) else if (requiredSpecs.attrs.getAny(attrName) < spec.attrs.getAny(name)
|| spec.attrs.get(name) < requiredSpecs.attrs.get(attrName)) || spec.attrs.getAny(name) < requiredSpecs.attrs.getAny(attrName))
{ {
// Attribute value mismatch // Attribute value mismatch
match = false; match = false;
...@@ -133,8 +142,8 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs) ...@@ -133,8 +142,8 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs)
const int attrPriority = (!qualifier.empty()) ? std::stoi(qualifier) : 0; const int attrPriority = (!qualifier.empty()) ? std::stoi(qualifier) : 0;
if (spec.attrs.hasAttr(name) if (spec.attrs.hasAttr(name)
&& !(requiredSpecs.attrs.get(attrName) < spec.attrs.get(name)) && !(requiredSpecs.attrs.getAny(attrName) < spec.attrs.getAny(name))
&& !(spec.attrs.get(name) < requiredSpecs.attrs.get(attrName))) && !(spec.attrs.getAny(name) < requiredSpecs.attrs.getAny(attrName)))
{ {
// Attribute value match // Attribute value match
priority = std::max(priority, attrPriority); priority = std::max(priority, attrPriority);
...@@ -162,6 +171,7 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs) ...@@ -162,6 +171,7 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs)
} }
bool Aidge::OperatorImpl::checkIOSpec(const ImplSpec::IOSpec& required, const ImplSpec::IOSpec& spec) const { bool Aidge::OperatorImpl::checkIOSpec(const ImplSpec::IOSpec& required, const ImplSpec::IOSpec& spec) const {
// Check type
if (required.type != DataType::Any if (required.type != DataType::Any
&& spec.type != DataType::Any && spec.type != DataType::Any
&& required.type != spec.type) && required.type != spec.type)
...@@ -169,6 +179,7 @@ bool Aidge::OperatorImpl::checkIOSpec(const ImplSpec::IOSpec& required, const Im ...@@ -169,6 +179,7 @@ bool Aidge::OperatorImpl::checkIOSpec(const ImplSpec::IOSpec& required, const Im
return false; return false;
} }
// Check format
if (required.format != DataFormat::Any if (required.format != DataFormat::Any
&& spec.format != DataFormat::Any && spec.format != DataFormat::Any
&& required.format != spec.format) && required.format != spec.format)
...@@ -182,6 +193,7 @@ bool Aidge::OperatorImpl::checkIOSpec(const ImplSpec::IOSpec& required, const Im ...@@ -182,6 +193,7 @@ bool Aidge::OperatorImpl::checkIOSpec(const ImplSpec::IOSpec& required, const Im
} }
} }
// Check dims
if (!required.dims.empty() && !spec.dims.empty()) { if (!required.dims.empty() && !spec.dims.empty()) {
if (required.dims.size() != spec.dims.size()) { if (required.dims.size() != spec.dims.size()) {
return false; return false;
...@@ -213,6 +225,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& ...@@ -213,6 +225,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
const ImplSpec::IOSpec& requiredIOSpec = requiredSpecs.inputs[i]; const ImplSpec::IOSpec& requiredIOSpec = requiredSpecs.inputs[i];
std::shared_ptr<Node> parent = node; std::shared_ptr<Node> parent = node;
// Input type
if (requiredIOSpec.type != DataType::Any if (requiredIOSpec.type != DataType::Any
&& IOSpec.type != DataType::Any && IOSpec.type != DataType::Any
&& requiredIOSpec.type != IOSpec.type) && requiredIOSpec.type != IOSpec.type)
...@@ -223,6 +236,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& ...@@ -223,6 +236,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
op->getInput(i)->setDataType(IOSpec.type); op->getInput(i)->setDataType(IOSpec.type);
} }
// Input format
if (requiredIOSpec.format != DataFormat::Any if (requiredIOSpec.format != DataFormat::Any
&& IOSpec.format != DataFormat::Any && IOSpec.format != DataFormat::Any
&& requiredIOSpec.format != IOSpec.format) && requiredIOSpec.format != IOSpec.format)
...@@ -236,6 +250,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& ...@@ -236,6 +250,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
op->getInput(i)->setDataFormat(IOSpec.format); op->getInput(i)->setDataFormat(IOSpec.format);
} }
// Input dims
if (!requiredIOSpec.dims.empty() && !IOSpec.dims.empty()) { if (!requiredIOSpec.dims.empty() && !IOSpec.dims.empty()) {
if (requiredIOSpec.dims.size() != IOSpec.dims.size()) { if (requiredIOSpec.dims.size() != IOSpec.dims.size()) {
return nullptr; return nullptr;
...@@ -261,6 +276,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& ...@@ -261,6 +276,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
const ImplSpec::IOSpec& requiredIOSpec = requiredSpecs.outputs[i]; const ImplSpec::IOSpec& requiredIOSpec = requiredSpecs.outputs[i];
std::shared_ptr<Node> parent = node; std::shared_ptr<Node> parent = node;
// Output type
if (requiredIOSpec.type != DataType::Any if (requiredIOSpec.type != DataType::Any
&& IOSpec.type != DataType::Any && IOSpec.type != DataType::Any
&& requiredIOSpec.type != IOSpec.type) && requiredIOSpec.type != IOSpec.type)
...@@ -271,6 +287,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& ...@@ -271,6 +287,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
op->getOutput(i)->setDataType(IOSpec.type); op->getOutput(i)->setDataType(IOSpec.type);
} }
// Output format
if (requiredIOSpec.format != DataFormat::Any if (requiredIOSpec.format != DataFormat::Any
&& IOSpec.format != DataFormat::Any && IOSpec.format != DataFormat::Any
&& requiredIOSpec.format != IOSpec.format) && requiredIOSpec.format != IOSpec.format)
...@@ -284,6 +301,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& ...@@ -284,6 +301,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
op->getOutput(i)->setDataFormat(IOSpec.format); op->getOutput(i)->setDataFormat(IOSpec.format);
} }
// Output dims
if (!requiredIOSpec.dims.empty() && !IOSpec.dims.empty()) { if (!requiredIOSpec.dims.empty() && !IOSpec.dims.empty()) {
if (requiredIOSpec.dims.size() != IOSpec.dims.size()) { if (requiredIOSpec.dims.size() != IOSpec.dims.size()) {
return nullptr; return nullptr;
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/data/Data.hpp"
Aidge::DataFormatTranspose Aidge::getDataFormatTranspose(const DataFormat& src, const DataFormat& dst) {
// Permutation array from default format to src format
const auto srcDefToFormat = DataFormatTransposeDict[static_cast<int>(src)];
// Permutation array from default format to dst format
const auto dstDefToFormat = DataFormatTransposeDict[static_cast<int>(dst)];
// Compute permutation array from src format to default format:
DataFormatTranspose srcFormatToDef{};
for (size_t i = 0; i < srcDefToFormat.size(); ++i) {
if (srcDefToFormat[i] > 0) {
srcFormatToDef[srcDefToFormat[i] - 1] = i;
}
else {
srcFormatToDef[i] = i;
}
}
// Compute permutation array from src format to dst format:
DataFormatTranspose srcToDst{};
for (size_t i = 0; i < dstDefToFormat.size(); ++i) {
if (dstDefToFormat[srcFormatToDef[i]] > 0) {
srcToDst[i] = dstDefToFormat[srcFormatToDef[i]] - 1;
}
else {
srcToDst[i] = i;
}
}
return srcToDst;
}
...@@ -430,6 +430,9 @@ std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta, std::set<Aidge::No ...@@ -430,6 +430,9 @@ std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta, std::set<Aidge::No
return out; return out;
} }
Aidge::Node::~Node() = default;
// namespace Aidge { // namespace Aidge {
// std::ostream& operator << (std::ostream& os, Aidge::Node& n) { // std::ostream& operator << (std::ostream& os, Aidge::Node& n) {
// using namespace std; // using namespace std;
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/utils/DynamicAttributes.hpp"
std::map<std::type_index, bool(*)(const future_std::any&, const future_std::any&)> Aidge::DynamicAttributes::mAnyCompare;
bool future_std::operator<(const future_std::any& lhs, const future_std::any& rhs) {
return (lhs.type() == rhs.type())
? Aidge::DynamicAttributes::mAnyCompare.at(lhs.type())(lhs, rhs)
: (lhs.type().before(rhs.type()));
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment