Commit ef66ded3 authored by Olivier BICHLER

Working LSTM with any sequence length

parent e5cc0ce2
2 merge requests: !50 version 0.2.0, !37 Support for recurrent networks
Pipeline #39011 failed
@@ -27,6 +27,7 @@
#include "aidge/backend/cpu/operator/MemorizeImpl.hpp"
#include "aidge/backend/cpu/operator/MulImpl.hpp"
#include "aidge/backend/cpu/operator/PadImpl.hpp"
#include "aidge/backend/cpu/operator/PopImpl.hpp"
#include "aidge/backend/cpu/operator/PowImpl.hpp"
#include "aidge/backend/cpu/operator/ProducerImpl.hpp"
#include "aidge/backend/cpu/operator/ReLUImpl.hpp"
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_POPIMPL_H_
#define AIDGE_CPU_OPERATOR_POPIMPL_H_
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Pop.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include <memory>
#include <vector>
namespace Aidge {
// class Pop_Op;
// compute kernel registry for forward and backward
class PopImplForward_cpu
    : public Registrable<PopImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
};
class PopImplBackward_cpu
    : public Registrable<PopImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
};
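// Note: these forward/backward registries are currently unused placeholders,
// kept for consistency with the other CPU operator implementations;
// PopImpl_cpu::forward() below works directly on Tensors via extract(),
// without looking up a registered compute kernel.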
class PopImpl_cpu : public OperatorImpl {
public:
    PopImpl_cpu(const Pop_Op& op) : OperatorImpl(op) {}

    static std::unique_ptr<PopImpl_cpu> create(const Pop_Op& op) {
        return std::make_unique<PopImpl_cpu>(op);
    }

    NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
    void forward() override;
};
namespace {
static Registrar<Pop_Op> registrarPopImpl_cpu("cpu", Aidge::PopImpl_cpu::create);
}
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_POPIMPL_H_ */
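For context, the anonymous-namespace Registrar above is what makes this implementation discoverable: selecting the "cpu" backend on a Pop node resolves to PopImpl_cpu::create(). A minimal usage sketch (assuming the node-level API exercised by the tests below):

// Sketch only: setBackend("cpu") triggers the Registrar<Pop_Op> lookup
// with key "cpu", instantiating PopImpl_cpu for this node's operator.
auto pop = Aidge::Pop();
pop->getOperator()->setDataType(Aidge::DataType::Float32);
pop->getOperator()->setBackend("cpu");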
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <memory>  // std::static_pointer_cast
#include <vector>
#include "aidge/operator/Pop.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/backend/cpu/operator/PopImpl.hpp"
Aidge::NbElts_t Aidge::PopImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
    // this implementation can be in-place
    return 0;
}
void Aidge::PopImpl_cpu::forward() {
    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");

    const Pop_Op& op = dynamic_cast<const Pop_Op&>(mOp);
    const unsigned int forwardStep = op.template getAttr<PopAttr::ForwardStep>();
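    // Copy the forwardStep-th slice of input #0 (taken along its leading axis)
    // into output #0.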
    *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))
        = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->extract({forwardStep});
}
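Each forward() call thus pops one timestep from the sequence. A sketch of the effect on the (2, 3, 2) input used in the test below, assuming Tensor::extract() slices along the leading (sequence) axis and that Tensor is constructible from an Array3D, as in the tests:

// Illustration only, not part of the commit.
Aidge::Tensor in = Aidge::Array3D<float, 2, 3, 2>{
    {{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}},
     {{2.0, 3.0}, {4.0, 5.0}, {6.0, 7.0}}}};
Aidge::Tensor step0 = in.extract({0});  // ForwardStep == 0: {{1,2},{3,4},{5,6}}, shape (3, 2)
Aidge::Tensor step1 = in.extract({1});  // ForwardStep == 1: {{2,3},{4,5},{6,7}}, shape (3, 2)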
@@ -14,6 +14,7 @@
#include <cstdlib>
#include <memory>
#include "aidge/utils/TensorUtils.hpp"
#include "aidge/backend/cpu/operator/ConvImpl.hpp"
#include "aidge/backend/cpu/operator/PadImpl.hpp"
#include "aidge/data/Tensor.hpp"
@@ -21,6 +22,7 @@
#include "aidge/operator/MetaOperator.hpp"
#include "aidge/operator/MetaOperatorDefs.hpp"
#include "aidge/operator/Pad.hpp"
#include "aidge/operator/Pop.hpp"
using namespace Aidge;
@@ -197,8 +199,8 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph();
microGraph->save("lstm", false, false);
-REQUIRE(myLSTM->nbInputs() == 3);
-REQUIRE(myLSTM->nbData() == 3);
+REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
+REQUIRE(myLSTM->nbData() == 1);
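// i.e. 19 inputs: X, h0, c0 plus the LSTM's 8 weight and 8 bias tensors,
// now exposed as explicit inputs of the meta-operator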
REQUIRE(myLSTM->nbOutputs() == 2);
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
@@ -207,12 +209,12 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
Array2D<float, 1, 64>{{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
op->associateInput(0, myInput);
-op->associateInput(1, myInit);
-op->associateInput(2, myInit);
+op->associateInput(17, myInit);
+op->associateInput(18, myInit);
op->computeOutputDims();
-microGraph->save("lstm_dims", true, true);
REQUIRE(op->outputDimsForwarded());
+microGraph->save("lstm_dims", false, false);
op->setDataType(DataType::Float32);
op->setBackend("cpu");
@@ -230,7 +232,7 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
microGraph->save("lstm", false, false);
REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
-REQUIRE(myLSTM->nbData() == 3);
+REQUIRE(myLSTM->nbData() == 1);
REQUIRE(myLSTM->nbOutputs() == 2);
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
@@ -278,4 +280,69 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
REQUIRE(approxEq<float>(*(op->getOutput(1)), *myHiddenState));
}
SECTION("LSTM(forward_values_seq)") {
auto pop = Pop();
auto myLSTM = LSTM(2, 3, 2, true, "ltsm");
auto op = std::static_pointer_cast<MetaOperator_Op>(myLSTM->getOperator());
// NOTE: the LSTM really needs to be flattened into the graph before execution.
// Here, we deliberately don't use the meta-op as a closed black box, because
// its scheduling cannot run independently of its input. Since we use the Pop
// operator to generate the inputs sequentially, running the meta-op's internal
// scheduler would not work, because it would not update its input!
// We just borrow its micro-graph into our larger myGraph graph.
auto myGraph = std::make_shared<GraphView>();
myGraph->add(pop);
myGraph->add(op->getMicroGraph());
pop->addChild(op->getMicroGraph()->getOrderedInputs()[0].first, 0, 0);
REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
REQUIRE(myLSTM->nbData() == 1);
REQUIRE(myLSTM->nbOutputs() == 2);
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
Array3D<float, 2, 3, 2>{{{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, {{2.0, 3.0}, {4.0, 5.0}, {6.0, 7.0}}}});
std::shared_ptr<Tensor> myInit = std::make_shared<Tensor>(
Array2D<float, 3, 3>{{{0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}}});
std::shared_ptr<Tensor> myInitW = std::make_shared<Tensor>(
Array2D<float, 3, 2>{{{0.1, 0.1}, {0.1, 0.1}, {0.1, 0.1}}});
std::shared_ptr<Tensor> myInitR = std::make_shared<Tensor>(
Array2D<float, 3, 3>{{{0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}}});
pop->getOperator()->associateInput(0, myInput);
op->associateInput(17, myInit);
op->associateInput(18, myInit);
// Weights X
op->associateInput(1, myInitW);
op->associateInput(2, myInitW);
op->associateInput(3, myInitW);
op->associateInput(4, myInitW);
// Weights H
op->associateInput(5, myInitR);
op->associateInput(6, myInitR);
op->associateInput(7, myInitR);
op->associateInput(8, myInitR);
myGraph->setDataType(DataType::Float32);
myGraph->setBackend("cpu");
std::shared_ptr<Tensor> myHiddenState = std::make_shared<Tensor>(
Array2D<float, 3, 3>{{{0.24439372, 0.24439372, 0.24439372},
{0.49801484, 0.49801484, 0.49801484},
{0.67162132, 0.67162132, 0.67162132}}});
auto scheduler = SequentialScheduler(myGraph);
scheduler.forward();
scheduler.saveSchedulingDiagram("lstm_seq_schedule");
myGraph->save("lstm_seq", true, true);
op->getOutput(1)->print();
myHiddenState->print();
REQUIRE(approxEq<float>(*(op->getOutput(1)), *myHiddenState));
}
}
\ No newline at end of file
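As a sanity check on the expected hidden state above: with every weight at 0.1, no bias, and zero initial state, the three hidden units of a batch element stay equal, so each row of myHiddenState can be recomputed by hand. A standalone sketch for the first batch element (plain C++, not Aidge code; assumes the standard LSTM gate equations i = f = o = sigmoid(Wx + Rh), g = tanh(Wx + Rh)):

#include <cmath>
#include <cstdio>

int main() {
    auto sig = [](double x) { return 1.0 / (1.0 + std::exp(-x)); };
    // Batch element 0 sees inputs {1, 2} then {2, 3}; with all weights at 0.1,
    // each gate's pre-activation is 0.1 * sum(x) + 0.1 * (3 * h), since the
    // 3 hidden units carry identical values.
    double h = 0.0, c = 0.0;
    const double xSum[2] = {1.0 + 2.0, 2.0 + 3.0};
    for (int t = 0; t < 2; ++t) {
        const double pre = 0.1 * xSum[t] + 0.1 * (3.0 * h);
        const double i = sig(pre), f = sig(pre), o = sig(pre);
        const double g = std::tanh(pre);
        c = f * c + i * g;   // cell state update
        h = o * std::tanh(c);  // hidden state update
    }
    std::printf("%f\n", h);  // ~0.24439, matching the first row of myHiddenState
    return 0;
}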