From f1afe39803d20c1eaaf49529a453d9a058d5d15b Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Mon, 5 Feb 2024 14:34:29 +0100 Subject: [PATCH] Working concept --- include/aidge/operator/Memorize.hpp | 101 ++++++++++++++++++++++++++ src/graph/GraphView.cpp | 10 ++- src/operator/Memorize.cpp | 52 ++++++++++++++ src/scheduler/Scheduler.cpp | 105 +++++++++++++++++++++++++--- 4 files changed, 256 insertions(+), 12 deletions(-) create mode 100644 include/aidge/operator/Memorize.hpp create mode 100644 src/operator/Memorize.cpp diff --git a/include/aidge/operator/Memorize.hpp b/include/aidge/operator/Memorize.hpp new file mode 100644 index 000000000..53340f866 --- /dev/null +++ b/include/aidge/operator/Memorize.hpp @@ -0,0 +1,101 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_MEMORIZE_H_ +#define AIDGE_CORE_OPERATOR_MEMORIZE_H_ + +#include <cassert> +#include <memory> +#include <vector> + +#include "aidge/utils/Registrar.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Types.h" +#include "aidge/utils/StaticAttributes.hpp" + +namespace Aidge { +enum class MemorizeAttr { ScheduleStep, ForwardStep, EndStep }; + +class Memorize_Op : public OperatorTensor, + public Registrable<Memorize_Op, std::string, std::unique_ptr<OperatorImpl>(const Memorize_Op&)>, + public StaticAttributes<MemorizeAttr, unsigned int, unsigned int, unsigned int> { +public: + static const std::string Type; + + using Attributes_ = StaticAttributes<MemorizeAttr, unsigned int, unsigned int, unsigned int>; + template <MemorizeAttr e> + using attr = typename Attributes_::template attr<e>; + + Memorize_Op(const unsigned int endStep) + : OperatorTensor(Type, 2, 0, 2), + Attributes_(attr<MemorizeAttr::ScheduleStep>(0), + attr<MemorizeAttr::ForwardStep>(0), + attr<MemorizeAttr::EndStep>(endStep)) + { + mOutputs[1] = mOutputs[0]; + } + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Memorize_Op(const Memorize_Op& op) + : OperatorTensor(op), + Attributes_(op) + { + mImpl = op.mImpl ? Registrar<Memorize_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr; + mOutputs[1] = mOutputs[0]; + } + + /** + * @brief Clone the operator using its copy-constructor. 
+ * @see Operator::Memorize_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Memorize_Op>(*this); + } + + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { + mImpl = Registrar<Memorize_Op>::create({name})(*this); + mOutputs[0]->setBackend(name, device); + } + + void computeOutputDims() override; + bool outputDimsForwarded() const override; + void updateConsummerProducer() override; + void forward() override; + + static const std::vector<std::string> getInputsName(){ + return {"data_input", "data_input_init"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output", "data_output_rec"}; + } +}; + +inline std::shared_ptr<Node> Memorize(const unsigned int endStep, const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Memorize_Op>(endStep), name); +} +} // namespace Aidge + +namespace { +template <> +const char *const EnumStrings<Aidge::MemorizeAttr>::data[] = { + "ScheduleStep", + "ForwardStep", + "EndStep" +}; +} + +#endif /* AIDGE_CORE_OPERATOR_MEMORIZE_H_ */ \ No newline at end of file diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 968e98e75..5c7a524af 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -306,7 +306,12 @@ void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) { nextList.insert(nodePtr); } else { // compute output dimensions of children std::set<std::shared_ptr<Node>> children = nodePtr->getChildren(); - nextList.insert(children.begin(), children.end()); + for (auto child : children) { + const auto childOp = std::static_pointer_cast<OperatorTensor>(child->getOperator()); + if (!childOp->outputDimsForwarded()) { + nextList.insert(child); + } + } } } } @@ -319,6 +324,9 @@ void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) { } } } + + AIDGE_INTERNAL_ASSERT(nextList != listNodes); + if (!nextList.empty()) { _forwardDims(nextList); } diff --git a/src/operator/Memorize.cpp b/src/operator/Memorize.cpp new file mode 100644 index 000000000..714bf97fc --- /dev/null +++ b/src/operator/Memorize.cpp @@ -0,0 +1,52 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/Memorize.hpp" + +const std::string Aidge::Memorize_Op::Type = "Memorize"; + +void Aidge::Memorize_Op::computeOutputDims() { + // Only require input #1 dims (initialization of the Memorize operator) + // Otherwise, forwardDims() won't converge! 
+ bool associated = (nbInputs() > 0); // do not compute anything if no input + if (!getInput(1)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Input #1 should be associated with a Tensor"); + } + associated &= !(getInput(1)->empty()); + + if (associated) { + const auto expectedDims = getInput(1)->dims(); + mOutputs[0]->resize(expectedDims); + } +} + +bool Aidge::Memorize_Op::outputDimsForwarded() const { + // Only check the output dims + bool forwarded = true; + // check outputs have been filled + for (IOIndex_t i = 0; i < nbOutputs(); ++i) { + forwarded &= !(getOutput(i)->empty()); + } + return forwarded; +} + +void Aidge::Memorize_Op::updateConsummerProducer() { + Operator::updateConsummerProducer(); + ++this->template getAttr<MemorizeAttr::ScheduleStep>(); + this->template getAttr<MemorizeAttr::ForwardStep>() = 0; +} + +void Aidge::Memorize_Op::forward() { + Operator::forward(); + ++this->template getAttr<MemorizeAttr::ForwardStep>(); + this->template getAttr<MemorizeAttr::ScheduleStep>() = 0; +} diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 3afbcd044..953804f62 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -20,6 +20,7 @@ #include "aidge/graph/Node.hpp" #include "aidge/utils/Types.h" #include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Producer.hpp" void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") { putchar('['); @@ -43,7 +44,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { // setup initial producers list std::set<std::shared_ptr<Node>> producers; for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) { - if (nodePtr->type() == "Producer") { + if (nodePtr->type() == Producer_Op::Type) { producers.insert(nodePtr); } } @@ -64,7 +65,62 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { /* It may not be necessary to initialize producer */ std::set<std::shared_ptr<Node>> consumers = mGraphView->inputNodes(); + std::set<std::shared_ptr<Node>> frozenConsumers; do { + // Check required producers + std::set<std::shared_ptr<Node>> requiredProducers; + if (verbose) printf("Required producers:\n"); + + for (const auto& consumer : consumers) { + if (verbose) { + printf("\t- consumer: " + "\x1b[1;37m" + "%s" + "\x1b[0m" + "\n", + (consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str()); + } + + std::set<std::shared_ptr<Node>> consumerRequiredProducers; + bool requiredProducerOnly = true; + IOIndex_t inputIdx = 0; + for (const auto& consumerParent : consumer->inputs()) { + if (verbose) printf("\t\t#%u: ", inputIdx); + + if (consumerParent.first && + (consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) > + consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) { + if (verbose) printf("required data from %s: C%zu + R%zu > P%zu\n", + consumerParent.first->type().c_str(), + consumer->getOperator()->getNbConsumedData(inputIdx), + consumer->getOperator()->getNbRequiredData(inputIdx), + consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)); + + if (consumerParent.first->type() == Producer_Op::Type) { + consumerRequiredProducers.insert(consumerParent.first); + } + else { + requiredProducerOnly = false; + break; + } + } + else { + if (verbose) printf("no data required\n"); + } + ++inputIdx; + } + + if (requiredProducerOnly) { + requiredProducers.insert(consumerRequiredProducers.begin(), 
consumerRequiredProducers.end()); + } + } + + // Make producers generate the required data + for (const auto& requiredProducer : requiredProducers) { + requiredProducer->getOperator()->updateConsummerProducer(); + mStaticSchedule.push_back(requiredProducer); + } + // find runnable consumers std::set<std::shared_ptr<Node>> runnableConsumers; if (verbose) printf("List of layers receiving data:\n"); @@ -74,7 +130,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { "\x1b[1;37m" "%s" "\x1b[0m" - "\n\t\tR/C:\t", + "\n\t\tC/R:\t", (consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str()); for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { printf("%zu/%zu\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), @@ -89,18 +145,25 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { printf("%zu", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1)); printf("\n"); } + bool isRunnable = true; - IOIndex_t parentID = 0; // FIXME: handle this correctly + IOIndex_t inputIdx = 0; // FIXME: handle this correctly // Check every input has got enought data to run - for (const auto& consumerParent : consumer->dataInputs()) { + for (const auto& consumerParent : consumer->inputs()) { if (consumerParent.first && - consumer->getOperator()->getNbRequiredData(parentID++) > + (consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) > consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) { + if (verbose) printf(" not runnable: C%zu + R%zu > P%zu\n", + consumer->getOperator()->getNbConsumedData(inputIdx), + consumer->getOperator()->getNbRequiredData(inputIdx), + consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)); + // not enough data to run isRunnable = false; break; } + ++inputIdx; } if (isRunnable) { @@ -115,13 +178,22 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { mStaticSchedule.push_back(runnable); } + if (runnableConsumers.empty()) { + if (frozenConsumers.empty()) { + frozenConsumers = consumers; + } + } + else { + frozenConsumers.clear(); + } + // update producers and consumers list if (verbose) printf("Updating producer and consumer lists...\n"); const auto oldConsumers = consumers; for (const auto& consumer : oldConsumers) { if (verbose) { - printf("\t- consumer: %s\n\t\tR/C:\t", + printf("\t- consumer: %s\n\t\tC/R:\t", (consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str()); for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { printf("%ld/%ld\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), @@ -138,16 +210,21 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { } bool isStillConsumer = false; - IOIndex_t parentID = 0; // FIXME: handle this correctly + IOIndex_t inputIdx = 0; // FIXME: handle this correctly // should we check input or dataInput ? 
for (const auto& consumerParent : consumer->inputs()) { if (consumerParent.first && - consumer->getOperator()->getNbConsumedData(parentID++) < + consumer->getOperator()->getNbConsumedData(inputIdx) < consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) { + if (verbose) printf(" still consumer: C%zu < P%zu\n", + consumer->getOperator()->getNbConsumedData(inputIdx), + consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)); + // there is still data to consume isStillConsumer = true; break; } + ++inputIdx; } for (IOIndex_t outId = 0; outId < consumer->nbOutputs(); ++outId) { @@ -169,9 +246,15 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { } } - if (verbose) printf("*************\n"); - } while (!consumers.empty()); + if (verbose) printf("********************\n"); + } while (!consumers.empty() && consumers != frozenConsumers); + if (verbose) { + if (!consumers.empty()) { + printf("*** Frozen state ***\n"); + printf("********************\n"); + } + } } // TODO: handle multiple inputs/outputs @@ -183,7 +266,7 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { // If scheduling was already generated (in one or several steps, i.e. one or // several successive call to generateScheduling()), do not generate it twice if (mStaticSchedule.empty()) { - this->generateScheduling(); + this->generateScheduling(verbose); } // Clear previous scheduling results -- GitLab