diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp index c966b5f5c1bb4914f3e46f96493da87a6707b1ff..449235712dd2867c4644ff9cbecb029778e508e2 100644 --- a/include/aidge/operator/GenericOperator.hpp +++ b/include/aidge/operator/GenericOperator.hpp @@ -38,7 +38,9 @@ private: public: GenericOperator_Op(const std::string& type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut) : OperatorTensor(type, nbData, nbParam, nbOut) - {} + { + mImpl = std::make_shared<OperatorImpl>(*this); + } /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). @@ -46,7 +48,9 @@ public: */ GenericOperator_Op(const GenericOperator_Op& op) : OperatorTensor(op) - {} + { + mImpl = std::make_shared<OperatorImpl>(*this); + } /** * @brief Clone the operator using its copy-constructor. @@ -58,6 +62,7 @@ public: // Helper functions that can be used with setComputeOutputDims(): static const ComputeDimsFunc Identity; + static const ComputeDimsFunc InputIdentity(IOIndex_t inputIdx, IOIndex_t nbOutputs); void setComputeOutputDims(ComputeDimsFunc func) { mComputeOutputDims = func; @@ -99,20 +104,6 @@ public: void setBackend(const std::string & /*name*/, DeviceIdx_t /*device*/ = 0) override { printf("setBackend: not available yet.\n"); } void setDataType(const DataType& /*datatype*/) const override { printf("setDataType: not available yet.\n"); } - void forward() override final { - if(mImpl){ - mImpl->forward(); - }else{ - printf("forward: No implementation is linked.\n"); - } - } - void backward() override final { - if(mImpl){ - mImpl->backward(); - }else{ - printf("backward: No implementation is linked.\n"); - } - } }; /** diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index fe9b044e2309eb7e724d6648b84c044d7407bafb..f0d6c29a5f39ecbd7e1b20c334368f64e745673a 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -45,6 +45,7 @@ public: Attributes_(attr<ProdAttr::Constant>(constant)) { mOutputs[0]->resize(dims); + mImpl = std::make_shared<OperatorImpl>(*this); } Producer_Op(const std::shared_ptr<Tensor> tensor, bool constant = false) @@ -52,6 +53,7 @@ public: Attributes_(attr<ProdAttr::Constant>(constant)) { mOutputs[0] = tensor; // copy the pointer of the Tensor + mImpl = std::make_shared<OperatorImpl>(*this); } /** @@ -65,7 +67,9 @@ public: for (std::size_t i = 0; i < static_cast<std::size_t>(nbOutputs()); ++i) { mOutputs[i] = std::make_shared<Tensor>(*(op.getOutput(i))); } - mImpl = op.mImpl ? Registrar<Producer_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr; + mImpl = (mOutputs[0]->getImpl() && Registrar<Producer_Op>::exists({mOutputs[0]->getImpl()->backend()})) + ? Registrar<Producer_Op>::create(mOutputs[0]->getImpl()->backend())(*this) + : std::make_shared<OperatorImpl>(*this); } /** @@ -88,7 +92,9 @@ public: inline const std::vector<DimSize_t> dims() const noexcept { return mOutputs[0]->dims(); } void setBackend(const std::string& name, DeviceIdx_t device = 0) override { - mImpl = Registrar<Producer_Op>::create(name)(*this); + if (Registrar<Producer_Op>::exists({name})) { + mImpl = Registrar<Producer_Op>::create({name})(*this); + } mOutputs[0]->setBackend(name, device); } diff --git a/src/operator/GenericOperator.cpp b/src/operator/GenericOperator.cpp index 192036651cfbe2df71139dd63ca3d71f07300964..5556f4ff5c87d1adc23f5bff1aaf90c230de06cc 100644 --- a/src/operator/GenericOperator.cpp +++ b/src/operator/GenericOperator.cpp @@ -15,3 +15,7 @@ const Aidge::GenericOperator_Op::ComputeDimsFunc Aidge::GenericOperator_Op::Identity = [](const std::vector<std::vector<size_t>>& inputsDims) { return inputsDims; }; + +const Aidge::GenericOperator_Op::ComputeDimsFunc Aidge::GenericOperator_Op::InputIdentity(IOIndex_t inputIdx, IOIndex_t nbOutputs) { + return [nbOutputs, inputIdx](const std::vector<std::vector<size_t>>& inputsDims) { return std::vector<std::vector<size_t>>(nbOutputs, inputsDims[inputIdx]); }; +} diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cb75669f382b4352492ccbf22f9c918bcbe18033 --- /dev/null +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -0,0 +1,78 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <algorithm> // std::sort +#include <cassert> +#include <map> +#include <memory> +#include <set> +#include <string> + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Testing.hpp" +#include "aidge/graph/OpArgs.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/scheduler/Scheduler.hpp" + +using namespace Aidge; + +TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { + const size_t nbTests = 100; + size_t nbUnicity = 0; + + for (int test = 0; test < nbTests; ++test) { + std::random_device rd; + const std::mt19937::result_type seed(rd()); + + RandomGraph randGraph; + randGraph.acyclic = true; + const auto g1 = std::make_shared<GraphView>("g1"); + const bool unicity1 = g1->add(randGraph.gen(seed, 10)); + + if (unicity1) { + for (auto& node : g1->getNodes()) { + std::static_pointer_cast<GenericOperator_Op>(node->getOperator())->setComputeOutputDims(GenericOperator_Op::InputIdentity(0, node->nbOutputs())); + } + + const auto orderedInputs = g1->getOrderedInputs(); + for (const auto& input : orderedInputs) { + auto prod = Producer({16, 32}); + prod->addChild(input.first, 0, input.second); + g1->add(prod); + } + + g1->save("schedule"); + g1->forwardDims(); + + auto scheduler = SequentialScheduler(g1); + scheduler.generateScheduling(true); + const auto sch = scheduler.getStaticScheduling(); + + std::map<std::shared_ptr<Node>, std::string> namePtrTable + = g1->getRankedNodesName("{0} ({1}#{3})"); + + std::vector<std::string> nodesName; + std::transform(sch.begin(), sch.end(), + std::back_inserter(nodesName), + [&namePtrTable](auto val){ return namePtrTable[val].c_str(); }); + + fmt::print("schedule: {}\n", nodesName); + REQUIRE(sch.size() == 10 + orderedInputs.size()); + } + } + + printf("nbUnicity = %zu/%zu\n", nbUnicity, nbTests); +}