From c85ae1703e5c1d1d78f19d1ebaaede1db6ba1bf4 Mon Sep 17 00:00:00 2001 From: Jerome Hue <jerome.hue@cea.fr> Date: Thu, 14 Nov 2024 12:19:29 +0100 Subject: [PATCH] feat: Stack operator --- include/aidge/operator/Stack.hpp | 88 ++++++++++++++++++++++++++ src/operator/Stack.cpp | 86 +++++++++++++++++++++++++ unit_tests/operator/Test_StackImpl.cpp | 47 ++++++++++++++ 3 files changed, 221 insertions(+) create mode 100644 include/aidge/operator/Stack.hpp create mode 100644 src/operator/Stack.cpp create mode 100644 unit_tests/operator/Test_StackImpl.cpp diff --git a/include/aidge/operator/Stack.hpp b/include/aidge/operator/Stack.hpp new file mode 100644 index 000000000..24ac075cf --- /dev/null +++ b/include/aidge/operator/Stack.hpp @@ -0,0 +1,88 @@ +#ifndef AIDGE_CORE_OPERATOR_STACK_H_ +#define AIDGE_CORE_OPERATOR_STACK_H_ + +#include <memory> +#include <string> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/StaticAttributes.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { + +class StackOpImpl : public OperatorImpl { + public: + StackOpImpl(const Operator &op, const std::string &backend = "") + : OperatorImpl(op, backend) {} + void forward() override; +}; + +enum class StackAttr { MaxElements }; + +class StackOp + : public OperatorTensor, + public Registrable< + StackOp, + std::string, + std::function<std::unique_ptr<OperatorImpl>(const StackOp &)>> { + + private: + using Attributes_ = StaticAttributes<StackAttr, std::uint32_t>; + template <StackAttr e> using attr = typename Attributes_::template attr<e>; + const std::shared_ptr<Attributes_> mAttributes; + + public: + static const std::string s_type; + + StackOp(std::uint32_t maxElements); + + /** + * @brief Copy-constructor. Copy the operator attributes and its output + * tensor(s), but not its input tensors (the new operator has no input + * associated). + * @param op Operator to copy. + */ + StackOp(const StackOp &op); + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::StackOp + */ + std::shared_ptr<Operator> clone() const override; + + void setBackend(const std::string &name, + DeviceIdx_t device = 0) override final; + + std::set<std::string> getAvailableBackends() const override; + + bool forwardDims(bool allowDataDependency = false) override final; + void forward() override; + + inline std::shared_ptr<Attributes> attributes() const override { + return mAttributes; + } + inline std::uint32_t &maxElements() const { + return mAttributes->template getAttr<StackAttr::MaxElements>(); + } + + static const std::vector<std::string> getInputsName() { + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName() { + return {"data_output"}; + } +}; + +std::shared_ptr<Node> stack(std::uint32_t maxElements, + const std::string &name = ""); +} // namespace Aidge + +namespace { +template <> +const char *const EnumStrings<Aidge::StackAttr>::data[] = {"max_elements"}; +} + +#endif diff --git a/src/operator/Stack.cpp b/src/operator/Stack.cpp new file mode 100644 index 000000000..9bdfadd08 --- /dev/null +++ b/src/operator/Stack.cpp @@ -0,0 +1,86 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/operator/Stack.hpp" + +#include <memory> +#include <string> + +#include "aidge/data/Tensor.hpp" +#include "aidge/utils/ErrorHandling.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/StaticAttributes.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { + +const std::string StackOp::s_type = "Stack"; + +void StackOpImpl::forward() { + const StackOp &op = dynamic_cast<const StackOp &>(mOp); + assert(op.getInput(0) && "missing input #0"); + //*op.getOutput(0) = op.getInput(0)->extract({op.forwardStep()}); +} + +StackOp::StackOp(std::uint32_t maxElements) + : OperatorTensor(s_type, {InputCategory::Data}, 1), + mAttributes(std::make_shared<Attributes_>( + attr<StackAttr::MaxElements>(maxElements))) { + mImpl = std::make_shared<StackOpImpl>(*this); +} + +StackOp::StackOp(const Aidge::StackOp &op) + : OperatorTensor(op), mAttributes(op.mAttributes) { + if (!op.backend().empty()) { + SET_IMPL_MACRO(StackOp, *this, op.backend()); + } else { + mImpl = std::make_shared<StackOpImpl>(*this); + } +} + +std::shared_ptr<Aidge::Operator> Aidge::StackOp::clone() const { + return std::make_shared<StackOp>(*this); +} + +bool Aidge::StackOp::forwardDims(bool /*allowDataDependency*/) { + if (inputsAssociated()) { + auto inputDims = getInput(0)->dims(); + inputDims.insert(inputDims.begin(), maxElements()); + getOutput(0)->resize(inputDims); + return true; + } + + return false; +} + +void StackOp::setBackend(const std::string &name, DeviceIdx_t device) { + if (Registrar<StackOp>::exists({name})) { + SET_IMPL_MACRO(StackOp, *this, name); + } else { + mImpl = std::make_shared<StackOpImpl>(*this); + } + mOutputs[0]->setBackend(name, device); +} + +std::set<std::string> StackOp::getAvailableBackends() const { + return Registrar<StackOp>::getKeys(); +} + +void StackOp::forward() { + Operator::forward(); +} + +std::shared_ptr<Node> stack(std::uint32_t maxElements, + const std::string &name) { + return std::make_shared<Node>(std::make_shared<StackOp>(maxElements), + name); +} +} // namespace Aidge diff --git a/unit_tests/operator/Test_StackImpl.cpp b/unit_tests/operator/Test_StackImpl.cpp new file mode 100644 index 000000000..7652c03a7 --- /dev/null +++ b/unit_tests/operator/Test_StackImpl.cpp @@ -0,0 +1,47 @@ +/******************************************************************************** + * Copyright (c) 2024 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> +#include <catch2/generators/catch_generators_random.hpp> +#include <catch2/matchers/catch_matchers_string.hpp> +#include <cstddef> +#include <memory> +#include <random> +#include <vector> + +#include "aidge/operator/Stack.hpp" +#include "aidge/data/Tensor.hpp" + +using Catch::Matchers::Equals; + +namespace Aidge { +TEST_CASE("[core/operator] Stack(forwardDims)", "[Stack][forwardDims]") { + + auto rd = Catch::Generators::Detail::getSeed; + std::mt19937 gen(rd()); + std::uniform_int_distribution<std::size_t> dimsDist(1, 10); + + auto maxElementsToStack = dimsDist(gen); + + auto stackNode = stack(maxElementsToStack); + auto op = std::dynamic_pointer_cast<StackOp>(stackNode->getOperator()); + std::shared_ptr<Tensor> t0 = std::make_shared<Tensor>(Aidge::Array1D<int,3>{{4,5,6}}); + + // input #0 should be associated with a Tensor + REQUIRE_THROWS_WITH(op->forwardDims(), Equals("Stack: input #0 should be associated with a Tensor")); + + op->associateInput(0, t0); + REQUIRE_NOTHROW(op->forwardDims()); + + auto dims = op->getOutput(0)->dims(); + REQUIRE(dims[0] == maxElementsToStack); +} +} // namespace Aidge -- GitLab