Skip to content
Snippets Groups Projects
Commit 5d0baf0b authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge branch 'lstm' into 'dev'

Fixed issues for LSTM

See merge request !267
parents 269145ad b6ef490e
No related branches found
No related tags found
No related merge requests found
...@@ -50,7 +50,7 @@ class StackOp ...@@ -50,7 +50,7 @@ class StackOp
public: public:
static const std::string s_type; static const std::string s_type;
StackOp(std::uint32_t maxElements); StackOp(std::uint32_t maxElements = 0);
/** /**
* @brief Copy-constructor. Copy the operator attributes and its output * @brief Copy-constructor. Copy the operator attributes and its output
...@@ -71,6 +71,7 @@ class StackOp ...@@ -71,6 +71,7 @@ class StackOp
std::set<std::string> getAvailableBackends() const override; std::set<std::string> getAvailableBackends() const override;
bool dimsForwarded() const override final;
bool forwardDims(bool allowDataDependency = false) override final; bool forwardDims(bool allowDataDependency = false) override final;
void forward() override; void forward() override;
...@@ -87,14 +88,14 @@ class StackOp ...@@ -87,14 +88,14 @@ class StackOp
} }
static const std::vector<std::string> getInputsName() { static const std::vector<std::string> getInputsName() {
return {"data_input"}; return {"data_input", "max_elements"};
} }
static const std::vector<std::string> getOutputsName() { static const std::vector<std::string> getOutputsName() {
return {"data_output"}; return {"data_output"};
} }
}; };
std::shared_ptr<Node> stack(std::uint32_t maxElements, std::shared_ptr<Node> Stack(std::uint32_t maxElements = 0,
const std::string &name = ""); const std::string &name = "");
} // namespace Aidge } // namespace Aidge
......
...@@ -29,8 +29,8 @@ void init_Stack(py::module &m) { ...@@ -29,8 +29,8 @@ void init_Stack(py::module &m) {
.def_readonly_static("Type", &StackOp::s_type); .def_readonly_static("Type", &StackOp::s_type);
m.def("Stack", m.def("Stack",
&stack, &Stack,
py::arg("max_elements"), py::arg("max_elements") = 0,
py::arg("name") = "", py::arg("name") = "",
R"mydelimiter( R"mydelimiter(
Initialize a node containing a Stack operator. Initialize a node containing a Stack operator.
......
...@@ -20,14 +20,8 @@ ...@@ -20,14 +20,8 @@
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
/**
 * @brief Execution step for the Shape operator — intentionally a no-op.
 *
 * The shape values are written into the output tensor during
 * forwardDims() (see Shape_Op::forwardDims), so by the time forward()
 * runs the output is already valid and nothing remains to be computed.
 */
