Skip to content
Snippets Groups Projects
Commit f1afe398 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Working concept

parent 1bd36647
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!77Support for recurrent networks
Pipeline #38317 passed
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_MEMORIZE_H_
#define AIDGE_CORE_OPERATOR_MEMORIZE_H_
#include <cassert>
#include <memory>
#include <vector>
#include "aidge/utils/Registrar.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/StaticAttributes.hpp"
namespace Aidge {
enum class MemorizeAttr { ScheduleStep, ForwardStep, EndStep };
class Memorize_Op : public OperatorTensor,
public Registrable<Memorize_Op, std::string, std::unique_ptr<OperatorImpl>(const Memorize_Op&)>,
public StaticAttributes<MemorizeAttr, unsigned int, unsigned int, unsigned int> {
public:
static const std::string Type;
using Attributes_ = StaticAttributes<MemorizeAttr, unsigned int, unsigned int, unsigned int>;
template <MemorizeAttr e>
using attr = typename Attributes_::template attr<e>;
Memorize_Op(const unsigned int endStep)
: OperatorTensor(Type, 2, 0, 2),
Attributes_(attr<MemorizeAttr::ScheduleStep>(0),
attr<MemorizeAttr::ForwardStep>(0),
attr<MemorizeAttr::EndStep>(endStep))
{
mOutputs[1] = mOutputs[0];
}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy.
*/
Memorize_Op(const Memorize_Op& op)
: OperatorTensor(op),
Attributes_(op)
{
mImpl = op.mImpl ? Registrar<Memorize_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
mOutputs[1] = mOutputs[0];
}
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::Memorize_Op
*/
std::shared_ptr<Operator> clone() const override {
return std::make_shared<Memorize_Op>(*this);
}
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Memorize_Op>::create({name})(*this);
mOutputs[0]->setBackend(name, device);
}
void computeOutputDims() override;
bool outputDimsForwarded() const override;
void updateConsummerProducer() override;
void forward() override;
static const std::vector<std::string> getInputsName(){
return {"data_input", "data_input_init"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output", "data_output_rec"};
}
};
inline std::shared_ptr<Node> Memorize(const unsigned int endStep, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Memorize_Op>(endStep), name);
}
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::MemorizeAttr>::data[] = {
"ScheduleStep",
"ForwardStep",
"EndStep"
};
}
#endif /* AIDGE_CORE_OPERATOR_MEMORIZE_H_ */
\ No newline at end of file
...@@ -306,7 +306,12 @@ void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) { ...@@ -306,7 +306,12 @@ void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) {
nextList.insert(nodePtr); nextList.insert(nodePtr);
} else { // compute output dimensions of children } else { // compute output dimensions of children
std::set<std::shared_ptr<Node>> children = nodePtr->getChildren(); std::set<std::shared_ptr<Node>> children = nodePtr->getChildren();
nextList.insert(children.begin(), children.end()); for (auto child : children) {
const auto childOp = std::static_pointer_cast<OperatorTensor>(child->getOperator());
if (!childOp->outputDimsForwarded()) {
nextList.insert(child);
}
}
} }
} }
} }
...@@ -319,6 +324,9 @@ void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) { ...@@ -319,6 +324,9 @@ void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) {
} }
} }
} }
AIDGE_INTERNAL_ASSERT(nextList != listNodes);
if (!nextList.empty()) { if (!nextList.empty()) {
_forwardDims(nextList); _forwardDims(nextList);
} }
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Memorize.hpp"
const std::string Aidge::Memorize_Op::Type = "Memorize";
void Aidge::Memorize_Op::computeOutputDims() {
// Only require input #1 dims (initialization of the Memorize operator)
// Otherwise, forwardDims() won't converge!
bool associated = (nbInputs() > 0); // do not compute anything if no input
if (!getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input #1 should be associated with a Tensor");
}
associated &= !(getInput(1)->empty());
if (associated) {
const auto expectedDims = getInput(1)->dims();
mOutputs[0]->resize(expectedDims);
}
}
bool Aidge::Memorize_Op::outputDimsForwarded() const {
// Only check the output dims
bool forwarded = true;
// check outputs have been filled
for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
forwarded &= !(getOutput(i)->empty());
}
return forwarded;
}
void Aidge::Memorize_Op::updateConsummerProducer() {
Operator::updateConsummerProducer();
++this->template getAttr<MemorizeAttr::ScheduleStep>();
this->template getAttr<MemorizeAttr::ForwardStep>() = 0;
}
void Aidge::Memorize_Op::forward() {
Operator::forward();
++this->template getAttr<MemorizeAttr::ForwardStep>();
this->template getAttr<MemorizeAttr::ScheduleStep>() = 0;
}
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") { void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") {
putchar('['); putchar('[');
...@@ -43,7 +44,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { ...@@ -43,7 +44,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
// setup initial producers list // setup initial producers list
std::set<std::shared_ptr<Node>> producers; std::set<std::shared_ptr<Node>> producers;
for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) { for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) {
if (nodePtr->type() == "Producer") { if (nodePtr->type() == Producer_Op::Type) {
producers.insert(nodePtr); producers.insert(nodePtr);
} }
} }
...@@ -64,7 +65,62 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { ...@@ -64,7 +65,62 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
/* It may not be necessary to initialize producer */ /* It may not be necessary to initialize producer */
std::set<std::shared_ptr<Node>> consumers = mGraphView->inputNodes(); std::set<std::shared_ptr<Node>> consumers = mGraphView->inputNodes();
std::set<std::shared_ptr<Node>> frozenConsumers;
do { do {
// Check required producers
std::set<std::shared_ptr<Node>> requiredProducers;
if (verbose) printf("Required producers:\n");
for (const auto& consumer : consumers) {
if (verbose) {
printf("\t- consumer: "
"\x1b[1;37m"
"%s"
"\x1b[0m"
"\n",
(consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str());
}
std::set<std::shared_ptr<Node>> consumerRequiredProducers;
bool requiredProducerOnly = true;
IOIndex_t inputIdx = 0;
for (const auto& consumerParent : consumer->inputs()) {
if (verbose) printf("\t\t#%u: ", inputIdx);
if (consumerParent.first &&
(consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) >
consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) {
if (verbose) printf("required data from %s: C%zu + R%zu > P%zu\n",
consumerParent.first->type().c_str(),
consumer->getOperator()->getNbConsumedData(inputIdx),
consumer->getOperator()->getNbRequiredData(inputIdx),
consumerParent.first->getOperator()->getNbProducedData(consumerParent.second));
if (consumerParent.first->type() == Producer_Op::Type) {
consumerRequiredProducers.insert(consumerParent.first);
}
else {
requiredProducerOnly = false;
break;
}
}
else {
if (verbose) printf("no data required\n");
}
++inputIdx;
}
if (requiredProducerOnly) {
requiredProducers.insert(consumerRequiredProducers.begin(), consumerRequiredProducers.end());
}
}
// Make producers generate the required data
for (const auto& requiredProducer : requiredProducers) {
requiredProducer->getOperator()->updateConsummerProducer();
mStaticSchedule.push_back(requiredProducer);
}
// find runnable consumers // find runnable consumers
std::set<std::shared_ptr<Node>> runnableConsumers; std::set<std::shared_ptr<Node>> runnableConsumers;
if (verbose) printf("List of layers receiving data:\n"); if (verbose) printf("List of layers receiving data:\n");
...@@ -74,7 +130,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { ...@@ -74,7 +130,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
"\x1b[1;37m" "\x1b[1;37m"
"%s" "%s"
"\x1b[0m" "\x1b[0m"
"\n\t\tR/C:\t", "\n\t\tC/R:\t",
(consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str()); (consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str());
for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
printf("%zu/%zu\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), printf("%zu/%zu\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
...@@ -89,18 +145,25 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { ...@@ -89,18 +145,25 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
printf("%zu", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1)); printf("%zu", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
printf("\n"); printf("\n");
} }
bool isRunnable = true; bool isRunnable = true;
IOIndex_t parentID = 0; // FIXME: handle this correctly IOIndex_t inputIdx = 0; // FIXME: handle this correctly
// Check every input has got enought data to run // Check every input has got enought data to run
for (const auto& consumerParent : consumer->dataInputs()) { for (const auto& consumerParent : consumer->inputs()) {
if (consumerParent.first && if (consumerParent.first &&
consumer->getOperator()->getNbRequiredData(parentID++) > (consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) >
consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) { consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) {
if (verbose) printf(" not runnable: C%zu + R%zu > P%zu\n",
consumer->getOperator()->getNbConsumedData(inputIdx),
consumer->getOperator()->getNbRequiredData(inputIdx),
consumerParent.first->getOperator()->getNbProducedData(consumerParent.second));
// not enough data to run // not enough data to run
isRunnable = false; isRunnable = false;
break; break;
} }
++inputIdx;
} }
if (isRunnable) { if (isRunnable) {
...@@ -115,13 +178,22 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { ...@@ -115,13 +178,22 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
mStaticSchedule.push_back(runnable); mStaticSchedule.push_back(runnable);
} }
if (runnableConsumers.empty()) {
if (frozenConsumers.empty()) {
frozenConsumers = consumers;
}
}
else {
frozenConsumers.clear();
}
// update producers and consumers list // update producers and consumers list
if (verbose) printf("Updating producer and consumer lists...\n"); if (verbose) printf("Updating producer and consumer lists...\n");
const auto oldConsumers = consumers; const auto oldConsumers = consumers;
for (const auto& consumer : oldConsumers) { for (const auto& consumer : oldConsumers) {
if (verbose) { if (verbose) {
printf("\t- consumer: %s\n\t\tR/C:\t", printf("\t- consumer: %s\n\t\tC/R:\t",
(consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str()); (consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str());
for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
printf("%ld/%ld\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), printf("%ld/%ld\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
...@@ -138,16 +210,21 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { ...@@ -138,16 +210,21 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
} }
bool isStillConsumer = false; bool isStillConsumer = false;
IOIndex_t parentID = 0; // FIXME: handle this correctly IOIndex_t inputIdx = 0; // FIXME: handle this correctly
// should we check input or dataInput ? // should we check input or dataInput ?
for (const auto& consumerParent : consumer->inputs()) { for (const auto& consumerParent : consumer->inputs()) {
if (consumerParent.first && if (consumerParent.first &&
consumer->getOperator()->getNbConsumedData(parentID++) < consumer->getOperator()->getNbConsumedData(inputIdx) <
consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) { consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) {
if (verbose) printf(" still consumer: C%zu < P%zu\n",
consumer->getOperator()->getNbConsumedData(inputIdx),
consumerParent.first->getOperator()->getNbProducedData(consumerParent.second));
// there is still data to consume // there is still data to consume
isStillConsumer = true; isStillConsumer = true;
break; break;
} }
++inputIdx;
} }
for (IOIndex_t outId = 0; outId < consumer->nbOutputs(); ++outId) { for (IOIndex_t outId = 0; outId < consumer->nbOutputs(); ++outId) {
...@@ -169,9 +246,15 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { ...@@ -169,9 +246,15 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
} }
} }
if (verbose) printf("*************\n"); if (verbose) printf("********************\n");
} while (!consumers.empty()); } while (!consumers.empty() && consumers != frozenConsumers);
if (verbose) {
if (!consumers.empty()) {
printf("*** Frozen state ***\n");
printf("********************\n");
}
}
} }
// TODO: handle multiple inputs/outputs // TODO: handle multiple inputs/outputs
...@@ -183,7 +266,7 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { ...@@ -183,7 +266,7 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
// If scheduling was already generated (in one or several steps, i.e. one or // If scheduling was already generated (in one or several steps, i.e. one or
// several successive call to generateScheduling()), do not generate it twice // several successive call to generateScheduling()), do not generate it twice
if (mStaticSchedule.empty()) { if (mStaticSchedule.empty()) {
this->generateScheduling(); this->generateScheduling(verbose);
} }
// Clear previous scheduling results // Clear previous scheduling results
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment