Skip to content
Snippets Groups Projects
Forked from Eclipse Projects / aidge / aidge_core
2025 commits behind the upstream repository.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Test_Scheduler.cpp 4.47 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <algorithm> // std::transform
#include <cassert>
#include <cstddef>
#include <iterator>  // std::back_inserter
#include <map>
#include <memory>
#include <random>
#include <set>
#include <string>
#include <vector>

#include <catch2/catch_test_macros.hpp>

#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/graph/Testing.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/scheduler/Scheduler.hpp"

using namespace Aidge;

TEST_CASE("randomScheduling", "[Scheduler][randomGen]") {
  const size_t nbTests = 10;
  size_t nbUnicity = 0;
  std::uniform_int_distribution<std::size_t> nb_nodes_dist(100, 500);

  for (int test = 0; test < nbTests; ++test) {
    std::random_device rd;
    const std::mt19937::result_type seed(rd());
    std::mt19937 gen(rd());

    RandomGraph randGraph;
    const auto g1 = std::make_shared<GraphView>("g1");
    const size_t nb_nodes = nb_nodes_dist(gen);

    SECTION("Acyclic Graph") {
      fmt::print("gen acyclic graph of {} nodes...\n", nb_nodes);
      randGraph.acyclic = true;

      const bool unicity1 = g1->add(randGraph.gen(seed, nb_nodes));
      // g1->save("test_graph_" + std::to_string(test));

      if (unicity1) {
        for (auto &node : g1->getNodes()) {
          std::static_pointer_cast<GenericOperator_Op>(node->getOperator())
              ->setComputeOutputDims(
                  GenericOperator_Op::InputIdentity(0, node->nbOutputs()));
        }

        const auto orderedInputs = g1->getOrderedInputs();
        for (const auto &input : orderedInputs) {
          auto prod = Producer({16, 32});
          prod->addChild(input.first, 0, input.second);
          g1->add(prod);
        }

        g1->save("schedule");
        g1->forwardDims();

        fmt::print("gen scheduling...\n");
        auto scheduler = SequentialScheduler(g1);
        scheduler.generateScheduling();
        fmt::print("gen scheduling finished\n");
        const auto sch = scheduler.getStaticScheduling();

        const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})");

        std::vector<std::string> nodesName;
        std::transform(
            sch.begin(), sch.end(), std::back_inserter(nodesName),
            [&namePtrTable](auto val) { return namePtrTable.at(val); });

        fmt::print("schedule: {}\n", nodesName);
        REQUIRE(sch.size() == nb_nodes + orderedInputs.size());
        ++nbUnicity;
      }
    }
    // SECTION("Cyclic graph") {
    //   fmt::print("gen cyclic graph of {} nodes...\n", nb_nodes);
    //   randGraph.acyclic = false;
    //   randGraph.types={"Memorize"};

    //   const bool unicity1 = g1->add(randGraph.gen(seed, nb_nodes));
    //   // g1->save("test_graph_" + std::to_string(test));

    //   if (unicity1) {
    //     for (auto &node : g1->getNodes()) {
    //       std::static_pointer_cast<GenericOperator_Op>(node->getOperator())
    //           ->setComputeOutputDims(
    //               GenericOperator_Op::InputIdentity(0, node->nbOutputs()));
    //     }

    //     const auto orderedInputs = g1->getOrderedInputs();
    //     for (const auto &input : orderedInputs) {
    //       auto prod = Producer({16, 32});
    //       prod->addChild(input.first, 0, input.second);
    //       g1->add(prod);
    //     }

    //     g1->save("schedule");
    //     g1->forwardDims();

    //     fmt::print("gen scheduling...\n");
    //     auto scheduler = SequentialScheduler(g1);
    //     scheduler.generateScheduling();
    //     fmt::print("gen scheduling finished\n");
    //     const auto sch = scheduler.getStaticScheduling();

    //     const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})");

    //     std::vector<std::string> nodesName;
    //     std::transform(
    //         sch.begin(), sch.end(), std::back_inserter(nodesName),
    //         [&namePtrTable](auto val) { return namePtrTable.at(val); });

    //     fmt::print("schedule: {}\n", nodesName);
    //     REQUIRE(sch.size() == nb_nodes + orderedInputs.size());
    //     ++nbUnicity;
    //   }
    // }
  }
  fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests);
}