Skip to content
Snippets Groups Projects
Forked from Eclipse Projects / aidge / aidge_core
1584 commits behind the upstream repository.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Registrar.hpp 5.38 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#ifndef AIDGE_CORE_UTILS_REGISTRAR_H_
#define AIDGE_CORE_UTILS_REGISTRAR_H_

#ifdef PYBIND
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // declare_registrable key can recquire stl
#include <pybind11/functional.h>// declare_registrable allow binding of lambda fn

#endif

#include "aidge/utils/ErrorHandling.hpp"

#include <functional>
#include <map>
#include <cassert>

namespace Aidge {
#ifdef PYBIND
namespace py = pybind11;
#endif

// Abstract class used to test if a class is Registrable.
class AbstractRegistrable {};

template <class DerivedClass, class Key, class Func> // curiously rucurring template pattern
class Registrable {
public:
    typedef Key registrar_key;
    typedef std::function<Func> registrar_type;

    static std::map<Key, std::function<Func>>& registry()
    {
        #ifdef PYBIND
        #define _CRT_SECURE_NO_WARNINGS
        if (Py_IsInitialized()){
            std::string name = std::string("registrar_")+typeid(Registrable<DerivedClass, Key, Func>).name();
            static auto shared_data = reinterpret_cast<std::map<Key, std::function<Func>> *>(py::get_shared_data(name));
            if (!shared_data)
                shared_data = static_cast<std::map<Key, std::function<Func>> *>(py::set_shared_data(name, new std::map<Key, std::function<Func>>()));
            return *shared_data;
        }
        #endif // PYBIND
        static std::map<Key, std::function<Func>> rMap;
        return rMap;
    }

};

template <class C>
struct Registrar {
    typedef typename C::registrar_key registrar_key;
    typedef typename C::registrar_type registrar_type;

    Registrar(const registrar_key& key, registrar_type func) {
        //fmt::print("REGISTRAR: {}\n", key);
        // bool newInsert;
        // std::tie(std::ignore, newInsert) = C::registry().insert(std::make_pair(key, func));
        C::registry().erase(key);
        C::registry().insert(std::make_pair(key, func));
        //assert(newInsert && "registrar already exists");
    }

    static bool exists(const registrar_key& key) {
        const auto it = C::registry().find(key);
        return (it != C::registry().end());
    }

    static auto create(const registrar_key& key){
        const auto it = C::registry().find(key);
        AIDGE_ASSERT(it != C::registry().end(), "missing or invalid registrar key: {}\nDid you include/import the corresponding module?", key);

        return (*it).second;
    }
    static std::vector<registrar_key> getKeys(){
        std::vector<registrar_key> keys;
        for(auto keyValue : C::registry())
            keys.push_back(keyValue.first);
        return keys;
    }
};

#ifdef PYBIND
/**
 * @brief Function to define register function for a registrable class
 * Defined here to have access to this function in every module who wants
 * to create a new registrable class.
 *
 * @tparam C registrable class
 * @param m pybind module
 * @param class_name python name of the class
 */
template <class C>
void declare_registrable(py::module& m, const std::string& class_name){
    typedef typename C::registrar_key registrar_key;
    typedef typename C::registrar_type registrar_type;
    m.def(("register_"+ class_name).c_str(), [](registrar_key& key, registrar_type function){
        Registrar<C>(key, function);
    })
    .def(("get_keys_"+ class_name).c_str(), [](){
        return Registrar<C>::getKeys();
    });
}
#endif

/*
* This macro allow to set an implementation to an operator
* This macro is mandatory for using implementation registered in python
* PyBind when calling create method will do a call to the copy ctor if
* op is not visible to the python world (if the create method return a python function)
* See this issue for more information https://github.com/pybind/pybind11/issues/4417
* Note: using a method to do this is not possible has any call to a function will call
* the cpy ctor. This is why I used a macro
* Note: I duplicated
*             (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \
* This is because the py::cast need to be done in the same scope.
* I know this only empyrically not sure what happens under the hood...
*
* If someone wants to find an alternative to this Macro, you can contact me:
*   cyril.moineau@cea.fr
*/
#ifdef PYBIND
#define SET_IMPL_MACRO(T_Op, op, backend_name) \
     \
        if (Registrar<T_Op>::exists(backend_name)) { \
            if(Py_IsInitialized()) { \
                auto obj = py::cast(&(op)); \
                (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \
            } else { \
                (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \
            } \
        }
#else
#define SET_IMPL_MACRO(T_Op, op, backend_name)                          \
    if (Registrar<T_Op>::exists(backend_name)) {                        \
        (op).setImpl(Registrar<T_Op>::create(backend_name)(op));        \
    }
#endif

}

#endif //AIDGE_CORE_UTILS_REGISTRAR_H_