Commit 08a66f38 authored by Maxence Naud

Merge branch 'fix/scheduler_exec_time' into 'dev'

fix/scheduler_exec_time

See merge request !100
parents 5fd1117a a1a83f25
2 merge requests: !105 version 0.2.0, !100 fix/scheduler_exec_time
Pipeline #42558 passed
@@ -456,7 +456,17 @@ private:
* @param inId index for adding the parent.
*/
void addParent(const NodePtr otherNode, const IOIndex_t inId);
// OPERATOR FUNCTIONAL but commented out to avoid iostream inclusion
// /**
// * @brief operator<< overload to ease print & debug of nodes
// * @param[inout] ostream to print to
// * @param[in] n node to print
// */
// friend std::ostream& operator << (std::ostream& os, Node& n);
};
} // namespace Aidge
#endif /* AIDGE_CORE_GRAPH_NODE_H_ */
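The operator<< above is kept commented out, per its own note, to avoid pulling <iostream> into the header. For reference only (not part of this commit), a minimal sketch of the usual workaround, with a stand-in Node class so the snippet is self-contained: forward-declare the stream type via <iosfwd> in the header and include <ostream> only where the operator is defined.

```cpp
// Header side would only need <iosfwd>; <ostream> (and <iostream> for the
// demo main) would live in the .cpp files.
#include <iosfwd>
#include <iostream>
#include <ostream>
#include <string>

// Stand-in for Aidge::Node, just to keep the sketch self-contained.
class Node {
public:
    std::string name() const { return "conv1"; }
    std::string type() const { return "Conv"; }
};

// In the header, a declaration like this only needs <iosfwd>.
std::ostream& operator<<(std::ostream& os, const Node& n);

// In the .cpp, the definition is where <ostream> is actually required.
std::ostream& operator<<(std::ostream& os, const Node& n) {
    return os << "Node \"" << n.name() << "\" (" << n.type() << ")";
}

int main() {
    std::cout << Node{} << '\n';
}
```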
@@ -121,6 +121,7 @@ private:
/** @brief List of nodes ordered by their scheduling step */
std::vector<std::vector<std::shared_ptr<Node>>> mStaticSchedule;
size_t mStaticScheduleStep = 0;
mutable std::map<std::shared_ptr<Node>, PriorProducersConsumers> mPriorCache;
};
} // namespace Aidge
@@ -173,7 +173,7 @@ void Aidge::Node::setInputId(const IOIndex_t inId, const IOIndex_t newNodeoutId)
"Input index ({}) is out of bound ({}) for node {} (of type {})",
inId, nbInputs(), name(), type());
if (mIdOutParents[inId] != gk_IODefaultIndex) {
fmt::print("Warning: filling a Tensor already attributed\n");
Log::warn("Warning: filling a Tensor already attributed\n");
auto originalParent = input(inId);
// remove original parent reference to child
// find the output ID for original Parent
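The hunk above routes the warning through the project's Log facility instead of an unconditional fmt::print; the updated test below can then filter it with Aidge::Log::setConsoleLevel(Aidge::Log::Warn). A rough sketch of such a leveled, fmt-based warn helper, using a hypothetical Log stand-in rather than the actual Aidge API:

```cpp
#include <fmt/core.h>
#include <utility>

// Hypothetical stand-in for a leveled logger: messages below the configured
// console level are dropped instead of always being printed.
enum class Level { Debug = 0, Info, Warn, Error };

struct Log {
    static inline Level consoleLevel = Level::Info;

    template <typename... Args>
    static void warn(fmt::format_string<Args...> msg, Args&&... args) {
        if (consoleLevel <= Level::Warn) {
            fmt::print("[WARN] ");
            fmt::print(msg, std::forward<Args>(args)...);
            fmt::print("\n");
        }
    }
};

int main() {
    Log::consoleLevel = Level::Warn;
    Log::warn("filling a Tensor already attributed (input #{})", 0);
}
```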
@@ -390,6 +390,26 @@ std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta, std::set<Aidge::No
return out;
}
// namespace Aidge {
// std::ostream& operator << (std::ostream& os, Aidge::Node& n) {
// using namespace std;
// os << "Node :\tName :\t\"" << n.name() << "\"\tType : \"" << n.getOperator()->type()<< "\"\tIN/OUTputs : "<< n.nbInputs() <<"/"<< n.nbOutputs() <<endl;
// os << "\tParents :\t" ;
// for (const auto & p : n.getParents())
// {
// os << "\"" <<p->name() << "\"\t";
// }
// os << endl;
// os << "\tChildren :\t" ;
// for (const auto & c : n.getChildren())
// {
// os << "\"" << c->name() << "\"\t";
// }
// os << endl;
// return os;
// }
// }
/////////////////////////////////////////////////////////////////////////////////////////////
// private
@@ -84,6 +84,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
if (verbose) fmt::print("List of consumers with their priors:\n");
std::set<std::shared_ptr<Node>> requiredProducers;
std::set<std::shared_ptr<Node>> priorConsumers;
mPriorCache.clear();
for (const auto& consumer : consumers) {
if (verbose) {
@@ -620,6 +621,11 @@ Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared
Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::getPriorProducersConsumers(
const std::shared_ptr<Node>& node) const
{
const auto priorCache = mPriorCache.find(node);
if (priorCache != mPriorCache.end()) {
return priorCache->second;
}
PriorProducersConsumers prior;
IOIndex_t inputIdx = 0;
@@ -660,5 +666,6 @@ Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::
if (prior.priorConsumers.empty()) {
prior.priorConsumers.insert(node);
}
mPriorCache.insert(std::make_pair(node, prior));
return prior;
}
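Taken together, the three scheduler hunks clear a per-node cache when generateScheduling() starts, probe it at the top of getPriorProducersConsumers(), and store the computed result before returning, so repeated queries for the same node skip the graph traversal. A condensed, self-contained sketch of that memoization pattern, with a hypothetical computePriors() and simplified types standing in for the Aidge ones:

```cpp
#include <map>
#include <memory>
#include <set>

struct Node {};
using NodePtr = std::shared_ptr<Node>;

// Simplified stand-in for SequentialScheduler::PriorProducersConsumers.
struct PriorProducersConsumers {
    std::set<NodePtr> requiredProducers;
    std::set<NodePtr> priorConsumers;
};

class Scheduler {
public:
    void generateScheduling() {
        mPriorCache.clear();      // drop results from any previous run
        // ... scheduling loop, calling computePriors(node) many times ...
    }

    // Hypothetical name standing in for getPriorProducersConsumers().
    PriorProducersConsumers computePriors(const NodePtr& node) const {
        const auto it = mPriorCache.find(node);
        if (it != mPriorCache.end()) {
            return it->second;    // cache hit: skip the graph traversal
        }
        PriorProducersConsumers prior;
        // ... expensive walk over the node's parents would go here ...
        mPriorCache.emplace(node, prior);
        return prior;
    }

private:
    // mutable so the const query above can still fill the cache
    mutable std::map<NodePtr, PriorProducersConsumers> mPriorCache;
};

int main() {
    Scheduler s;
    s.generateScheduling();
    const auto n = std::make_shared<Node>();
    s.computePriors(n);   // miss: computed and stored
    s.computePriors(n);   // hit: served from mPriorCache
}
```

The mutable qualifier on the cache, mirroring the header change above, is what lets the const query method populate it.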
@@ -21,8 +21,8 @@
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Testing.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/graph/Testing.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/scheduler/Scheduler.hpp"
@@ -30,48 +30,105 @@
using namespace Aidge;
TEST_CASE("randomScheduling", "[Scheduler][randomGen]") {
const size_t nbTests = 100;
size_t nbUnicity = 0;
for (int test = 0; test < nbTests; ++test) {
std::random_device rd;
const std::mt19937::result_type seed(rd());
RandomGraph randGraph;
randGraph.acyclic = true;
const auto g1 = std::make_shared<GraphView>("g1");
const bool unicity1 = g1->add(randGraph.gen(seed, 10));
if (unicity1) {
for (auto& node : g1->getNodes()) {
std::static_pointer_cast<GenericOperator_Op>(node->getOperator())->setComputeOutputDims(GenericOperator_Op::InputIdentity(0, node->nbOutputs()));
}
const auto orderedInputs = g1->getOrderedInputs();
for (const auto& input : orderedInputs) {
auto prod = Producer({16, 32});
prod->addChild(input.first, 0, input.second);
g1->add(prod);
}
g1->save("schedule");
g1->compile();
auto scheduler = SequentialScheduler(g1);
scheduler.generateScheduling(true);
const auto sch = scheduler.getStaticScheduling();
const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})");
std::vector<std::string> nodesName;
std::transform(sch.begin(), sch.end(),
std::back_inserter(nodesName),
[&namePtrTable](auto val){ return namePtrTable.at(val); });
fmt::print("schedule: {}\n", nodesName);
REQUIRE(sch.size() == 10 + orderedInputs.size());
const size_t nbTests = 10;
size_t nbUnicity = 0;
std::uniform_int_distribution<std::size_t> nb_nodes_dist(100, 500);
for (int test = 0; test < nbTests; ++test) {
std::random_device rd;
const std::mt19937::result_type seed(rd());
std::mt19937 gen(rd());
RandomGraph randGraph;
const auto g1 = std::make_shared<GraphView>("g1");
const size_t nb_nodes = nb_nodes_dist(gen);
SECTION("Acyclic Graph") {
Aidge::Log::setConsoleLevel(Aidge::Log::Warn);
fmt::print("gen acyclic graph of {} nodes...\n", nb_nodes);
randGraph.acyclic = true;
const bool unicity1 = g1->add(randGraph.gen(seed, nb_nodes));
// g1->save("test_graph_" + std::to_string(test));
if (unicity1) {
for (auto &node : g1->getNodes()) {
std::static_pointer_cast<GenericOperator_Op>(node->getOperator())
->setComputeOutputDims(
GenericOperator_Op::InputIdentity(0, node->nbOutputs()));
}
const auto orderedInputs = g1->getOrderedInputs();
for (const auto &input : orderedInputs) {
auto prod = Producer({16, 32});
prod->addChild(input.first, 0, input.second);
g1->add(prod);
}
g1->save("schedule");
g1->compile();
fmt::print("gen scheduling...\n");
auto scheduler = SequentialScheduler(g1);
scheduler.generateScheduling();
fmt::print("gen scheduling finished\n");
const auto sch = scheduler.getStaticScheduling();
const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})");
std::vector<std::string> nodesName;
std::transform(
sch.begin(), sch.end(), std::back_inserter(nodesName),
[&namePtrTable](auto val) { return namePtrTable.at(val); });
fmt::print("schedule: {}\n", nodesName);
REQUIRE(sch.size() == nb_nodes + orderedInputs.size());
++nbUnicity;
}
}
// SECTION("Cyclic graph") {
// fmt::print("gen cyclic graph of {} nodes...\n", nb_nodes);
// randGraph.acyclic = false;
// randGraph.types={"Memorize"};
// const bool unicity1 = g1->add(randGraph.gen(seed, nb_nodes));
// // g1->save("test_graph_" + std::to_string(test));
// if (unicity1) {
// for (auto &node : g1->getNodes()) {
// std::static_pointer_cast<GenericOperator_Op>(node->getOperator())
// ->setComputeOutputDims(
// GenericOperator_Op::InputIdentity(0, node->nbOutputs()));
// }
// const auto orderedInputs = g1->getOrderedInputs();
// for (const auto &input : orderedInputs) {
// auto prod = Producer({16, 32});
// prod->addChild(input.first, 0, input.second);
// g1->add(prod);
// }
// g1->save("schedule");
// g1->forwardDims();
// fmt::print("gen scheduling...\n");
// auto scheduler = SequentialScheduler(g1);
// scheduler.generateScheduling();
// fmt::print("gen scheduling finished\n");
// const auto sch = scheduler.getStaticScheduling();
// const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})");
// std::vector<std::string> nodesName;
// std::transform(
// sch.begin(), sch.end(), std::back_inserter(nodesName),
// [&namePtrTable](auto val) { return namePtrTable.at(val); });
fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests);
// fmt::print("schedule: {}\n", nodesName);
// REQUIRE(sch.size() == nb_nodes + orderedInputs.size());
// ++nbUnicity;
// }
// }
}
fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests);
}