From 47515b82d50851fabf1fa003240620fc60b0a888 Mon Sep 17 00:00:00 2001 From: thibault allenet <thibault.allenet@cea.fr> Date: Wed, 6 Dec 2023 13:08:46 +0000 Subject: [PATCH] Add Stimuli and stimuliImpl classes for data loading --- include/aidge/backend/StimuliImpl.hpp | 34 +++++++++ include/aidge/stimuli/Stimuli.hpp | 101 ++++++++++++++++++++++++++ 2 files changed, 135 insertions(+) create mode 100644 include/aidge/backend/StimuliImpl.hpp create mode 100644 include/aidge/stimuli/Stimuli.hpp diff --git a/include/aidge/backend/StimuliImpl.hpp b/include/aidge/backend/StimuliImpl.hpp new file mode 100644 index 000000000..925020beb --- /dev/null +++ b/include/aidge/backend/StimuliImpl.hpp @@ -0,0 +1,34 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_STIMULIIMPL_H_ +#define AIDGE_STIMULIIMPL_H_ +#include <memory> +#include "aidge/data/Data.hpp" +#include "aidge/data/Tensor.hpp" + +namespace Aidge { +/** + * @brief StimuliImpl. Base class to implement data loading functions. + * + */ +class StimuliImpl { +public: + + StimuliImpl(){}; + + virtual std::shared_ptr<Tensor> load(){}; + + virtual ~StimuliImpl() = default; +}; +} // namespace Aidge + +#endif /* AIDGE_STIMULIIMPL_H_ */ diff --git a/include/aidge/stimuli/Stimuli.hpp b/include/aidge/stimuli/Stimuli.hpp new file mode 100644 index 000000000..a6e08347b --- /dev/null +++ b/include/aidge/stimuli/Stimuli.hpp @@ -0,0 +1,101 @@ +#ifndef STIMULI_H +#define STIMULI_H + +#include <cstring> +#include <iostream> +#include <memory> + +#include "aidge/backend/StimuliImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/utils/Registrar.hpp" + +namespace Aidge { +/** + * @brief Stimuli. A class wrapping a data sample. Stimuli has two functioning modes. The first mode enables to load data samples from a dataPath and optionnaly store the data in-memory. The second mode enables to store a data sample that was already loaded in memory. + * @details When Stimuli is used in the first mode, the loading function is determined automaticaly based on the backend and the file extension. + */ +class Stimuli : public Registrable<Stimuli, std::tuple<std::string, std::string>, std::unique_ptr<StimuliImpl>(const std::string&)> { +public: + + Stimuli() = delete; + /** + * @brief Construct a new Stimuli object based on a dataPath to load the data. + * + * @param dataPath path to the data to be loaded. + * @param loadDataInMemory when true, keep the data in memory once loaded + */ + Stimuli(const std::string& dataPath, + bool loadDataInMemory = false) : + mDataPath(dataPath) + { + size_t dotPos = dataPath.find_last_of("."); + assert(dotPos != std::string::npos && "Cannot find extension"); + mFileExtension = dataPath.substr(dotPos + 1); + }; + + /** + * @brief Construct a new Stimuli object copied from another one. + * @param otherStimuli + */ + Stimuli(const Stimuli& otherStimuli) + : mDataPath(otherStimuli.mDataPath), + mLoadDataInMemory(otherStimuli.mLoadDataInMemory), + mFileExtension(otherStimuli.mFileExtension), + mData(otherStimuli.mData) + { + if (otherStimuli.mImpl) { + mImpl = Registrar<Stimuli>::create({"opencv", mFileExtension})(mDataPath); + } + } + + /** + * @brief Construct a new Stimuli object based on a tensor that is already loaded in memory. + * + * @param data the data tensor. + */ + Stimuli(const std::shared_ptr<Tensor> data) : + mData(data), + mLoadDataInMemory(true) {} + virtual ~Stimuli() {}; + + /** + * @brief Set the backend of the stimuli associated load implementation + * @details Create and initialize an implementation. + * @param name name of the backend. + */ + inline void setBackend(const std::string &name) { + mImpl = Registrar<Stimuli>::create({name, mFileExtension})(mDataPath); + } + + /** + * @brief Get the data tensor associated to the stimuli. The data is either loaded from a datapath or passed from an in-memory tensor. + * + * @return std::shared_ptr<Tensor> the data tensor. + */ + virtual std::shared_ptr<Tensor> load(){ + assert((mImpl!=nullptr || mData!=nullptr) && "No load implementation and No stored data"); + + if (mLoadDataInMemory){ + if (mData == nullptr){ + mData = mImpl->load(); + } + return mData; + } + return mImpl->load(); + }; + +protected: + // Implementation of the Stimuli + std::unique_ptr<StimuliImpl> mImpl; + + /// Stimuli data path + std::string mDataPath; + std::string mFileExtension; + bool mLoadDataInMemory; + + /// Stimuli data ptr + std::shared_ptr<Tensor> mData; +}; +} // namespace Aidge + +#endif // STIMULI_H -- GitLab