Skip to content
Snippets Groups Projects
Commit b6ef490e authored by Olivier BICHLER's avatar Olivier BICHLER Committed by Maxence Naud
Browse files

Added maxElements input to Stack

parent 3d9f7779
No related branches found
No related tags found
No related merge requests found
......@@ -50,7 +50,7 @@ class StackOp
public:
static const std::string s_type;
StackOp(std::uint32_t maxElements);
StackOp(std::uint32_t maxElements = 0);
/**
* @brief Copy-constructor. Copy the operator attributes and its output
......@@ -71,6 +71,7 @@ class StackOp
std::set<std::string> getAvailableBackends() const override;
bool dimsForwarded() const override final;
bool forwardDims(bool allowDataDependency = false) override final;
void forward() override;
......@@ -87,14 +88,14 @@ class StackOp
}
static const std::vector<std::string> getInputsName() {
return {"data_input"};
return {"data_input", "max_elements"};
}
static const std::vector<std::string> getOutputsName() {
return {"data_output"};
}
};
std::shared_ptr<Node> stack(std::uint32_t maxElements,
std::shared_ptr<Node> Stack(std::uint32_t maxElements = 0,
const std::string &name = "");
} // namespace Aidge
......
......@@ -29,8 +29,8 @@ void init_Stack(py::module &m) {
.def_readonly_static("Type", &StackOp::s_type);
m.def("Stack",
&stack,
py::arg("max_elements"),
&Stack,
py::arg("max_elements") = 0,
py::arg("name") = "",
R"mydelimiter(
Initialize a node containing a Stack operator.
......
......@@ -26,7 +26,7 @@ namespace Aidge {
// inputSize
Elts_t StackProdConso::getRequiredMemory(
const Aidge::IOIndex_t inputIdx,
const std::vector<DimSize_t> &inputsSize) const {
const std::vector<DimSize_t> &/*inputsSize*/) const {
assert(mOp.getRawInput(inputIdx) && "requires valid input");
const StackOp &op = dynamic_cast<const StackOp &>(mOp);
......@@ -62,15 +62,10 @@ void StackOpImpl::forward() {
}
StackOp::StackOp(std::uint32_t maxElements)
: OperatorTensor(s_type, {InputCategory::Data}, 1),
: OperatorTensor(s_type, {InputCategory::Data, InputCategory::OptionalData}, 1),
mAttributes(std::make_shared<Attributes_>(
attr<StackAttr::MaxElements>(maxElements),
attr<StackAttr::ForwardStep>(0))) {
if (maxElements == 0) {
AIDGE_THROW_OR_ABORT(
std::invalid_argument,
"StackOp creation failed: maxElements must be greater than 0.");
}
mImpl = std::make_shared<StackOpImpl>(*this);
}
......@@ -87,8 +82,33 @@ std::shared_ptr<Aidge::Operator> Aidge::StackOp::clone() const {
return std::make_shared<StackOp>(*this);
}
bool Aidge::StackOp::forwardDims(bool /*allowDataDependency*/) {
bool Aidge::StackOp::dimsForwarded() const {
if ((getInput(1) && !getInput(1)->undefined()))
{
// output dims are data dependent
return false;
}
return OperatorTensor::dimsForwarded();
}
bool Aidge::StackOp::forwardDims(bool allowDataDependency) {
if (inputsAssociated()) {
// Copy optional input #1 first dimension, if present, to attribute MaxElements
if (getInput(1)) {
if (!allowDataDependency) {
Log::warn("StackOp: unable to forwardDims() because output dims are data dependent on input#1");
return false;
}
std::shared_ptr<Tensor> fallback;
const auto& maxElements = getInput(1)->refCastFrom(fallback, NativeType<std::uint32_t>::type, "cpu");
AIDGE_ASSERT(maxElements.size() > 0, "Input#1 size should be > 0");
this->maxElements() = static_cast<std::uint32_t*>(maxElements.getImpl()->hostPtr())[0];
}
AIDGE_ASSERT(this->maxElements() > 0, "Input#1 first element or MaxElements attribute should be > 0");
auto inputDims = getInput(0)->dims();
inputDims.insert(inputDims.begin(), maxElements());
getOutput(0)->resize(inputDims);
......@@ -116,7 +136,7 @@ void StackOp::forward() {
++forwardStep();
}
std::shared_ptr<Node> stack(std::uint32_t maxElements,
std::shared_ptr<Node> Stack(std::uint32_t maxElements,
const std::string &name) {
return std::make_shared<Node>(std::make_shared<StackOp>(maxElements),
name);
......
......@@ -56,9 +56,6 @@ TEST_CASE("[core/operator] Stack(forward)", "[Stack]") {
REQUIRE(op2.maxElements() == maxElements);
REQUIRE(op2.forwardStep() == 0);
}
// Invalid arguments
REQUIRE_THROWS_AS(StackOp(0), std::invalid_argument);
}
SECTION("forwardDims") {
......@@ -111,7 +108,7 @@ TEST_CASE("[core/operator] Stack(forward)", "[Stack]") {
tensors[i]->getImpl()->setRawPtr(arrays[i], nbElems);
}
auto myStack = stack(numTensors);
auto myStack = Stack(numTensors);
myStack->getOperator()->setBackend("cpu");
myStack->getOperator()->setDataType(DataType::Float32);
auto op =
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment