diff --git a/include/aidge/backend/cpu.hpp b/include/aidge/backend/cpu.hpp index f78598057cafe0b5b02d268bd5a73ede5a2981d8..a0d232f6b2ec30adf5c505c2e1f9acddf18c85e2 100644 --- a/include/aidge/backend/cpu.hpp +++ b/include/aidge/backend/cpu.hpp @@ -24,6 +24,7 @@ #include "aidge/backend/cpu/operator/FCImpl.hpp" #include "aidge/backend/cpu/operator/LeakyReLUImpl.hpp" #include "aidge/backend/cpu/operator/MatMulImpl.hpp" +#include "aidge/backend/cpu/operator/MemorizeImpl.hpp" #include "aidge/backend/cpu/operator/MulImpl.hpp" #include "aidge/backend/cpu/operator/PadImpl.hpp" #include "aidge/backend/cpu/operator/PowImpl.hpp" diff --git a/include/aidge/backend/cpu/operator/MemorizeImpl.hpp b/include/aidge/backend/cpu/operator/MemorizeImpl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c003e7b5757740f4282e7a39300ccc118558b1c0 --- /dev/null +++ b/include/aidge/backend/cpu/operator/MemorizeImpl.hpp @@ -0,0 +1,43 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CPU_OPERATOR_MEMORIZEIMPL_H_ +#define AIDGE_CPU_OPERATOR_MEMORIZEIMPL_H_ + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/Memorize.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/Types.h" +#include "aidge/backend/cpu/data/GetCPUPtr.h" +#include <memory> +#include <vector> + +namespace Aidge { +class MemorizeImpl_cpu : public OperatorImpl { +public: + MemorizeImpl_cpu(const Memorize_Op& op) : OperatorImpl(op) {} + + static std::unique_ptr<MemorizeImpl_cpu> create(const Memorize_Op& op) { + return std::make_unique<MemorizeImpl_cpu>(op); + } + + NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override final; + NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final; + void updateConsummerProducer() override final; + void forward() override; +}; + +namespace { +static Registrar<Memorize_Op> registrarMemorizeImpl_cpu("cpu", Aidge::MemorizeImpl_cpu::create); +} +} // namespace Aidge + +#endif /* AIDGE_CPU_OPERATOR_MEMORIZEIMPL_H_ */ diff --git a/include/aidge/backend/cpu/operator/ProducerImpl.hpp b/include/aidge/backend/cpu/operator/ProducerImpl.hpp index c1d27f7efc4457fd3b02b6cde006401e2ca71661..2e9c90c428c0ee746f5b483b71132b37bbbbcd06 100644 --- a/include/aidge/backend/cpu/operator/ProducerImpl.hpp +++ b/include/aidge/backend/cpu/operator/ProducerImpl.hpp @@ -29,7 +29,6 @@ public: return std::make_unique<ProducerImpl_cpu>(op); } - NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final; void forward() override; }; diff --git a/src/operator/MemorizeImpl.cpp b/src/operator/MemorizeImpl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..64cb3bcf237acd0aea706d8635eb4ab5e1b947b1 --- /dev/null +++ b/src/operator/MemorizeImpl.cpp @@ -0,0 +1,80 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <cassert> +#include <chrono> // std::chrono::milliseconds +#include <numeric> // std::accumulate +#include <thread> // std::this_thread::sleep_for +#include <vector> + +#include "aidge/operator/Memorize.hpp" +#include "aidge/utils/Types.h" +#include "aidge/backend/cpu/data/GetCPUPtr.h" + +#include "aidge/backend/cpu/operator/MemorizeImpl.hpp" + +Aidge::DimSize_t Aidge::MemorizeImpl_cpu::getNbRequiredData( + Aidge::IOIndex_t inputIdx) const +{ + const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp); + const unsigned int scheduleStep = op.template getAttr<MemorizeAttr::ScheduleStep>(); + + if (scheduleStep == 0 && inputIdx == 0) { + // No data input is required for the initial step. + // Initialization data is required however. + return 0; + } + else if (scheduleStep > 0 && inputIdx == 1) { + // No initialization data is required after the initial step. + return 0; + } + else { + return OperatorImpl::getNbRequiredData(inputIdx); + } +} + +Aidge::DimSize_t Aidge::MemorizeImpl_cpu::getNbProducedData( + Aidge::IOIndex_t outputIdx) const +{ + const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp); + const unsigned int scheduleStep = op.template getAttr<MemorizeAttr::ScheduleStep>(); + const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>(); + + if (outputIdx == 1 && scheduleStep >= endStep) { + return endStep * std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size(); + } + else { + return scheduleStep * std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size(); + } +} + +void Aidge::MemorizeImpl_cpu::updateConsummerProducer() { + OperatorImpl::updateConsummerProducer(); + + const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp); + const unsigned int scheduleStep = op.template getAttr<MemorizeAttr::ScheduleStep>(); + const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>(); + AIDGE_ASSERT(scheduleStep <= endStep, "cannot update consumer producer anymore, number of cycles exceeded"); +} + +void Aidge::MemorizeImpl_cpu::forward() { + const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp); + const unsigned int forwardStep = op.template getAttr<MemorizeAttr::ForwardStep>(); + const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>(); + AIDGE_ASSERT(forwardStep <= endStep, "cannot forward anymore, number of cycles exceeded"); + + if (forwardStep == 0) { + op.getOutput(0)->getImpl()->copy(op.getInput(1)->getImpl()->rawPtr(), op.getInput(1)->size()); + } + else { + op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(), op.getInput(0)->size()); + } +} diff --git a/src/operator/ProducerImpl.cpp b/src/operator/ProducerImpl.cpp index 4c5883a9b0155e7bb6e16cbac1b8de1a3a9e9e16..2decd9559a98f28e170b7ed45bafbac3643d040a 100644 --- a/src/operator/ProducerImpl.cpp +++ b/src/operator/ProducerImpl.cpp @@ -20,16 +20,6 @@ #include "aidge/backend/cpu/operator/ProducerImpl.hpp" -Aidge::DimSize_t Aidge::ProducerImpl_cpu::getNbProducedData( - Aidge::IOIndex_t outputIdx) const -{ - // Requires the whole tensors, regardless of available data on inputs - assert(outputIdx == 0 && "operator has only one output"); - (void) outputIdx; - - return std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->size(); -} - void Aidge::ProducerImpl_cpu::forward() { } diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index 8ea8e726f286035a1059a317471b893ce4639251..e1815725f32b3afab01f691a53bfa50a2ef624ed 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -205,5 +205,54 @@ TEST_CASE("[cpu/scheduler] SequentialScheduler(forward)") { SECTION("Test Residual graph") { } - SECTION("Test Recurrent graph") {} -} \ No newline at end of file + SECTION("Test Recurrent graph") { + std::shared_ptr<Tensor> in = std::make_shared<Tensor>( + Array2D<int, 2, 3>{{{1, 2, 3}, {4, 5, 6}}}); + std::shared_ptr<Tensor> initTensor = std::make_shared<Tensor>( + Array2D<int, 2, 3>{{{0, 0, 0}, {1, 1, 1}}}); + std::shared_ptr<Tensor> biasTensor = std::make_shared<Tensor>( + Array2D<int, 2, 3>{{{2, 0, 0}, {1, 0, 0}}}); + + auto add1 = Add(2, "add1"); + auto mem = Memorize(3, "mem1"); + auto add2 = Add(2, "add2"); + auto bias = Producer(biasTensor, "bias"); + auto init = Producer(initTensor, "init"); + init->getOperator()->setBackend("cpu"); + init->getOperator()->setDataType(Aidge::DataType::Int32); + + std::shared_ptr<GraphView> g = Sequential({add1, mem, add2}); + init->addChild(mem, 0, 1); + mem->addChild(add1, 1, 1); + bias->addChild(add2, 0, 1); + add1->getOperator()->setInput(0, in); + // Update GraphView inputs/outputs following previous connections: + g->add(mem); + g->add(add1); + g->add(add2); + //g->add(init); // not working because of forwardDims() + // TODO: FIXME: + // forwardDims() starts with inputNodes(). If the initializer + // of Memorize is inside the graph, forwardDims() will get stuck + // to the node taking the recursive connection because Memorize + // output dims must first be computed from the initializer. + g->add(bias); + + g->setBackend("cpu"); + g->setDataType(Aidge::DataType::Int32); + g->save("graphRecurrent"); + g->forwardDims(); + SequentialScheduler scheduler(g); + REQUIRE_NOTHROW(scheduler.forward(true, true)); + scheduler.saveSchedulingDiagram("schedulingRecurrent"); + + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>( + Array2D<int, 2, 3>{{{5, 6, 9}, {14, 16, 19}}}); + std::shared_ptr<Tensor> result = + std::static_pointer_cast<Tensor>(g->getNode("add2")->getOperator()->getRawOutput(0)); + result->print(); + expectedOutput->print(); + bool equal = (*result == *expectedOutput); + REQUIRE(equal); + } +}