From a3f3e5109367a9a6dfa6b8e6d7722050bc1e306c Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Wed, 31 Jul 2024 10:35:53 +0200 Subject: [PATCH] Added hook mechanism and Abs operator --- include/aidge/data/Tensor.hpp | 12 ++++++ include/aidge/graph/Node.hpp | 20 ++++++++++ include/aidge/operator/Abs.hpp | 71 ++++++++++++++++++++++++++++++++++ src/data/Tensor.cpp | 28 ++++++++++++++ src/graph/Node.cpp | 29 +++++++++++--- src/operator/Abs.cpp | 25 ++++++++++++ 6 files changed, 180 insertions(+), 5 deletions(-) create mode 100644 include/aidge/operator/Abs.hpp create mode 100644 src/operator/Abs.cpp diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 89d7a3a7b..c7b712be4 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -312,6 +312,18 @@ class Tensor : public Data, */ Tensor sqrt() const; + /** + * @brief Element-wise abs operation for Tensor. + * @return Tensor + */ + Tensor abs() const; + + /** + * @brief Mean operation for Tensor. + * @return Tensor + */ + Tensor mean() const; + ~Tensor() noexcept; public: diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index f694a1234..3be17d6d2 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -17,6 +17,7 @@ #include <set> #include <string> #include <vector> +#include <deque> #include <utility> #ifdef PYBIND @@ -63,6 +64,9 @@ private: std::vector<std::vector<IOIndex_t>> mIdInChildren; /** List of input index for each Node linked to each output of the Node. */ std::vector<IOIndex_t> mIdOutParents; /** index of the output linked to each input of the Node. Default: gk_IODefaultIndex. */ + std::deque<std::function<bool()>> mForward; + std::deque<std::function<bool()>> mBackward; + public: Node() = delete; @@ -79,6 +83,22 @@ public: return lhs.shared_from_this() == rhs.shared_from_this(); } + void addBeforeForward(std::function<bool()> func) { + mForward.push_front(func); + } + + void addAfterForward(std::function<bool()> func) { + mForward.push_back(func); + } + + void addBeforeBackward(std::function<bool()> func) { + mBackward.push_front(func); + } + + void addAfterBackward(std::function<bool()> func) { + mBackward.push_back(func); + } + public: /////////////////////////////////////////////////////// // FUNCTIONAL DESCRIPTION diff --git a/include/aidge/operator/Abs.hpp b/include/aidge/operator/Abs.hpp new file mode 100644 index 000000000..3c2f1bb38 --- /dev/null +++ b/include/aidge/operator/Abs.hpp @@ -0,0 +1,71 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_ABS_H_ +#define AIDGE_CORE_OPERATOR_ABS_H_ + +#include <memory> +#include <string> +#include <vector> + +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { + +class Abs_Op : public OperatorTensor, + public Registrable<Abs_Op, std::string, std::shared_ptr<OperatorImpl>(const Abs_Op&)> { +public: + static const std::string Type; + + Abs_Op() : OperatorTensor(Type, {InputCategory::Data}, 1) {} + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Abs_Op(const Abs_Op& op) + : OperatorTensor(op) + { + if (op.mImpl) { + SET_IMPL_MACRO(Abs_Op, *this, op.backend()); + } else { + mImpl = nullptr; + } + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Abs_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Abs_Op>(*this); + } + + void setBackend(const std::string& name, DeviceIdx_t device = 0) override; + + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +inline std::shared_ptr<Node> Abs(const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Abs_Op>(), name); +} +} + +#endif /* AIDGE_CORE_OPERATOR_ABS_H_ */ diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp index 20bf3fb78..e382fe2ac 100644 --- a/src/data/Tensor.cpp +++ b/src/data/Tensor.cpp @@ -16,9 +16,11 @@ #include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/Registrar.hpp" +#include "aidge/operator/Abs.hpp" #include "aidge/operator/Add.hpp" #include "aidge/operator/Div.hpp" #include "aidge/operator/Mul.hpp" +#include "aidge/operator/ReduceMean.hpp" #include "aidge/operator/Sub.hpp" #include "aidge/operator/Sqrt.hpp" #include "aidge/operator/Transpose.hpp" @@ -106,6 +108,32 @@ Aidge::Tensor Aidge::Tensor::sqrt() const { return sqrt_.getOutput(0)->clone(); } +Aidge::Tensor Aidge::Tensor::abs() const { + AIDGE_ASSERT(hasImpl(), "Tensor has no implementation."); + auto abs_ = Abs_Op(); + abs_.associateInput(0, std::make_shared<Tensor>(*this)); + abs_.setDataType(dataType()); + abs_.setDataFormat(dataFormat()); + abs_.setBackend(mImpl->backend()); + abs_.forward(); + return abs_.getOutput(0)->clone(); +} + +Aidge::Tensor Aidge::Tensor::mean() const { + AIDGE_ASSERT(hasImpl(), "Tensor has no implementation."); + // TODO: should be the default behavior of ReduceMean_Op + // No need to specify the list of all axes! + std::vector<std::int32_t> axes(nbDims()); + std::iota(std::begin(axes), std::end(axes), 0); + auto mean_ = ReduceMean_Op(axes, 0); + mean_.associateInput(0, std::make_shared<Tensor>(*this)); + mean_.setDataType(dataType()); + mean_.setDataFormat(dataFormat()); + mean_.setBackend(mImpl->backend()); + mean_.forward(); + return mean_.getOutput(0)->clone(); +} + Aidge::Tensor& Aidge::Tensor::operator=(const Aidge::Tensor& other) { if (this == &other) { return *this; diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 7fe155b5a..6f24e0170 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -29,8 +29,13 @@ Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name) mIdInChildren(std::vector<std::vector<IOIndex_t>>(static_cast<std::size_t>(op->nbOutputs()), std::vector<IOIndex_t>())), mIdOutParents( - std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)) { + std::vector<IOIndex_t>(static_cast<std::size_t>(op->nbInputs()), gk_IODefaultIndex)) +{ // ctor + if (op) { + mForward.push_back([this](){ this->mOperator->forward(); return true; }); + mBackward.push_back([this](){ this->mOperator->backward(); return true; }); + } } /////////////////////////////////////////////////////// @@ -82,13 +87,27 @@ std::string Aidge::Node::createUniqueName(std::string name){ /////////////////////////////////////////////////////// void Aidge::Node::forward() { - assert((mOperator != nullptr) && "No Operator interface provided, can't run forward().\n"); - mOperator->forward(); + for (auto it = mForward.begin(); it != mForward.end(); ) { + const auto keep = (*it)(); + if (!keep) { + it = mForward.erase(it); + } + else { + ++it; + } + } } void Aidge::Node::backward() { - assert((mOperator != nullptr) && "No Operator interface provided, can't run backward().\n"); - mOperator->backward(); + for (auto it = mBackward.begin(); it != mBackward.end(); ) { + const auto keep = (*it)(); + if (!keep) { + it = mBackward.erase(it); + } + else { + ++it; + } + } } /////////////////////////////////////////////////////// diff --git a/src/operator/Abs.cpp b/src/operator/Abs.cpp new file mode 100644 index 000000000..a8ee706f6 --- /dev/null +++ b/src/operator/Abs.cpp @@ -0,0 +1,25 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/operator/Abs.hpp" + +#include <string> + +#include "aidge/data/Tensor.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/Types.h" + +const std::string Aidge::Abs_Op::Type = "Abs"; + +void Aidge::Abs_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { + SET_IMPL_MACRO(Abs_Op, *this, name); + mOutputs[0]->setBackend(name, device); +} -- GitLab