From d39f04f3c3662e0eafdeb2a805fd6f6a3b8be296 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 14 Feb 2025 17:57:29 +0100 Subject: [PATCH 01/13] Add Select, Mod and CryptoHash operators --- include/aidge/operator/CryptoHash.hpp | 126 +++++++++++++++++ include/aidge/operator/Mod.hpp | 128 ++++++++++++++++++ include/aidge/operator/Select.hpp | 92 +++++++++++++ include/aidge/utils/Registrar.hpp | 13 +- python_binding/operator/pybind_CryptoHash.cpp | 38 ++++++ python_binding/operator/pybind_Mod.cpp | 60 ++++++++ python_binding/operator/pybind_Select.cpp | 49 +++++++ python_binding/pybind_core.cpp | 6 + src/operator/CryptoHash.cpp | 64 +++++++++ src/operator/Mod.cpp | 89 ++++++++++++ src/operator/Select.cpp | 110 +++++++++++++++ 11 files changed, 774 insertions(+), 1 deletion(-) create mode 100644 include/aidge/operator/CryptoHash.hpp create mode 100644 include/aidge/operator/Mod.hpp create mode 100644 include/aidge/operator/Select.hpp create mode 100644 python_binding/operator/pybind_CryptoHash.cpp create mode 100644 python_binding/operator/pybind_Mod.cpp create mode 100644 python_binding/operator/pybind_Select.cpp create mode 100644 src/operator/CryptoHash.cpp create mode 100644 src/operator/Mod.cpp create mode 100644 src/operator/Select.cpp diff --git a/include/aidge/operator/CryptoHash.hpp b/include/aidge/operator/CryptoHash.hpp new file mode 100644 index 000000000..266adecd3 --- /dev/null +++ b/include/aidge/operator/CryptoHash.hpp @@ -0,0 +1,126 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_CRYPTOHASH_H_ +#define AIDGE_CORE_OPERATOR_CRYPTOHASH_H_ + +#include <memory> +#include <string> +#include <vector> + +#include "aidge/utils/Registrar.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/StaticAttributes.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { +/** + * @enum CryptoHashAttr + * @brief Attributes for the CryptoHash operator. + */ +enum class CryptoHashAttr { + CryptoHashFunction ///< Cryptographic hash function to use. +}; + +/** + * @enum CryptoHashFunction + * @brief Cryptographic hash function. + */ +enum class CryptoHashFunction { + SHA256 ///< SHA256 +}; + +/** + * @brief Produce a cryptographic hash from the input. + * + * @see OperatorTensor + * @see Registrable + */ +class CryptoHash_Op : public OperatorTensor, + public Registrable<CryptoHash_Op, std::string, std::function<std::shared_ptr<OperatorImpl>(const CryptoHash_Op&)>> { + +public: + static const std::string Type; + +private: + using Attributes_ = StaticAttributes<CryptoHashAttr, CryptoHashFunction>; + template <CryptoHashAttr e> using attr = typename Attributes_::template attr<e>; + const std::shared_ptr<Attributes_> mAttributes; + +public: + CryptoHash_Op(); + + /** + * @brief Copy-constructor. + * @param op CryptoHash_Op to copy. + * @details Copies the operator attributes and its output tensor(s), but not + * its input tensors. The new operator has no associated input. + */ + CryptoHash_Op(const CryptoHash_Op& op); + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::CryptoHash_Op + */ + std::shared_ptr<Operator> clone() const override; + + bool forwardDims(bool allowDataDependency = false) override final; + + void setBackend(const std::string& name, DeviceIdx_t device = 0) override; + std::set<std::string> getAvailableBackends() const override; + + /** + * @brief Get the attributes of the operator. + * @return A shared pointer to the attributes. + */ + inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; } + + /** + * @brief Get or modify the `crypto_hash_function` attribute. + * @return Reference to the `crypto_hash_function` attribute. + */ + inline CryptoHashFunction& cryptoHashFunction() const noexcept { return mAttributes->getAttr<CryptoHashAttr::CryptoHashFunction>(); } + + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +std::shared_ptr<Node> CryptoHash(const std::string& name = ""); + +} // namespace Aidge + +namespace { + +/** + * @brief EnumStrings specialization for CryptoHashAttr. + */ +template <> +const char* const EnumStrings<Aidge::CryptoHashAttr>::data[] = { + "crypto_hash_function" +}; + +/** + * @brief EnumStrings specialization for CryptoHashFunction. + */ +template <> +const char* const EnumStrings<Aidge::CryptoHashFunction>::data[] = { + "sha256" +}; + +} // namespace + +#endif /* AIDGE_CORE_OPERATOR_CRYPTOHASH_H_ */ diff --git a/include/aidge/operator/Mod.hpp b/include/aidge/operator/Mod.hpp new file mode 100644 index 000000000..56a9381e0 --- /dev/null +++ b/include/aidge/operator/Mod.hpp @@ -0,0 +1,128 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_MOD_H_ +#define AIDGE_CORE_OPERATOR_MOD_H_ + +#include <memory> +#include <string> +#include <vector> + +#include "aidge/utils/Registrar.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/StaticAttributes.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { +enum class ModAttr { +/** + * @brief Enable fmod like behavior + * + * Whether the operator should behave like fmod (default is false meaning it + * will do integer mods); Set this to true to force fmod treatment + */ + Fmod +}; + +/** + * @brief Description of an element-wise binary modulus operation on input Tensors, + * supporting NumPy broadcasting. + * + * For each pair of elements x and y from the input Tensors, the function + * is defined as: + * `f(x, y) = x mod y + * + * Broadcasting adjusts shapes of the input Tensors to make them compatible: + * - Tensors are aligned from the rightmost dimensions. + * - Dimensions are compatible if they are equal, one of them is 1, or missing. + * + * The output Tensor shape is determined by taking the maximum size along + * each dimension of the input Tensors after broadcasting. + * + * Examples: + * 1. Input A: (3, 4, 2), Input B: (2), Output: (3, 4, 2) + * 2. Input A: (1, 5, 3), Input B: (2, 1, 3), Output: (2, 5, 3) + * + * @see OperatorTensor + * @see Registrable + */ +class Mod_Op : public OperatorTensor, + public Registrable<Mod_Op, std::string, std::function<std::shared_ptr<OperatorImpl>(const Mod_Op&)>> { + +public: + static const std::string Type; + +private: + using Attributes_ = StaticAttributes<ModAttr, bool>; + template <ModAttr e> using attr = typename Attributes_::template attr<e>; + const std::shared_ptr<Attributes_> mAttributes; + +public: + Mod_Op(); + + /** + * @brief Copy-constructor. + * @param op Mod_Op to copy. + * @details Copies the operator attributes and its output tensor(s), but not + * its input tensors. The new operator has no associated input. + */ + Mod_Op(const Mod_Op& op); + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Mod_Op + */ + std::shared_ptr<Operator> clone() const override; + + bool forwardDims(bool allowDataDependency = false) override final; + + void setBackend(const std::string& name, DeviceIdx_t device = 0) override; + std::set<std::string> getAvailableBackends() const override; + + /** + * @brief Get the attributes of the operator. + * @return A shared pointer to the attributes. + */ + inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; } + + /** + * @brief Get or modify the `fmod` attribute. + * @return Reference to the `fmod` attribute. + */ + inline bool& fmod() const noexcept { return mAttributes->getAttr<ModAttr::Fmod>(); } + + static const std::vector<std::string> getInputsName(){ + return {"dividend", "divisor"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"remainder"}; + } +}; + +std::shared_ptr<Node> Mod(const std::string& name = ""); + +} // namespace Aidge + +namespace { + +/** + * @brief EnumStrings specialization for ModAttr. + */ +template <> +const char* const EnumStrings<Aidge::ModAttr>::data[] = { + "fmod" +}; + +} // namespace + +#endif /* AIDGE_CORE_OPERATOR_MOD_H_ */ diff --git a/include/aidge/operator/Select.hpp b/include/aidge/operator/Select.hpp new file mode 100644 index 000000000..4dcace84e --- /dev/null +++ b/include/aidge/operator/Select.hpp @@ -0,0 +1,92 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_SELECT_H_ +#define AIDGE_CORE_OPERATOR_SELECT_H_ + +#include <memory> +#include <string> +#include <vector> + +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/ErrorHandling.hpp" +#include "aidge/utils/Types.h" +#include "aidge/utils/Registrar.hpp" + +namespace Aidge { +/** + * @brief Implementation of the Select operator. + * @note This operator implementation is agnostic to the backend and is located here instead of in aidge_backend. + */ +class Select_OpImpl : public OperatorImpl { +public: + /** + * @brief Constructor for Select_OpImpl. + * @param[in] op The Operator instance. + * @param[in] backend The backend name (optional). + */ + Select_OpImpl(const Operator& op, const std::string& backend = "") + : OperatorImpl(op, backend) {} + + /** + * @brief Perform the forward operation for the reshape. + */ + void forward() override; + void backward() override; +}; + +/** + * @brief + * @see OperatorTensor + * @see Registrable + */ +class Select_Op : public OperatorTensor, + public Registrable<Select_Op, + std::string, + std::function<std::shared_ptr<OperatorImpl>(const Select_Op&)>> +{ +public: + static const std::string Type; + + Select_Op(const Aidge::IOIndex_t nbIn); + + /** + * @brief Copy-constructor. + * @param op Select_Op to copy. + * @details Copies the operator attributes and its output tensor(s), but not + * its input tensors. The new operator has no associated input. + */ + Select_Op(const Select_Op& op); + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Select_Op + */ + std::shared_ptr<Operator> clone() const override; + + bool forwardDims(bool allowDataDependency = false) override final; + + void setBackend(const std::string& name, DeviceIdx_t device = 0) override; + std::set<std::string> getAvailableBackends() const override; + + static const std::vector<std::string> getInputsName() { + return {"select", "data_input_0", "data_input_n"}; + } + static const std::vector<std::string> getOutputsName() { + return {"data_output"}; + } +}; + +std::shared_ptr<Node> Select(const IOIndex_t nbIn, const std::string& name = ""); +} + +#endif /* AIDGE_CORE_OPERATOR_SELECT_H_ */ diff --git a/include/aidge/utils/Registrar.hpp b/include/aidge/utils/Registrar.hpp index 28dab05f8..a2f33eea5 100644 --- a/include/aidge/utils/Registrar.hpp +++ b/include/aidge/utils/Registrar.hpp @@ -76,9 +76,20 @@ struct Registrar { } static auto create(const registrar_key& key) { - AIDGE_ASSERT(exists(key), "missing or invalid registrar key: {} for registrable object {}\nDid you include/import the corresponding module?\nIf so, it is possible that the object is not yet supported.", key, typeid(C).name()); + if (!exists(key)) { + Log::error("missing or invalid registrar key: {} for registrable object {}\nDid you include/import the corresponding module?\nIf so, it is possible that the object is not yet supported.", key, typeid(C).name()); + + Log::info("Available registrar keys are:"); + for(const auto& keyValue : C::registry()) { + Log::info("- {}", keyValue.first); + } + + AIDGE_THROW_OR_ABORT(std::runtime_error, "missing or invalid registrar key"); + } + return C::registry().at(key); } + static std::set<registrar_key> getKeys(){ std::set<registrar_key> keys; for(const auto& keyValue : C::registry()) diff --git a/python_binding/operator/pybind_CryptoHash.cpp b/python_binding/operator/pybind_CryptoHash.cpp new file mode 100644 index 000000000..923f91b60 --- /dev/null +++ b/python_binding/operator/pybind_CryptoHash.cpp @@ -0,0 +1,38 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + + #include <pybind11/pybind11.h> + + #include "aidge/data/Tensor.hpp" + #include "aidge/operator/CryptoHash.hpp" + #include "aidge/operator/OperatorTensor.hpp" + + namespace py = pybind11; + namespace Aidge { + + void init_CryptoHash(py::module& m) { + py::enum_<CryptoHashFunction>(m, "crypto_hash_function") + .value("SHA256", CryptoHashFunction::SHA256) + .export_values(); + + py::class_<CryptoHash_Op, std::shared_ptr<CryptoHash_Op>, OperatorTensor>(m, "CryptoHashOp", py::multiple_inheritance()) + .def(py::init<>()) + .def_static("get_inputs_name", &CryptoHash_Op::getInputsName) + .def_static("get_outputs_name", &CryptoHash_Op::getOutputsName) + .def_readonly_static("Type", &CryptoHash_Op::Type); + + declare_registrable<CryptoHash_Op>(m, "CryptoHashOp"); + + m.def("CryptoHash", &CryptoHash, py::arg("name") = ""); + } + + } // namespace Aidge + \ No newline at end of file diff --git a/python_binding/operator/pybind_Mod.cpp b/python_binding/operator/pybind_Mod.cpp new file mode 100644 index 000000000..aa88f2068 --- /dev/null +++ b/python_binding/operator/pybind_Mod.cpp @@ -0,0 +1,60 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + + #include <pybind11/pybind11.h> + + #include "aidge/data/Tensor.hpp" + #include "aidge/operator/Mod.hpp" + #include "aidge/operator/OperatorTensor.hpp" + + namespace py = pybind11; + namespace Aidge { + + void init_Mod(py::module& m) { + py::class_<Mod_Op, std::shared_ptr<Mod_Op>, OperatorTensor>(m, "ModOp", py::multiple_inheritance(), + R"mydelimiter( + Initialize a Mod operator. + This operator performs element-wise binary modulus between two input tensors. + The operation is defined as: + Output = Input1 mod Input2 + The output tensor shape is determined by taking the maximum size along each dimension of the input tensors after broadcasting. + Examples: + Input A: (3, 4, 2), Input B: (2), Output: (3, 4, 2) + Input A: (1, 5, 3), Input B: (2, 1, 3), Output: (2, 5, 3) + :param name : Name of the node (optional). + :type name : str + )mydelimiter") + .def(py::init<>()) + .def_static("get_inputs_name", &Mod_Op::getInputsName) + .def_static("get_outputs_name", &Mod_Op::getOutputsName) + .def_readonly_static("Type", &Mod_Op::Type); + + declare_registrable<Mod_Op>(m, "ModOp"); + + m.def("Mod", &Mod, py::arg("name") = "", + R"mydelimiter( + Initialize a node containing a Mod operator that performs element-wise binary modulus between two tensors. + The operation is defined as: + Output = Input1 mod Input2 + The output tensor shape is determined by taking the maximum size along each dimension of the input tensors after broadcasting. + Examples: + Input A: (3, 4, 2), Input B: (2), Output: (3, 4, 2) + Input A: (1, 5, 3), Input B: (2, 1, 3), Output: (2, 5, 3) + + :param name : Name of the node (optional). + :type name : str + :return: A node containing the Mod operator. + :rtype: :py:class:`ModOp` + )mydelimiter"); + } + + } // namespace Aidge + \ No newline at end of file diff --git a/python_binding/operator/pybind_Select.cpp b/python_binding/operator/pybind_Select.cpp new file mode 100644 index 000000000..0cb858acd --- /dev/null +++ b/python_binding/operator/pybind_Select.cpp @@ -0,0 +1,49 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <string> + +#include "aidge/data/Tensor.hpp" +#include "aidge/operator/Select.hpp" +#include "aidge/operator/OperatorTensor.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_Select(py::module& m) { + py::class_<Select_Op, std::shared_ptr<Select_Op>, OperatorTensor>(m, "SelectOp", py::multiple_inheritance(), + R"mydelimiter( + Initialize a Select operator. + + :param nb_inputs : The number of input tensors to select from. + :type nb_inputs : :py:class:`int` + )mydelimiter") + .def(py::init<const IOIndex_t>(), + py::arg("nb_inputs")) + .def_static("get_inputs_name", &Select_Op::getInputsName) + .def_static("get_outputs_name", &Select_Op::getOutputsName) + .def_readonly_static("Type", &Select_Op::Type); + + declare_registrable<Select_Op>(m, "SelectOp"); + + m.def("Select", &Select, py::arg("nb_inputs"), py::arg("name") = "", + R"mydelimiter( + Initialize a node containing a Select operator. + + :param nb_inputs : The number of input tensors to select from. + :type nb_inputs : :py:class:`int` + :param name : Name of the node. + :type name : :py:class:`str` + )mydelimiter"); +} + +} // namespace Aidge diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index ef1111b39..b2aa93dc9 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -48,6 +48,7 @@ void init_Concat(py::module&); void init_ConstantOfShape(py::module&); void init_Conv(py::module&); void init_ConvDepthWise(py::module&); +void init_CryptoHash(py::module&); void init_DepthToSpace(py::module&); void init_Div(py::module&); void init_Equal(py::module&); @@ -67,6 +68,7 @@ void init_MatMul(py::module&); void init_MaxPooling(py::module&); void init_Memorize(py::module&); void init_MetaOperatorDefs(py::module&); +void init_Mod(py::module&); void init_Mul(py::module&); void init_Pad(py::module&); void init_Pop(py::module&); @@ -79,6 +81,7 @@ void init_Reshape(py::module&); void init_Resize(py::module&); void init_Round(py::module&); void init_Scaling(py::module&); +void init_Select(py::module&); void init_Shape(py::module&); void init_Sigmoid(py::module&); void init_Slice(py::module&); @@ -149,6 +152,7 @@ void init_Aidge(py::module& m) { init_Conv(m); init_ConvDepthWise(m); init_ConstantOfShape(m); + init_CryptoHash(m); init_DepthToSpace(m); init_Div(m); init_Equal(m); @@ -168,6 +172,7 @@ void init_Aidge(py::module& m) { init_MaxPooling(m); init_Memorize(m); init_MetaOperatorDefs(m); + init_Mod(m); init_Mul(m); init_Pad(m); init_Pop(m); @@ -179,6 +184,7 @@ void init_Aidge(py::module& m) { init_Resize(m); init_Round(m); init_Scaling(m); + init_Select(m); init_Shape(m); init_Sigmoid(m); init_Slice(m); diff --git a/src/operator/CryptoHash.cpp b/src/operator/CryptoHash.cpp new file mode 100644 index 000000000..f6656cd61 --- /dev/null +++ b/src/operator/CryptoHash.cpp @@ -0,0 +1,64 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <cstddef> // std::size_t +#include <stdexcept> // std::runtime_error +#include <string> +#include <vector> + +#include "aidge/data/Tensor.hpp" +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/CryptoHash.hpp" +#include "aidge/utils/Types.h" +#include "aidge/utils/ErrorHandling.hpp" + +const std::string Aidge::CryptoHash_Op::Type = "CryptoHash"; + +Aidge::CryptoHash_Op::CryptoHash_Op() + : OperatorTensor(Type, {InputCategory::Data}, 1), + mAttributes(std::make_shared<Attributes_>( + attr<CryptoHashAttr::CryptoHashFunction>(CryptoHashFunction::SHA256))) +{} + +Aidge::CryptoHash_Op::CryptoHash_Op(const Aidge::CryptoHash_Op& op) + : OperatorTensor(op), + mAttributes(op.mAttributes) +{ + if (op.mImpl){ + SET_IMPL_MACRO(CryptoHash_Op, *this, op.backend()); + }else{ + mImpl = nullptr; + } +} + +std::shared_ptr<Aidge::Operator> Aidge::CryptoHash_Op::clone() const { + return std::make_shared<CryptoHash_Op>(*this); +} + +bool Aidge::CryptoHash_Op::forwardDims(bool /*allowDataDependency*/) { + mOutputs[0]->resize({256}); + return true; +} + +void Aidge::CryptoHash_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { + SET_IMPL_MACRO(CryptoHash_Op, *this, name); + mOutputs[0]->setBackend(name, device); +} + +std::set<std::string> Aidge::CryptoHash_Op::getAvailableBackends() const { + return Registrar<CryptoHash_Op>::getKeys(); +} + +/////////////////////////////////////////// + +std::shared_ptr<Aidge::Node> Aidge::CryptoHash(const std::string& name) { + return std::make_shared<Node>(std::make_shared<CryptoHash_Op>(), name); +} diff --git a/src/operator/Mod.cpp b/src/operator/Mod.cpp new file mode 100644 index 000000000..038a3c284 --- /dev/null +++ b/src/operator/Mod.cpp @@ -0,0 +1,89 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <cstddef> // std::size_t +#include <stdexcept> // std::runtime_error +#include <string> +#include <vector> + +#include "aidge/data/Tensor.hpp" +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/Mod.hpp" +#include "aidge/utils/Types.h" +#include "aidge/utils/ErrorHandling.hpp" + +const std::string Aidge::Mod_Op::Type = "Mod"; + +Aidge::Mod_Op::Mod_Op() + : OperatorTensor(Type, {InputCategory::Data, InputCategory::Data}, 1), + mAttributes(std::make_shared<Attributes_>( + attr<ModAttr::Fmod>(false))) +{} + +Aidge::Mod_Op::Mod_Op(const Aidge::Mod_Op& op) + : OperatorTensor(op), + mAttributes(op.mAttributes) +{ + if (op.mImpl){ + SET_IMPL_MACRO(Mod_Op, *this, op.backend()); + }else{ + mImpl = nullptr; + } +} + +std::shared_ptr<Aidge::Operator> Aidge::Mod_Op::clone() const { + return std::make_shared<Mod_Op>(*this); +} + +bool Aidge::Mod_Op::forwardDims(bool /*allowDataDependency*/) { + if (inputsAssociated()) { + const std::vector<std::size_t>& inputsDims0 = getInput(0)->dims(); + const std::vector<std::size_t>& inputsDims1 = getInput(1)->dims(); + + std::vector<std::size_t> outDims = (inputsDims0.size() >= inputsDims1.size()) ? inputsDims0 : inputsDims1; + const std::vector<std::size_t>& lowDims = (inputsDims0.size() < inputsDims1.size()) ? inputsDims0 : inputsDims1; + + std::size_t out_id = outDims.size() - 1; + std::size_t low_id = lowDims.size() - 1; + std::size_t i = 0; + while (i++ < lowDims.size()) { + if (outDims[out_id] == 1) { + outDims[out_id] = lowDims[low_id]; + } + else if ((lowDims[low_id] != 1) && (lowDims[low_id] != outDims[out_id])) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Incompatible Tensor shape for Mod Operation: {} for input#0 vs {} for input#1", + inputsDims0, inputsDims1); + } + --out_id; + --low_id; + } + mOutputs[0]->resize(outDims); + return true; + } + + return false; +} + + +void Aidge::Mod_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { + SET_IMPL_MACRO(Mod_Op, *this, name); + mOutputs[0]->setBackend(name, device); +} + +std::set<std::string> Aidge::Mod_Op::getAvailableBackends() const { + return Registrar<Mod_Op>::getKeys(); +} + +/////////////////////////////////////////// + +std::shared_ptr<Aidge::Node> Aidge::Mod(const std::string& name) { + return std::make_shared<Node>(std::make_shared<Mod_Op>(), name); +} \ No newline at end of file diff --git a/src/operator/Select.cpp b/src/operator/Select.cpp new file mode 100644 index 000000000..67e792cd0 --- /dev/null +++ b/src/operator/Select.cpp @@ -0,0 +1,110 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <cstddef> // std::size_t +#include <stdexcept> // std::runtime_error +#include <string> +#include <vector> + +#include "aidge/data/Tensor.hpp" +#include "aidge/operator/Select.hpp" +#include "aidge/utils/Types.h" +#include "aidge/utils/ErrorHandling.hpp" +#include "aidge/utils/Registrar.hpp" + +void Aidge::Select_OpImpl::forward() { + const Select_Op& op = dynamic_cast<const Select_Op&>(mOp); + AIDGE_ASSERT(op.getInput(0)->size() > 0, "Select input is empty!"); + + std::shared_ptr<Tensor> selectFallback; + const auto& select = op.getInput(0)->refCastFrom(selectFallback, DataType::Int32, "cpu"); + const auto selectVal = select.get<int32_t>(0); + AIDGE_ASSERT(selectVal >= 0 && selectVal < op.nbInputs() - 1, "Select input out of range. Expected value in range [0, {}], got {}", op.nbInputs() - 2, selectVal); + + op.getOutput(0)->getImpl()->copy(op.getInput(selectVal + 1)->getImpl()->rawPtr(), op.getInput(selectVal + 1)->size()); +} + +void Aidge::Select_OpImpl::backward() { + const Select_Op& op = dynamic_cast<const Select_Op&>(mOp); + AIDGE_ASSERT(op.getInput(0)->size() > 0, "Select input is empty!"); + + std::shared_ptr<Tensor> selectFallback; + const auto& select = op.getInput(0)->refCastFrom(selectFallback, DataType::Int32, "cpu"); + const auto selectVal = select.get<int32_t>(0); + AIDGE_ASSERT(selectVal >= 0 && selectVal < op.nbInputs() - 1, "Select input out of range. Expected value in range [0, {}], got {}", op.nbInputs() - 2, selectVal); + + op.getInput(selectVal + 1)->grad()->getImpl()->copy(op.getOutput(0)->grad()->getImpl()->rawPtr(), op.getOutput(0)->size()); +} + +////////////////////////////////////////////////// + +const std::string Aidge::Select_Op::Type = "Select"; + +Aidge::Select_Op::Select_Op(const Aidge::IOIndex_t nbIn) + : OperatorTensor(Type, std::vector<InputCategory>(nbIn + 1, InputCategory::Data), 1) +{ + // ctor + AIDGE_ASSERT(nbIn > 1, "Select operator should have at least two inputs."); + mImpl = std::make_shared<Select_OpImpl>(*this); +} + +Aidge::Select_Op::Select_Op(const Select_Op& op) + : OperatorTensor(op) +{ + if (!op.backend().empty()) { + SET_IMPL_MACRO(Select_Op, *this, op.backend()); + } + else { + mImpl = std::make_shared<Select_OpImpl>(*this); + } +} + +std::shared_ptr<Aidge::Operator> Aidge::Select_Op::clone() const { + return std::make_shared<Select_Op>(*this); +} + +bool Aidge::Select_Op::forwardDims(bool /*allowDataDependency*/) { + if (inputsAssociated()) { + // First input is select input + const auto expectedDims = getInput(1)->dims(); + for (std::size_t i = 2; i < nbInputs(); ++i) { + if (expectedDims != getInput(i)->dims()) { + AIDGE_THROW_OR_ABORT(std::runtime_error, + "{} operator's inputs should have the same dimensions: expected {} (input #0), given {} (input #{})", + type(), expectedDims, getInput(i)->dims(), i); + } + } + mOutputs[0]->resize(expectedDims); + return true; + } + + return false; +} + +void Aidge::Select_Op::setBackend(const std::string& name, DeviceIdx_t device) { + if (Registrar<Select_Op>::exists({name})){ + SET_IMPL_MACRO(Select_Op, *this, name); + } + else { + mImpl = std::make_shared<Select_OpImpl>(*this); + } + mOutputs[0]->setBackend(name, device); +} + +std::set<std::string> Aidge::Select_Op::getAvailableBackends() const { + return Registrar<Select_Op>::getKeys(); +} + +//////////////////////////////////////////////////////////////////////////////// + +std::shared_ptr<Aidge::Node> Aidge::Select(const Aidge::IOIndex_t nbIn, const std::string& name) { + return std::make_shared<Node>(std::make_shared<Select_Op>(nbIn), name); +} -- GitLab From 2512f855e591bb6230822f58e9ba64cc856ff83e Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Sun, 16 Feb 2025 17:49:47 +0100 Subject: [PATCH 02/13] Adjust output size w.r.t datatype --- src/operator/CryptoHash.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/CryptoHash.cpp b/src/operator/CryptoHash.cpp index f6656cd61..064b480b1 100644 --- a/src/operator/CryptoHash.cpp +++ b/src/operator/CryptoHash.cpp @@ -44,7 +44,7 @@ std::shared_ptr<Aidge::Operator> Aidge::CryptoHash_Op::clone() const { } bool Aidge::CryptoHash_Op::forwardDims(bool /*allowDataDependency*/) { - mOutputs[0]->resize({256}); + mOutputs[0]->resize({256 / getDataTypeBitWidth(mOutputs[0]->dataType())}); return true; } -- GitLab From c3318428b1855e94abdd703e371bbffe42368bbe Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Sun, 16 Feb 2025 23:46:37 +0100 Subject: [PATCH 03/13] Working concept of with tagConditionalNodes() --- include/aidge/scheduler/Scheduler.hpp | 2 + include/aidge/utils/DynamicAttributes.hpp | 5 +- src/scheduler/Scheduler.cpp | 129 ++++++++++++++++++++++ 3 files changed, 135 insertions(+), 1 deletion(-) diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index dfdc270fa..ed9db47b3 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -127,6 +127,8 @@ public: virtual ~Scheduler(); public: + void tagConditionalNodes(); + /** * @brief Get the static scheduling order of nodes. * @param step The step of the static schedule to retrieve (default is 0). diff --git a/include/aidge/utils/DynamicAttributes.hpp b/include/aidge/utils/DynamicAttributes.hpp index 6ac76c138..633ce40d9 100644 --- a/include/aidge/utils/DynamicAttributes.hpp +++ b/include/aidge/utils/DynamicAttributes.hpp @@ -165,7 +165,10 @@ public: else { const auto ns = name.substr(0, dot); const auto nsName = name.substr(dot + 1); - future_std::any_cast<DynamicAttributes&>(mAttrs.at(ns)).delAttr(nsName); + auto it = mAttrs.find(ns); + if (it != mAttrs.end()) { + future_std::any_cast<DynamicAttributes&>(it->second).delAttr(nsName); + } } } diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index fabdc7ad2..3f59f2fdc 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -56,6 +56,54 @@ void Aidge::Scheduler::generateScheduling() { mStaticSchedule.push_back(schedule); } +void Aidge::Scheduler::tagConditionalNodes() { + // Get a list of selectors + std::vector<NodePtr> selectors; + for (const auto& node : mGraphView->getNodes()) { + if (node->type() == "Select") { + selectors.push_back(node); + } + node->attributes()->delAttr("schedule.cond"); + } + + std::function<void(NodePtr, std::set<NodePtr>&)> recInBranch = [&recInBranch](NodePtr node, std::set<NodePtr>& branchNodes) { + bool inBranch = true; + for (const auto& child : node->getChildren()) { + if (branchNodes.find(child) == branchNodes.end()) { + inBranch = false; + break; + } + } + + if (inBranch) { + branchNodes.insert(node); + for (const auto& parent : node->getParents()) { + recInBranch(parent, branchNodes); + } + } + }; + + // For each selector, tag nodes + for (const auto& select : selectors) { + for (size_t branch = 0; branch < select->getParents().size() - 1; ++branch) { + std::set<NodePtr> branchNodes; + branchNodes.insert(select); + recInBranch(select->getParent(branch + 1), branchNodes); + branchNodes.erase(select); + + for (const auto& node : branchNodes) { + std::set<std::pair<NodePtr, size_t>> attr; + if (node->attributes()->hasAttr("schedule.cond")) { + attr = node->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond"); + } + + attr.insert({select, branch}); + node->attributes()->setAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond", attr); + } + } + } +} + std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::generateBaseScheduling() const { // 0) setup useful variables @@ -182,6 +230,22 @@ std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::genera } } + if (consumer->attributes()->hasAttr("schedule.cond")) { + auto attr = consumer->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond"); + + for (const auto& cond : attr) { + const auto& select = cond.first; + AvailableDataStatus status; + + if ((select->getOperator()->getNbConsumedData(0) + select->getOperator()->getNbRequiredData(0)) > + getNbAvailableData(select, 0, status)) + { + isRunnable = false; + break; + } + } + } + if (isRunnable) { runnableConsumers.insert(consumer); } @@ -386,6 +450,23 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE } } + if (node->attributes()->hasAttr("schedule.cond")) { + auto attr = node->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond"); + + for (const auto& cond : attr) { + const auto& select = cond.first; + const auto& parent = select->input(0).first; + + const auto it = std::find_if(schedule.rend() - elt, schedule.rend(), + [parent](const auto& v) { return (v->node == parent); }); + if (it != schedule.rend()) { + const std::size_t step = std::distance(schedule.begin(), it.base()) - 1; + early = std::max(early, schedule[step]->early + 1); + schedule[step]->earlierThan.push_back(schedule[elt]); + } + } + } + latest = std::max(latest, early); schedule[elt]->early = early; } @@ -421,8 +502,32 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE late = std::min(late, schedule[step]->late - 1); schedule[step]->laterThan.push_back(schedule[elt]); } + + if (child->type() == "Select") { + for (const auto& condNode : mGraphView->getNodes()) { + if (condNode->attributes()->hasAttr("schedule.cond")) { + auto attr = condNode->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond"); + + for (const auto& cond : attr) { + const auto& select = cond.first; + + if (node == select->input(0).first) { + const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(), + [condNode](const auto& v) { return (v->node == condNode); }); + if (it != schedule.end()) { + const std::size_t step = std::distance(schedule.begin(), it); + late = std::min(late, schedule[step]->late - 1); + schedule[step]->laterThan.push_back(schedule[elt]); + } + } + } + } + } + } } + // TODO: ADD HERE SCHEDULE COND + schedule[elt]->late = late; } } @@ -1148,6 +1253,30 @@ Aidge::Scheduler::getPriorProducersConsumers(const std::shared_ptr<Node>& node) ++inputIdx; } + if (node->attributes()->hasAttr("schedule.cond")) { + auto attr = node->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond"); + + for (const auto& cond : attr) { + const auto& select = cond.first; + const auto& parent = select->input(0); + + if ((select->getOperator()->getNbConsumedData(0) + select->getOperator()->getNbRequiredData(0)) > + parent.first->getOperator()->getNbProducedData(parent.second)) + { + const auto& parentPrior = getPriorProducersConsumers(parent.first); + + if (!parentPrior.isPrior) { + // only happens in case of cyclic graphs + return PriorProducersConsumers(); // not scheduled + } + else { + prior.requiredProducers.insert(parentPrior.requiredProducers.cbegin(), parentPrior.requiredProducers.cend()); + prior.priorConsumers.insert(parentPrior.priorConsumers.cbegin(), parentPrior.priorConsumers.cend()); + } + } + } + } + prior.isPrior = true; if (node->type() == Producer_Op::Type) { prior.requiredProducers.insert(node); -- GitLab From a1b31cdcf8db7fff9b5b170080e91e55288492bc Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Mon, 17 Feb 2025 12:04:03 +0100 Subject: [PATCH 04/13] Added conditional execution --- include/aidge/scheduler/Scheduler.hpp | 17 ++++++++++++++- src/scheduler/ParallelScheduler.cpp | 10 +++++++++ src/scheduler/Scheduler.cpp | 31 +++++++++++++++++++++------ src/scheduler/SequentialScheduler.cpp | 15 +++++++------ 4 files changed, 60 insertions(+), 13 deletions(-) diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index ed9db47b3..a164f6c76 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -127,7 +127,22 @@ public: virtual ~Scheduler(); public: - void tagConditionalNodes(); + /** + * @brief Add schedule.cond attribute to conditional nodes. + * The schedule.cond attribute is a `std::set<std::pair<NodePtr, size_t>>`, + * where the first element is the Select node and the second element, the + * Select input index. + */ + void tagConditionalNodes() const; + + /** + * @brief Check if the node condition is valid. + * + * @param node Node to check the condition. + * @return true If the node condition is valid, meaning it has to be executed. + * @return false If the node condition is not valid, meaning it can be skipped. + */ + bool isNodeCondValid(NodePtr node) const; /** * @brief Get the static scheduling order of nodes. diff --git a/src/scheduler/ParallelScheduler.cpp b/src/scheduler/ParallelScheduler.cpp index 2a44dd49f..8e53254d7 100644 --- a/src/scheduler/ParallelScheduler.cpp +++ b/src/scheduler/ParallelScheduler.cpp @@ -88,6 +88,11 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std:: // Add the critical node to the thread pool queue, to be run ASAP finished[runnable] = false; pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() { + if (!isNodeCondValid(node)) { + finished = true; + return; + } + const auto tStart = std::chrono::high_resolution_clock::now(); node->forward(); const auto tEnd = std::chrono::high_resolution_clock::now(); @@ -144,6 +149,11 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std:: // Add the node to the thread pool queue, to be run ASAP finished[runnable] = false; pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() { + if (!isNodeCondValid(node)) { + finished = true; + return; + } + const auto tStart = std::chrono::high_resolution_clock::now(); node->forward(); const auto tEnd = std::chrono::high_resolution_clock::now(); diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 3f59f2fdc..bc0b19cfe 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -56,7 +56,7 @@ void Aidge::Scheduler::generateScheduling() { mStaticSchedule.push_back(schedule); } -void Aidge::Scheduler::tagConditionalNodes() { +void Aidge::Scheduler::tagConditionalNodes() const { // Get a list of selectors std::vector<NodePtr> selectors; for (const auto& node : mGraphView->getNodes()) { @@ -104,6 +104,27 @@ void Aidge::Scheduler::tagConditionalNodes() { } } +bool Aidge::Scheduler::isNodeCondValid(NodePtr node) const { + bool skip = false; + if (node->attributes()->hasAttr("schedule.cond")) { + skip = true; + + auto attr = node->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond"); + for (const auto& cond : attr) { + const auto& selectNode = cond.first; + const auto selectOp = std::static_pointer_cast<OperatorTensor>(selectNode->getOperator()); + + std::shared_ptr<Tensor> selectFallback; + const auto& select = selectOp->getInput(0)->refCastFrom(selectFallback, DataType::Int32, "cpu"); + const auto selectVal = select.get<int32_t>(0); + + skip &= (selectVal != static_cast<int32_t>(cond.second)); + } + } + + return !skip; +} + std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::generateBaseScheduling() const { // 0) setup useful variables @@ -512,10 +533,10 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE const auto& select = cond.first; if (node == select->input(0).first) { - const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(), + const auto itElt = std::find_if(schedule.begin() + elt + 1, schedule.end(), [condNode](const auto& v) { return (v->node == condNode); }); - if (it != schedule.end()) { - const std::size_t step = std::distance(schedule.begin(), it); + if (itElt != schedule.end()) { + const std::size_t step = std::distance(schedule.begin(), itElt); late = std::min(late, schedule[step]->late - 1); schedule[step]->laterThan.push_back(schedule[elt]); } @@ -526,8 +547,6 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE } } - // TODO: ADD HERE SCHEDULE COND - schedule[elt]->late = late; } } diff --git a/src/scheduler/SequentialScheduler.cpp b/src/scheduler/SequentialScheduler.cpp index 4e6e91f51..5f6bb6c07 100644 --- a/src/scheduler/SequentialScheduler.cpp +++ b/src/scheduler/SequentialScheduler.cpp @@ -59,12 +59,15 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); for (const auto& runnable : staticSchedule) { - Log::debug("run: {}", namePtrTable.at(runnable->node)); - - const auto tStart = std::chrono::high_resolution_clock::now(); - runnable->node->forward(); - const auto tEnd = std::chrono::high_resolution_clock::now(); - mScheduling.push_back(SchedulingElement(runnable->node, tStart, tEnd)); + const bool skip = !isNodeCondValid(runnable->node); + Log::debug("run: {}{}", namePtrTable.at(runnable->node), (skip) ? " -- skipped" : ""); + + if (!skip) { + const auto tStart = std::chrono::high_resolution_clock::now(); + runnable->node->forward(); + const auto tEnd = std::chrono::high_resolution_clock::now(); + mScheduling.push_back(SchedulingElement(runnable->node, tStart, tEnd)); + } } ++mStaticScheduleStep; -- GitLab From a74fcf7cc80c00d5bcb74438fbd6661f8927e39c Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Mon, 17 Feb 2025 15:12:32 +0100 Subject: [PATCH 05/13] Improved doc --- include/aidge/scheduler/Scheduler.hpp | 11 +++--- src/graph/GraphView.cpp | 5 +-- src/scheduler/ParallelScheduler.cpp | 4 +-- src/scheduler/Scheduler.cpp | 48 +++++++++++++++++---------- src/scheduler/SequentialScheduler.cpp | 2 +- 5 files changed, 43 insertions(+), 27 deletions(-) diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index a164f6c76..881d16e05 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -131,18 +131,19 @@ public: * @brief Add schedule.cond attribute to conditional nodes. * The schedule.cond attribute is a `std::set<std::pair<NodePtr, size_t>>`, * where the first element is the Select node and the second element, the - * Select input index. + * Select input index (starting from 0, ignoring the condition input). */ void tagConditionalNodes() const; /** - * @brief Check if the node condition is valid. + * @brief Check if the conditional node is required (if one of its conditions + * is true). * * @param node Node to check the condition. - * @return true If the node condition is valid, meaning it has to be executed. - * @return false If the node condition is not valid, meaning it can be skipped. + * @return true If any node condition is true, meaning it has to be executed. + * @return false If all node conditions are false, meaning it can be skipped. */ - bool isNodeCondValid(NodePtr node) const; + bool isConditionalNodeRequired(NodePtr node) const; /** * @brief Get the static scheduling order of nodes. diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 315844858..dd17cd344 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -100,10 +100,11 @@ void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProd const auto namePtrTable = getRankedNodesName("{3}"); for (const std::shared_ptr<Node> &node_ptr : mNodes) { + const std::string hasCondition = (node_ptr->attributes()->hasAttr("schedule.cond")) ? " fa:fa-circle-question" : ""; std::string givenName = (node_ptr->name().empty()) - ? "<em>" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + "</em>" - : "\"" + node_ptr->name() + "<br/><sub><em>(" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + ")</em></sub>\""; + ? "<em>" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + "</em>" + hasCondition + : "\"" + node_ptr->name() + hasCondition + "<br/><sub><em>(" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + ")</em></sub>\""; if (verbose) { givenName += "<br/><span style='color:white; background-color: purple; float: right'>" + node_ptr->getOperator()->backend() + "</span>"; diff --git a/src/scheduler/ParallelScheduler.cpp b/src/scheduler/ParallelScheduler.cpp index 8e53254d7..fb0d45c94 100644 --- a/src/scheduler/ParallelScheduler.cpp +++ b/src/scheduler/ParallelScheduler.cpp @@ -88,7 +88,7 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std:: // Add the critical node to the thread pool queue, to be run ASAP finished[runnable] = false; pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() { - if (!isNodeCondValid(node)) { + if (!isConditionalNodeRequired(node)) { finished = true; return; } @@ -149,7 +149,7 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std:: // Add the node to the thread pool queue, to be run ASAP finished[runnable] = false; pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() { - if (!isNodeCondValid(node)) { + if (!isConditionalNodeRequired(node)) { finished = true; return; } diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index bc0b19cfe..18210a1f8 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -104,7 +104,7 @@ void Aidge::Scheduler::tagConditionalNodes() const { } } -bool Aidge::Scheduler::isNodeCondValid(NodePtr node) const { +bool Aidge::Scheduler::isConditionalNodeRequired(NodePtr node) const { bool skip = false; if (node->attributes()->hasAttr("schedule.cond")) { skip = true; @@ -471,6 +471,8 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE } } + // Node can be run the earliest just after all its conditions are computed. + // A condition act like an additionnal parent. if (node->attributes()->hasAttr("schedule.cond")) { auto attr = node->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond"); @@ -514,7 +516,14 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE } // Node can be run the latest just before its earliest child is run + bool condition = false; // check if node can be a condition for conditional nodes for (const auto& child : node->getChildren()) { + if (child->type() == "Select" && node == child->input(0).first) { + // If the node child is a Select operator, it may be a condition to + // some conditional nodes (if node is the first input of Select). + condition = true; + } + // Find child node earliest scheduled position const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(), [child](const auto& v) { return (v->node == child); }); @@ -523,23 +532,28 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE late = std::min(late, schedule[step]->late - 1); schedule[step]->laterThan.push_back(schedule[elt]); } + } - if (child->type() == "Select") { - for (const auto& condNode : mGraphView->getNodes()) { - if (condNode->attributes()->hasAttr("schedule.cond")) { - auto attr = condNode->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond"); - - for (const auto& cond : attr) { - const auto& select = cond.first; - - if (node == select->input(0).first) { - const auto itElt = std::find_if(schedule.begin() + elt + 1, schedule.end(), - [condNode](const auto& v) { return (v->node == condNode); }); - if (itElt != schedule.end()) { - const std::size_t step = std::distance(schedule.begin(), itElt); - late = std::min(late, schedule[step]->late - 1); - schedule[step]->laterThan.push_back(schedule[elt]); - } + // When node is a condition to conditional nodes, it acts like a parent + // to them. Therefore, the conditional nodes should be considered as + // childs to this node. + if (condition) { + for (const auto& condNode : mGraphView->getNodes()) { + if (condNode->attributes()->hasAttr("schedule.cond")) { + auto attr = condNode->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond"); + + for (const auto& cond : attr) { + const auto& select = cond.first; + + // Check if node is a condition to this conditional node + if (node == select->input(0).first) { + // If so, the conditional node act like a child + const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(), + [condNode](const auto& v) { return (v->node == condNode); }); + if (it != schedule.end()) { + const std::size_t step = std::distance(schedule.begin(), it); + late = std::min(late, schedule[step]->late - 1); + schedule[step]->laterThan.push_back(schedule[elt]); } } } diff --git a/src/scheduler/SequentialScheduler.cpp b/src/scheduler/SequentialScheduler.cpp index 5f6bb6c07..2b1956d79 100644 --- a/src/scheduler/SequentialScheduler.cpp +++ b/src/scheduler/SequentialScheduler.cpp @@ -59,7 +59,7 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); for (const auto& runnable : staticSchedule) { - const bool skip = !isNodeCondValid(runnable->node); + const bool skip = !isConditionalNodeRequired(runnable->node); Log::debug("run: {}{}", namePtrTable.at(runnable->node), (skip) ? " -- skipped" : ""); if (!skip) { -- GitLab From 5460accb52c9fce7c09ebaf3f5d56cb1f8914377 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Mon, 17 Feb 2025 17:13:37 +0100 Subject: [PATCH 06/13] Minor changes --- include/aidge/scheduler/Scheduler.hpp | 20 +++++++++---------- python_binding/scheduler/pybind_Scheduler.cpp | 1 + 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 881d16e05..61aeb49ba 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -135,16 +135,6 @@ public: */ void tagConditionalNodes() const; - /** - * @brief Check if the conditional node is required (if one of its conditions - * is true). - * - * @param node Node to check the condition. - * @return true If any node condition is true, meaning it has to be executed. - * @return false If all node conditions are false, meaning it can be skipped. - */ - bool isConditionalNodeRequired(NodePtr node) const; - /** * @brief Get the static scheduling order of nodes. * @param step The step of the static schedule to retrieve (default is 0). @@ -220,6 +210,16 @@ public: protected: + /** + * @brief Check if the conditional node is required (if one of its conditions + * is true). + * + * @param node Node to check the condition. + * @return true If any node condition is true, meaning it has to be executed. + * @return false If all node conditions are false, meaning it can be skipped. + */ + bool isConditionalNodeRequired(NodePtr node) const; + /** * @brief Getter for the set of children Nodes of the given input Nodes. * @param producers Set of Nodes for which we want to obtain the set of children Nodes. diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp index 01a27e455..34ed93520 100644 --- a/python_binding/scheduler/pybind_Scheduler.cpp +++ b/python_binding/scheduler/pybind_Scheduler.cpp @@ -30,6 +30,7 @@ void init_Scheduler(py::module& m){ py::class_<Scheduler, std::shared_ptr<Scheduler>>(m, "Scheduler") .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) .def("graph_view", &Scheduler::graphView) + .def("tag_conditional_nodes", &Scheduler::tagConditionalNodes) .def("save_scheduling_diagram", &Scheduler::saveSchedulingDiagram, py::arg("file_name")) .def("save_static_scheduling_diagram", &Scheduler::saveStaticSchedulingDiagram, py::arg("file_name")) .def("save_factorized_static_scheduling_diagram", &Scheduler::saveFactorizedStaticSchedulingDiagram, py::arg("file_name"), py::arg("min_repeat") = 2) -- GitLab From 4f8f27b865149c83f1485ab31f941bc36978f088 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Wed, 19 Feb 2025 11:19:33 +0100 Subject: [PATCH 07/13] Imrpoved display --- include/aidge/utils/Registrar.hpp | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/include/aidge/utils/Registrar.hpp b/include/aidge/utils/Registrar.hpp index a2f33eea5..a9368ecaf 100644 --- a/include/aidge/utils/Registrar.hpp +++ b/include/aidge/utils/Registrar.hpp @@ -76,16 +76,12 @@ struct Registrar { } static auto create(const registrar_key& key) { - if (!exists(key)) { - Log::error("missing or invalid registrar key: {} for registrable object {}\nDid you include/import the corresponding module?\nIf so, it is possible that the object is not yet supported.", key, typeid(C).name()); - - Log::info("Available registrar keys are:"); - for(const auto& keyValue : C::registry()) { - Log::info("- {}", keyValue.first); - } - - AIDGE_THROW_OR_ABORT(std::runtime_error, "missing or invalid registrar key"); - } + AIDGE_ASSERT(exists(key), + "missing or invalid registrar key: {} for registrable object {}\n" + "Did you include/import the corresponding module?\n" + "If so, it is possible that the object is not yet supported.\n\n" + "Available registrar keys are:\n {}", + key, typeid(C).name(), fmt::join(getKeys(), "\n ")); return C::registry().at(key); } -- GitLab From 7a9b48f4c3652efc544fdb5721896044b435c33e Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Thu, 20 Feb 2025 16:06:25 +0100 Subject: [PATCH 08/13] Integrate a small fix for @raphaelmillet --- src/scheduler/Scheduler.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 18210a1f8..f65729d64 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -1165,8 +1165,14 @@ std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getStaticScheduling( [](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); }); } else if (sorting == EarlyLateSort::AsLateAsPossible) { + // The last condition (lhs->early > rhs->early) ensures that when on a + // branch join, one does not switch branch just before the join if there + // is only a single node (scheduled as late as possible, since not in the + // critical path) in one of the branch. + // @raphaelmillet Required for PNeuro export. + // TODO: add branch-level sorting policies (shortest to longuest branch for example) std::stable_sort(staticSchedule.begin(), staticSchedule.end(), - [](const auto& lhs, const auto& rhs) { return ((lhs->late < rhs->late) || (lhs->late == rhs->late && lhs->early < rhs->early)); }); + [](const auto& lhs, const auto& rhs) { return ((lhs->late < rhs->late) || (lhs->late == rhs->late && lhs->early > rhs->early)); }); } std::vector<std::shared_ptr<Node>> schedule; -- GitLab From d453aa1dd8f6ac7b925ba8aab3fe8e8f1b26659d Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 21 Feb 2025 10:20:32 +0100 Subject: [PATCH 09/13] Added hint in generateMemory() error message about forward dims --- src/scheduler/Scheduler.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index f65729d64..99f2de669 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -686,7 +686,7 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr for (IOIndex_t outputIdx = 0; outputIdx < node->nbOutputs(); ++outputIdx) { const auto requiredSize = op->getRequiredMemory(outputIdx, {}); AIDGE_ASSERT(requiredSize.type == Elts_t::Data, - "Cannot generate memory with token-based producer-consumer model for node {} (of type {}).", + "Cannot generate memory with token-based producer-consumer model for node {} (of type {}). You may need to forward dimensions in the graph first.", node->name(), node->type()); // By default, specifies a fully monolithic memory block @@ -724,7 +724,7 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr const auto requiredData = op->getNbRequiredData(inputIdx); const auto requiredProtected = op->getNbRequiredProtected(inputIdx); AIDGE_ASSERT(requiredData.type == Elts_t::Data && requiredProtected.type == Elts_t::Data, - "Cannot generate memory with token-based producer-consumer model for node {} (of type {}).", + "Cannot generate memory with token-based producer-consumer model for node {} (of type {}). You may need to forward dimensions in the graph first.", node->name(), node->type()); const bool isWrappable = (requiredProtected.data < requiredData.data); @@ -848,7 +848,7 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemoryAutoConcat(bool incProducer } AIDGE_ASSERT(requiredSize.type == Elts_t::Data, - "Cannot generate memory with token-based producer-consumer model for node {} (of type {}).", + "Cannot generate memory with token-based producer-consumer model for node {} (of type {}). You may need to forward dimensions in the graph first.", node->name(), node->type()); // By default, specifies a fully monolithic memory block @@ -892,7 +892,7 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemoryAutoConcat(bool incProducer const auto requiredData = allocableNode->getOperator()->getNbRequiredData(inputIdx); const auto requiredProtected = allocableNode->getOperator()->getNbRequiredProtected(inputIdx); AIDGE_ASSERT(requiredData.type == Elts_t::Data && requiredProtected.type == Elts_t::Data, - "Cannot generate memory with token-based producer-consumer model for node {} (of type {}).", + "Cannot generate memory with token-based producer-consumer model for node {} (of type {}). You may need to forward dimensions in the graph first.", node->name(), node->type()); const bool isWrappable = (requiredProtected.data < requiredData.data); -- GitLab From a89e63904c1ac932f2ad4cd34ef579a8a8d80ce9 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 21 Feb 2025 15:22:15 +0100 Subject: [PATCH 10/13] Refactored StaticAnalysis --- aidge_core/dynamic_analysis.py | 161 ++++++++++++++++++ aidge_core/export_utils/scheduler_export.py | 4 +- aidge_core/mem_info.py | 4 +- aidge_core/static_analysis.py | 4 +- include/aidge/analysis/DynamicAnalysis.hpp | 55 ++++++ .../OperatorStats.hpp} | 81 ++------- include/aidge/analysis/StaticAnalysis.hpp | 87 ++++++++++ include/aidge/scheduler/Scheduler.hpp | 32 +++- .../aidge/scheduler/SequentialScheduler.hpp | 7 - .../analysis/pybind_DynamicAnalysis.cpp | 39 +++++ .../pybind_OperatorStats.cpp} | 54 +----- .../analysis/pybind_StaticAnalysis.cpp | 69 ++++++++ python_binding/pybind_core.cpp | 4 + python_binding/scheduler/pybind_Scheduler.cpp | 30 +++- src/analysis/DynamicAnalysis.cpp | 57 +++++++ src/analysis/OperatorStats.cpp | 66 +++++++ src/{graph => analysis}/StaticAnalysis.cpp | 43 +---- src/scheduler/Scheduler.cpp | 26 ++- src/scheduler/SequentialScheduler.cpp | 41 ++--- .../Test_StaticAnalysis.cpp | 5 +- unit_tests/operator/Test_MetaOperator.cpp | 2 +- unit_tests/scheduler/Test_Scheduler.cpp | 12 +- 22 files changed, 651 insertions(+), 232 deletions(-) create mode 100644 aidge_core/dynamic_analysis.py create mode 100644 include/aidge/analysis/DynamicAnalysis.hpp rename include/aidge/{graph/StaticAnalysis.hpp => analysis/OperatorStats.hpp} (89%) create mode 100644 include/aidge/analysis/StaticAnalysis.hpp create mode 100644 python_binding/analysis/pybind_DynamicAnalysis.cpp rename python_binding/{graph/pybind_StaticAnalysis.cpp => analysis/pybind_OperatorStats.cpp} (57%) create mode 100644 python_binding/analysis/pybind_StaticAnalysis.cpp create mode 100644 src/analysis/DynamicAnalysis.cpp create mode 100644 src/analysis/OperatorStats.cpp rename src/{graph => analysis}/StaticAnalysis.cpp (78%) rename unit_tests/{graph => analysis}/Test_StaticAnalysis.cpp (93%) diff --git a/aidge_core/dynamic_analysis.py b/aidge_core/dynamic_analysis.py new file mode 100644 index 000000000..c8cdd6710 --- /dev/null +++ b/aidge_core/dynamic_analysis.py @@ -0,0 +1,161 @@ +import matplotlib +import matplotlib.pyplot as plt +from functools import partial +import numpy as np +import aidge_core + +class DynamicAnalysis(aidge_core.DynamicAnalysis): + def log_nb_arithm_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_arithm_ops, filename, title, log_scale) + + def log_nb_logic_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_logic_ops, filename, title, log_scale) + + def log_nb_comp_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_comp_ops, filename, title, log_scale) + + def log_nb_nl_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_nl_ops, filename, title, log_scale) + + def log_nb_mac_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_mac_ops, filename, title, log_scale) + + def log_nb_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_ops, filename, title, log_scale) + + def log_nb_arithm_int_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_arithm_int_ops, filename, title, log_scale) + + def log_nb_arithm_fp_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_arithm_fp_ops, filename, title, log_scale) + + def log_nb_ops_by_type(self, filename, title=None, log_scale=False): + return self._log_callback([aidge_core.OperatorStats.get_nb_arithm_int_ops, + aidge_core.OperatorStats.get_nb_arithm_fp_ops, + aidge_core.OperatorStats.get_nb_logic_ops, + aidge_core.OperatorStats.get_nb_comp_ops, + aidge_core.OperatorStats.get_nb_nl_ops], filename, title, log_scale) + + def _log_callback(self, callback, filename, title=None, log_scale=False): + """ + Log a statistic given by an OperatorStats callback member function. + Usage: + + stats = DynamicAnalysis(model) + stats.log_callback(aidge_core.OperatorStats.get_nb_params, "stats.png", "Nb params per operator") + + :param func: OperatorStats member function to call. + :param filename: Output graph file name. + :type filename: str + :param title: Title of the graph. + :type title: str + """ + + namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})"); + nodes = self.get_graph().get_ordered_nodes() + series = [] + legend = None + + for node in nodes: + if node.type() == "Producer": + continue + + stats = self.get_op_stats(node) + name = namePtrTable[node] + attr = {} + if type(node.get_operator()) is aidge_core.GenericOperatorOp: + # Display Generic Op in orange + attr = {'color': 'orange'} + elif not node.get_operator().is_atomic(): + # Display Meta Op in bold + attr = {'fontweight': 'bold'} + elif node.type() not in aidge_core.get_keys_OperatorStats(): + # Display unsupported operator in red labels + attr = {'color': 'red'} + if attr: + name = (name, attr) + if isinstance(callback, list): + series.append([name, [partial(cb, stats)() for cb in callback]]) + legend = [cb.__name__ for cb in callback] + if title is None: title = str(legend) + else: + series.append([name, partial(callback, stats)()]) + if title is None: title = callback.__name__ + + if title is None: title = str(callback) + if filename is not None: + self._log_bar(series, filename, title, legend, log_scale) + return series + + def _log_bar(self, series, filename, title=None, legend=None, log_scale=False): + names, values = zip(*series) + names_only = [item[0] if isinstance(item, tuple) else item for item in names] + fig, ax = plt.subplots(figsize=(max(5, len(names)/4), 5)) + plt.xlim(-0.5, len(names) - 0.5) + if isinstance(values[0], list): + series = [list(i) for i in zip(*values)] + bot = np.zeros(len(series[0])) + for i, serie in enumerate(series): + plt.bar(names_only, serie, bottom=bot) + bot += serie + else: + plt.bar(names_only, values) + if callable(getattr(ax.yaxis, 'minorticks_on', None)): + ax.yaxis.minorticks_on() # introduced in matplotlib 3.9.x + plt.grid(axis='y', which='major', linestyle='--', color='gray') + plt.grid(axis='y', which='minor', linestyle=':', color='lightgray') + formatter0 = matplotlib.ticker.EngFormatter(unit='') + ax.yaxis.set_major_formatter(formatter0) + plt.gca().set_axisbelow(True) + + labels = plt.gca().get_xticks() + tick_labels = plt.gca().get_xticklabels() + for i, label in enumerate(labels): + if isinstance(names[i], tuple): + if 'color' in names[i][1]: + tick_labels[i].set_color(names[i][1]['color']) + elif 'fontweight' in names[i][1]: + tick_labels[i].set_fontweight(names[i][1]['fontweight']) + + plt.xticks(rotation='vertical') + if log_scale: plt.yscale('log') + if title is not None: plt.title(title) + if legend is not None: plt.legend(legend) + plt.savefig(filename, bbox_inches='tight') + + def _log_barh(self, series, filename, title=None, legend=None, log_scale=False): + names, values = zip(*series) + names_only = [item[0] if isinstance(item, tuple) else item for item in names] + fig, ax = plt.subplots(figsize=(10, max(5, len(names)/4))) + plt.ylim(-0.5, len(names) - 0.5) + if isinstance(values[0], list): + series = [list(i) for i in zip(*values)] + left = np.zeros(len(series[0])) + for i, serie in enumerate(series): + plt.barh(names_only, serie, left=left) + left += serie + else: + plt.barh(names_only, values) + if callable(getattr(ax.xaxis, 'minorticks_on', None)): + ax.xaxis.minorticks_on() # introduced in matplotlib 3.9.x + plt.grid(axis='x', which='major', linestyle='--', color='gray') + plt.grid(axis='x', which='minor', linestyle=':', color='lightgray') + formatter0 = matplotlib.ticker.EngFormatter(unit='') + ax.xaxis.set_major_formatter(formatter0) + plt.gca().set_axisbelow(True) + plt.gca().xaxis.set_label_position('top') + plt.gca().xaxis.tick_top() + + labels = plt.gca().get_yticks() + tick_labels = plt.gca().get_yticklabels() + for i, label in enumerate(labels): + if isinstance(names[i], tuple): + if 'color' in names[i][1]: + tick_labels[i].set_color(names[i][1]['color']) + elif 'fontweight' in names[i][1]: + tick_labels[i].set_fontweight(names[i][1]['fontweight']) + + if log_scale: plt.xscale('log') + if title is not None: plt.title(title) + if legend is not None: plt.legend(legend) + plt.savefig(filename, bbox_inches='tight') diff --git a/aidge_core/export_utils/scheduler_export.py b/aidge_core/export_utils/scheduler_export.py index 0995e4cea..8aaedc18d 100644 --- a/aidge_core/export_utils/scheduler_export.py +++ b/aidge_core/export_utils/scheduler_export.py @@ -42,7 +42,7 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = :param scheduler: Scheduler instance managing the computation graph. - Uses `graph_view` and `get_static_scheduling` methods + Uses `graph_view` and `get_sequential_static_scheduling` methods to retrieve the computation graph layout and ordered nodes. :type scheduler: aidge_core.Scheduler :param export_folder_path: Path to the folder where the generated export files will be saved. @@ -88,7 +88,7 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = outputs_size: List[int] = [] # List of aidge_core.Node ordered by scheduler - list_forward_nodes: List[aidge_core.Node] = scheduler.get_static_scheduling() + list_forward_nodes: List[aidge_core.Node] = scheduler.get_sequential_static_scheduling() # If exportLib define use it # else parse component in platform diff --git a/aidge_core/mem_info.py b/aidge_core/mem_info.py index cabc2c72e..b8d3c6101 100644 --- a/aidge_core/mem_info.py +++ b/aidge_core/mem_info.py @@ -22,7 +22,7 @@ def compute_default_mem_info(scheduler: aidge_core.Scheduler) -> Tuple[int, List mem_size = 0 # Exclude Producers and the last layers (because the results are stored outside the export) - for i, node in enumerate(scheduler.get_static_scheduling()): + for i, node in enumerate(scheduler.get_sequential_static_scheduling()): if node.type() != "Producer": node_mem_info = [] for out_id in range(node.get_nb_outputs()): @@ -161,7 +161,7 @@ def generate_optimized_memory_info(scheduler: aidge_core.Scheduler, stats_folder mem_planes = mem_manager.get_planes() - for node in scheduler.get_static_scheduling(): + for node in scheduler.get_sequential_static_scheduling(): node_mem_info = [] if node.type() == "Producer": pass diff --git a/aidge_core/static_analysis.py b/aidge_core/static_analysis.py index b4a82a4fb..907bc48f5 100644 --- a/aidge_core/static_analysis.py +++ b/aidge_core/static_analysis.py @@ -4,7 +4,7 @@ from functools import partial import numpy as np import aidge_core -class StaticAnalysisExt(aidge_core.StaticAnalysis): +class StaticAnalysis(aidge_core.StaticAnalysis): def log_nb_params(self, filename, title=None, log_scale=False): namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})"); nodes = self.get_graph().get_ordered_nodes() @@ -77,7 +77,7 @@ class StaticAnalysisExt(aidge_core.StaticAnalysis): Log a statistic given by an OperatorStats callback member function. Usage: - stats = StaticAnalysisExt(model) + stats = StaticAnalysis(model) stats.log_callback(aidge_core.OperatorStats.get_nb_params, "stats.png", "Nb params per operator") :param func: OperatorStats member function to call. diff --git a/include/aidge/analysis/DynamicAnalysis.hpp b/include/aidge/analysis/DynamicAnalysis.hpp new file mode 100644 index 000000000..3dadf79b3 --- /dev/null +++ b/include/aidge/analysis/DynamicAnalysis.hpp @@ -0,0 +1,55 @@ + +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_ANALYSIS_DYNAMICANALYSIS_H_ +#define AIDGE_CORE_ANALYSIS_DYNAMICANALYSIS_H_ + +#include <cstddef> // std::size_t +#include <memory> +#include <string> + +#include "aidge/analysis/OperatorStats.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/utils/Registrar.hpp" + +namespace Aidge { +/** + * @brief Base class to compute statistics from a scheduled graph + * + */ +class DynamicAnalysis : public std::enable_shared_from_this<DynamicAnalysis> { +public: + DynamicAnalysis() = delete; + DynamicAnalysis(const Scheduler& scheduler); + + virtual ~DynamicAnalysis(); + + std::size_t getNbArithmOps() const; + std::size_t getNbLogicOps() const; + std::size_t getNbCompOps() const; + std::size_t getNbNLOps() const; + std::size_t getNbOps() const; + std::size_t getNbArithmIntOps() const; + std::size_t getNbArithmFpOps() const; + std::size_t getNbMACOps() const; + +protected: + const Scheduler& mScheduler; + + std::size_t accumulate(std::size_t (OperatorStats::*func)() const) const; +}; +} + +#endif /* AIDGE_CORE_ANALYSIS_DYNAMICANALYSIS_H_ */ diff --git a/include/aidge/graph/StaticAnalysis.hpp b/include/aidge/analysis/OperatorStats.hpp similarity index 89% rename from include/aidge/graph/StaticAnalysis.hpp rename to include/aidge/analysis/OperatorStats.hpp index cc5532224..ac1abcee7 100644 --- a/include/aidge/graph/StaticAnalysis.hpp +++ b/include/aidge/analysis/OperatorStats.hpp @@ -10,8 +10,8 @@ * ********************************************************************************/ -#ifndef AIDGE_CORE_GRAPH_STATICANALYSIS_H_ -#define AIDGE_CORE_GRAPH_STATICANALYSIS_H_ +#ifndef AIDGE_CORE_ANALYSIS_OPERATORSTATS_H_ +#define AIDGE_CORE_ANALYSIS_OPERATORSTATS_H_ #include <cstddef> // std::size_t #include <memory> @@ -44,6 +44,14 @@ public: OperatorStats() = delete; OperatorStats(const Operator& op); + /** + * @brief Get the Operator Stats object corresponding to the given node. + * + * @param node Node + * @return std::shared_ptr<OperatorStats> Node's Operator stats + */ + static std::shared_ptr<OperatorStats> getOpStats(std::shared_ptr<Node> node); + virtual ~OperatorStats(); inline const Operator& getOperator() const noexcept { return mOp; } @@ -156,73 +164,6 @@ protected: const Operator &mOp; }; -/** - * @brief Base class to compute statistics from a GraphView - * - */ -class StaticAnalysis : public std::enable_shared_from_this<StaticAnalysis> { -public: - StaticAnalysis() = delete; - StaticAnalysis(std::shared_ptr<GraphView> graph); - - virtual ~StaticAnalysis(); - - inline const std::shared_ptr<GraphView> getGraph() const noexcept { return mGraph; } - - /** - * @brief Get the Operator Stats object corresponding to the given node. - * - * @param node Node - * @return std::shared_ptr<OperatorStats> Node's Operator stats - */ - std::shared_ptr<OperatorStats> getOpStats(std::shared_ptr<Node> node) const; - - /** - * @brief Get the number of parameters associated to a node. This includes - * all Producers directly connected to the node's inputs as well as all - * internal Producers (in case of a meta operator). - * - * Note: this function does not check if parameters are shared between - * several nodes or not. This means that simply adding parameters count from - * several nodes may lead to a higher number of parameters than in reality - * if some of them are shared. - * - * @param node Node - * @return std::size_t Number of parameters - */ - virtual std::size_t getNbParams(std::shared_ptr<Node> node) const; - - /** - * @brief Get the total parameters memory size, in bits, associated to a node. - * This includes all Producers directly connected to the node's inputs as - * well as all internal Producers (in case of a meta operator). - * - * Note: this function does not check if parameters are shared between - * several nodes or not. This means that simply adding parameters size from - * several nodes may lead to a higher parameter size than in reality - * if some of them are shared. - * - * @param node Node - * @return std::size_t Total parameters memory, in bits - */ - virtual std::size_t getParamsSize(std::shared_ptr<Node> node) const; - - std::size_t getNbArithmOps() const; - std::size_t getNbLogicOps() const; - std::size_t getNbCompOps() const; - std::size_t getNbNLOps() const; - std::size_t getNbOps() const; - std::size_t getNbArithmIntOps() const; - std::size_t getNbArithmFpOps() const; - std::size_t getNbMACOps() const; - virtual void summary(bool incProducers = false) const; - -protected: - const std::shared_ptr<GraphView> mGraph; - - std::size_t accumulate(std::size_t (OperatorStats::*func)() const) const; -}; - //////////////////////////////////////////////////////////////////////////////// class MetaOpStats : public OperatorStats { @@ -579,4 +520,4 @@ REGISTRAR(OperatorStats, "Tanh", ElemWiseNLOpStats::create); REGISTRAR(OperatorStats, "Pow", ElemWiseNLOpStats::create); } -#endif /* AIDGE_CORE_GRAPH_STATICANALYSIS_H_ */ +#endif /* AIDGE_CORE_ANALYSIS_OPERATORSTATS_H_ */ diff --git a/include/aidge/analysis/StaticAnalysis.hpp b/include/aidge/analysis/StaticAnalysis.hpp new file mode 100644 index 000000000..a0feadd72 --- /dev/null +++ b/include/aidge/analysis/StaticAnalysis.hpp @@ -0,0 +1,87 @@ + +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_ANALYSIS_STATICANALYSIS_H_ +#define AIDGE_CORE_ANALYSIS_STATICANALYSIS_H_ + +#include <cstddef> // std::size_t +#include <memory> +#include <string> + +#include "aidge/analysis/OperatorStats.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/utils/Registrar.hpp" + +namespace Aidge { +/** + * @brief Base class to compute statistics from a GraphView + * + */ +class StaticAnalysis : public std::enable_shared_from_this<StaticAnalysis> { +public: + StaticAnalysis() = delete; + StaticAnalysis(std::shared_ptr<GraphView> graph); + + virtual ~StaticAnalysis(); + + inline const std::shared_ptr<GraphView> getGraph() const noexcept { return mGraph; } + + /** + * @brief Get the number of parameters associated to a node. This includes + * all Producers directly connected to the node's inputs as well as all + * internal Producers (in case of a meta operator). + * + * Note: this function does not check if parameters are shared between + * several nodes or not. This means that simply adding parameters count from + * several nodes may lead to a higher number of parameters than in reality + * if some of them are shared. + * + * @param node Node + * @return std::size_t Number of parameters + */ + virtual std::size_t getNbParams(std::shared_ptr<Node> node) const; + + /** + * @brief Get the total parameters memory size, in bits, associated to a node. + * This includes all Producers directly connected to the node's inputs as + * well as all internal Producers (in case of a meta operator). + * + * Note: this function does not check if parameters are shared between + * several nodes or not. This means that simply adding parameters size from + * several nodes may lead to a higher parameter size than in reality + * if some of them are shared. + * + * @param node Node + * @return std::size_t Total parameters memory, in bits + */ + virtual std::size_t getParamsSize(std::shared_ptr<Node> node) const; + + std::size_t getNbArithmOps() const; + std::size_t getNbLogicOps() const; + std::size_t getNbCompOps() const; + std::size_t getNbNLOps() const; + std::size_t getNbOps() const; + std::size_t getNbArithmIntOps() const; + std::size_t getNbArithmFpOps() const; + std::size_t getNbMACOps() const; + virtual void summary(bool incProducers = false) const; + +protected: + const std::shared_ptr<GraphView> mGraph; + + std::size_t accumulate(std::size_t (OperatorStats::*func)() const) const; +}; +} + +#endif /* AIDGE_CORE_ANALYSIS_STATICANALYSIS_H_ */ diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 61aeb49ba..7c309783d 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -46,7 +46,7 @@ class GraphView; * @see MemoryManager */ class Scheduler { -protected: +public: /** * @struct StaticSchedulingElement * @brief Represents a node in the static schedule. @@ -81,7 +81,7 @@ protected: std::chrono::time_point<std::chrono::high_resolution_clock> start; /** Actual start time of execution */ std::chrono::time_point<std::chrono::high_resolution_clock> end; /** Actual end time of execution */ }; -public: + enum class AvailableDataStatus { Connected, UpperNodeInputFound, @@ -90,7 +90,7 @@ public: NotConnected }; - enum class EarlyLateSort { + enum class SchedulingPolicy { Default, AsSoonAsPossible, AsLateAsPossible @@ -136,12 +136,28 @@ public: void tagConditionalNodes() const; /** - * @brief Get the static scheduling order of nodes. + * @brief Get the static scheduling (after generate scheduling). + * @return Vector of StaticSchedulingElement pointers. + */ + std::vector<StaticSchedulingElement*> getStaticScheduling(std::size_t step = 0) const { + return mStaticSchedule.at(step); + } + + /** + * @brief Get the static scheduling sequential order of nodes. * @param step The step of the static schedule to retrieve (default is 0). - * @param sorting Sorting mode. + * @param policy Sorting mode. * @return Vector of shared pointers to Nodes in their scheduled order. */ - std::vector<std::shared_ptr<Node>> getStaticScheduling(std::size_t step = 0, EarlyLateSort sorting = EarlyLateSort::Default) const; + std::vector<std::shared_ptr<Node>> getSequentialStaticScheduling(std::size_t step = 0, SchedulingPolicy policy = SchedulingPolicy::Default) const; + + /** + * @brief Get the dynamic scheduling (after graph execution). + * @return Vector of SchedulingElement. + */ + std::vector<SchedulingElement> getScheduling() const { + return mScheduling; + } /** * @brief Get the GraphView associated with this Scheduler. @@ -199,14 +215,14 @@ public: * order of execution for the nodes, to a file in Mermaid format. * @param fileName Name of the file to save the diagram (without extension). */ - void saveStaticSchedulingDiagram(const std::string& fileName) const; + void saveStaticSchedulingDiagram(const std::string& fileName, bool ignoreProducers = false) const; void saveFactorizedStaticSchedulingDiagram(const std::string& fileName, size_t minRepeat = 2) const; /** * @brief Save in a Mermaid file the order of layers execution. * @param fileName Name of the generated file. */ - void saveSchedulingDiagram(const std::string& fileName) const; + void saveSchedulingDiagram(const std::string& fileName, bool ignoreProducers = false) const; protected: diff --git a/include/aidge/scheduler/SequentialScheduler.hpp b/include/aidge/scheduler/SequentialScheduler.hpp index 35dafead6..0ae18b085 100644 --- a/include/aidge/scheduler/SequentialScheduler.hpp +++ b/include/aidge/scheduler/SequentialScheduler.hpp @@ -25,13 +25,6 @@ namespace Aidge { * Multi-threaded parallel scheduler with dynamic scheduling. */ class SequentialScheduler : public Scheduler { -public: - enum class SchedulingPolicy { - Default, - AsSoonAsPossible, - AsLateAsPossible - }; - public: SequentialScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr) : Scheduler(graphView, upperNode), diff --git a/python_binding/analysis/pybind_DynamicAnalysis.cpp b/python_binding/analysis/pybind_DynamicAnalysis.cpp new file mode 100644 index 000000000..3cd71f741 --- /dev/null +++ b/python_binding/analysis/pybind_DynamicAnalysis.cpp @@ -0,0 +1,39 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include "aidge/analysis/DynamicAnalysis.hpp" + +namespace py = pybind11; +namespace Aidge { + +class pyDynamicAnalysis: public DynamicAnalysis { +public: + using DynamicAnalysis::DynamicAnalysis; // Inherit constructors + +}; + +void init_DynamicAnalysis(py::module& m){ + py::class_<DynamicAnalysis, std::shared_ptr<DynamicAnalysis>, pyDynamicAnalysis>(m, "DynamicAnalysis", py::multiple_inheritance(), py::dynamic_attr()) + .def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph")) + .def("get_nb_arithm_ops", &DynamicAnalysis::getNbArithmOps) + .def("get_nb_logic_ops", &DynamicAnalysis::getNbLogicOps) + .def("get_nb_comp_ops", &DynamicAnalysis::getNbCompOps) + .def("get_nb_nl_ops", &DynamicAnalysis::getNbNLOps) + .def("get_nb_ops", &DynamicAnalysis::getNbOps) + .def("get_nb_arithm_int_ops", &DynamicAnalysis::getNbArithmIntOps) + .def("get_nb_arithm_fp_ops", &DynamicAnalysis::getNbArithmFpOps) + .def("get_nb_mac_ops", &DynamicAnalysis::getNbMACOps) + ; +} +} diff --git a/python_binding/graph/pybind_StaticAnalysis.cpp b/python_binding/analysis/pybind_OperatorStats.cpp similarity index 57% rename from python_binding/graph/pybind_StaticAnalysis.cpp rename to python_binding/analysis/pybind_OperatorStats.cpp index b7c704d72..be2b79e67 100644 --- a/python_binding/graph/pybind_StaticAnalysis.cpp +++ b/python_binding/analysis/pybind_OperatorStats.cpp @@ -12,7 +12,7 @@ #include <pybind11/pybind11.h> #include <pybind11/stl.h> -#include "aidge/graph/StaticAnalysis.hpp" +#include "aidge/analysis/OperatorStats.hpp" namespace py = pybind11; namespace Aidge { @@ -74,41 +74,10 @@ public: } }; -class pyStaticAnalysis: public StaticAnalysis { -public: - using StaticAnalysis::StaticAnalysis; // Inherit constructors - - size_t getNbParams(std::shared_ptr<Node> node) const override { - PYBIND11_OVERRIDE( - size_t, - StaticAnalysis, - getNbParams, - node - ); - } - - size_t getParamsSize(std::shared_ptr<Node> node) const override { - PYBIND11_OVERRIDE( - size_t, - StaticAnalysis, - getParamsSize, - node - ); - } - - void summary(bool incProducers) const override { - PYBIND11_OVERRIDE( - void, - StaticAnalysis, - summary, - incProducers - ); - } -}; - -void init_StaticAnalysis(py::module& m){ +void init_OperatorStats(py::module& m){ py::class_<OperatorStats, std::shared_ptr<OperatorStats>, pyOperatorStats>(m, "OperatorStats", py::multiple_inheritance(), py::dynamic_attr()) .def(py::init<const Operator&>(), py::arg("op")) + .def_static("get_op_stats", &OperatorStats::getOpStats, py::arg("node")) .def("get_operator", &OperatorStats::getOperator) .def("get_nb_arithm_ops", &OperatorStats::getNbArithmOps) .def("get_nb_logic_ops", &OperatorStats::getNbLogicOps) @@ -120,22 +89,5 @@ void init_StaticAnalysis(py::module& m){ .def("get_nb_mac_ops", &OperatorStats::getNbMACOps) ; declare_registrable<OperatorStats>(m, "OperatorStats"); - - py::class_<StaticAnalysis, std::shared_ptr<StaticAnalysis>, pyStaticAnalysis>(m, "StaticAnalysis", py::multiple_inheritance(), py::dynamic_attr()) - .def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph")) - .def("get_graph", &StaticAnalysis::getGraph) - .def("get_nb_params", &StaticAnalysis::getNbParams, py::arg("node")) - .def("get_params_size", &StaticAnalysis::getParamsSize, py::arg("node")) - .def("get_nb_arithm_ops", &StaticAnalysis::getNbArithmOps) - .def("get_nb_logic_ops", &StaticAnalysis::getNbLogicOps) - .def("get_nb_comp_ops", &StaticAnalysis::getNbCompOps) - .def("get_nb_nl_ops", &StaticAnalysis::getNbNLOps) - .def("get_nb_ops", &StaticAnalysis::getNbOps) - .def("get_nb_arithm_int_ops", &StaticAnalysis::getNbArithmIntOps) - .def("get_nb_arithm_fp_ops", &StaticAnalysis::getNbArithmFpOps) - .def("get_nb_mac_ops", &StaticAnalysis::getNbMACOps) - .def("summary", &StaticAnalysis::summary, py::arg("inc_producers") = false) - .def("get_op_stats", &StaticAnalysis::getOpStats, py::arg("node")) - ; } } diff --git a/python_binding/analysis/pybind_StaticAnalysis.cpp b/python_binding/analysis/pybind_StaticAnalysis.cpp new file mode 100644 index 000000000..65ee8e8b0 --- /dev/null +++ b/python_binding/analysis/pybind_StaticAnalysis.cpp @@ -0,0 +1,69 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include "aidge/analysis/StaticAnalysis.hpp" + +namespace py = pybind11; +namespace Aidge { + +class pyStaticAnalysis: public StaticAnalysis { +public: + using StaticAnalysis::StaticAnalysis; // Inherit constructors + + size_t getNbParams(std::shared_ptr<Node> node) const override { + PYBIND11_OVERRIDE( + size_t, + StaticAnalysis, + getNbParams, + node + ); + } + + size_t getParamsSize(std::shared_ptr<Node> node) const override { + PYBIND11_OVERRIDE( + size_t, + StaticAnalysis, + getParamsSize, + node + ); + } + + void summary(bool incProducers) const override { + PYBIND11_OVERRIDE( + void, + StaticAnalysis, + summary, + incProducers + ); + } +}; + +void init_StaticAnalysis(py::module& m){ + py::class_<StaticAnalysis, std::shared_ptr<StaticAnalysis>, pyStaticAnalysis>(m, "StaticAnalysis", py::multiple_inheritance(), py::dynamic_attr()) + .def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph")) + .def("get_graph", &StaticAnalysis::getGraph) + .def("get_nb_params", &StaticAnalysis::getNbParams, py::arg("node")) + .def("get_params_size", &StaticAnalysis::getParamsSize, py::arg("node")) + .def("get_nb_arithm_ops", &StaticAnalysis::getNbArithmOps) + .def("get_nb_logic_ops", &StaticAnalysis::getNbLogicOps) + .def("get_nb_comp_ops", &StaticAnalysis::getNbCompOps) + .def("get_nb_nl_ops", &StaticAnalysis::getNbNLOps) + .def("get_nb_ops", &StaticAnalysis::getNbOps) + .def("get_nb_arithm_int_ops", &StaticAnalysis::getNbArithmIntOps) + .def("get_nb_arithm_fp_ops", &StaticAnalysis::getNbArithmFpOps) + .def("get_nb_mac_ops", &StaticAnalysis::getNbMACOps) + .def("summary", &StaticAnalysis::summary, py::arg("inc_producers") = false) + ; +} +} diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index b2aa93dc9..c7a7330b6 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -32,7 +32,9 @@ void init_OperatorImpl(py::module&); void init_Log(py::module&); void init_Operator(py::module&); void init_OperatorTensor(py::module&); +void init_OperatorStats(py::module&); void init_StaticAnalysis(py::module&); +void init_DynamicAnalysis(py::module&); void init_Abs(py::module&); void init_Add(py::module&); @@ -136,7 +138,9 @@ void init_Aidge(py::module& m) { init_Log(m); init_Operator(m); init_OperatorTensor(m); + init_OperatorStats(m); init_StaticAnalysis(m); + init_DynamicAnalysis(m); init_Abs(m); init_Add(m); diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp index 34ed93520..582ba4678 100644 --- a/python_binding/scheduler/pybind_Scheduler.cpp +++ b/python_binding/scheduler/pybind_Scheduler.cpp @@ -21,23 +21,39 @@ namespace py = pybind11; namespace Aidge { void init_Scheduler(py::module& m){ - py::enum_<Scheduler::EarlyLateSort>(m, "EarlyLateSort") - .value("Default", Scheduler::EarlyLateSort::Default) - .value("AsSoonAsPossible", Scheduler::EarlyLateSort::AsSoonAsPossible) - .value("AsLateAsPossible", Scheduler::EarlyLateSort::AsLateAsPossible) + py::class_<Scheduler::StaticSchedulingElement>(m, "StaticSchedulingElement") + .def_readonly("node", &Scheduler::StaticSchedulingElement::node) + .def_readonly("early", &Scheduler::StaticSchedulingElement::early) + .def_readonly("late", &Scheduler::StaticSchedulingElement::late) + .def_readonly("earlier_than", &Scheduler::StaticSchedulingElement::earlierThan) + .def_readonly("later_than", &Scheduler::StaticSchedulingElement::laterThan) + ; + + py::class_<Scheduler::SchedulingElement>(m, "SchedulingElement") + .def_readonly("node", &Scheduler::SchedulingElement::node) + .def_readonly("start", &Scheduler::SchedulingElement::start) + .def_readonly("end", &Scheduler::SchedulingElement::end) + ; + + py::enum_<Scheduler::SchedulingPolicy>(m, "SchedulingPolicy") + .value("Default", Scheduler::SchedulingPolicy::Default) + .value("AsSoonAsPossible", Scheduler::SchedulingPolicy::AsSoonAsPossible) + .value("AsLateAsPossible", Scheduler::SchedulingPolicy::AsLateAsPossible) .export_values(); py::class_<Scheduler, std::shared_ptr<Scheduler>>(m, "Scheduler") .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) .def("graph_view", &Scheduler::graphView) .def("tag_conditional_nodes", &Scheduler::tagConditionalNodes) - .def("save_scheduling_diagram", &Scheduler::saveSchedulingDiagram, py::arg("file_name")) - .def("save_static_scheduling_diagram", &Scheduler::saveStaticSchedulingDiagram, py::arg("file_name")) + .def("save_scheduling_diagram", &Scheduler::saveSchedulingDiagram, py::arg("file_name"), py::arg("ignore_producers") = false) + .def("save_static_scheduling_diagram", &Scheduler::saveStaticSchedulingDiagram, py::arg("file_name"), py::arg("ignore_producers") = false) .def("save_factorized_static_scheduling_diagram", &Scheduler::saveFactorizedStaticSchedulingDiagram, py::arg("file_name"), py::arg("min_repeat") = 2) .def("reset_scheduling", &Scheduler::resetScheduling) .def("clear_scheduling", &Scheduler::clearScheduling) .def("generate_scheduling", &Scheduler::generateScheduling) - .def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0, py::arg("sorting") = Scheduler::EarlyLateSort::Default) + .def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0) + .def("get_sequential_static_scheduling", &Scheduler::getSequentialStaticScheduling, py::arg("step") = 0, py::arg("sorting") = Scheduler::SchedulingPolicy::Default) + .def("get_scheduling", &Scheduler::getScheduling) .def("graph_view", &Scheduler::graphView) .def("generate_memory", &Scheduler::generateMemory, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false) .def("generate_memory_auto_concat", &Scheduler::generateMemoryAutoConcat, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false) diff --git a/src/analysis/DynamicAnalysis.cpp b/src/analysis/DynamicAnalysis.cpp new file mode 100644 index 000000000..039820154 --- /dev/null +++ b/src/analysis/DynamicAnalysis.cpp @@ -0,0 +1,57 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/analysis/DynamicAnalysis.hpp" + +#include <cstddef> // std::size_t +#include <memory> +#include <numeric> // std::accumulate +#include <set> + +#include <fmt/core.h> // fmt::println +#include <fmt/format.h> +#include <fmt/ranges.h> + +#include "aidge/data/DataType.hpp" // Aidge::isFloatingPoint +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/scheduler/Scheduler.hpp" + +Aidge::DynamicAnalysis::DynamicAnalysis(const Scheduler& scheduler) + : mScheduler(scheduler) +{ + //ctor +} + +Aidge::DynamicAnalysis::~DynamicAnalysis() = default; + +std::size_t Aidge::DynamicAnalysis::getNbArithmOps() const { return accumulate(&OperatorStats::getNbArithmOps); } +std::size_t Aidge::DynamicAnalysis::getNbLogicOps() const { return accumulate(&OperatorStats::getNbLogicOps); } +std::size_t Aidge::DynamicAnalysis::getNbCompOps() const { return accumulate(&OperatorStats::getNbCompOps); } +std::size_t Aidge::DynamicAnalysis::getNbNLOps() const { return accumulate(&OperatorStats::getNbNLOps); } +std::size_t Aidge::DynamicAnalysis::getNbOps() const { return accumulate(&OperatorStats::getNbOps); } +std::size_t Aidge::DynamicAnalysis::getNbArithmIntOps() const { return accumulate(&OperatorStats::getNbArithmIntOps); } +std::size_t Aidge::DynamicAnalysis::getNbArithmFpOps() const { return accumulate(&OperatorStats::getNbArithmFpOps); } +std::size_t Aidge::DynamicAnalysis::getNbMACOps() const { return accumulate(&OperatorStats::getNbMACOps); } + +std::size_t Aidge::DynamicAnalysis::accumulate(std::size_t (OperatorStats::*func)() const) const { + const auto& scheduling = mScheduler.getScheduling(); + return std::accumulate( + scheduling.cbegin(), + scheduling.cend(), + std::size_t(0), + [this, func](const std::size_t& lhs, const Scheduler::SchedulingElement& rhs) { + return lhs + (OperatorStats::getOpStats(rhs.node).get()->*func)(); + }); +} diff --git a/src/analysis/OperatorStats.cpp b/src/analysis/OperatorStats.cpp new file mode 100644 index 000000000..a020403ad --- /dev/null +++ b/src/analysis/OperatorStats.cpp @@ -0,0 +1,66 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/analysis/OperatorStats.hpp" +#include "aidge/analysis/StaticAnalysis.hpp" + +#include <cstddef> // std::size_t +#include <memory> +#include <numeric> // std::accumulate +#include <set> + +#include <fmt/core.h> // fmt::println +#include <fmt/format.h> +#include <fmt/ranges.h> + +#include "aidge/data/DataType.hpp" // Aidge::isFloatingPoint +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/operator/OperatorTensor.hpp" + +Aidge::OperatorStats::OperatorStats(const Operator& op) + : mOp(op) +{ + //ctor +} + +std::shared_ptr<Aidge::OperatorStats> Aidge::OperatorStats::getOpStats(std::shared_ptr<Node> node) { + return (Registrar<OperatorStats>::exists(node->type())) + ? Registrar<OperatorStats>::create(node->type())(*(node->getOperator())) + : (node->getOperator()->isAtomic()) + ? std::make_shared<OperatorStats>(*(node->getOperator())) + : std::make_shared<MetaOpStats>(*(node->getOperator())); +} + +Aidge::OperatorStats::~OperatorStats() = default; + +std::size_t Aidge::OperatorStats::getNbArithmIntOps() const { + const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp); + if (opTensor) { + if (!isFloatingPoint(opTensor->getOutput(0)->dataType())) { + return getNbArithmOps(); + } + } + return 0; +} + +//////////////////////////////////////////////////////////////////////////////// + +Aidge::MetaOpStats::~MetaOpStats() = default; + +std::size_t Aidge::MetaOpStats::getNbArithmOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmOps(); } +std::size_t Aidge::MetaOpStats::getNbLogicOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbLogicOps(); } +std::size_t Aidge::MetaOpStats::getNbCompOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbCompOps(); } +std::size_t Aidge::MetaOpStats::getNbNLOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbNLOps(); } +std::size_t Aidge::MetaOpStats::getNbArithmIntOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmIntOps(); } +std::size_t Aidge::MetaOpStats::getNbMACOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbMACOps(); } diff --git a/src/graph/StaticAnalysis.cpp b/src/analysis/StaticAnalysis.cpp similarity index 78% rename from src/graph/StaticAnalysis.cpp rename to src/analysis/StaticAnalysis.cpp index 418ae8936..0e32618c2 100644 --- a/src/graph/StaticAnalysis.cpp +++ b/src/analysis/StaticAnalysis.cpp @@ -9,7 +9,7 @@ * ********************************************************************************/ -#include "aidge/graph/StaticAnalysis.hpp" +#include "aidge/analysis/StaticAnalysis.hpp" #include <cstddef> // std::size_t #include <memory> @@ -27,26 +27,6 @@ #include "aidge/operator/Operator.hpp" #include "aidge/operator/OperatorTensor.hpp" -Aidge::OperatorStats::OperatorStats(const Operator& op) - : mOp(op) -{ - //ctor -} - -Aidge::OperatorStats::~OperatorStats() = default; - -std::size_t Aidge::OperatorStats::getNbArithmIntOps() const { - const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp); - if (opTensor) { - if (!isFloatingPoint(opTensor->getOutput(0)->dataType())) { - return getNbArithmOps(); - } - } - return 0; -} - -//////////////////////////////////////////////////////////////////////////////// - Aidge::StaticAnalysis::StaticAnalysis(std::shared_ptr<GraphView> graph) : mGraph(graph) { @@ -174,14 +154,6 @@ std::size_t Aidge::StaticAnalysis::getParamsSize(std::shared_ptr<Node> node) con return paramsSize; } -std::shared_ptr<Aidge::OperatorStats> Aidge::StaticAnalysis::getOpStats(std::shared_ptr<Node> node) const { - return (Registrar<OperatorStats>::exists(node->type())) - ? Registrar<OperatorStats>::create(node->type())(*(node->getOperator())) - : (node->getOperator()->isAtomic()) - ? std::make_shared<OperatorStats>(*(node->getOperator())) - : std::make_shared<MetaOpStats>(*(node->getOperator())); -} - std::size_t Aidge::StaticAnalysis::getNbArithmOps() const { return accumulate(&OperatorStats::getNbArithmOps); } std::size_t Aidge::StaticAnalysis::getNbLogicOps() const { return accumulate(&OperatorStats::getNbLogicOps); } std::size_t Aidge::StaticAnalysis::getNbCompOps() const { return accumulate(&OperatorStats::getNbCompOps); } @@ -197,17 +169,6 @@ std::size_t Aidge::StaticAnalysis::accumulate(std::size_t (OperatorStats::*func) mGraph->getNodes().cend(), std::size_t(0), [this, func](const std::size_t& lhs, const std::shared_ptr<Node>& rhs) { - return lhs + (this->getOpStats(rhs).get()->*func)(); + return lhs + (OperatorStats::getOpStats(rhs).get()->*func)(); }); } - -//////////////////////////////////////////////////////////////////////////////// - -Aidge::MetaOpStats::~MetaOpStats() = default; - -std::size_t Aidge::MetaOpStats::getNbArithmOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmOps(); } -std::size_t Aidge::MetaOpStats::getNbLogicOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbLogicOps(); } -std::size_t Aidge::MetaOpStats::getNbCompOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbCompOps(); } -std::size_t Aidge::MetaOpStats::getNbNLOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbNLOps(); } -std::size_t Aidge::MetaOpStats::getNbArithmIntOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmIntOps(); } -std::size_t Aidge::MetaOpStats::getNbMACOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbMACOps(); } diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 99f2de669..155a5e7e4 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -668,7 +668,7 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr MemoryManager memManager; for (std::size_t step = 0; step < mStaticSchedule.size(); ++step) { - for (const auto& node : getStaticScheduling(step)) { + for (const auto& node : getSequentialStaticScheduling(step)) { if (!incProducers && node->type() == Producer_Op::Type) { memManager.releaseDependencies(node); continue; @@ -787,7 +787,7 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemoryAutoConcat(bool incProducer for (std::size_t step = 0; step < mStaticSchedule.size(); ++step) { // AsLateAsPossible ensures that when a node child is Concat, all the parents // of the Concat parents have already been memory mapped! - for (const auto& node : getStaticScheduling(step, EarlyLateSort::AsLateAsPossible)) { + for (const auto& node : getSequentialStaticScheduling(step, SchedulingPolicy::AsLateAsPossible)) { if (!incProducers && node->type() == Producer_Op::Type) { memManager.releaseDependencies(node); continue; @@ -1038,7 +1038,7 @@ void Aidge::Scheduler::connectInputs(const std::vector<std::shared_ptr<Aidge::Te } } -void Aidge::Scheduler::saveSchedulingDiagram(const std::string& fileName) const { +void Aidge::Scheduler::saveSchedulingDiagram(const std::string& fileName, bool ignoreProducers) const { auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose); if (!fp) { @@ -1054,6 +1054,10 @@ void Aidge::Scheduler::saveSchedulingDiagram(const std::string& fileName) const const auto globalStart = mScheduling[0].start; for (const auto& element : mScheduling) { + if (ignoreProducers && element.node->type() == "Producer") { + continue; + } + auto name = namePtrTable.at(element.node); // Mermaid does not allow : character in task title std::replace(name.begin(), name.end(), ':', '_'); @@ -1068,7 +1072,7 @@ void Aidge::Scheduler::saveSchedulingDiagram(const std::string& fileName) const fmt::print(fp.get(), "\n"); } -void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName) const { +void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName, bool ignoreProducers) const { auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose); if (!fp) { @@ -1084,6 +1088,10 @@ void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName) for (const auto& schedule : mStaticSchedule) { for (const auto& element : schedule) { + if (ignoreProducers && element->node->type() == "Producer") { + continue; + } + auto name = namePtrTable.at(element->node); // Mermaid does not allow : character in task title std::replace(name.begin(), name.end(), ':', '_'); @@ -1154,17 +1162,17 @@ void Aidge::Scheduler::saveFactorizedStaticSchedulingDiagram(const std::string& fmt::print(fp.get(), "\n"); } -std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getStaticScheduling(std::size_t step, EarlyLateSort sorting) const { - AIDGE_ASSERT(!mStaticSchedule.empty(), "Scheduler::getStaticScheduling(): static scheduling is empty, did you generate scheduling first?"); - AIDGE_ASSERT(step < mStaticSchedule.size(), "Scheduler::getStaticScheduling(): no static scheduling at step {} (available steps: {})", mStaticSchedule.size(), step); +std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getSequentialStaticScheduling(std::size_t step, SchedulingPolicy policy) const { + AIDGE_ASSERT(!mStaticSchedule.empty(), "Scheduler::getSequentialStaticScheduling(): static scheduling is empty, did you generate scheduling first?"); + AIDGE_ASSERT(step < mStaticSchedule.size(), "Scheduler::getSequentialStaticScheduling(): no static scheduling at step {} (available steps: {})", mStaticSchedule.size(), step); std::deque<StaticSchedulingElement*> staticSchedule(mStaticSchedule.at(step).begin(), mStaticSchedule.at(step).end()); - if (sorting == EarlyLateSort::AsSoonAsPossible) { + if (policy == SchedulingPolicy::AsSoonAsPossible) { std::stable_sort(staticSchedule.begin(), staticSchedule.end(), [](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); }); } - else if (sorting == EarlyLateSort::AsLateAsPossible) { + else if (policy == SchedulingPolicy::AsLateAsPossible) { // The last condition (lhs->early > rhs->early) ensures that when on a // branch join, one does not switch branch just before the join if there // is only a single node (scheduled as late as possible, since not in the diff --git a/src/scheduler/SequentialScheduler.cpp b/src/scheduler/SequentialScheduler.cpp index 2b1956d79..07f01ce09 100644 --- a/src/scheduler/SequentialScheduler.cpp +++ b/src/scheduler/SequentialScheduler.cpp @@ -45,28 +45,18 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std } // Sort static scheduling according to the policy - std::vector<StaticSchedulingElement*> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end()); - - if (mSchedulingPolicy == SchedulingPolicy::AsSoonAsPossible) { - std::stable_sort(staticSchedule.begin(), staticSchedule.end(), - [](const auto& lhs, const auto& rhs) { return (lhs->early < rhs->early); }); - } - else if (mSchedulingPolicy == SchedulingPolicy::AsLateAsPossible) { - std::stable_sort(staticSchedule.begin(), staticSchedule.end(), - [](const auto& lhs, const auto& rhs) { return (lhs->late < rhs->late); }); - } - + const auto nodes = getSequentialStaticScheduling(mStaticScheduleStep, mSchedulingPolicy); const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); - for (const auto& runnable : staticSchedule) { - const bool skip = !isConditionalNodeRequired(runnable->node); - Log::debug("run: {}{}", namePtrTable.at(runnable->node), (skip) ? " -- skipped" : ""); + for (const auto& runnable : nodes) { + const bool skip = !isConditionalNodeRequired(runnable); + Log::debug("run: {}{}", namePtrTable.at(runnable), (skip) ? " -- skipped" : ""); if (!skip) { const auto tStart = std::chrono::high_resolution_clock::now(); - runnable->node->forward(); + runnable->forward(); const auto tEnd = std::chrono::high_resolution_clock::now(); - mScheduling.push_back(SchedulingElement(runnable->node, tStart, tEnd)); + mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd)); } } @@ -87,17 +77,20 @@ void Aidge::SequentialScheduler::backward() { } // map of node <-> info to display with verbose + const auto nodes = getSequentialStaticScheduling(mStaticScheduleStep, mSchedulingPolicy); const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); // run scheduled operators in reverse order - const auto& runnableList = mStaticSchedule.at(mStaticScheduleStep); - for (auto runnable = runnableList.crbegin(); runnable != runnableList.crend(); ++runnable) { - Log::debug("run: {}", namePtrTable.at((*runnable)->node)); - - const auto tStart = std::chrono::high_resolution_clock::now(); - (*runnable)->node->backward(); - const auto tEnd = std::chrono::high_resolution_clock::now(); - mScheduling.push_back(SchedulingElement((*runnable)->node, tStart, tEnd)); + for (auto runnable = nodes.crbegin(); runnable != nodes.crend(); ++runnable) { + const bool skip = !isConditionalNodeRequired((*runnable)); + Log::debug("run: {}{}", namePtrTable.at((*runnable)), (skip) ? " -- skipped" : ""); + + if (!skip) { + const auto tStart = std::chrono::high_resolution_clock::now(); + (*runnable)->backward(); + const auto tEnd = std::chrono::high_resolution_clock::now(); + mScheduling.push_back(SchedulingElement((*runnable), tStart, tEnd)); + } } ++mStaticScheduleStep; diff --git a/unit_tests/graph/Test_StaticAnalysis.cpp b/unit_tests/analysis/Test_StaticAnalysis.cpp similarity index 93% rename from unit_tests/graph/Test_StaticAnalysis.cpp rename to unit_tests/analysis/Test_StaticAnalysis.cpp index 9488cbaf6..a491cb143 100644 --- a/unit_tests/graph/Test_StaticAnalysis.cpp +++ b/unit_tests/analysis/Test_StaticAnalysis.cpp @@ -16,7 +16,8 @@ #include "aidge/data/Tensor.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/graph/OpArgs.hpp" -#include "aidge/graph/StaticAnalysis.hpp" +#include "aidge/analysis/OperatorStats.hpp" +#include "aidge/analysis/StaticAnalysis.hpp" #include "aidge/operator/Add.hpp" #include "aidge/operator/BatchNorm.hpp" #include "aidge/operator/FC.hpp" @@ -53,7 +54,7 @@ TEST_CASE("[core/graph] StaticAnalysis") { REQUIRE(stats.getNbParams(g1->getNode("conv2")) == 4 * 8 * 5 * 5 + 8); REQUIRE(stats.getNbParams(g1->getNode("conv3")) == 8 * 16 * 3 * 3 + 16); - const auto conv1Stats = stats.getOpStats(g1->getNode("conv1")); + const auto conv1Stats = OperatorStats::getOpStats(g1->getNode("conv1")); REQUIRE(conv1Stats->getNbMACOps() == 1LL * (16 * 508 * 508) * (5 * 5 * 3 * 4)); REQUIRE(conv1Stats->getNbArithmOps() == 2LL * (16 * 508 * 508) * (5 * 5 * 3 * 4)); REQUIRE(conv1Stats->getNbArithmFpOps() == 2LL * (16 * 508 * 508) * (5 * 5 * 3 * 4)); diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp index 042b04f01..cf4428055 100644 --- a/unit_tests/operator/Test_MetaOperator.cpp +++ b/unit_tests/operator/Test_MetaOperator.cpp @@ -58,7 +58,7 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator][MetaOperator]") { //op->getOperator()->updateConsummerProducer(); // require implementation //auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator())->getMicroGraphScheduler(); - //REQUIRE(microGraphScheduler->getStaticScheduling().size() == 2); + //REQUIRE(microGraphScheduler->getSequentialStaticScheduling().size() == 2); } SECTION("LSTM") { diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index ec850d281..dbe0ef3ae 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -75,7 +75,7 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { auto scheduler = SequentialScheduler(g1); scheduler.generateScheduling(); fmt::print("gen scheduling finished\n"); - const auto sch = scheduler.getStaticScheduling(); + const auto sch = scheduler.getSequentialStaticScheduling(); const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})"); @@ -118,7 +118,7 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { // auto scheduler = SequentialScheduler(g1); // scheduler.generateScheduling(); // fmt::print("gen scheduling finished\n"); - // const auto sch = scheduler.getStaticScheduling(); + // const auto sch = scheduler.getSequentialStaticScheduling(); // const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})"); @@ -146,7 +146,7 @@ TEST_CASE("someScheduling", "[Scheduler][someUseCases]") { g->add({data1, identity}); auto scheduler = SequentialScheduler(g); scheduler.generateScheduling(); - const auto sch = scheduler.getStaticScheduling(); + const auto sch = scheduler.getSequentialStaticScheduling(); const auto nodes = g->getNodes(); REQUIRE(sch.size() == nodes.size()); REQUIRE(sch[0] == data1); @@ -159,7 +159,7 @@ TEST_CASE("someScheduling", "[Scheduler][someUseCases]") { g->add({data1}); auto scheduler = SequentialScheduler(g); scheduler.generateScheduling(); - const auto sch = scheduler.getStaticScheduling(); + const auto sch = scheduler.getSequentialStaticScheduling(); const auto nodes = g->getNodes(); REQUIRE(sch.size() == nodes.size()); REQUIRE(sch[0] == data1); @@ -171,7 +171,7 @@ TEST_CASE("someScheduling", "[Scheduler][someUseCases]") { g->add({gen1}); auto scheduler = SequentialScheduler(g); scheduler.generateScheduling(); - const auto sch = scheduler.getStaticScheduling(); + const auto sch = scheduler.getSequentialStaticScheduling(); const auto nodes = g->getNodes(); REQUIRE(sch.size() == nodes.size()); REQUIRE(sch[0] == gen1); @@ -183,7 +183,7 @@ TEST_CASE("someScheduling", "[Scheduler][someUseCases]") { g->add({dead1}); auto scheduler = SequentialScheduler(g); scheduler.generateScheduling(); - const auto sch = scheduler.getStaticScheduling(); + const auto sch = scheduler.getSequentialStaticScheduling(); const auto nodes = g->getNodes(); REQUIRE(nodes.size() == 1); REQUIRE(sch.size() == 0); -- GitLab From 316bf5ca687d580af6b50fa05db2c990785d7ddf Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 21 Feb 2025 16:31:30 +0100 Subject: [PATCH 11/13] Add working tests from !281 on GraphView::replace() --- unit_tests/graph/Test_GraphView.cpp | 803 ++++++++++++++++++++++++++++ 1 file changed, 803 insertions(+) diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index d2f41269e..8462d50ee 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -1087,3 +1087,806 @@ TEST_CASE("[core/graph] GraphView(insertParent)") { } } + +////////////////////////////////////////// +// test case 1: one input, one output +// 1/ graphview= other_input --> other1 --> old1 --> old2 --> other2 +// 2/ replace oldGraph= -->old1 --> old2 --> by newGraph= --> new1 --> new2 --> +// 3/ verify : +// - output value of replace == TRUE +// - graphview = other_input --> other1 --> new1 --> new2 --> other2 +// - old1 -and old2 are remove from graphview +////////////////////////////////////////// +// test case 2 multiple input, multiple output +// 1/ graphview== other_input1 --> other11 --> old11 --> old12 --> other21 +// other_input2 --> other12 --> old12 --> old22 --> other22 +// 2/ replace oldGraph= --> old11 --> old12 by newGraph= --> new11 --> new12 --> +// --> old12 --> old22 --> new21 --> new22 --> +// 3/ verify : +// - output value of replace == TRUE +// - graphview== other_input1 --> other11 --> new11 --> new12 --> other21 +// other_input2 --> other12 --> new21 --> new22 --> other22 +// - old11, old12, old21 and old22 are remove from graphview +////////////////////////////////////////// +// test case 3 none input, multiple output +// 1/ graphview= old_input --> old1 --> other +// 2/ replace oldGraph= old_input --> old --> by newGraph = new_input --> new --> +// 3/ verify : +// - output value of replace == TRUE +// - graphview = new_input --> new --> other +// - old_input -and old are remove from graphview +////////////////////////////////////////// +TEST_CASE("[core/graph] Graph: replacing a set of nodes, same old/new inputs and same old/new outputs", "[GraphView][replace]") +{ + SECTION("test case 1 one input, one output") { + std::shared_ptr<GraphView> graphTest = std::make_shared<GraphView>("test_graph"); + // Create old graph + auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input"); + auto other1 = GenericOperator("Other", 1, 0, 1, "other1"); + auto myOld1 = GenericOperator("Old", 1, 0, 1, "old1"); + auto myOld2 = GenericOperator("Old", 1, 0, 1, "old2"); + auto other2 = GenericOperator("Other", 1, 0, 1, "other2"); + // Link old graph + otherInput->addChild(other1); + other1->addChild(myOld1); + myOld1->addChild(myOld2); + myOld2->addChild(other2); + graphTest->add({other1, myOld1, myOld2, other2}); + + // Create and link new graph + auto myNew1 = GenericOperator("New", 1, 0, 1, "new1"); + auto myNew2 = GenericOperator("New", 1, 0, 1, "new2"); + myNew1->addChild(myNew2); + + // Replace + bool retValue = GraphView::replace({myOld1, myOld2}, {myNew1, myNew2}); + + // Check outputs + CHECK(graphTest->getNodes() == std::set<std::shared_ptr<Node>>({other1, myNew1, myNew2, other2})); + graphTest->save("myGraph",true,true); + CHECK(retValue); + // Check links + CHECK(myNew1->input(0).first == other1); + CHECK(myNew2->output(0).at(0).first == other2); + // Check graph Nodes + CHECK(graphTest->getNode("old1") == nullptr); + CHECK(graphTest->getNode("old2") == nullptr); + CHECK(graphTest->getNode("new1") == myNew1); + CHECK(graphTest->getNode("new2") == myNew2); + } + SECTION("test case 2 multiple input, multiple output") { + std::shared_ptr<GraphView> graphTest = std::make_shared<GraphView>("test_graph"); + // Create old graph + auto otherInput1 = GenericOperator("Producer", 0, 0, 1, "other_input1"); + auto otherInput2 = GenericOperator("Producer", 0, 0, 1, "other_input2"); + auto other11 = GenericOperator("Other", 1, 0, 1, "other11"); + auto other12 = GenericOperator("Other", 1, 0, 1, "other12"); + auto myOld11 = GenericOperator("Old", 1, 0, 1, "old11"); + auto myOld12 = GenericOperator("Old", 1, 0, 1, "old12"); + auto myOld21 = GenericOperator("Old", 1, 0, 1, "old21"); + auto myOld22 = GenericOperator("Old", 1, 0, 1, "old22"); + auto other21 = GenericOperator("Other", 1, 0, 1, "other21"); + auto other22 = GenericOperator("Other", 1, 0, 1, "other22"); + // Link old graph + otherInput1->addChild(other11); + other11->addChild(myOld11); + myOld11->addChild(myOld12); + myOld12->addChild(other12); + otherInput2->addChild(other21); + other21->addChild(myOld21); + myOld21->addChild(myOld22); + myOld22->addChild(other22); + graphTest->add({other11, myOld11, myOld12, other12, other21, myOld21, myOld22, other22}); + + //std::vector<std::pair<NodePtr, IOIndex_t>> orderInput; + //orderInput.push_back(std::pair<NodePtr, IOIndex_t>(other11,0)); + //orderInput.push_back(std::pair<NodePtr, IOIndex_t>(other21,0)); + //graphTest->setOrderedInputs(orderInput); + graphTest->save("myGraph",true,true); + // Create and link new graph + auto myNew11 = GenericOperator("New", 1, 0, 1, "new11"); + auto myNew12 = GenericOperator("New", 1, 0, 1, "new12"); + auto myNew21 = GenericOperator("New", 1, 0, 1, "new21"); + auto myNew22 = GenericOperator("New", 1, 0, 1, "new22"); + myNew11->addChild(myNew12); + myNew21->addChild(myNew22); + // graphOld + std::shared_ptr<GraphView> graphOld = std::make_shared<GraphView>("oldGraph"); + graphOld->add({myOld11, myOld12, myOld21, myOld22}); + graphOld->setOrderedInputs({std::pair<NodePtr, IOIndex_t>(myOld11,0), std::pair<NodePtr, IOIndex_t>(myOld21,0)}); + graphOld->setOrderedOutputs({std::pair<NodePtr, IOIndex_t>(myOld12,0), std::pair<NodePtr, IOIndex_t>(myOld22,0)}); + // graphNew + std::shared_ptr<GraphView> graphNew= std::make_shared<GraphView>("newGraph"); + graphNew->add({myNew11, myNew12, myNew21, myNew22}); + graphNew->setOrderedInputs({std::pair<NodePtr, IOIndex_t>(myNew11,0), std::pair<NodePtr, IOIndex_t>(myNew21,0)}); + graphNew->setOrderedOutputs({std::pair<NodePtr, IOIndex_t>(myNew12,0), std::pair<NodePtr, IOIndex_t>(myNew22,0)}); + + // Replace + bool retValue = GraphView::replace(graphOld, graphNew); + //bool retValue = GraphView::replace({myOld11, myOld12, myOld21, myOld22}, {myNew11, myNew12, myNew21, myNew22}); + + // Check outputs + CHECK(graphTest->getNodes() == std::set<std::shared_ptr<Node>>({other11, myNew11, myNew12, other12, other21, myNew21, myNew22, other22})); + graphTest->save("myGraph2",true,true); + CHECK(retValue); + // Check links + CHECK(myNew11->input(0).first == other11); + CHECK(myNew12->output(0).at(0).first == other12); + CHECK(myNew21->input(0).first == other21); + CHECK(myNew22->output(0).at(0).first == other22); + // Check graph Nodes + CHECK(graphTest->getNode("old11") == nullptr); + CHECK(graphTest->getNode("old21") == nullptr); + CHECK(graphTest->getNode("old12") == nullptr); + CHECK(graphTest->getNode("old22") == nullptr); + CHECK(graphTest->getNode("new11") == myNew11); + CHECK(graphTest->getNode("new12") == myNew12); + CHECK(graphTest->getNode("new21") == myNew21); + CHECK(graphTest->getNode("new22") == myNew22); + } + SECTION("test case 3 none input, multiple output") { + std::shared_ptr<GraphView> graphTest = std::make_shared<GraphView>("test_graph"); + auto oldInput = GenericOperator("Producer", 0, 0, 1, "old_input"); + auto myOld = GenericOperator("Old", 1, 0, 1, "old"); + auto other = GenericOperator("Other", 1, 0, 1, "other"); + oldInput->addChild(myOld); + myOld->addChild(other); + graphTest->add({oldInput, myOld, other}); + + auto newInput = GenericOperator("Producer", 0, 0, 1, "newInput"); + auto myNew = GenericOperator("New", 1, 0, 1, "new"); + newInput->addChild(myNew); + + bool retValue = GraphView::replace({oldInput, myOld}, {newInput, myNew}); + CHECK(graphTest->getNodes() == std::set<std::shared_ptr<Node>>({newInput, myNew, other})); + graphTest->save("myGraph",true,true); + CHECK(retValue); + CHECK(myNew->output(0).at(0).first == other); + CHECK(graphTest->getNode("old_input") == nullptr); + CHECK(graphTest->getNode("old") == nullptr); + CHECK(graphTest->getNode("newInput") == newInput); + CHECK(graphTest->getNode("new") == myNew); + } +} + +TEST_CASE("[core/graph] Replacing a set of nodes, one old input, same number of outputs", "[GraphView][replace]") +{ +////////////////////////////////////////// +// test case 1 one old input several (2) new inputs, same number of outputs (1) +// 1/ graphview== other_input1 --> other1 --> old1 --> old2 --> other2 +// 2/ replace oldGraph= -->old1 --> old2--> by newGraph= -->new1 --> new3 --> +// -->new2 ---/ +// 3/ verify : +// - output value of replace == TRUE +// - graphview== other_input1 -->new1 --> new3 --> other2 +// \-->new2 ---/ +// - old1, old2 are removed from graphview +// - new1, new2,and new3 are added to graphview +////////////////////////////////////////// +// test case 2 one old input, several (2) new inputs, same number of outputs (0) +// 1/ graphview== other_input1 --> other1 --> old1 --> old2 +// 2/ replace oldGraph= -->old1 --> old2 by newGraph= -->new1 --> new3 +// -->new2 ---/ +// 3/ verify : +// - output value of replace == TRUE +// - graphview== other_input1 -->new1 --> new3 +// \-->new2 ---/ +// - old1, old2 are removed from graphview +// - new1, new2,and new3 are added to graphview +////////////////////////////////////////// +// test case 3 one old input, several (2) new input, same number of outputs (4) +// 1/ graphview== other_input1 --> other1 --> old1 --> old2 --> other4 +// \ \--> other2 +// \---------/ +// \-------> other3 +// 2/ replace oldGraph= -->old1 --> old2 by newGraph= -->new1 --> new3 --> +// -->new2 ---/ \-> +// 3/ verify : +// - output value of replace == TRUE +// /----------\ +// / /-----> other2 +// - graphview== other_input1 --> other1 -->new1 --> new3 --> other4 +// \-->new2 ---/------> other 3 +// - old1, old2 are removed from graphview +// - new1, new2,and new3 are added to graphview +////////////////////////////////////////// +// test case 4 one old input no (0) new inputs, same number of outputs (1) +// 1/ graphview== other_input1 --> other1 --> old1 --> old2 --> other2 +// 2/ replace oldGraph= -->old1 --> old2 by newGraph= new1 --> new3 --> +// new2 ---/ +// 3/ verify : +// - output value of replace == TRUE +// - graphview== other_input1 -->new1 --> new3 +// \-->new2 ---/ +// - old1, old2 are removed from graphview +// - new1, new2,and new3 are added to graphview +////////////////////////////////////////// + SECTION("test case 1 one old input several (2) new inputs, same number of outputs (1)") { + std::shared_ptr<GraphView> graphTest = std::make_shared<GraphView>("test_graph"); + std::shared_ptr<GraphView> graphOld = std::make_shared<GraphView>("old_graph"); + std::shared_ptr<GraphView> graphNew = std::make_shared<GraphView>("new_graph"); + // Create old graph + auto otherInput = GenericOperator("Producer", {}, 1, "other_input"); + auto other1 = GenericOperator("Other", {InputCategory::Data}, 1, "other1"); + auto myOld1 = GenericOperator("Old", {InputCategory::Data}, 1, "old1"); + auto myOld2 = GenericOperator("Old", {InputCategory::Data}, 1, "old2"); + auto other2 = GenericOperator("Other", {InputCategory::Data}, 1, "other2"); + // Link old graph + otherInput->addChild(other1); + other1->addChild(myOld1); + myOld1->addChild(myOld2); + myOld2->addChild(other2); + graphTest->add({other1, myOld1, myOld2, other2}); + graphOld->add({myOld1, myOld2}); + + // Create and link new graph + auto myNew1 = GenericOperator("New", {InputCategory::Data}, 1, "new1"); + auto myNew2 = GenericOperator("New", {InputCategory::Data}, 1, "new2"); + auto myNew3 = GenericOperator("New", {InputCategory::Data, InputCategory::Data}, 1, "new3"); + myNew1->addChild(myNew3); + myNew2->addChild(myNew3, 0, 1); + graphNew->add(std::set<Aidge::NodePtr>{myNew1, myNew2, myNew3}); + graphTest->save("myGraphBefore",true,true); + + // Replace + bool retValue = GraphView::replace(graphOld, graphNew); + graphTest->save("myGraphAfter",true,true); + + // Check outputs + CHECK(graphTest->getNodes() == std::set<std::shared_ptr<Node>>({other1, myNew1, myNew2, myNew3, other2})); + CHECK(retValue); + // Check links + CHECK(myNew1->input(0).first == other1); + CHECK(myNew2->input(0).first == other1); + CHECK(myNew3->output(0).at(0).first == other2); + // Check graph Nodes + CHECK(graphTest->getNode("old1") == nullptr); + CHECK(graphTest->getNode("old2") == nullptr); + CHECK(graphTest->getNode("new1") == myNew1); + CHECK(graphTest->getNode("new2") == myNew2); + CHECK(graphTest->getNode("new3") == myNew3); + } + SECTION("test case 2 one old input, several (2) new inputs, same number of outputs (0)") { + std::shared_ptr<GraphView> graphTest = std::make_shared<GraphView>("test_graph"); + std::shared_ptr<GraphView> graphOld = std::make_shared<GraphView>("old_graph"); + std::shared_ptr<GraphView> graphNew = std::make_shared<GraphView>("new_graph"); + // Create old graph + auto otherInput = GenericOperator("Producer", {}, 1, "other_input"); + auto other1 = GenericOperator("Other", {InputCategory::Data}, 1, "other1"); + auto myOld1 = GenericOperator("Old", {InputCategory::Data}, 1, "old1"); + auto myOld2 = GenericOperator("Old", {InputCategory::Data}, 1, "old2"); + // Link old graph + otherInput->addChild(other1); + other1->addChild(myOld1); + myOld1->addChild(myOld2); + graphTest->add({other1, myOld1, myOld2}); + graphOld->add({myOld1, myOld2}); + + // Create and link new graph + auto myNew1 = GenericOperator("New", {InputCategory::Data}, 1, "new1"); + auto myNew2 = GenericOperator("New", {InputCategory::Data}, 1, "new2"); + auto myNew3 = GenericOperator("New", {InputCategory::Data, InputCategory::Data}, 1, "new3"); + myNew1->addChild(myNew3); + myNew2->addChild(myNew3, 0, 1); + graphNew->add(std::set<Aidge::NodePtr>{myNew1, myNew2, myNew3}); + + // Replace + bool retValue = GraphView::replace(graphOld, graphNew); + + // Check outputs + CHECK(graphTest->getNodes() == std::set<std::shared_ptr<Node>>({other1, myNew1, myNew2, myNew3})); + CHECK(retValue); + // Check links + CHECK(myNew1->input(0).first == other1); + CHECK(myNew2->input(0).first == other1); + // Check graph Nodes + CHECK(graphTest->getNode("old1") == nullptr); + CHECK(graphTest->getNode("old2") == nullptr); + CHECK(graphTest->getNode("new1") == myNew1); + CHECK(graphTest->getNode("new2") == myNew2); + CHECK(graphTest->getNode("new3") == myNew3); + } + SECTION("test case 3 one old input, several (2) new input, same number of outputs (4)") { + std::shared_ptr<GraphView> graphTest = std::make_shared<GraphView>("test_graph"); + std::shared_ptr<GraphView> graphOld = std::make_shared<GraphView>("old_graph"); + std::shared_ptr<GraphView> graphNew = std::make_shared<GraphView>("new_graph"); + // Create old graph + auto otherInput = GenericOperator("Producer", {}, 1, "other_input"); + auto other1 = GenericOperator("Other", {InputCategory::Data}, 1, "other1"); + auto myOld1 = GenericOperator("Old", {InputCategory::Data}, 1, "old1"); + auto myOld2 = GenericOperator("Old", {InputCategory::Data}, 2, "old2"); + auto other2 = GenericOperator("Other", {InputCategory::Data, InputCategory::Data}, 1, "other2"); + auto other3 = GenericOperator("Other", {InputCategory::Data}, 1, "other3"); + auto other4 = GenericOperator("Other", {InputCategory::Data}, 1, "other4"); + // Link old graph + otherInput->addChild(other1); + other1->addChild(myOld1); + myOld1->addChild(myOld2, 0, 0); + myOld1->addChild(other2, 0, 1); + myOld1->addChild(other3, 0, 0); + myOld2->addChild(other2, 1, 0); + myOld2->addChild(other4, 0, 0); + graphTest->add({other1, myOld1, myOld2, other2}); + graphOld->add({myOld1, myOld2}); + std::vector<std::pair<NodePtr, IOIndex_t>> oldOutputs; + graphOld->setOrderedOutputs(std::vector<std::pair<NodePtr, IOIndex_t>> {{myOld1, 0}, {myOld1, 0}, {myOld2, 0}, {myOld2, 1}}); + // Create and link new graph + auto myNew1 = GenericOperator("New", {InputCategory::Data}, 1, "new1"); + auto myNew2 = GenericOperator("New", {InputCategory::Data}, 1, "new2"); + auto myNew3 = GenericOperator("New", {InputCategory::Data, InputCategory::Data}, 2, "new3"); + myNew1->addChild(myNew3); + myNew2->addChild(myNew3, 0, 1); + graphNew->add(std::set<Aidge::NodePtr>{myNew1, myNew2, myNew3}); + graphNew->setOrderedOutputs(std::vector<std::pair<NodePtr, IOIndex_t>> {{myNew1, 0}, {myNew2, 0}, {myNew3, 0}, {myNew3, 1}}); + + // Replace + bool retValue = GraphView::replace(graphOld, graphNew); + + // Check outputs + CHECK(graphTest->getNodes() == std::set<std::shared_ptr<Node>>({other1, myNew1, myNew2, myNew3, other2})); + CHECK(retValue); + // Check links + CHECK(myNew1->input(0).first == other1); + CHECK(myNew2->input(0).first == other1); + + CHECK(myNew3->output(0).at(0).first == other4); + CHECK(myNew3->output(1).at(0).first == other2); + + // TODO: check if the following conditions should be true (there aren't right now) + //CHECK(myNew2->output(0).at(0).first == other3); + //CHECK(myNew1->output(0).at(0).first == other2); + // Check graph Nodes + CHECK(graphTest->getNode("old1") == nullptr); + CHECK(graphTest->getNode("old2") == nullptr); + CHECK(graphTest->getNode("new1") == myNew1); + CHECK(graphTest->getNode("new2") == myNew2); + CHECK(graphTest->getNode("new3") == myNew3); + } + SECTION("test case 4 one old input no (0) new inputs, same number of outputs (1)") { + std::shared_ptr<GraphView> graphTest = std::make_shared<GraphView>("test_graph"); + std::shared_ptr<GraphView> graphOld = std::make_shared<GraphView>("old_graph"); + std::shared_ptr<GraphView> graphNew = std::make_shared<GraphView>("new_graph"); + // Create old graph + auto otherInput = GenericOperator("Producer", {}, 1, "other_input"); + auto other1 = GenericOperator("Other", {InputCategory::Data}, 1, "other1"); + auto myOld1 = GenericOperator("Old", {InputCategory::Data}, 1, "old1"); + auto myOld2 = GenericOperator("Old", {InputCategory::Data}, 1, "old2"); + auto other2 = GenericOperator("Other", {InputCategory::Data}, 1, "other2"); + // Link old graph + otherInput->addChild(other1); + other1->addChild(myOld1); + myOld1->addChild(myOld2); + myOld2->addChild(other2); + graphTest->add({other1, myOld1, myOld2, other2}); + graphOld->add({myOld1, myOld2}); + + // Create and link new graph + auto myNew1 = GenericOperator("Producer", {}, 1, "new1"); + auto myNew2 = GenericOperator("Producer", {}, 1, "new2"); + auto myNew3 = GenericOperator("New", {InputCategory::Data, InputCategory::Data}, 1, "new3"); + myNew1->addChild(myNew3); + myNew2->addChild(myNew3, 0, 1); + graphNew->add(std::set<Aidge::NodePtr>{myNew1, myNew2, myNew3}); + + // Replace + bool retValue = GraphView::replace(graphOld, graphNew); + + // Check outputs + CHECK(graphTest->getNodes() == std::set<std::shared_ptr<Node>>({other1, myNew1, myNew2, myNew3, other2})); + CHECK(retValue); + // Check links + CHECK(myNew3->output(0).at(0).first == other2); + // Check graph Nodes + CHECK(graphTest->getNode("old1") == nullptr); + CHECK(graphTest->getNode("old2") == nullptr); + CHECK(graphTest->getNode("new1") == myNew1); + CHECK(graphTest->getNode("new2") == myNew2); + CHECK(graphTest->getNode("new3") == myNew3); + } + SECTION("test case 5 one old input several (3) new inputs, same number of outputs (3) from same node output") { + std::shared_ptr<GraphView> graphTest = std::make_shared<GraphView>("test_graph"); + std::shared_ptr<GraphView> graphOld = std::make_shared<GraphView>("old_graph"); + std::shared_ptr<GraphView> graphNew = std::make_shared<GraphView>("new_graph"); + // Create old graph + auto otherInput = GenericOperator("Producer", {}, 1, "other_input"); + auto myOld1 = GenericOperator("Old", {InputCategory::Data}, 1, "old1"); + auto other1 = GenericOperator("Other", {InputCategory::Data}, 1, "other1"); + auto other2 = GenericOperator("Other", {InputCategory::Data}, 1, "other2"); + auto other3 = GenericOperator("Other", {InputCategory::Data}, 1, "other3"); + // Link old graph + otherInput->addChild(myOld1); + myOld1->addChild(other1); + myOld1->addChild(other2); + myOld1->addChild(other3); + graphTest->add({myOld1, other1, other2, other3}); + graphOld->add({myOld1}); + // Create and link new graph + auto myNew1 = GenericOperator("New", {InputCategory::Data}, 1, "new1"); + graphNew->add(std::set<Aidge::NodePtr>{myNew1}); + + // Replace + bool retValue = GraphView::replace(graphOld, graphNew); + + // Check outputs + CHECK(graphTest->getNodes() == std::set<std::shared_ptr<Node>>({ myNew1, other1, other2, other3})); + CHECK(retValue); + // Check links + CHECK(myNew1->input(0).first == otherInput); + + CHECK(myNew1->output(0).at(0).first == other1); + // Check graph Nodes + CHECK(graphTest->getNode("old1") == nullptr); + CHECK(graphTest->getNode("new1") == myNew1); + } +} + +////////////////////////////////////////// +// test case 1 same input, one new output +// 1/ graphview== other_input1 --> other11 --> old11 --> old12 --> other21 +// other_input2 --> other12 --> old12 --> old22 --> other22 +// 2/ replace oldGraph= -->old11 --> old12--> by newGraph= -->new --> +// -->old12 --> old22--> _/ +// 3/ verify : +// - output value of replace == FALSE +// - graphview== other_input1 --> other11 --> old11 --> old12 --> other21 +// other_input2 --> other12 --> old121 --> old22 --> other22 +// - old11, old12, old21 and old22 are not remove from graphview +// - new is not added to graphview +////////////////////////////////////////// +TEST_CASE("[core/graph] Graph: replacing a set of nodes, same old/new inputs and multiple old/one new output", "[GraphView][replace]") { + +SECTION("test case 1 same input, one new output") { + std::shared_ptr<GraphView> graphTest = std::make_shared<GraphView>("test_graph"); + auto otherInput1 = GenericOperator("Producer", 0, 0, 1, "other_input1"); + auto otherInput2 = GenericOperator("Producer", 0, 0, 1, "other_input2"); + auto other11 = GenericOperator("Other", 1, 0, 1, "other11"); + auto other12 = GenericOperator("Other", 1, 0, 1, "other12"); + auto myOld11 = GenericOperator("Old", 1, 0, 1, "old11"); + auto myOld12 = GenericOperator("Old", 1, 0, 1, "old12"); + auto myOld21 = GenericOperator("Old", 1, 0, 1, "old21"); + auto myOld22 = GenericOperator("Old", 1, 0, 1, "old22"); + auto other21 = GenericOperator("Other", 1, 0, 1, "other21"); + auto other22 = GenericOperator("Other", 1, 0, 1, "other22"); + otherInput1->addChild(other11); + other11->addChild(myOld11); + myOld11->addChild(myOld12); + myOld12->addChild(other12); + otherInput2->addChild(other21); + other21->addChild(myOld21); + myOld21->addChild(myOld22); + myOld22->addChild(other22); + graphTest->add({other11, myOld11, myOld12, other12, other21, myOld21, myOld22, other22}); + + graphTest->save("myGraph",true,true); + auto myNew = GenericOperator("New", 2, 0, 1, "new"); + // graphOld + std::shared_ptr<GraphView> graphOld = std::make_shared<GraphView>("oldGraph"); + graphOld->add({myOld11, myOld12, myOld21, myOld22}); + graphOld->setOrderedInputs({std::pair<NodePtr, IOIndex_t>(myOld11,0), std::pair<NodePtr, IOIndex_t>(myOld21,0)}); + graphOld->setOrderedOutputs({std::pair<NodePtr, IOIndex_t>(myOld12,0), std::pair<NodePtr, IOIndex_t>(myOld22,0)}); + // graphNew + std::shared_ptr<GraphView> graphNew= std::make_shared<GraphView>("newGraph"); + graphNew->add({myNew}); + graphNew->setOrderedInputs({std::pair<NodePtr, IOIndex_t>(myNew,0), std::pair<NodePtr, IOIndex_t>(myNew,1)}); + graphNew->setOrderedOutputs({std::pair<NodePtr, IOIndex_t>(myNew,0)}); + + + bool retValue = GraphView::replace(graphOld, graphNew); + //bool retValue = GraphView::replace({myOld11, myOld12, myOld21, myOld22}, {myNew11, myNew12, myNew21, myNew22}); + + CHECK(graphTest->getNodes() == std::set<std::shared_ptr<Node>>({other11, myOld11, myOld12, other12, other21, myOld21, myOld22, other22})); + graphTest->save("myGraph2",true,true); + CHECK(retValue == false); + CHECK(graphTest->getNode("old11") == myOld11); + CHECK(graphTest->getNode("old12") == myOld12); + CHECK(graphTest->getNode("old21") == myOld21); + CHECK(graphTest->getNode("old22") == myOld22); + CHECK(graphTest->getNode("new") == nullptr); + } +} + +////////////////////////////////////////// +// test case 1 multiple new input, multiple old/one new output +// 1/ graphview== other_input1 --> other11 --> old --> other21 +// \--> other22 +// 2/ replace oldGraph= --> old --> by newGraph= -->new --> +// \--> --/ +// 3/ verify : +// - output value of replace == FALSE +// - graphview== other_input1 --> other11 --> old --> other21 +// \--> other22 +// - old is not remove from graphview +// - new is not added to graphview +////////////////////////////////////////// +TEST_CASE("[core/graph] Graph: replacing a set of nodes, one old/ multiple new inputs and multiple old/ one output output", "[GraphView][replace]") { + +SECTION("test case 1 multiple new input, multiple old/one new output") { + std::shared_ptr<GraphView> graphTest = std::make_shared<GraphView>("test_graph"); + auto otherInput1 = GenericOperator("Producer", 0, 0, 1, "other_input1"); + auto other11 = GenericOperator("Other", 1, 0, 1, "other11"); + auto myOld = GenericOperator("Old", 1, 0, 2, "old"); + auto other21 = GenericOperator("Other", 1, 0, 1, "other21"); + auto other22 = GenericOperator("Other", 1, 0, 1, "other22"); + otherInput1->addChild(other11); + other11->addChild(myOld); + myOld->addChild(other21); + myOld->addChild(other22); + graphTest->add({other11, myOld, other21, other22}); + + graphTest->save("myGraph",true,true); + auto myNew = GenericOperator("New", 2, 0, 1, "new"); + // graphOld + std::shared_ptr<GraphView> graphOld = std::make_shared<GraphView>("oldGraph"); + graphOld->add({myOld}); + graphOld->setOrderedInputs({std::pair<NodePtr, IOIndex_t>(myOld,0)}); + graphOld->setOrderedOutputs({std::pair<NodePtr, IOIndex_t>(myOld,0), std::pair<NodePtr, IOIndex_t>(myOld,1)}); + // graphNew + std::shared_ptr<GraphView> graphNew= std::make_shared<GraphView>("newGraph"); + graphNew->add({myNew}); + graphNew->setOrderedInputs({std::pair<NodePtr, IOIndex_t>(myNew,0), std::pair<NodePtr, IOIndex_t>(myNew,1)}); + graphNew->setOrderedOutputs({std::pair<NodePtr, IOIndex_t>(myNew,0)}); + + bool retValue = GraphView::replace(graphOld, graphNew); + //bool retValue = GraphView::replace({myOld}, {myNew}); + + CHECK(graphTest->getNodes() == std::set<std::shared_ptr<Node>>({other11, myOld, other21, other22})); + graphTest->save("myGraph2",true,true); + CHECK(retValue == false); + CHECK(graphTest->getNode("old") == myOld); + CHECK(graphTest->getNode("new") == nullptr); + } + +} + +////////////////////////////////////////// +// test case 1 multiple old input/output, not same new input/ same new output +// 1/ graphview== other_input1 --> other11 --> old1 --> other21 +// other_input2 --> other21 --> old2 --> other22 +// other_input3 --> other31 --/ +// 2/ replace oldGraph= --> old1 --> by newGraph= --> new1 --> +// --> old2 --> --> new2 --> +// --/ +// 3/ verify : +// - output value of replace == FALSE +// - graphview== other_input1 --> other11 --> old1 --> other12 +// other_input2 --> other21 --> old2 --> other22 +// other_input3 --> other31 --/ +// - old1, old2 are not remove from graphview +// - new1, new2 is not added to graphview +////////////////////////////////////////// +// test case 2 multiple old input/output, same new input/ not same new output +// 1/ graphview== other_input1 --> other11 --> old1 --> other12 +// other_input2 --> other21 --> old2 --> other22 +// \--> other32 +// 2/ replace oldGraph= --> old1 --> by newGraph= --> new1 --> +// --> old2 --> --> new2 --> +// \--> +// 3/ verify : +// - output value of replace == FALSE +// - graphview== other_input1 --> other11 --> old1 --> other12 +// other_input2 --> other21 --> old2 --> other22 +// \-> other32 +// - old1, old2 are not remove from graphview +// - new1, new2 is not added to graphview +TEST_CASE("[core/graph] Graph: replacing a set of nodes, not same number old/new inputs/outputs, failed", "[GraphView][replace]") { + +SECTION("test case 1 not same old/new input, same old/new output") { + std::shared_ptr<GraphView> graphTest = std::make_shared<GraphView>("test_graph"); + auto otherInput1 = GenericOperator("Producer", 0, 0, 1, "other_input1"); + auto otherInput2 = GenericOperator("Producer", 0, 0, 1, "other_input2"); + auto otherInput3 = GenericOperator("Producer", 0, 0, 1, "other_input3"); + auto other11 = GenericOperator("Other", 1, 0, 1, "other11"); + auto other21= GenericOperator("Other", 1, 0, 1, "other21"); + auto other31 = GenericOperator("Other", 1, 0, 1, "other31"); + auto myOld1 = GenericOperator("Old", 1, 0, 1, "old1"); + auto myOld2 = GenericOperator("Old", 2, 0, 1, "old2"); + auto other12 = GenericOperator("Other", 1, 0, 1, "other12"); + auto other22 = GenericOperator("Other", 1, 0, 1, "other22"); + otherInput1->addChild(other11); + otherInput2->addChild(other21); + otherInput3->addChild(other31); + other11->addChild(myOld1); + other21->addChild(myOld2); + other31->addChild(myOld2); + myOld1->addChild(other12); + myOld2->addChild(other22); + graphTest->add({other11, other21, other31, myOld1, myOld2, other12, other22}); + + graphTest->save("myGraph",true,true); + auto myNew1 = GenericOperator("New", 1, 0, 1, "new1"); + auto myNew2 = GenericOperator("New", 1, 0, 1, "new2"); + // graphOld + std::shared_ptr<GraphView> graphOld = std::make_shared<GraphView>("oldGraph"); + graphOld->add({myOld1, myOld2}); + graphOld->setOrderedInputs({std::pair<NodePtr, IOIndex_t>(myOld1,0),std::pair<NodePtr, IOIndex_t>(myOld2,0), std::pair<NodePtr, IOIndex_t>(myOld2,1)}); + graphOld->setOrderedOutputs({std::pair<NodePtr, IOIndex_t>(myOld1,0), std::pair<NodePtr, IOIndex_t>(myOld2,0)}); + // graphNew + std::shared_ptr<GraphView> graphNew= std::make_shared<GraphView>("newGraph"); + graphNew->add({myNew1}); + graphNew->add({myNew2}); + graphNew->setOrderedInputs({std::pair<NodePtr, IOIndex_t>(myNew1,0), std::pair<NodePtr, IOIndex_t>(myNew2,0)}); + graphNew->setOrderedOutputs({std::pair<NodePtr, IOIndex_t>(myNew1,0), std::pair<NodePtr, IOIndex_t>(myNew2,0)}); + + + bool retValue = GraphView::replace(graphOld, graphNew); + //bool retValue = GraphView::replace({myOld}, {myNew}); + + CHECK(graphTest->getNodes() == std::set<std::shared_ptr<Node>>({other11, other21, other31, myOld1, myOld2, other12, other22})); + graphTest->save("myGraph2",true,true); + CHECK(retValue == false); + CHECK(graphTest->getNode("old1") == myOld1); + CHECK(graphTest->getNode("old2") == myOld2); + CHECK(graphTest->getNode("new1") == nullptr); + CHECK(graphTest->getNode("new2") == nullptr); + } + +SECTION("test case 2 same old/new input, not same old/new output") { + std::shared_ptr<GraphView> graphTest = std::make_shared<GraphView>("test_graph"); + auto otherInput1 = GenericOperator("Producer", 0, 0, 1, "other_input1"); + auto otherInput2 = GenericOperator("Producer", 0, 0, 1, "other_input2"); + auto other11 = GenericOperator("Other", 1, 0, 1, "other11"); + auto other21 = GenericOperator("Other", 1, 0, 1, "other21"); + auto myOld1 = GenericOperator("Old", 1, 0, 1, "old1"); + auto myOld2 = GenericOperator("Old", 1, 0, 2, "old2"); + auto other12 = GenericOperator("Other", 1, 0, 1, "other12"); + auto other22 = GenericOperator("Other", 1, 0, 1, "other22"); + auto other32 = GenericOperator("Other", 1, 0, 1, "other32"); + otherInput1->addChild(other11); + otherInput2->addChild(other21); + other11->addChild(myOld1); + other21->addChild(myOld2); + myOld1->addChild(other12); + myOld2->addChild(other22); + myOld2->addChild(other32); + graphTest->add({other11, other21, myOld1, myOld2, other12, other22, other32}); + + graphTest->save("myGraph",true,true); + auto myNew1 = GenericOperator("New", 1, 0, 1, "new1"); + auto myNew2 = GenericOperator("New", 1, 0, 1, "new2"); + // graphOld + std::shared_ptr<GraphView> graphOld = std::make_shared<GraphView>("oldGraph"); + graphOld->add({myOld1, myOld2}); + graphOld->setOrderedInputs({std::pair<NodePtr, IOIndex_t>(myOld1,0), std::pair<NodePtr, IOIndex_t>(myOld2,0)}); + graphOld->setOrderedOutputs({std::pair<NodePtr, IOIndex_t>(myOld1,0), std::pair<NodePtr, IOIndex_t>(myOld2,0), std::pair<NodePtr, IOIndex_t>(myOld2,1)}); + // graphNew + std::shared_ptr<GraphView> graphNew= std::make_shared<GraphView>("newGraph"); + graphNew->add({myNew1}); + graphNew->add({myNew2}); + graphNew->setOrderedInputs({std::pair<NodePtr, IOIndex_t>(myNew1,0), std::pair<NodePtr, IOIndex_t>(myNew2,0)}); + graphNew->setOrderedOutputs({std::pair<NodePtr, IOIndex_t>(myNew1,0), std::pair<NodePtr, IOIndex_t>(myNew2,0)}); + + + bool retValue = GraphView::replace(graphOld, graphNew); + //bool retValue = GraphView::replace({myOld}, {myNew}); + + CHECK(graphTest->getNodes() == std::set<std::shared_ptr<Node>>({other11, other21, myOld1, myOld2, other12, other22, other32})); + graphTest->save("myGraph2",true,true); + CHECK(retValue==false); + CHECK(graphTest->getNode("old1") == myOld1); + CHECK(graphTest->getNode("old2") == myOld2); + CHECK(graphTest->getNode("new1") == nullptr); + CHECK(graphTest->getNode("new2") == nullptr); + } + +} + +////////////////////////////////////////// +// test case 1 replacing a set of nodes by empty set, same number of inputs and outputs +// 1/ graphview== other_input1 --> other11 --> old1 --> other12 +// other_input2 --> other21 --> old2 --> other22 +// other_input3 --> other31 --/ \--> other32 +// 2/ replace oldGraph= --> old1 --> by newGraph= "empty" +// --> old2 --> +// --/ \--> +// 3/ verify : +// - output value of replace == TRUE +// - graphview== other_input1 --> other11 --> other12 +// other_input2 --> other21 --> other22 +// other_input3 --> other31 --> other32 +// - old1, old2 are remove from graphview +TEST_CASE("[core/graph] Graph: replacing a set of nodes, new set of node is empty", "[GraphView][replace]") { + +SECTION("test case 1 multiple old input/output, not same new input/ same new output") { + std::shared_ptr<GraphView> graphTest = std::make_shared<GraphView>("test_graph"); + auto otherInput1 = GenericOperator("Producer", 0, 0, 1, "other_input1"); + auto otherInput2 = GenericOperator("Producer", 0, 0, 1, "other_input2"); + auto otherInput3 = GenericOperator("Producer", 0, 0, 1, "other_input1"); + auto other11 = GenericOperator("Other", 1, 0, 1, "other11"); + auto other21 = GenericOperator("Other", 1, 0, 1, "other21"); + auto other31 = GenericOperator("Other", 1, 0, 1, "other31"); + auto myOld1 = GenericOperator("Old", 1, 0, 1, "old1"); + auto myOld2 = GenericOperator("Old", 2, 0, 2, "old2"); + auto other12 = GenericOperator("Other", 1, 0, 1, "other12"); + auto other22 = GenericOperator("Other", 1, 0, 1, "other22"); + auto other32 = GenericOperator("Other", 1, 0, 1, "other32"); + otherInput1->addChild(other11); + otherInput2->addChild(other21); + otherInput3->addChild(other31); + other11->addChild(myOld1); + other21->addChild(myOld2); + other31->addChild(myOld2); + myOld1->addChild(other12); + myOld2->addChild(other22); + myOld2->addChild(other32); + graphTest->add({other11, other21, other31, myOld1, myOld2, other12, other22, other32}); + + graphTest->save("myGraph",true,true); + // graphOld + std::shared_ptr<GraphView> graphOld = std::make_shared<GraphView>("oldGraph"); + graphOld->add({myOld1, myOld2}); + graphOld->setOrderedInputs({std::pair<NodePtr, IOIndex_t>(myOld1,0), std::pair<NodePtr, IOIndex_t>(myOld2,0), std::pair<NodePtr, IOIndex_t>(myOld2,1)}); + graphOld->setOrderedOutputs({std::pair<NodePtr, IOIndex_t>(myOld1,0), std::pair<NodePtr, IOIndex_t>(myOld2,0)}); + // graphNew + std::shared_ptr<GraphView> graphNew= std::make_shared<GraphView>("newGraph"); + + + bool retValue = GraphView::replace(graphOld, graphNew); + //bool retValue = GraphView::replace({myOld}, {}); + + CHECK(graphTest->getNodes() == std::set<std::shared_ptr<Node>>({other11, other21, other31, other12, other22, other32})); + graphTest->save("myGraph2",true,true); + CHECK(retValue ); + CHECK(graphTest->getNode("old1") == nullptr); + CHECK(graphTest->getNode("old2") == nullptr); + } +} + +////////////////////////////////////////// +// test case 1 replacing a set of nodes by empty set, there is a single input +// 1/ graphview== other_input1* -->*other11*-->*old1**--> *other12 +// | +// *old2**--> *other22 +// \--> *other32 +// 2/ replace oldGraph= --> *old1** --> by newGraph= "empty" +// | +// *old2** --> +// \--> +// 3/ verify : +// - output value of replace == TRUE +// - graphview== other_input1*--> *other11*-->*other12 +// \-->*other22 +// \-->*other32 +// - old1, old2 are remove from graphview +TEST_CASE("[core/graph] Graph: replacing a set of nodes, new set of node is empty and single output", "[GraphView][replace]") { + +SECTION("test case 1 multiple old input/output, not same new input / same new output") { + std::shared_ptr<GraphView> graphTest = std::make_shared<GraphView>("test_graph"); + auto otherInput1 = GenericOperator("Producer", 0, 0, 1, "other_input1"); + auto other11 = GenericOperator("Other", 1, 0, 1, "other11"); + auto myOld1 = GenericOperator("Old", 1, 0, 2, "old1"); + auto myOld2 = GenericOperator("Old", 1, 0, 2, "old2"); + auto other12 = GenericOperator("Other", 1, 0, 1, "other12"); + auto other22 = GenericOperator("Other", 1, 0, 1, "other22"); + auto other32 = GenericOperator("Other", 1, 0, 1, "other32"); + otherInput1->addChild(other11); + other11->addChild(myOld1); + myOld1->addChild(myOld2, 1, 0); + myOld1->addChild(other12, 0, 0); + myOld2->addChild(other22, 0, 0); + myOld2->addChild(other32,1, 0); + graphTest->add({other11, myOld1, myOld2, other12, other22, other32}); + + graphTest->save("myGraph",true,true); + // graphOld + std::shared_ptr<GraphView> graphOld = std::make_shared<GraphView>("oldGraph"); + graphOld->add({myOld1, myOld2}); + graphOld->setOrderedInputs({std::pair<NodePtr, IOIndex_t>(myOld1,0)}); + graphOld->setOrderedOutputs({std::pair<NodePtr, IOIndex_t>(myOld1,0), std::pair<NodePtr, IOIndex_t>(myOld2,0), std::pair<NodePtr, IOIndex_t>(myOld2,1)}); + // graphNew + std::shared_ptr<GraphView> graphNew= std::make_shared<GraphView>("newGraph"); + + + bool retValue = GraphView::replace(graphOld, graphNew); + //bool retValue = GraphView::replace({myOld}, {}); + + CHECK(graphTest->getNodes() == std::set<std::shared_ptr<Node>>({other11, other12, other22, other32})); + auto childs = other11->getChildren(0); + CHECK(childs.at(0)== other12); + CHECK(childs.at(1)== other22); + CHECK(childs.at(2)== other32); + graphTest->save("myGraph2",true,true); + CHECK(retValue ); + CHECK(graphTest->getNode("old1") == nullptr); + CHECK(graphTest->getNode("old2") == nullptr); + } +} -- GitLab From d1ac46b27711f17fa6995f91c91010e9c2528468 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Sun, 2 Mar 2025 15:31:16 +0000 Subject: [PATCH 12/13] Removed unused files --- include/aidge/utilsParsing/AstNode.hpp | 69 ------------------ include/aidge/utilsParsing/ParsingToken.hpp | 79 --------------------- 2 files changed, 148 deletions(-) delete mode 100644 include/aidge/utilsParsing/AstNode.hpp delete mode 100644 include/aidge/utilsParsing/ParsingToken.hpp diff --git a/include/aidge/utilsParsing/AstNode.hpp b/include/aidge/utilsParsing/AstNode.hpp deleted file mode 100644 index bf4f73236..000000000 --- a/include/aidge/utilsParsing/AstNode.hpp +++ /dev/null @@ -1,69 +0,0 @@ - - -#ifndef AIDGE_CORE_AST_NODE_H_ -#define AIDGE_CORE_AST_NODE_H_ - -#include <string> -#include <type_traits> -#include <vector> -#include <memory> -#include "aidge/utilsParsing/ParsingToken.hpp" - -namespace Aidge{ - - template <typename EnumType> - class AstNode: public std::enable_shared_from_this<AstNode<EnumType>> - { - static_assert(std::is_enum<EnumType>::value, "AstNode EnumType must be an enum type"); - public: - AstNode(std::shared_ptr<ParsingToken<EnumType>> token,std::vector<std::shared_ptr<AstNode<EnumType>>> child ={}):mToken(token),mChild(child){} - /** - * @brief get the type of the token - * @return the type - */ - EnumType getType() const{ - return mToken->getType(); - } - - /** - * @brief get the lexeme of the token - * @return the lexeme - */ - std::string getValue() const{ - return mToken->getLexeme(); - } - /** - * @brief get the child of the node - * @return child - */ - const std::vector<std::shared_ptr<AstNode>>& getChilds() const { - return mChild; - } - /** - * @brief test if the node is a leaf in the tree - * @return true if a leaf - */ - bool isLeaf() const { - return mChild.size() == 0; - } - - /** - * @brief get the number of child - * @return the number of child - */ - std::size_t nbChild() const{ - return mChild.size(); - } - private: - /** - * @brief the token of the node - */ - const std::shared_ptr<ParsingToken<EnumType>> mToken; - /** - * @brief list of child - */ - const std::vector<std::shared_ptr<AstNode>> mChild; - }; -} - -#endif //AIDGE_CORE_AST_NODE_H_ diff --git a/include/aidge/utilsParsing/ParsingToken.hpp b/include/aidge/utilsParsing/ParsingToken.hpp deleted file mode 100644 index 3781fcbf1..000000000 --- a/include/aidge/utilsParsing/ParsingToken.hpp +++ /dev/null @@ -1,79 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#ifndef AIDGE_CORE_PARSING_TOKEN_H_ -#define AIDGE_CORE_PARSING_TOKEN_H_ - -#include <string> -#include <type_traits> - -#include <fmt/format.h> - -namespace Aidge { - -template <typename EnumType> -class ParsingToken: public std::enable_shared_from_this<ParsingToken<EnumType>> -{ - static_assert(std::is_enum<EnumType>::value, "ParsingToken EnumType must be an enum type"); -public: - /** - * @brief Token container - * @param type one of the token type - * @param lexeme String representing additional information of the token - */ - ParsingToken(const EnumType type , const std::string& lexeme ) - : mLexeme(lexeme), mType(type){} - - /** - * @brief get the lexeme - * @return std::string - */ - const std::string getLexeme(void) const noexcept { - return mLexeme; - } - - /** - * @brief get the token type - * - * @return ParsingToken - */ - const EnumType getType(void) const noexcept { - return mType; - } - - /** - * @brief copy the token - * @return deep copy of the token - */ - std::shared_ptr<ParsingToken> copy() const noexcept { - return std::make_shared<ParsingToken<EnumType>>(mType, mLexeme); - } - - //TODO - std::string rep(void) const { return fmt::format(" Token ({})\n", mLexeme); } - -private: - /** - * @brief additional information of the token - */ - const std::string mLexeme; - - /** - * @brief type of the token - * @see ConditionalTokenTypes - */ - const EnumType mType; - -}; - -} // namespace Aidge - -#endif //AIDGE_CORE_PARSING_TOKEN_H_ -- GitLab From 51a9ebe6577714815b341471fdd228cd18aa5b9e Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Sun, 2 Mar 2025 15:39:39 +0000 Subject: [PATCH 13/13] Merged with dev (missing change) --- src/scheduler/Scheduler.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index fdda95727..177975545 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -1075,8 +1075,8 @@ void Aidge::Scheduler::saveSchedulingDiagram(const std::string& fileName, bool i fmt::print(fp.get(), "\n"); } -void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName) const { - auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose); +void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName, bool ignoreProducers) const { + auto fp = createFile(fileName + ".mmd", "w"); if (!fp) { AIDGE_THROW_OR_ABORT(std::runtime_error, -- GitLab