Commit b6ef490e authored by Olivier BICHLER, committed by Maxence Naud

Added maxElements input to Stack

parent 3d9f7779
Merge requests: !279 (v0.4.0), !267 (Fixed issues for LSTM)
Pipeline #60489 passed
@@ -50,7 +50,7 @@ class StackOp
 public:
     static const std::string s_type;
-    StackOp(std::uint32_t maxElements);
+    StackOp(std::uint32_t maxElements = 0);
 
     /**
      * @brief Copy-constructor. Copy the operator attributes and its output
@@ -71,6 +71,7 @@ class StackOp
     std::set<std::string> getAvailableBackends() const override;
+    bool dimsForwarded() const override final;
     bool forwardDims(bool allowDataDependency = false) override final;
     void forward() override;
@@ -87,14 +88,14 @@ class StackOp
     }
     static const std::vector<std::string> getInputsName() {
-        return {"data_input"};
+        return {"data_input", "max_elements"};
     }
     static const std::vector<std::string> getOutputsName() {
         return {"data_output"};
     }
 };
 
-std::shared_ptr<Node> stack(std::uint32_t maxElements,
+std::shared_ptr<Node> Stack(std::uint32_t maxElements = 0,
                             const std::string &name = "");
 } // namespace Aidge
......
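Taken together, the header changes mean maxElements can now either be fixed at construction or left at its default of 0 and supplied later through the new optional "max_elements" input. Below is a minimal sketch of the attribute-based form; the Stack(), setBackend() and setDataType() calls mirror the test file at the end of this commit, while the include path and wrapper function are assumptions for illustration only.

    #include <memory>
    #include "aidge/operator/Stack.hpp"   // assumed header location for StackOp / Stack()

    void buildStackWithAttribute() {
        // Attribute form: maxElements is fixed at construction time,
        // so the optional "max_elements" input stays unconnected.
        std::shared_ptr<Aidge::Node> myStack = Aidge::Stack(/*maxElements=*/4, "stack");
        myStack->getOperator()->setBackend("cpu");
        myStack->getOperator()->setDataType(Aidge::DataType::Float32);
    }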
@@ -29,8 +29,8 @@ void init_Stack(py::module &m) {
         .def_readonly_static("Type", &StackOp::s_type);
     m.def("Stack",
-          &stack,
-          py::arg("max_elements"),
+          &Stack,
+          py::arg("max_elements") = 0,
           py::arg("name") = "",
           R"mydelimiter(
         Initialize a node containing a Stack operator.
......
@@ -26,7 +26,7 @@ namespace Aidge {
 // inputSize
 Elts_t StackProdConso::getRequiredMemory(
     const Aidge::IOIndex_t inputIdx,
-    const std::vector<DimSize_t> &inputsSize) const {
+    const std::vector<DimSize_t> &/*inputsSize*/) const {
     assert(mOp.getRawInput(inputIdx) && "requires valid input");
     const StackOp &op = dynamic_cast<const StackOp &>(mOp);
@@ -62,15 +62,10 @@ void StackOpImpl::forward() {
 }
 
 StackOp::StackOp(std::uint32_t maxElements)
-    : OperatorTensor(s_type, {InputCategory::Data}, 1),
+    : OperatorTensor(s_type, {InputCategory::Data, InputCategory::OptionalData}, 1),
       mAttributes(std::make_shared<Attributes_>(
           attr<StackAttr::MaxElements>(maxElements),
           attr<StackAttr::ForwardStep>(0))) {
-    if (maxElements == 0) {
-        AIDGE_THROW_OR_ABORT(
-            std::invalid_argument,
-            "StackOp creation failed: maxElements must be greater than 0.");
-    }
     mImpl = std::make_shared<StackOpImpl>(*this);
 }
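With the zero-check removed from the constructor, a StackOp built with the default maxElements = 0 is now legal; validation moves to forwardDims(), which requires the value to become positive via the attribute or input #1. A small sketch of that behavioural change, reusing the maxElements() accessor exercised by the tests below (the header path and function name are illustrative assumptions):

    #include <cassert>
    #include "aidge/operator/Stack.hpp"   // assumed header location

    void defaultConstructedStackIsValid() {
        Aidge::StackOp op;             // previously threw std::invalid_argument for maxElements == 0
        assert(op.maxElements() == 0); // stays 0 until the attribute or input #1 provides a value
    }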
@@ -87,8 +82,33 @@ std::shared_ptr<Aidge::Operator> Aidge::StackOp::clone() const {
     return std::make_shared<StackOp>(*this);
 }
 
+bool Aidge::StackOp::dimsForwarded() const {
+    if ((getInput(1) && !getInput(1)->undefined()))
+    {
+        // output dims are data dependent
+        return false;
+    }
+
+    return OperatorTensor::dimsForwarded();
+}
+
-bool Aidge::StackOp::forwardDims(bool /*allowDataDependency*/) {
+bool Aidge::StackOp::forwardDims(bool allowDataDependency) {
     if (inputsAssociated()) {
+        // Copy optional input #1 first dimension, if present, to attribute MaxElements
+        if (getInput(1)) {
+            if (!allowDataDependency) {
+                Log::warn("StackOp: unable to forwardDims() because output dims are data dependent on input#1");
+                return false;
+            }
+
+            std::shared_ptr<Tensor> fallback;
+            const auto& maxElements = getInput(1)->refCastFrom(fallback, NativeType<std::uint32_t>::type, "cpu");
+            AIDGE_ASSERT(maxElements.size() > 0, "Input#1 size should be > 0");
+            this->maxElements() = static_cast<std::uint32_t*>(maxElements.getImpl()->hostPtr())[0];
+        }
+
+        AIDGE_ASSERT(this->maxElements() > 0, "Input#1 first element or MaxElements attribute should be > 0");
+
         auto inputDims = getInput(0)->dims();
         inputDims.insert(inputDims.begin(), maxElements());
         getOutput(0)->resize(inputDims);
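Because the output shape now depends on the value carried by input #1, dimsForwarded() reports false whenever that input is connected, and forwardDims() only reads the data when allowDataDependency is true. The sketch below illustrates the input-tensor form under those rules; Producer, Array1D, the Tensor constructors and Node::addChild are assumed from aidge_core and are not part of this commit, so treat it as a sketch rather than verified usage.

    #include <cstdint>
    #include <memory>
    #include "aidge/data/Tensor.hpp"          // assumed header locations
    #include "aidge/operator/Producer.hpp"
    #include "aidge/operator/Stack.hpp"

    void wireMaxElementsAsInput() {
        using namespace Aidge;

        // Data input #0: only its dims matter for forwardDims().
        auto dataProd = Producer(std::make_shared<Tensor>(std::vector<DimSize_t>{2, 3}), "data");
        // Scalar uint32 tensor holding the stack depth (4 here), wired to optional input #1.
        auto maxElemsProd = Producer(std::make_shared<Tensor>(Array1D<std::uint32_t, 1>{{4}}), "max_elements");

        auto myStack = Stack(/*maxElements=*/0, "stack");   // attribute left at its default
        dataProd->addChild(myStack, 0, 0);
        maxElemsProd->addChild(myStack, 0, 1);

        // Output dims are data dependent on input #1, so data access must be allowed.
        auto op = std::static_pointer_cast<StackOp>(myStack->getOperator());
        op->forwardDims(/*allowDataDependency=*/true);      // expected output dims: {4, 2, 3}
    }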
@@ -116,7 +136,7 @@ void StackOp::forward() {
     ++forwardStep();
 }
 
-std::shared_ptr<Node> stack(std::uint32_t maxElements,
+std::shared_ptr<Node> Stack(std::uint32_t maxElements,
                             const std::string &name) {
     return std::make_shared<Node>(std::make_shared<StackOp>(maxElements),
                                   name);
......
@@ -56,9 +56,6 @@ TEST_CASE("[core/operator] Stack(forward)", "[Stack]") {
         REQUIRE(op2.maxElements() == maxElements);
         REQUIRE(op2.forwardStep() == 0);
     }
-
-    // Invalid arguments
-    REQUIRE_THROWS_AS(StackOp(0), std::invalid_argument);
 }
 
     SECTION("forwardDims") {
@@ -111,7 +108,7 @@ TEST_CASE("[core/operator] Stack(forward)", "[Stack]") {
             tensors[i]->getImpl()->setRawPtr(arrays[i], nbElems);
         }
 
-        auto myStack = stack(numTensors);
+        auto myStack = Stack(numTensors);
         myStack->getOperator()->setBackend("cpu");
         myStack->getOperator()->setDataType(DataType::Float32);
         auto op =
......