Skip to content
Snippets Groups Projects
Commit 14b20fe1 authored by Grégoire Kubler's avatar Grégoire Kubler
Browse files

chore : change type of nodes generated

parent 0c4623cd
No related branches found
No related tags found
No related merge requests found
Pipeline #42004 passed
...@@ -85,48 +85,49 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { ...@@ -85,48 +85,49 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") {
++nbUnicity; ++nbUnicity;
} }
} }
SECTION("Cyclic graph") { // SECTION("Cyclic graph") {
fmt::print("gen cyclic graph of {} nodes...\n", nb_nodes); // fmt::print("gen cyclic graph of {} nodes...\n", nb_nodes);
randGraph.acyclic = false; // randGraph.acyclic = false;
// randGraph.types={"Memorize"};
const bool unicity1 = g1->add(randGraph.gen(seed, nb_nodes));
// g1->save("test_graph_" + std::to_string(test)); // const bool unicity1 = g1->add(randGraph.gen(seed, nb_nodes));
// // g1->save("test_graph_" + std::to_string(test));
if (unicity1) {
for (auto &node : g1->getNodes()) { // if (unicity1) {
std::static_pointer_cast<GenericOperator_Op>(node->getOperator()) // for (auto &node : g1->getNodes()) {
->setComputeOutputDims( // std::static_pointer_cast<GenericOperator_Op>(node->getOperator())
GenericOperator_Op::InputIdentity(0, node->nbOutputs())); // ->setComputeOutputDims(
} // GenericOperator_Op::InputIdentity(0, node->nbOutputs()));
// }
const auto orderedInputs = g1->getOrderedInputs();
for (const auto &input : orderedInputs) { // const auto orderedInputs = g1->getOrderedInputs();
auto prod = Producer({16, 32}); // for (const auto &input : orderedInputs) {
prod->addChild(input.first, 0, input.second); // auto prod = Producer({16, 32});
g1->add(prod); // prod->addChild(input.first, 0, input.second);
} // g1->add(prod);
// }
g1->save("schedule");
g1->forwardDims(); // g1->save("schedule");
// g1->forwardDims();
fmt::print("gen scheduling...\n");
auto scheduler = SequentialScheduler(g1); // fmt::print("gen scheduling...\n");
scheduler.generateScheduling(); // auto scheduler = SequentialScheduler(g1);
fmt::print("gen scheduling finished\n"); // scheduler.generateScheduling();
const auto sch = scheduler.getStaticScheduling(); // fmt::print("gen scheduling finished\n");
// const auto sch = scheduler.getStaticScheduling();
const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})");
// const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})");
std::vector<std::string> nodesName;
std::transform( // std::vector<std::string> nodesName;
sch.begin(), sch.end(), std::back_inserter(nodesName), // std::transform(
[&namePtrTable](auto val) { return namePtrTable.at(val); }); // sch.begin(), sch.end(), std::back_inserter(nodesName),
// [&namePtrTable](auto val) { return namePtrTable.at(val); });
fmt::print("schedule: {}\n", nodesName);
REQUIRE(sch.size() == nb_nodes + orderedInputs.size()); // fmt::print("schedule: {}\n", nodesName);
++nbUnicity; // REQUIRE(sch.size() == nb_nodes + orderedInputs.size());
} // ++nbUnicity;
} // }
// }
} }
fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests); fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests);
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment