From 761b32074fee26d74363f9a31a29f73e7d038813 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Tue, 19 Mar 2024 08:04:16 +0000 Subject: [PATCH] Add basic Uniform filler. --- include/aidge/aidge.hpp | 1 + include/aidge/data/Tensor.hpp | 2 +- include/aidge/filler/Filler.hpp | 52 +++++++++++++++++++++ python_binding/filler/pybind_Filler.cpp | 60 +++++++++++++++++++++++++ python_binding/pybind_core.cpp | 14 +++--- src/filler/ConstantFiller.cpp | 32 +++++++++++++ 6 files changed, 152 insertions(+), 9 deletions(-) create mode 100644 include/aidge/filler/Filler.hpp create mode 100644 python_binding/filler/pybind_Filler.cpp create mode 100644 src/filler/ConstantFiller.cpp diff --git a/include/aidge/aidge.hpp b/include/aidge/aidge.hpp index 9c5165756..84d77e9f1 100644 --- a/include/aidge/aidge.hpp +++ b/include/aidge/aidge.hpp @@ -64,6 +64,7 @@ #include "aidge/stimuli/Stimulus.hpp" #include "aidge/recipes/Recipes.hpp" +#include "aidge/filler/Filler.hpp" #include "aidge/utils/Attributes.hpp" #include "aidge/utils/StaticAttributes.hpp" diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 1f9c5a5ec..2bd73d537 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -292,7 +292,7 @@ public: * @brief Set the DataType of the Tensor and converts data * if the Tensor has already been initialized and copyCast is true. * @param dt DataType - * @param copyCast If true (default), previous data is copy-casted. Otherwise + * @param copyCast If truOe (default), previous data is copy-casted. Otherwise * previous data is lost. */ void setDataType(const DataType dt, bool copyCast = true) { diff --git a/include/aidge/filler/Filler.hpp b/include/aidge/filler/Filler.hpp new file mode 100644 index 000000000..a8419e1ca --- /dev/null +++ b/include/aidge/filler/Filler.hpp @@ -0,0 +1,52 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_FILLER_H_ +#define AIDGE_CORE_FILLER_H_ + +#include <memory> + +#include "aidge/data/Tensor.hpp" + +namespace Aidge { + +// void heFiller(std::shared_ptr<Tensor> tensor); + +// template <typename T> +// void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue); + +template <typename T> +void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue) { + AIDGE_ASSERT(tensor->getImpl(), + "Tensor got no implementation, cannot fill it."); + AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type"); + + std::shared_ptr<Tensor> cpyTensor; + // Create cpy only if tensor not on CPU + Tensor& tensorWithValues = + tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu"); + + // Setting values + for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) { + tensorWithValues.set<T>(idx, constantValue); + } + + // Copy values back to the original tensors (actual copy only if needed) + tensor->copyCastFrom(tensorWithValues); +} + +void normalFiller(std::shared_ptr<Tensor> tensor, float mean, float var); +// void uniformFiller(std::shared_ptr<Tensor> tensor); +// void xavierFiller(std::shared_ptr<Tensor> tensor); + +} // namespace Aidge + +#endif /* AIDGE_CORE_FILLER_H_ */ diff --git a/python_binding/filler/pybind_Filler.cpp b/python_binding/filler/pybind_Filler.cpp new file mode 100644 index 000000000..dacf0af40 --- /dev/null +++ b/python_binding/filler/pybind_Filler.cpp @@ -0,0 +1,60 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "aidge/data/Tensor.hpp" +#include "aidge/filler/Filler.hpp" + +namespace py = pybind11; + +namespace Aidge { + +void init_Filler(py::module &m) { + m.def("constant_filler", + [](std::shared_ptr<Tensor> tensor, py::object value) -> void { + switch (tensor->dataType()) { + case DataType::Float64: + constantFiller<double>(tensor, value.cast<double>()); + break; + case DataType::Float32: + constantFiller<float>(tensor, value.cast<float>()); + break; + case DataType::Int8: + constantFiller<int8_t>(tensor, value.cast<int8_t>()); + break; + case DataType::Int16: + constantFiller<std::int16_t>(tensor, + value.cast<std::int16_t>()); + break; + case DataType::Int32: + constantFiller<std::int32_t>(tensor, + value.cast<std::int32_t>()); + break; + case DataType::Int64: + constantFiller<std::int64_t>(tensor, + value.cast<std::int64_t>()); + break; + case DataType::UInt8: + constantFiller<std::uint8_t>(tensor, + value.cast<std::uint8_t>()); + break; + case DataType::UInt16: + constantFiller<std::uint16_t>( + tensor, value.cast<std::uint16_t>()); + break; + default: + AIDGE_THROW_OR_ABORT(py::value_error, + "Data type is not supported."); + } + }); +} +} // namespace Aidge diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index 52863735c..80f0cc44f 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -11,8 +11,7 @@ #include <pybind11/pybind11.h> -#include "aidge/backend/cpu/data/TensorImpl.hpp" // This include add Tensor - +#include "aidge/backend/cpu/data/TensorImpl.hpp" // This include add Tensor namespace py = pybind11; @@ -71,9 +70,9 @@ void init_Recipes(py::module&); void init_Scheduler(py::module&); void init_TensorUtils(py::module&); +void init_Filler(py::module&); - -void init_Aidge(py::module& m){ +void init_Aidge(py::module& m) { init_Data(m); init_Database(m); init_DataProvider(m); @@ -129,9 +128,8 @@ void init_Aidge(py::module& m){ init_Recipes(m); init_Scheduler(m); init_TensorUtils(m); + init_Filler(m); } -PYBIND11_MODULE(aidge_core, m) { - init_Aidge(m); -} -} +PYBIND11_MODULE(aidge_core, m) { init_Aidge(m); } +} // namespace Aidge diff --git a/src/filler/ConstantFiller.cpp b/src/filler/ConstantFiller.cpp new file mode 100644 index 000000000..9a67d40ca --- /dev/null +++ b/src/filler/ConstantFiller.cpp @@ -0,0 +1,32 @@ +// /******************************************************************************** +// * Copyright (c) 2023 CEA-List +// * +// * This program and the accompanying materials are made available under the +// * terms of the Eclipse Public License 2.0 which is available at +// * http://www.eclipse.org/legal/epl-2.0. +// * +// * SPDX-License-Identifier: EPL-2.0 +// * +// ********************************************************************************/ + +// #include "aidge/filler/Filler.hpp" + +// template<typename T> +// void Aidge::constantFiller(std::shared_ptr<Aidge::Tensor> tensor, T constantValue){ +// AIDGE_ASSERT(tensor->getImpl(), "Tensor got no implementation, cannot fill it."); +// AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type"); + +// std::shared_ptr<Tensor> cpyTensor; +// // Create cpy only if tensor not on CPU +// const Tensor& tensorWithValues = tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu"); + +// // Setting values +// for(std::size_t idx = 0; idx<tensorWithValues.size(); ++idx){ +// tensorWithValues.set<T>(idx, constantValue); +// } + +// // Copy values back to the original tensors (actual copy only if needed) +// tensor->copyCastFrom(tensorWithValues); + + +// } -- GitLab