Skip to content
Snippets Groups Projects
Commit b6505394 authored by Grégoire Kubler's avatar Grégoire Kubler
Browse files

feat : added @olivierbichler fix for scheduling + added FAILING test for cyclical graphs

parent b89440d2
No related branches found
No related tags found
No related merge requests found
...@@ -116,6 +116,7 @@ private: ...@@ -116,6 +116,7 @@ private:
/** @brief List of nodes ordered by their */ /** @brief List of nodes ordered by their */
std::vector<std::vector<std::shared_ptr<Node>>> mStaticSchedule; std::vector<std::vector<std::shared_ptr<Node>>> mStaticSchedule;
size_t mStaticScheduleStep = 0; size_t mStaticScheduleStep = 0;
mutable std::map<std::shared_ptr<Node>, PriorProducersConsumers> mPriorCache;
}; };
} // namespace Aidge } // namespace Aidge
......
...@@ -171,7 +171,6 @@ Aidge::IOIndex_t Aidge::Node::nbValidOutputs() const { ...@@ -171,7 +171,6 @@ Aidge::IOIndex_t Aidge::Node::nbValidOutputs() const {
void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId) { void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId) {
assert(inId != gk_IODefaultIndex && (inId < nbInputs()) && "Must be a valid index"); assert(inId != gk_IODefaultIndex && (inId < nbInputs()) && "Must be a valid index");
if (mIdOutParents[inId] != gk_IODefaultIndex) { if (mIdOutParents[inId] != gk_IODefaultIndex) {
fmt::print("Warning: filling a Tensor already attributed\n");
auto originalParent = input(inId); auto originalParent = input(inId);
// remove original parent reference to child // remove original parent reference to child
// find the output ID for original Parent // find the output ID for original Parent
......
...@@ -83,6 +83,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { ...@@ -83,6 +83,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
if (verbose) fmt::print("List of consumers with their priors:\n"); if (verbose) fmt::print("List of consumers with their priors:\n");
std::set<std::shared_ptr<Node>> requiredProducers; std::set<std::shared_ptr<Node>> requiredProducers;
std::set<std::shared_ptr<Node>> priorConsumers; std::set<std::shared_ptr<Node>> priorConsumers;
mPriorCache.clear();
for (const auto& consumer : consumers) { for (const auto& consumer : consumers) {
if (verbose) { if (verbose) {
...@@ -567,12 +568,15 @@ Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared ...@@ -567,12 +568,15 @@ Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared
Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::getPriorProducersConsumers( Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::getPriorProducersConsumers(
const std::shared_ptr<Node>& node) const const std::shared_ptr<Node>& node) const
{ {
const auto priorCache = mPriorCache.find(node);
if (priorCache != mPriorCache.end()) {
return priorCache->second;
}
PriorProducersConsumers prior; PriorProducersConsumers prior;
IOIndex_t inputIdx = 0; IOIndex_t inputIdx = 0;
std::cout << *node << std::endl;
for (const auto& parent : node->inputs()) { for (const auto& parent : node->inputs()) {
std::cout << "parent.first " << *parent.first << std::endl;
if (parent.first && if (parent.first &&
(node->getOperator()->getNbConsumedData(inputIdx) + node->getOperator()->getNbRequiredData(inputIdx)) > (node->getOperator()->getNbConsumedData(inputIdx) + node->getOperator()->getNbRequiredData(inputIdx)) >
parent.first->getOperator()->getNbProducedData(parent.second)) parent.first->getOperator()->getNbProducedData(parent.second))
...@@ -609,5 +613,6 @@ Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler:: ...@@ -609,5 +613,6 @@ Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::
if (prior.priorConsumers.empty()) { if (prior.priorConsumers.empty()) {
prior.priorConsumers.insert(node); prior.priorConsumers.insert(node);
} }
mPriorCache.insert(std::make_pair(node, prior));
return prior; return prior;
} }
...@@ -21,8 +21,8 @@ ...@@ -21,8 +21,8 @@
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Testing.hpp"
#include "aidge/graph/OpArgs.hpp" #include "aidge/graph/OpArgs.hpp"
#include "aidge/graph/Testing.hpp"
#include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp" #include "aidge/operator/Producer.hpp"
#include "aidge/scheduler/Scheduler.hpp" #include "aidge/scheduler/Scheduler.hpp"
...@@ -30,53 +30,103 @@ ...@@ -30,53 +30,103 @@
using namespace Aidge; using namespace Aidge;
TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { TEST_CASE("randomScheduling", "[Scheduler][randomGen]") {
const size_t nbTests = 10; const size_t nbTests = 10;
size_t nbUnicity = 0; size_t nbUnicity = 0;
std::uniform_int_distribution<std::size_t> nb_nodes_dist(70, 100); std::uniform_int_distribution<std::size_t> nb_nodes_dist(100, 500);
for (int test = 0; test < nbTests; ++test) { for (int test = 0; test < nbTests; ++test) {
std::random_device rd; std::random_device rd;
const std::mt19937::result_type seed(rd()); const std::mt19937::result_type seed(rd());
std::mt19937 gen(rd()); std::mt19937 gen(rd());
RandomGraph randGraph; RandomGraph randGraph;
randGraph.acyclic = true; const auto g1 = std::make_shared<GraphView>("g1");
const auto g1 = std::make_shared<GraphView>("g1"); const size_t nb_nodes = nb_nodes_dist(gen);
// const size_t nb_nodes = nb_nodes_dist(gen);
const size_t nb_nodes = 85; SECTION("Acyclic Graph") {
const bool unicity1 = g1->add(randGraph.gen(seed, nb_nodes)); fmt::print("gen acyclic graph of {} nodes...\n", nb_nodes);
g1->save("test_graph_"+std::to_string(test)); randGraph.acyclic = true;
if (unicity1) { const bool unicity1 = g1->add(randGraph.gen(seed, nb_nodes));
for (auto& node : g1->getNodes()) { // g1->save("test_graph_" + std::to_string(test));
std::static_pointer_cast<GenericOperator_Op>(node->getOperator())->setComputeOutputDims(GenericOperator_Op::InputIdentity(0, node->nbOutputs()));
} if (unicity1) {
for (auto &node : g1->getNodes()) {
const auto orderedInputs = g1->getOrderedInputs(); std::static_pointer_cast<GenericOperator_Op>(node->getOperator())
for (const auto& input : orderedInputs) { ->setComputeOutputDims(
auto prod = Producer({16, 32}); GenericOperator_Op::InputIdentity(0, node->nbOutputs()));
prod->addChild(input.first, 0, input.second); }
g1->add(prod);
} const auto orderedInputs = g1->getOrderedInputs();
for (const auto &input : orderedInputs) {
g1->save("schedule"); auto prod = Producer({16, 32});
g1->forwardDims(); prod->addChild(input.first, 0, input.second);
g1->add(prod);
auto scheduler = SequentialScheduler(g1);
scheduler.generateScheduling(true);
const auto sch = scheduler.getStaticScheduling();
const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})");
std::vector<std::string> nodesName;
std::transform(sch.begin(), sch.end(),
std::back_inserter(nodesName),
[&namePtrTable](auto val){ return namePtrTable.at(val); });
fmt::print("schedule: {}\n", nodesName);
CHECK(sch.size() == nb_nodes + orderedInputs.size());
} }
g1->save("schedule");
g1->forwardDims();
fmt::print("gen scheduling...\n");
auto scheduler = SequentialScheduler(g1);
scheduler.generateScheduling();
fmt::print("gen scheduling finished\n");
const auto sch = scheduler.getStaticScheduling();
const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})");
std::vector<std::string> nodesName;
std::transform(
sch.begin(), sch.end(), std::back_inserter(nodesName),
[&namePtrTable](auto val) { return namePtrTable.at(val); });
fmt::print("schedule: {}\n", nodesName);
REQUIRE(sch.size() == nb_nodes + orderedInputs.size());
++nbUnicity;
}
} }
SECTION("Cyclic graph") {
fmt::print("gen cyclic graph of {} nodes...\n", nb_nodes);
randGraph.acyclic = false;
const bool unicity1 = g1->add(randGraph.gen(seed, nb_nodes));
// g1->save("test_graph_" + std::to_string(test));
if (unicity1) {
for (auto &node : g1->getNodes()) {
std::static_pointer_cast<GenericOperator_Op>(node->getOperator())
->setComputeOutputDims(
GenericOperator_Op::InputIdentity(0, node->nbOutputs()));
}
const auto orderedInputs = g1->getOrderedInputs();
for (const auto &input : orderedInputs) {
auto prod = Producer({16, 32});
prod->addChild(input.first, 0, input.second);
g1->add(prod);
}
g1->save("schedule");
g1->forwardDims();
fmt::print("gen scheduling...\n");
auto scheduler = SequentialScheduler(g1);
scheduler.generateScheduling();
fmt::print("gen scheduling finished\n");
const auto sch = scheduler.getStaticScheduling();
fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests); const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})");
std::vector<std::string> nodesName;
std::transform(
sch.begin(), sch.end(), std::back_inserter(nodesName),
[&namePtrTable](auto val) { return namePtrTable.at(val); });
fmt::print("schedule: {}\n", nodesName);
REQUIRE(sch.size() == nb_nodes + orderedInputs.size());
++nbUnicity;
}
}
}
fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests);
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment