diff --git a/include/aidge/aidge.hpp b/include/aidge/aidge.hpp index 9c516575690fbca947496920c7068874bda6bf63..84d77e9f1370977e899331bad27f2ade4b2178f3 100644 --- a/include/aidge/aidge.hpp +++ b/include/aidge/aidge.hpp @@ -64,6 +64,7 @@ #include "aidge/stimuli/Stimulus.hpp" #include "aidge/recipes/Recipes.hpp" +#include "aidge/filler/Filler.hpp" #include "aidge/utils/Attributes.hpp" #include "aidge/utils/StaticAttributes.hpp" diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 1f9c5a5ec14cca4469b0329f2f968cf9dbc7b0de..2bd73d53784add075a355112aa2565d9e380e39c 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -292,7 +292,7 @@ public: * @brief Set the DataType of the Tensor and converts data * if the Tensor has already been initialized and copyCast is true. * @param dt DataType - * @param copyCast If true (default), previous data is copy-casted. Otherwise + * @param copyCast If truOe (default), previous data is copy-casted. Otherwise * previous data is lost. */ void setDataType(const DataType dt, bool copyCast = true) { diff --git a/include/aidge/filler/Filler.hpp b/include/aidge/filler/Filler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a8419e1caf3f8639e067892501b8b4c8996952ae --- /dev/null +++ b/include/aidge/filler/Filler.hpp @@ -0,0 +1,52 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_FILLER_H_ +#define AIDGE_CORE_FILLER_H_ + +#include <memory> + +#include "aidge/data/Tensor.hpp" + +namespace Aidge { + +// void heFiller(std::shared_ptr<Tensor> tensor); + +// template <typename T> +// void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue); + +template <typename T> +void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue) { + AIDGE_ASSERT(tensor->getImpl(), + "Tensor got no implementation, cannot fill it."); + AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type"); + + std::shared_ptr<Tensor> cpyTensor; + // Create cpy only if tensor not on CPU + Tensor& tensorWithValues = + tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu"); + + // Setting values + for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) { + tensorWithValues.set<T>(idx, constantValue); + } + + // Copy values back to the original tensors (actual copy only if needed) + tensor->copyCastFrom(tensorWithValues); +} + +void normalFiller(std::shared_ptr<Tensor> tensor, float mean, float var); +// void uniformFiller(std::shared_ptr<Tensor> tensor); +// void xavierFiller(std::shared_ptr<Tensor> tensor); + +} // namespace Aidge + +#endif /* AIDGE_CORE_FILLER_H_ */ diff --git a/python_binding/filler/pybind_Filler.cpp b/python_binding/filler/pybind_Filler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dacf0af40012fdb64655e80c6949e688d3a5d131 --- /dev/null +++ b/python_binding/filler/pybind_Filler.cpp @@ -0,0 +1,60 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "aidge/data/Tensor.hpp" +#include "aidge/filler/Filler.hpp" + +namespace py = pybind11; + +namespace Aidge { + +void init_Filler(py::module &m) { + m.def("constant_filler", + [](std::shared_ptr<Tensor> tensor, py::object value) -> void { + switch (tensor->dataType()) { + case DataType::Float64: + constantFiller<double>(tensor, value.cast<double>()); + break; + case DataType::Float32: + constantFiller<float>(tensor, value.cast<float>()); + break; + case DataType::Int8: + constantFiller<int8_t>(tensor, value.cast<int8_t>()); + break; + case DataType::Int16: + constantFiller<std::int16_t>(tensor, + value.cast<std::int16_t>()); + break; + case DataType::Int32: + constantFiller<std::int32_t>(tensor, + value.cast<std::int32_t>()); + break; + case DataType::Int64: + constantFiller<std::int64_t>(tensor, + value.cast<std::int64_t>()); + break; + case DataType::UInt8: + constantFiller<std::uint8_t>(tensor, + value.cast<std::uint8_t>()); + break; + case DataType::UInt16: + constantFiller<std::uint16_t>( + tensor, value.cast<std::uint16_t>()); + break; + default: + AIDGE_THROW_OR_ABORT(py::value_error, + "Data type is not supported."); + } + }); +} +} // namespace Aidge diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index 52863735ca431e797fab3426d7e61796a8725dd2..80f0cc44f8f85e042bb8c2ce2cef3d8bba9099a2 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -11,8 +11,7 @@ #include <pybind11/pybind11.h> -#include "aidge/backend/cpu/data/TensorImpl.hpp" // This include add Tensor - +#include "aidge/backend/cpu/data/TensorImpl.hpp" // This include add Tensor namespace py = pybind11; @@ -71,9 +70,9 @@ void init_Recipes(py::module&); void init_Scheduler(py::module&); void init_TensorUtils(py::module&); +void init_Filler(py::module&); - -void init_Aidge(py::module& m){ +void init_Aidge(py::module& m) { init_Data(m); init_Database(m); init_DataProvider(m); @@ -129,9 +128,8 @@ void init_Aidge(py::module& m){ init_Recipes(m); init_Scheduler(m); init_TensorUtils(m); + init_Filler(m); } -PYBIND11_MODULE(aidge_core, m) { - init_Aidge(m); -} -} +PYBIND11_MODULE(aidge_core, m) { init_Aidge(m); } +} // namespace Aidge diff --git a/src/filler/ConstantFiller.cpp b/src/filler/ConstantFiller.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9a67d40cac6fcabd974b17f7e6195921facffdc6 --- /dev/null +++ b/src/filler/ConstantFiller.cpp @@ -0,0 +1,32 @@ +// /******************************************************************************** +// * Copyright (c) 2023 CEA-List +// * +// * This program and the accompanying materials are made available under the +// * terms of the Eclipse Public License 2.0 which is available at +// * http://www.eclipse.org/legal/epl-2.0. +// * +// * SPDX-License-Identifier: EPL-2.0 +// * +// ********************************************************************************/ + +// #include "aidge/filler/Filler.hpp" + +// template<typename T> +// void Aidge::constantFiller(std::shared_ptr<Aidge::Tensor> tensor, T constantValue){ +// AIDGE_ASSERT(tensor->getImpl(), "Tensor got no implementation, cannot fill it."); +// AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type"); + +// std::shared_ptr<Tensor> cpyTensor; +// // Create cpy only if tensor not on CPU +// const Tensor& tensorWithValues = tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu"); + +// // Setting values +// for(std::size_t idx = 0; idx<tensorWithValues.size(); ++idx){ +// tensorWithValues.set<T>(idx, constantValue); +// } + +// // Copy values back to the original tensors (actual copy only if needed) +// tensor->copyCastFrom(tensorWithValues); + + +// }