void Aidge::Shape_OpImpl::forward() {
    // Nothing to do here: the output was fully populated by forwardDims().
}
/////////////////////////////////////////////// ///////////////////////////////////////////////
...@@ -75,6 +69,11 @@ bool Aidge::Shape_Op::forwardDims(bool /*allowDataDependency*/) { ...@@ -75,6 +69,11 @@ bool Aidge::Shape_Op::forwardDims(bool /*allowDataDependency*/) {
AIDGE_ASSERT(roi> 1, "Invalid ROI for Shape"); AIDGE_ASSERT(roi> 1, "Invalid ROI for Shape");
mOutputs[0]->resize({roi}); mOutputs[0]->resize({roi});
// Ensure the output of this operator is valid after forwardDims():
mOutputs[0]->getImpl()->copyCast(std::next(getInput(0)->dims().data(),
start),
DataType::UInt64,
end - start + 1);
return true; return true;
} }
......
...@@ -26,7 +26,7 @@ namespace Aidge { ...@@ -26,7 +26,7 @@ namespace Aidge {
// inputSize // inputSize
Elts_t StackProdConso::getRequiredMemory( Elts_t StackProdConso::getRequiredMemory(
const Aidge::IOIndex_t inputIdx, const Aidge::IOIndex_t inputIdx,
const std::vector<DimSize_t> &inputsSize) const { const std::vector<DimSize_t> &/*inputsSize*/) const {
assert(mOp.getRawInput(inputIdx) && "requires valid input"); assert(mOp.getRawInput(inputIdx) && "requires valid input");
const StackOp &op = dynamic_cast<const StackOp &>(mOp); const StackOp &op = dynamic_cast<const StackOp &>(mOp);
...@@ -62,15 +62,10 @@ void StackOpImpl::forward() { ...@@ -62,15 +62,10 @@ void StackOpImpl::forward() {
} }
/**
 * @brief Construct a Stack operator.
 *
 * @param maxElements Maximum number of tensors to stack. A value of 0 is
 *        permitted: in that case the limit is expected to arrive at runtime
 *        through the optional second input ("max_elements") and is resolved
 *        in forwardDims().
 *
 * The operator declares one mandatory data input plus one optional data
 * input (the runtime max-elements value) and a single output.
 */
StackOp::StackOp(std::uint32_t maxElements)
    : OperatorTensor(s_type,
                     {InputCategory::Data, InputCategory::OptionalData},
                     1),
      mAttributes(std::make_shared<Attributes_>(
          attr<StackAttr::MaxElements>(maxElements),
          attr<StackAttr::ForwardStep>(0))) {
    // Attach the default (built-in) implementation.
    mImpl = std::make_shared<StackOpImpl>(*this);
}
...@@ -87,8 +82,33 @@ std::shared_ptr<Aidge::Operator> Aidge::StackOp::clone() const { ...@@ -87,8 +82,33 @@ std::shared_ptr<Aidge::Operator> Aidge::StackOp::clone() const {
return std::make_shared<StackOp>(*this); return std::make_shared<StackOp>(*this);
} }
bool Aidge::StackOp::forwardDims(bool /*allowDataDependency*/) { bool Aidge::StackOp::dimsForwarded() const {
if ((getInput(1) && !getInput(1)->undefined()))
{
// output dims are data dependent
return false;
}
return OperatorTensor::dimsForwarded();
}
bool Aidge::StackOp::forwardDims(bool allowDataDependency) {
if (inputsAssociated()) { if (inputsAssociated()) {
// Copy optional input #1 first dimension, if present, to attribute MaxElements
if (getInput(1)) {
if (!allowDataDependency) {
Log::warn("StackOp: unable to forwardDims() because output dims are data dependent on input#1");
return false;
}
std::shared_ptr<Tensor> fallback;
const auto& maxElements = getInput(1)->refCastFrom(fallback, NativeType<std::uint32_t>::type, "cpu");
AIDGE_ASSERT(maxElements.size() > 0, "Input#1 size should be > 0");
this->maxElements() = static_cast<std::uint32_t*>(maxElements.getImpl()->hostPtr())[0];
}
AIDGE_ASSERT(this->maxElements() > 0, "Input#1 first element or MaxElements attribute should be > 0");
auto inputDims = getInput(0)->dims(); auto inputDims = getInput(0)->dims();
inputDims.insert(inputDims.begin(), maxElements()); inputDims.insert(inputDims.begin(), maxElements());
getOutput(0)->resize(inputDims); getOutput(0)->resize(inputDims);
...@@ -116,7 +136,7 @@ void StackOp::forward() { ...@@ -116,7 +136,7 @@ void StackOp::forward() {
++forwardStep(); ++forwardStep();
} }
std::shared_ptr<Node> stack(std::uint32_t maxElements, std::shared_ptr<Node> Stack(std::uint32_t maxElements,
const std::string &name) { const std::string &name) {
return std::make_shared<Node>(std::make_shared<StackOp>(maxElements), return std::make_shared<Node>(std::make_shared<StackOp>(maxElements),
name); name);
......
...@@ -56,9 +56,6 @@ TEST_CASE("[core/operator] Stack(forward)", "[Stack]") { ...@@ -56,9 +56,6 @@ TEST_CASE("[core/operator] Stack(forward)", "[Stack]") {
REQUIRE(op2.maxElements() == maxElements); REQUIRE(op2.maxElements() == maxElements);
REQUIRE(op2.forwardStep() == 0); REQUIRE(op2.forwardStep() == 0);
} }
// Invalid arguments
REQUIRE_THROWS_AS(StackOp(0), std::invalid_argument);
} }
SECTION("forwardDims") { SECTION("forwardDims") {
...@@ -111,7 +108,7 @@ TEST_CASE("[core/operator] Stack(forward)", "[Stack]") { ...@@ -111,7 +108,7 @@ TEST_CASE("[core/operator] Stack(forward)", "[Stack]") {
tensors[i]->getImpl()->setRawPtr(arrays[i], nbElems); tensors[i]->getImpl()->setRawPtr(arrays[i], nbElems);
} }
auto myStack = stack(numTensors); auto myStack = Stack(numTensors);
myStack->getOperator()->setBackend("cpu"); myStack->getOperator()->setBackend("cpu");
myStack->getOperator()->setDataType(DataType::Float32); myStack->getOperator()->setDataType(DataType::Float32);
auto op = auto op =
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment