Skip to content
Snippets Groups Projects
Commit c85ae170 authored by Jerome Hue's avatar Jerome Hue Committed by Olivier BICHLER
Browse files

feat: Stack operator

parent 33c2fca9
No related branches found
No related tags found
2 merge requests!279v0.4.0,!256Add a Stack operator
#ifndef AIDGE_CORE_OPERATOR_STACK_H_
#define AIDGE_CORE_OPERATOR_STACK_H_
#include <memory>
#include <string>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
class StackOpImpl : public OperatorImpl {
public:
StackOpImpl(const Operator &op, const std::string &backend = "")
: OperatorImpl(op, backend) {}
void forward() override;
};
enum class StackAttr { MaxElements };
class StackOp
: public OperatorTensor,
public Registrable<
StackOp,
std::string,
std::function<std::unique_ptr<OperatorImpl>(const StackOp &)>> {
private:
using Attributes_ = StaticAttributes<StackAttr, std::uint32_t>;
template <StackAttr e> using attr = typename Attributes_::template attr<e>;
const std::shared_ptr<Attributes_> mAttributes;
public:
static const std::string s_type;
StackOp(std::uint32_t maxElements);
/**
* @brief Copy-constructor. Copy the operator attributes and its output
* tensor(s), but not its input tensors (the new operator has no input
* associated).
* @param op Operator to copy.
*/
StackOp(const StackOp &op);
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::StackOp
*/
std::shared_ptr<Operator> clone() const override;
void setBackend(const std::string &name,
DeviceIdx_t device = 0) override final;
std::set<std::string> getAvailableBackends() const override;
bool forwardDims(bool allowDataDependency = false) override final;
void forward() override;
inline std::shared_ptr<Attributes> attributes() const override {
return mAttributes;
}
inline std::uint32_t &maxElements() const {
return mAttributes->template getAttr<StackAttr::MaxElements>();
}
static const std::vector<std::string> getInputsName() {
return {"data_input"};
}
static const std::vector<std::string> getOutputsName() {
return {"data_output"};
}
};
std::shared_ptr<Node> stack(std::uint32_t maxElements,
const std::string &name = "");
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::StackAttr>::data[] = {"max_elements"};
}
#endif
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/operator/Stack.hpp"
#include <memory>
#include <string>
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
const std::string StackOp::s_type = "Stack";
void StackOpImpl::forward() {
const StackOp &op = dynamic_cast<const StackOp &>(mOp);
assert(op.getInput(0) && "missing input #0");
//*op.getOutput(0) = op.getInput(0)->extract({op.forwardStep()});
}
StackOp::StackOp(std::uint32_t maxElements)
: OperatorTensor(s_type, {InputCategory::Data}, 1),
mAttributes(std::make_shared<Attributes_>(
attr<StackAttr::MaxElements>(maxElements))) {
mImpl = std::make_shared<StackOpImpl>(*this);
}
StackOp::StackOp(const Aidge::StackOp &op)
: OperatorTensor(op), mAttributes(op.mAttributes) {
if (!op.backend().empty()) {
SET_IMPL_MACRO(StackOp, *this, op.backend());
} else {
mImpl = std::make_shared<StackOpImpl>(*this);
}
}
std::shared_ptr<Aidge::Operator> Aidge::StackOp::clone() const {
return std::make_shared<StackOp>(*this);
}
bool Aidge::StackOp::forwardDims(bool /*allowDataDependency*/) {
if (inputsAssociated()) {
auto inputDims = getInput(0)->dims();
inputDims.insert(inputDims.begin(), maxElements());
getOutput(0)->resize(inputDims);
return true;
}
return false;
}
void StackOp::setBackend(const std::string &name, DeviceIdx_t device) {
if (Registrar<StackOp>::exists({name})) {
SET_IMPL_MACRO(StackOp, *this, name);
} else {
mImpl = std::make_shared<StackOpImpl>(*this);
}
mOutputs[0]->setBackend(name, device);
}
std::set<std::string> StackOp::getAvailableBackends() const {
return Registrar<StackOp>::getKeys();
}
void StackOp::forward() {
Operator::forward();
}
std::shared_ptr<Node> stack(std::uint32_t maxElements,
const std::string &name) {
return std::make_shared<Node>(std::make_shared<StackOp>(maxElements),
name);
}
} // namespace Aidge
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <catch2/generators/catch_generators_random.hpp>
#include <catch2/matchers/catch_matchers_string.hpp>
#include <cstddef>
#include <memory>
#include <random>
#include <vector>
#include "aidge/operator/Stack.hpp"
#include "aidge/data/Tensor.hpp"
using Catch::Matchers::Equals;
namespace Aidge {
TEST_CASE("[core/operator] Stack(forwardDims)", "[Stack][forwardDims]") {
auto rd = Catch::Generators::Detail::getSeed;
std::mt19937 gen(rd());
std::uniform_int_distribution<std::size_t> dimsDist(1, 10);
auto maxElementsToStack = dimsDist(gen);
auto stackNode = stack(maxElementsToStack);
auto op = std::dynamic_pointer_cast<StackOp>(stackNode->getOperator());
std::shared_ptr<Tensor> t0 = std::make_shared<Tensor>(Aidge::Array1D<int,3>{{4,5,6}});
// input #0 should be associated with a Tensor
REQUIRE_THROWS_WITH(op->forwardDims(), Equals("Stack: input #0 should be associated with a Tensor"));
op->associateInput(0, t0);
REQUIRE_NOTHROW(op->forwardDims());
auto dims = op->getOutput(0)->dims();
REQUIRE(dims[0] == maxElementsToStack);
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment