From b6ef490e3e9c6f8d0bd568267df6307f9ec9167e Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 29 Nov 2024 14:13:26 +0100 Subject: [PATCH] Added maxElements input to Stack --- include/aidge/operator/Stack.hpp | 7 +++-- python_binding/operator/pybind_Stack.cpp | 4 +-- src/operator/Stack.cpp | 38 ++++++++++++++++++------ unit_tests/operator/Test_StackImpl.cpp | 5 +--- 4 files changed, 36 insertions(+), 18 deletions(-) diff --git a/include/aidge/operator/Stack.hpp b/include/aidge/operator/Stack.hpp index 9644620d7..21633e451 100644 --- a/include/aidge/operator/Stack.hpp +++ b/include/aidge/operator/Stack.hpp @@ -50,7 +50,7 @@ class StackOp public: static const std::string s_type; - StackOp(std::uint32_t maxElements); + StackOp(std::uint32_t maxElements = 0); /** * @brief Copy-constructor. Copy the operator attributes and its output @@ -71,6 +71,7 @@ class StackOp std::set<std::string> getAvailableBackends() const override; + bool dimsForwarded() const override final; bool forwardDims(bool allowDataDependency = false) override final; void forward() override; @@ -87,14 +88,14 @@ class StackOp } static const std::vector<std::string> getInputsName() { - return {"data_input"}; + return {"data_input", "max_elements"}; } static const std::vector<std::string> getOutputsName() { return {"data_output"}; } }; -std::shared_ptr<Node> stack(std::uint32_t maxElements, +std::shared_ptr<Node> Stack(std::uint32_t maxElements = 0, const std::string &name = ""); } // namespace Aidge diff --git a/python_binding/operator/pybind_Stack.cpp b/python_binding/operator/pybind_Stack.cpp index 232889210..c9bd969fa 100644 --- a/python_binding/operator/pybind_Stack.cpp +++ b/python_binding/operator/pybind_Stack.cpp @@ -29,8 +29,8 @@ void init_Stack(py::module &m) { .def_readonly_static("Type", &StackOp::s_type); m.def("Stack", - &stack, - py::arg("max_elements"), + &Stack, + py::arg("max_elements") = 0, py::arg("name") = "", R"mydelimiter( Initialize a node containing a Stack operator. diff --git a/src/operator/Stack.cpp b/src/operator/Stack.cpp index 4ca7cc983..ab9ddc4f7 100644 --- a/src/operator/Stack.cpp +++ b/src/operator/Stack.cpp @@ -26,7 +26,7 @@ namespace Aidge { // inputSize Elts_t StackProdConso::getRequiredMemory( const Aidge::IOIndex_t inputIdx, - const std::vector<DimSize_t> &inputsSize) const { + const std::vector<DimSize_t> &/*inputsSize*/) const { assert(mOp.getRawInput(inputIdx) && "requires valid input"); const StackOp &op = dynamic_cast<const StackOp &>(mOp); @@ -62,15 +62,10 @@ void StackOpImpl::forward() { } StackOp::StackOp(std::uint32_t maxElements) - : OperatorTensor(s_type, {InputCategory::Data}, 1), + : OperatorTensor(s_type, {InputCategory::Data, InputCategory::OptionalData}, 1), mAttributes(std::make_shared<Attributes_>( attr<StackAttr::MaxElements>(maxElements), attr<StackAttr::ForwardStep>(0))) { - if (maxElements == 0) { - AIDGE_THROW_OR_ABORT( - std::invalid_argument, - "StackOp creation failed: maxElements must be greater than 0."); - } mImpl = std::make_shared<StackOpImpl>(*this); } @@ -87,8 +82,33 @@ std::shared_ptr<Aidge::Operator> Aidge::StackOp::clone() const { return std::make_shared<StackOp>(*this); } -bool Aidge::StackOp::forwardDims(bool /*allowDataDependency*/) { +bool Aidge::StackOp::dimsForwarded() const { + if ((getInput(1) && !getInput(1)->undefined())) + { + // output dims are data dependent + return false; + } + + return OperatorTensor::dimsForwarded(); +} + +bool Aidge::StackOp::forwardDims(bool allowDataDependency) { if (inputsAssociated()) { + // Copy optional input #1 first dimension, if present, to attribute MaxElements + if (getInput(1)) { + if (!allowDataDependency) { + Log::warn("StackOp: unable to forwardDims() because output dims are data dependent on input#1"); + return false; + } + + std::shared_ptr<Tensor> fallback; + const auto& maxElements = getInput(1)->refCastFrom(fallback, NativeType<std::uint32_t>::type, "cpu"); + AIDGE_ASSERT(maxElements.size() > 0, "Input#1 size should be > 0"); + this->maxElements() = static_cast<std::uint32_t*>(maxElements.getImpl()->hostPtr())[0]; + } + + AIDGE_ASSERT(this->maxElements() > 0, "Input#1 first element or MaxElements attribute should be > 0"); + auto inputDims = getInput(0)->dims(); inputDims.insert(inputDims.begin(), maxElements()); getOutput(0)->resize(inputDims); @@ -116,7 +136,7 @@ void StackOp::forward() { ++forwardStep(); } -std::shared_ptr<Node> stack(std::uint32_t maxElements, +std::shared_ptr<Node> Stack(std::uint32_t maxElements, const std::string &name) { return std::make_shared<Node>(std::make_shared<StackOp>(maxElements), name); diff --git a/unit_tests/operator/Test_StackImpl.cpp b/unit_tests/operator/Test_StackImpl.cpp index d853a1ba2..ccdf5787d 100644 --- a/unit_tests/operator/Test_StackImpl.cpp +++ b/unit_tests/operator/Test_StackImpl.cpp @@ -56,9 +56,6 @@ TEST_CASE("[core/operator] Stack(forward)", "[Stack]") { REQUIRE(op2.maxElements() == maxElements); REQUIRE(op2.forwardStep() == 0); } - - // Invalid arguments - REQUIRE_THROWS_AS(StackOp(0), std::invalid_argument); } SECTION("forwardDims") { @@ -111,7 +108,7 @@ TEST_CASE("[core/operator] Stack(forward)", "[Stack]") { tensors[i]->getImpl()->setRawPtr(arrays[i], nbElems); } - auto myStack = stack(numTensors); + auto myStack = Stack(numTensors); myStack->getOperator()->setBackend("cpu"); myStack->getOperator()->setDataType(DataType::Float32); auto op = -- GitLab