Skip to content
Snippets Groups Projects
Commit 5d0baf0b authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge branch 'lstm' into 'dev'

Fixed issues for LSTM

See merge request eclipse/aidge/aidge_core!267
parents 269145ad b6ef490e
No related branches found
No related tags found
2 merge requests!279v0.4.0,!267Fixed issues for LSTM
Pipeline #60496 passed
......@@ -50,7 +50,7 @@ class StackOp
public:
static const std::string s_type;
StackOp(std::uint32_t maxElements);
StackOp(std::uint32_t maxElements = 0);
/**
* @brief Copy-constructor. Copy the operator attributes and its output
......@@ -71,6 +71,7 @@ class StackOp
std::set<std::string> getAvailableBackends() const override;
bool dimsForwarded() const override final;
bool forwardDims(bool allowDataDependency = false) override final;
void forward() override;
......@@ -87,14 +88,14 @@ class StackOp
}
static const std::vector<std::string> getInputsName() {
return {"data_input"};
return {"data_input", "max_elements"};
}
static const std::vector<std::string> getOutputsName() {
return {"data_output"};
}
};
std::shared_ptr<Node> stack(std::uint32_t maxElements,
std::shared_ptr<Node> Stack(std::uint32_t maxElements = 0,
const std::string &name = "");
} // namespace Aidge
......
......@@ -29,8 +29,8 @@ void init_Stack(py::module &m) {
.def_readonly_static("Type", &StackOp::s_type);
m.def("Stack",
&stack,
py::arg("max_elements"),
&Stack,
py::arg("max_elements") = 0,
py::arg("name") = "",
R"mydelimiter(
Initialize a node containing a Stack operator.
......
......@@ -20,14 +20,8 @@
#include "aidge/utils/Types.h"
// Forward pass of the Shape operator's default implementation.
// NOTE(review): this span is merge-diff residue — it contains BOTH the old
// implementation (the copyCast below, which copied the input tensor's
// dimensions [start..end] into the UInt64 output) AND the replacement no-op
// comments. After this MR the copy is expected to be performed in
// Shape_Op::forwardDims() instead, making forward() a no-op — confirm
// against the merged repository before relying on either half.
void Aidge::Shape_OpImpl::forward() {
const Shape_Op& op = dynamic_cast<const Shape_Op&>(mOp);
// Attribute-driven slice bounds of the dims to report.
const auto start = op.start();
const auto end = op.end();
// Old behavior: copy (end - start + 1) dimension values, starting at index
// `start` of the input's dims, into output #0, cast to UInt64.
op.getOutput(0)->getImpl()->copyCast(std::next(op.getInput(0)->dims().data(),
start),
DataType::UInt64,
end - start + 1);
// Do nothing...
// Output is already valid after forwardDims()
}
///////////////////////////////////////////////
......@@ -75,6 +69,11 @@ bool Aidge::Shape_Op::forwardDims(bool /*allowDataDependency*/) {
AIDGE_ASSERT(roi> 1, "Invalid ROI for Shape");
mOutputs[0]->resize({roi});
// Ensure the output of this operator is valid after forwardDims():
mOutputs[0]->getImpl()->copyCast(std::next(getInput(0)->dims().data(),
start),
DataType::UInt64,
end - start + 1);
return true;
}
......
......@@ -26,7 +26,7 @@ namespace Aidge {
// inputSize
Elts_t StackProdConso::getRequiredMemory(
const Aidge::IOIndex_t inputIdx,
const std::vector<DimSize_t> &inputsSize) const {
const std::vector<DimSize_t> &/*inputsSize*/) const {
assert(mOp.getRawInput(inputIdx) && "requires valid input");
const StackOp &op = dynamic_cast<const StackOp &>(mOp);
......@@ -62,15 +62,10 @@ void StackOpImpl::forward() {
}
StackOp::StackOp(std::uint32_t maxElements)
: OperatorTensor(s_type, {InputCategory::Data}, 1),
: OperatorTensor(s_type, {InputCategory::Data, InputCategory::OptionalData}, 1),
mAttributes(std::make_shared<Attributes_>(
attr<StackAttr::MaxElements>(maxElements),
attr<StackAttr::ForwardStep>(0))) {
if (maxElements == 0) {
AIDGE_THROW_OR_ABORT(
std::invalid_argument,
"StackOp creation failed: maxElements must be greater than 0.");
}
mImpl = std::make_shared<StackOpImpl>(*this);
}
......@@ -87,8 +82,33 @@ std::shared_ptr<Aidge::Operator> Aidge::StackOp::clone() const {
return std::make_shared<StackOp>(*this);
}
bool Aidge::StackOp::forwardDims(bool /*allowDataDependency*/) {
// Reports whether the output dimensions are already up to date.
// If the optional input #1 (max_elements) is connected and holds defined
// data, the output dims depend on that runtime value, so this always
// returns false to force forwardDims(allowDataDependency = true) to be
// re-run; otherwise defer to the base-class bookkeeping.
bool Aidge::StackOp::dimsForwarded() const {
if ((getInput(1) && !getInput(1)->undefined()))
{
// output dims are data dependent
return false;
}
return OperatorTensor::dimsForwarded();
}
bool Aidge::StackOp::forwardDims(bool allowDataDependency) {
if (inputsAssociated()) {
// Copy optional input #1 first dimension, if present, to attribute MaxElements
if (getInput(1)) {
if (!allowDataDependency) {
Log::warn("StackOp: unable to forwardDims() because output dims are data dependent on input#1");
return false;
}
std::shared_ptr<Tensor> fallback;
const auto& maxElements = getInput(1)->refCastFrom(fallback, NativeType<std::uint32_t>::type, "cpu");
AIDGE_ASSERT(maxElements.size() > 0, "Input#1 size should be > 0");
this->maxElements() = static_cast<std::uint32_t*>(maxElements.getImpl()->hostPtr())[0];
}
AIDGE_ASSERT(this->maxElements() > 0, "Input#1 first element or MaxElements attribute should be > 0");
auto inputDims = getInput(0)->dims();
inputDims.insert(inputDims.begin(), maxElements());
getOutput(0)->resize(inputDims);
......@@ -116,7 +136,7 @@ void StackOp::forward() {
++forwardStep();
}
std::shared_ptr<Node> stack(std::uint32_t maxElements,
std::shared_ptr<Node> Stack(std::uint32_t maxElements,
const std::string &name) {
return std::make_shared<Node>(std::make_shared<StackOp>(maxElements),
name);
......
......@@ -56,9 +56,6 @@ TEST_CASE("[core/operator] Stack(forward)", "[Stack]") {
REQUIRE(op2.maxElements() == maxElements);
REQUIRE(op2.forwardStep() == 0);
}
// Invalid arguments
REQUIRE_THROWS_AS(StackOp(0), std::invalid_argument);
}
SECTION("forwardDims") {
......@@ -111,7 +108,7 @@ TEST_CASE("[core/operator] Stack(forward)", "[Stack]") {
tensors[i]->getImpl()->setRawPtr(arrays[i], nbElems);
}
auto myStack = stack(numTensors);
auto myStack = Stack(numTensors);
myStack->getOperator()->setBackend("cpu");
myStack->getOperator()->setDataType(DataType::Float32);
auto op =
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment