Skip to content
Snippets Groups Projects
Forked from Eclipse Projects / aidge / aidge_core
1376 commits behind the upstream repository.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
pybind_Data.cpp 2.14 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <algorithm>
#include <cctype>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include <pybind11/pybind11.h>

#include "aidge/data/Data.hpp"

namespace py = pybind11;
namespace Aidge {

/**
 * @brief Register the data-related Python bindings on module @p m.
 *
 * Exposes the DataType enumeration to Python under the name `dtype`, with one
 * value per entry of EnumStrings<DataType>::data. Each value is named after
 * the lowercase form of the string returned by Aidge::format_as(). Also
 * registers the Data class (held by std::shared_ptr) under the name `Data`.
 *
 * @param m pybind11 module that receives the bindings.
 */
void init_Data(py::module& m){
    // Python-side enumeration names are the lowercase dtype names, so that
    // they line up with the basic numpy dtype names such as:
    // float32, float64, [u]int32, [u]int64, ...
    const auto python_enum_name = [](const DataType& dtype) {
        auto dtype_name = std::string(Aidge::format_as(dtype));
        std::transform(dtype_name.begin(), dtype_name.end(), dtype_name.begin(),
                       [](unsigned char c) {
                           // std::tolower works on int; go through
                           // unsigned char in and cast back to char out.
                           return static_cast<char>(std::tolower(c));
                       });
        return dtype_name;
    };

    // Build one name per EnumStrings entry. The position in enum_names is
    // used as the enumerator value, so the names stay index-aligned.
    std::vector<std::string> enum_names;
    for (const auto& dtype_str : EnumStrings<Aidge::DataType>::data) {
        // Only the number of entries drives this loop: the name itself is
        // obtained through format_as() on the matching enumerator.
        static_cast<void>(dtype_str);
        const auto dtype = static_cast<DataType>(enum_names.size());
        enum_names.push_back(python_enum_name(dtype));
    }

    // Define python side enumeration aidge_core.dtype
    auto e_dtype = py::enum_<DataType>(m, "dtype");

    // Add one enum value for each generated name
    for (std::size_t idx = 0; idx < enum_names.size(); ++idx) {
        e_dtype.value(enum_names[idx].c_str(), static_cast<DataType>(idx));
    }

    // Define str() to return the bare enum value name; this allows comparing
    // directly, for instance, str(tensor.dtype()) with str(nparray.dtype).
    // enum_names is captured by copy because the callable outlives this
    // function. at() is used so that a value outside the registered range
    // raises std::out_of_range instead of reading out of bounds.
    e_dtype.def("__str__", [enum_names](const DataType& dtype) {
        return enum_names.at(static_cast<std::size_t>(dtype));
    }, py::prepend());

    py::class_<Data, std::shared_ptr<Data>>(m,"Data");
}
}