Skip to content
Snippets Groups Projects
Commit 86d866d4 authored by vincent  lorrain's avatar vincent lorrain
Browse files

Merge remote-tracking branch 'origin/main' into graphRegex

parents c2f55dc6 51af88a0
No related branches found
No related tags found
1 merge request!14Graph regex
Pipeline #32148 failed
...@@ -326,7 +326,7 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara ...@@ -326,7 +326,7 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara
// add learnable parameters to the graph // add learnable parameters to the graph
if (includeLearnableParam) { if (includeLearnableParam) {
for (IOIndex_t i = node->nbDataInputs(); i < node->nbInputs(); ++i) { for (IOIndex_t i = node->nbDataInputs(); i < node->nbInputs(); ++i) {
std::shared_ptr<Node> parentNode = node->getParents(static_cast<IOIndex_t>(i)); std::shared_ptr<Node> parentNode = node->getParent(static_cast<IOIndex_t>(i));
if (parentNode) { if (parentNode) {
parentNode->addView(shared_from_this()); parentNode->addView(shared_from_this());
mNodes.insert(parentNode); mNodes.insert(parentNode);
......
...@@ -226,7 +226,7 @@ void Aidge::Node::addChild(std::shared_ptr<GraphView> otherView, const IOIndex_t ...@@ -226,7 +226,7 @@ void Aidge::Node::addChild(std::shared_ptr<GraphView> otherView, const IOIndex_t
} }
void Aidge::Node::addParent(const std::shared_ptr<Node> other_node, const IOIndex_t inId) { void Aidge::Node::addParent(const std::shared_ptr<Node> other_node, const IOIndex_t inId) {
if (getParents(inId) != nullptr) { if (getParent(inId) != nullptr) {
printf("Warning, you're replacing a Parent.\n"); printf("Warning, you're replacing a Parent.\n");
} }
assert((inId != gk_IODefaultIndex) && (inId < nbInputs()) && "Input index out of bound."); assert((inId != gk_IODefaultIndex) && (inId < nbInputs()) && "Input index out of bound.");
......
...@@ -14,13 +14,13 @@ ...@@ -14,13 +14,13 @@
#include "aidge/graph/OpArgs.hpp" #include "aidge/graph/OpArgs.hpp"
std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::initializer_list<OpArgs> inputs) { std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::vector<OpArgs> inputs) {
std::shared_ptr<GraphView> gv = std::make_shared<GraphView>(); std::shared_ptr<GraphView> gv = std::make_shared<GraphView>();
for (const OpArgs& elt : inputs) { for (const OpArgs& elt : inputs) {
if(elt.node() != nullptr) { if(elt.node() != nullptr) {
// >= to allow incomplete graphViews // >= to allow incomplete graphViews
assert(static_cast<std::size_t>(elt.node()->getNbFreeDataInputs()) >= gv->outputNodes().size()); assert(static_cast<std::size_t>(elt.node()->getNbFreeDataInputs()) >= gv->outputNodes().size());
/* /*
* /!\ mn.view()->outputNodes() is a set, order of Nodes cannot be guaranted. * /!\ mn.view()->outputNodes() is a set, order of Nodes cannot be guaranted.
* Prefer a functional description for detailed inputs * Prefer a functional description for detailed inputs
*/ */
...@@ -44,7 +44,7 @@ std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::initializer_list<OpArgs ...@@ -44,7 +44,7 @@ std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::initializer_list<OpArgs
} }
std::shared_ptr<Aidge::GraphView> Aidge::Parallel(std::initializer_list<OpArgs> inputs) { std::shared_ptr<Aidge::GraphView> Aidge::Parallel(std::vector<OpArgs> inputs) {
std::shared_ptr<GraphView> gv = std::make_shared<GraphView>(); std::shared_ptr<GraphView> gv = std::make_shared<GraphView>();
for(const OpArgs& elt : inputs) { for(const OpArgs& elt : inputs) {
if (elt.node()!=nullptr) if (elt.node()!=nullptr)
...@@ -56,7 +56,7 @@ std::shared_ptr<Aidge::GraphView> Aidge::Parallel(std::initializer_list<OpArgs> ...@@ -56,7 +56,7 @@ std::shared_ptr<Aidge::GraphView> Aidge::Parallel(std::initializer_list<OpArgs>
} }
std::shared_ptr<Aidge::GraphView> Aidge::Residual(std::initializer_list<OpArgs> inputs) { std::shared_ptr<Aidge::GraphView> Aidge::Residual(std::vector<OpArgs> inputs) {
std::shared_ptr<GraphView> gv = Sequential(inputs); std::shared_ptr<GraphView> gv = Sequential(inputs);
assert(gv->outputNodes().size() == 1U && "Zero or more than one output Node for the GraphView, don't know which one to choose from for the residual connection"); assert(gv->outputNodes().size() == 1U && "Zero or more than one output Node for the GraphView, don't know which one to choose from for the residual connection");
std::shared_ptr<Node> lastNode = *gv->outputNodes().begin(); std::shared_ptr<Node> lastNode = *gv->outputNodes().begin();
...@@ -70,4 +70,4 @@ std::shared_ptr<Aidge::GraphView> Aidge::Residual(std::initializer_list<OpArgs> ...@@ -70,4 +70,4 @@ std::shared_ptr<Aidge::GraphView> Aidge::Residual(std::initializer_list<OpArgs>
assert(lastNode->getNbFreeDataInputs()>=1); assert(lastNode->getNbFreeDataInputs()>=1);
gv->addChild(lastNode, firstNode, 0U, gk_IODefaultIndex); gv->addChild(lastNode, firstNode, 0U, gk_IODefaultIndex);
return gv; return gv;
} }
\ No newline at end of file
...@@ -38,7 +38,18 @@ Aidge::NbElts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) co ...@@ -38,7 +38,18 @@ Aidge::NbElts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) co
Aidge::NbElts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) const { Aidge::NbElts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
return mImpl->getNbProducedData(outputIdx); return mImpl->getNbProducedData(outputIdx);
} }
void Aidge::Operator::updateConsummerProducer(){
mImpl->updateConsummerProducer();
}
void Aidge::Operator::forward() { mImpl->forward(); } void Aidge::Operator::runHooks() const {
for (auto& hook : mHooks) {
hook.second->call();
}
}
void Aidge::Operator::forward() {
mImpl->forward();
runHooks();
}
void Aidge::Operator::backward() { mImpl->backward(); } void Aidge::Operator::backward() { mImpl->backward(); }
...@@ -59,12 +59,12 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ ...@@ -59,12 +59,12 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
// Step 2 : Branch existing producers & create the others // Step 2 : Branch existing producers & create the others
// link weights & bias // link weights & bias
if (matmul->getParents(1)==nullptr) { if (matmul->getParent(1)==nullptr) {
matmul->getParents(0)->addChild(fc, 0, 1); matmul->getParent(0)->addChild(fc, 0, 1);
} else { } else {
if (matmul->getParents(0)!=nullptr) if (matmul->getParent(0)!=nullptr)
matmul->getParents(0)->addChild(fc, 0, 0); matmul->getParent(0)->addChild(fc, 0, 0);
matmul->getParents(1)->addChild(fc, 0, 1); matmul->getParent(1)->addChild(fc, 0, 1);
} }
(producer_add_bias.first)->addChild(fc,0,2); (producer_add_bias.first)->addChild(fc,0,2);
......
...@@ -33,26 +33,19 @@ void drawProgressBar(double progress, int barWidth, const std::string& additiona ...@@ -33,26 +33,19 @@ void drawProgressBar(double progress, int barWidth, const std::string& additiona
fflush(stdout); fflush(stdout);
} }
// TODO: handle multiple inputs/outputs void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
void Aidge::SequentialScheduler::forward(bool frowardDims, bool verbose) {
if (frowardDims) {mGraphView->forwardDims(); }
mScheduling.clear();
// setup initial producers list // setup initial producers list
// add each Producer Node. mComputationNumber = 0;
std::set<std::shared_ptr<Node>> computationOver;
std::size_t computationNumber = 0;
std::set<std::shared_ptr<Node>> producers; std::set<std::shared_ptr<Node>> producers;
for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) { for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) {
if (nodePtr->type() == "Producer") { if (nodePtr->type() == "Producer") {
producers.insert(nodePtr); producers.insert(nodePtr);
} else { } else {
++computationNumber; ++mComputationNumber;
} }
} }
// add Data Input // add Data Input
// FIXME : shoudl be changed when the real system for providing // FIXME : should be changed when the real system for providing
// data is implemented // data is implemented
for (const std::shared_ptr<Node>& nodePtr : mGraphView->inputNodes()) { for (const std::shared_ptr<Node>& nodePtr : mGraphView->inputNodes()) {
for (const auto& parentPtr : nodePtr->getParents()) { for (const auto& parentPtr : nodePtr->getParents()) {
...@@ -112,21 +105,10 @@ void Aidge::SequentialScheduler::forward(bool frowardDims, bool verbose) { ...@@ -112,21 +105,10 @@ void Aidge::SequentialScheduler::forward(bool frowardDims, bool verbose) {
} }
} }
// run sequencially every runnable consumers once // Push consumers in the list of nodes to run and update the consumer producer system
// TODO: handle memory allocation in scheduler
// TODO: optimize memory usage
for (const auto& runnable : runnableConsumers) { for (const auto& runnable : runnableConsumers) {
if (verbose) runnable->getOperator()->updateConsummerProducer();
printf("run: %s\n", mStaticSchedule.push_back(runnable);
(runnable->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))).c_str());
else
drawProgressBar(static_cast<float>(computationOver.size()) / static_cast<float>(computationNumber), 50,
(std::string("running ") + runnable->type() + "_" +
std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))));
const auto tStart = std::chrono::high_resolution_clock::now();
runnable->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd));
} }
// update producers and consumers list // update producers and consumers list
...@@ -164,18 +146,6 @@ void Aidge::SequentialScheduler::forward(bool frowardDims, bool verbose) { ...@@ -164,18 +146,6 @@ void Aidge::SequentialScheduler::forward(bool frowardDims, bool verbose) {
} }
} }
bool computationOverForConsumer = true;
for (IOIndex_t parentIDi = 0; parentIDi < consumer->nbInputs(); ++parentIDi) {
if (consumer->getOperator()->getNbConsumedData(parentIDi) <
consumer->getOperator()->getNbRequiredData(parentIDi)) {
computationOverForConsumer = false;
break;
}
}
if (computationOverForConsumer) {
computationOver.insert(consumer);
}
for (IOIndex_t outId = 0; outId < consumer->nbOutputs(); ++outId) { for (IOIndex_t outId = 0; outId < consumer->nbOutputs(); ++outId) {
if (consumer->getOperator()->getNbProducedData(outId) > 0) { if (consumer->getOperator()->getNbProducedData(outId) > 0) {
if (verbose) printf(" also producer\n"); if (verbose) printf(" also producer\n");
...@@ -197,8 +167,52 @@ void Aidge::SequentialScheduler::forward(bool frowardDims, bool verbose) { ...@@ -197,8 +167,52 @@ void Aidge::SequentialScheduler::forward(bool frowardDims, bool verbose) {
if (verbose) printf("*************\n"); if (verbose) printf("*************\n");
} while (!consumers.empty()); } while (!consumers.empty());
}
// TODO: handle multiple inputs/outputs
void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
if (forwardDims) {mGraphView->forwardDims(); }
// add each Producer Node.
std::set<std::shared_ptr<Node>> computationOver;
mScheduling.clear();
this->generateScheduling();
// TODO: For loop on the list of node to run
// run sequencially every runnable consumers once
// TODO: handle memory allocation in scheduler
// TODO: optimize memory usage
for (const auto& runnable : mStaticSchedule) {
bool computationOverForConsumer = true;
for (IOIndex_t parentIDi = 0; parentIDi < runnable->nbInputs(); ++parentIDi) {
if (runnable->getOperator()->getNbConsumedData(parentIDi) <
runnable->getOperator()->getNbRequiredData(parentIDi)) {
computationOverForConsumer = false;
break;
}
}
if (computationOverForConsumer) {
computationOver.insert(runnable);
}
if (verbose)
printf("run: %s\n",
(runnable->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))).c_str());
else
drawProgressBar(static_cast<float>(computationOver.size()) / static_cast<float>(mComputationNumber), 50,
(std::string("running ") + runnable->type() + "_" +
std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))));
const auto tStart = std::chrono::high_resolution_clock::now();
runnable->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd));
}
if (!verbose) drawProgressBar(1.0, 50, " "); if (!verbose) drawProgressBar(1.0, 50, " ");
printf("\n"); printf("\n");
} }
void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const { void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const {
...@@ -231,4 +245,4 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers( ...@@ -231,4 +245,4 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers(
} }
return consumers; return consumers;
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment