Skip to content
Snippets Groups Projects
Commit f3364c9f authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge branch 'feat/operator_globalAveragePooling' into 'dev'

Feat/operator global average pooling

See merge request eclipse/aidge/aidge_core!91
parents 08a66f38 6ec0de5c
No related branches found
No related tags found
No related merge requests found
...@@ -19,7 +19,7 @@ option(PYBIND "python binding" ON) ...@@ -19,7 +19,7 @@ option(PYBIND "python binding" ON)
option(WERROR "Warning as error" OFF) option(WERROR "Warning as error" OFF)
option(TEST "Enable tests" ON) option(TEST "Enable tests" ON)
option(COVERAGE "Enable coverage" OFF) option(COVERAGE "Enable coverage" OFF)
option(ENABLE_ASAN "Enable ASan (adress sanitizer) for runtime analysis of memory use (over/underflow, memory leak, ...)" OFF) option(ENABLE_ASAN "Enable ASan (AddressSanitizer) for runtime analysis of memory use (over/underflow, memory leak, ...)" OFF)
############################################## ##############################################
# Import utils CMakeLists # Import utils CMakeLists
......
/******************************************************************************** /********************************************************************************
* Copyright (c) 2023 CEA-List * Copyright (c) 2023 CEA-List
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at * terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0. * http://www.eclipse.org/legal/epl-2.0.
* *
* SPDX-License-Identifier: EPL-2.0 * SPDX-License-Identifier: EPL-2.0
* *
********************************************************************************/ ********************************************************************************/
#ifndef AIDGE_IMPORTS_H_ #ifndef AIDGE_IMPORTS_H_
#define AIDGE_IMPORTS_H_ #define AIDGE_IMPORTS_H_
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/backend/TensorImpl.hpp" #include "aidge/backend/TensorImpl.hpp"
#include "aidge/backend/StimulusImpl.hpp" #include "aidge/backend/StimulusImpl.hpp"
#include "aidge/backend/cpu/data/TensorImpl.hpp" #include "aidge/backend/cpu/data/TensorImpl.hpp"
#include "aidge/backend/cpu/data/GetCPUPtr.h" #include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/data/Data.hpp" #include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/data/Database.hpp" #include "aidge/data/Database.hpp"
#include "aidge/data/DataProvider.hpp" #include "aidge/data/DataProvider.hpp"
#include "aidge/graph/Connector.hpp" #include "aidge/graph/Connector.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/graph/OpArgs.hpp" #include "aidge/graph/OpArgs.hpp"
#include "aidge/graphRegex/GraphRegex.hpp" #include "aidge/graphRegex/GraphRegex.hpp"
#include "aidge/filler/Filler.hpp" #include "aidge/filler/Filler.hpp"
#include "aidge/nodeTester/ConditionalInterpreter.hpp" #include "aidge/nodeTester/ConditionalInterpreter.hpp"
#include "aidge/operator/Add.hpp" #include "aidge/operator/Add.hpp"
#include "aidge/operator/AvgPooling.hpp" #include "aidge/operator/AvgPooling.hpp"
#include "aidge/operator/BatchNorm.hpp" #include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/Concat.hpp" #include "aidge/operator/Concat.hpp"
#include "aidge/operator/Conv.hpp" #include "aidge/operator/Conv.hpp"
#include "aidge/operator/ConvDepthWise.hpp" #include "aidge/operator/ConvDepthWise.hpp"
#include "aidge/operator/Div.hpp" #include "aidge/operator/Div.hpp"
#include "aidge/operator/Erf.hpp" #include "aidge/operator/Erf.hpp"
#include "aidge/operator/FC.hpp" #include "aidge/operator/FC.hpp"
#include "aidge/operator/Gather.hpp" #include "aidge/operator/Gather.hpp"
#include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/MatMul.hpp" #include "aidge/operator/GlobalAveragePooling.hpp"
#include "aidge/operator/MaxPooling.hpp" #include "aidge/operator/MatMul.hpp"
#include "aidge/operator/MetaOperator.hpp" #include "aidge/operator/MaxPooling.hpp"
#include "aidge/operator/MetaOperatorDefs.hpp" #include "aidge/operator/MetaOperator.hpp"
#include "aidge/operator/Mul.hpp" #include "aidge/operator/MetaOperatorDefs.hpp"
#include "aidge/operator/Operator.hpp" #include "aidge/operator/Mul.hpp"
#include "aidge/operator/Pad.hpp" #include "aidge/operator/Operator.hpp"
#include "aidge/operator/Producer.hpp" #include "aidge/operator/Pad.hpp"
#include "aidge/operator/Pow.hpp" #include "aidge/operator/Producer.hpp"
#include "aidge/operator/ReduceMean.hpp" #include "aidge/operator/Pow.hpp"
#include "aidge/operator/ReLU.hpp" #include "aidge/operator/ReduceMean.hpp"
#include "aidge/operator/Reshape.hpp" #include "aidge/operator/ReLU.hpp"
#include "aidge/operator/Scaling.hpp" #include "aidge/operator/Reshape.hpp"
#include "aidge/operator/Slice.hpp" #include "aidge/operator/Scaling.hpp"
#include "aidge/operator/Softmax.hpp" #include "aidge/operator/Slice.hpp"
#include "aidge/operator/Sqrt.hpp" #include "aidge/operator/Softmax.hpp"
#include "aidge/operator/Sub.hpp" #include "aidge/operator/Sqrt.hpp"
#include "aidge/operator/Transpose.hpp" #include "aidge/operator/Sub.hpp"
#include "aidge/scheduler/Scheduler.hpp" #include "aidge/operator/Transpose.hpp"
#include "aidge/stimuli/Stimulus.hpp" #include "aidge/scheduler/Scheduler.hpp"
#include "aidge/stimuli/Stimulus.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/utils/Attributes.hpp"
#include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/Attributes.hpp"
#include "aidge/utils/DynamicAttributes.hpp" #include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Random.hpp" #include "aidge/utils/DynamicAttributes.hpp"
#include "aidge/utils/Registrar.hpp" #include "aidge/utils/Random.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#endif /* AIDGE_IMPORTS_H_ */
#endif /* AIDGE_IMPORTS_H_ */
...@@ -44,10 +44,10 @@ private: ...@@ -44,10 +44,10 @@ private:
/// @brief Set of nodes included in the graphview with names /// @brief Set of nodes included in the graphview with names
std::map<std::string, NodePtr> mNodeRegistry; std::map<std::string, NodePtr> mNodeRegistry;
/// @brief GraphView inputs /// @brief GraphView inputs IOIndex_t designates the input number
std::vector<std::pair<NodePtr, IOIndex_t>> mInputNodes; std::vector<std::pair<NodePtr, IOIndex_t>> mInputNodes;
/// @brief GraphView outputs /// @brief GraphView outputs IOIndex_t designates the input number
std::vector<std::pair<NodePtr, IOIndex_t>> mOutputNodes; std::vector<std::pair<NodePtr, IOIndex_t>> mOutputNodes;
public: public:
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_GLOBAL_AVERAGE_POOLING_H_
#define AIDGE_CORE_OPERATOR_GLOBAL_AVERAGE_POOLING_H_
#include <memory>
#include <string>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
/**
* @brief Description for the tensor data structure.
* @details Sets the properties of the tensor without actually containing any
* data. Contains a pointer to an actual contiguous implementation of data.
*/
class GlobalAveragePooling_Op
: public OperatorTensor,
public Registrable<GlobalAveragePooling_Op, std::string,
std::shared_ptr<OperatorImpl>(
const GlobalAveragePooling_Op &)> {
public:
static const std::string Type;
GlobalAveragePooling_Op() : OperatorTensor(Type, 1, 0, 1) {}
GlobalAveragePooling_Op(const GlobalAveragePooling_Op &op)
: OperatorTensor(op) {
if (op.mImpl) {
SET_IMPL_MACRO(GlobalAveragePooling_Op, *this, op.backend());
} else {
mImpl = nullptr;
}
}
std::shared_ptr<Operator> clone() const override {
return std::make_shared<GlobalAveragePooling_Op>(*this);
}
void computeOutputDims() override final;
void setBackend(const std::string &name, DeviceIdx_t device = 0) override final;
static const std::vector<std::string> getInputsName() {
return {"data_input"};
}
static const std::vector<std::string> getOutputsName() {
return {"data_output"};
}
};
inline std::shared_ptr<Node>
GlobalAveragePooling(const std::string &name = "") {
return std::make_shared<Node>(std::make_shared<GlobalAveragePooling_Op>(),
name);
}
} // namespace Aidge
#endif /* AIDGE_CORE_OPERATOR_GLOBAL_AVERAGE_POOLING_H_ */
...@@ -57,9 +57,9 @@ void init_Node(py::module& m) { ...@@ -57,9 +57,9 @@ void init_Node(py::module& m) {
:param other_node: Pointer to the other Node. :param other_node: Pointer to the other Node.
:type other_node: :py:class: Node :type other_node: :py:class: Node
:param out_id: ID of the current Node output to connect to the other Node. Default to 0. :param out_id: ID of the output of the current Node to connect to the other Node. (If Node has 1 output max ID is 0). Default to 0.
:type out_id: int :type out_id: int
:param other_in_id: ID of the other Node input to connect to the current Node. Default to the first avaible data input. :param other_in_id: ID of the input of the other Node to connect to the current Node (If the node is a Mul op it has 2 input then Max ID is 1).Default to the first avaible data input.
:type other_in_id: int :type other_in_id: int
)mydelimiter") )mydelimiter")
...@@ -85,7 +85,7 @@ void init_Node(py::module& m) { ...@@ -85,7 +85,7 @@ void init_Node(py::module& m) {
:type other_view: :py:class: GraphView :type other_view: :py:class: GraphView
:param out_id: ID of the current Node output to connect to the other Node. Default to 0. :param out_id: ID of the current Node output to connect to the other Node. Default to 0.
:type out_id: int :type out_id: int
:param other_in_id: Pair of Node and input connection ID for specifying the connection. If the GraphView whose content is linked has only one input Node, then it defaults to the first available data input ID of this Node. :param other_in_id: Pair of Node and input connection ID for specifying the connection. If the GraphView whose content is linked has only one input Node, then it defaults to the first available data input ID of this Node.
:type other_in_id: tuple[:py:class: Node, int] :type other_in_id: tuple[:py:class: Node, int]
)mydelimiter") )mydelimiter")
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/operator/GlobalAveragePooling.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Attributes.hpp"
namespace py = pybind11;
namespace Aidge {
const std::string pyClassName("GlobalAveragePoolingOp");
void init_GlobalAveragePooling(py::module &m) {
py::class_<GlobalAveragePooling_Op, std::shared_ptr<GlobalAveragePooling_Op>,
OperatorTensor>(m, pyClassName.c_str(),
py::multiple_inheritance())
.def("get_inputs_name", &GlobalAveragePooling_Op::getInputsName)
.def("get_outputs_name", &GlobalAveragePooling_Op::getOutputsName);
declare_registrable<GlobalAveragePooling_Op>(m, pyClassName);
m.def("globalaveragepooling", &GlobalAveragePooling, py::arg("name") = "");
}
} // namespace Aidge
...@@ -38,6 +38,7 @@ void init_Erf(py::module&); ...@@ -38,6 +38,7 @@ void init_Erf(py::module&);
void init_FC(py::module&); void init_FC(py::module&);
void init_Gather(py::module&); void init_Gather(py::module&);
void init_GenericOperator(py::module&); void init_GenericOperator(py::module&);
void init_GlobalAveragePooling(py::module&);
void init_LeakyReLU(py::module&); void init_LeakyReLU(py::module&);
void init_MatMul(py::module&); void init_MatMul(py::module&);
void init_MaxPooling(py::module&); void init_MaxPooling(py::module&);
...@@ -103,6 +104,7 @@ void init_Aidge(py::module& m) { ...@@ -103,6 +104,7 @@ void init_Aidge(py::module& m) {
init_FC(m); init_FC(m);
init_Gather(m); init_Gather(m);
init_GenericOperator(m); init_GenericOperator(m);
init_GlobalAveragePooling(m);
init_LeakyReLU(m); init_LeakyReLU(m);
init_MatMul(m); init_MatMul(m);
init_MaxPooling(m); init_MaxPooling(m);
......
...@@ -889,36 +889,45 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s ...@@ -889,36 +889,45 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s
return GraphView::replace(oldG, newG); return GraphView::replace(oldG, newG);
} }
bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldG, const std::shared_ptr<GraphView>& newG) { bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const std::shared_ptr<GraphView>& newGraph) {
// TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes) // TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes)
// How to distinguish it from data input? // How to distinguish it from data input?
// TODO: Parameter Tensors could be identified with their dimensions // TODO: Parameter Tensors could be identified with their dimensions
// TODO: Take GraphView as input parameters since new Nodes should be connected whatever. // TODO: Take GraphView as input parameters since new Nodes should be connected whatever.
// It also avoids specifying each producer since they are automatically included // It also avoids specifying each producer since they are automatically included
const auto& oldNodes = oldG->getNodes(); const std::set<NodePtr>& oldNodes = oldGraph->getNodes();
const auto& newNodes = newG->getNodes(); const std::set<NodePtr>& newNodes = newGraph->getNodes();
const auto oldOI = oldG->getOrderedInputs(); const std::vector<std::pair<NodePtr, IOIndex_t>> oldOIn =
const auto oldOO = oldG->getOrderedOutputs(); oldGraph->getOrderedInputs();
const auto newOI = newG->getOrderedInputs(); const std::vector<std::pair<NodePtr, IOIndex_t>> oldOOut =
const auto newOO = newG->getOrderedOutputs(); oldGraph->getOrderedOutputs();
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputParents = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOI.size()); const std::vector<std::pair<NodePtr, IOIndex_t>> newOIn =
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> outputChildren = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOO.size()); newGraph->getOrderedInputs();
const std::vector<std::pair<NodePtr, IOIndex_t>> newOOut =
// keep in memory every parent newGraph->getOrderedOutputs();
for (std::size_t i = 0; i < oldOI.size(); ++i) {
auto inputParent = oldOI[i].first -> input(oldOI[i].second); auto inputParents = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOIn.size());
auto outputChildren = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOOut.size());
// keep in memory every node related to the node to replace :
// Parent
for (std::size_t i = 0; i < oldOIn.size(); ++i) {
std::pair<NodePtr, IOIndex_t> inputParent =
oldOIn[i].first -> input(oldOIn[i].second);
inputParents[i]= inputParent; inputParents[i]= inputParent;
// inputParent.first -> addChild(newOI[i].first, inputParent.second, newOI[i].second); // inputParent.first -> addChild(newOI[i].first, inputParent.second, newOI[i].second);
} }
for (std::size_t i = 0; i < oldOO.size();) { // Children
auto outputChildList = oldOO[i].first -> output(oldOO[i].second); for (std::size_t i = 0; i < oldOOut.size();) {
if (outputChildList.empty()) { std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> outputChild =
oldOOut[i].first -> output(oldOOut[i].second);
if (outputChild.empty()) {
outputChildren[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>({nullptr, gk_IODefaultIndex}); outputChildren[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>({nullptr, gk_IODefaultIndex});
++i; ++i;
} }
else { else {
for (const auto& child : outputChildList) { for (const auto& child : outputChild) {
if (oldNodes.find(child.first) == oldNodes.cend()) { if (oldNodes.find(child.first) == oldNodes.cend()) {
outputChildren[i] = child; outputChildren[i] = child;
++i; ++i;
...@@ -931,37 +940,37 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldG, const std ...@@ -931,37 +940,37 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldG, const std
// set of common GraphView for oldNodes' Nodes // set of common GraphView for oldNodes' Nodes
std::set<std::shared_ptr<GraphView>> commonGraphViews = (*oldNodes.begin())->views(); std::set<std::shared_ptr<GraphView>> commonGraphViews = (*oldNodes.begin())->views();
for (const auto& nodePtr : oldNodes) { for (const auto& nodePtr : oldNodes) {
const auto nodeView = nodePtr->views(); const std::set<std::shared_ptr<GraphView>> nodeView = nodePtr->views();
std::set<std::shared_ptr<GraphView>> intersection; std::set<std::shared_ptr<GraphView>> intersection;
std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(), std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(),
nodeView.begin(), nodeView.end(), nodeView.begin(), nodeView.end(),
std::inserter(intersection, intersection.begin())); std::inserter(intersection, intersection.begin()));
commonGraphViews = intersection; commonGraphViews = intersection;
} }
commonGraphViews.erase(oldG); commonGraphViews.erase(oldGraph);
commonGraphViews.erase(newG); commonGraphViews.erase(newGraph);
if ((newNodes.size() > 0) && (oldOI.size() != newOI.size()) && (oldOO.size() != newOO.size())) { if ((newNodes.size() > 0) && (oldOIn.size() != newOIn.size()) && (oldOOut.size() != newOOut.size())) {
for (const auto& nodePtr : oldNodes) { for (const auto& nodePtr : oldNodes) {
nodePtr->removeView(oldG); nodePtr->removeView(oldGraph);
} }
for (const auto& nodePtr : newNodes) { for (const auto& nodePtr : newNodes) {
nodePtr->removeView(newG); nodePtr->removeView(newGraph);
} }
return false; return false;
} }
if ((oldOI.size() == newOI.size()) && if ((oldOIn.size() == newOIn.size()) &&
(oldOO.size() == newOO.size())) { (oldOOut.size() == newOOut.size())) {
// Case 1 // Case 1
for (std::size_t i = 0; i < oldOI.size(); ++i) { for (std::size_t i = 0; i < oldOIn.size(); ++i) {
if (inputParents[i].first) { if (inputParents[i].first) {
inputParents[i].first -> addChild(newOI[i].first, inputParents[i].second, newOI[i].second); inputParents[i].first -> addChild(newOIn[i].first, inputParents[i].second, newOIn[i].second);
} }
} }
for (std::size_t o = 0; o < oldOO.size(); ++o) { for (std::size_t o = 0; o < oldOOut.size(); ++o) {
if (outputChildren[o].first) { if (outputChildren[o].first) {
newOO[o].first -> addChild(outputChildren[o].first, newOO[o].second, outputChildren[o].second); newOOut[o].first -> addChild(outputChildren[o].first, newOOut[o].second, outputChildren[o].second);
} }
} }
} }
...@@ -970,52 +979,53 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldG, const std ...@@ -970,52 +979,53 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldG, const std
// get the number of Children for oldg->outputNodes() // get the number of Children for oldg->outputNodes()
if (newNodes.size() == 0) { if (newNodes.size() == 0) {
// Case 3 // Case 3
if (oldOI.size() == oldOO.size()) { if (oldOIn.size() == oldOOut.size()) {
for (std::size_t i = 0; i < oldOI.size(); ++i) { for (std::size_t i = 0; i < oldOIn.size(); ++i) {
if (inputParents[i].first) if (inputParents[i].first) {
inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second); inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second);
}
} }
} }
else if ((oldOI.size() == 1) && (inputParents[0].first)) { else if ((oldOIn.size() == 1) && (inputParents[0].first)) {
for (std::size_t i = 0; i < oldOI.size(); ++i) { for (std::size_t i = 0; i < oldOIn.size(); ++i) {
inputParents[0].first -> addChild(outputChildren[i].first, inputParents[0].second, outputChildren[i].second); inputParents[0].first -> addChild(outputChildren[i].first, inputParents[0].second, outputChildren[i].second);
} }
} }
} }
else if ( // for tiling-like cases. The number of inputNodes changes but not outputNodes else if ( // for tiling-like cases. The number of inputNodes changes but not outputNodes
((oldOI.size() == 1) || (newOI.size() == 1)) && // (oldOI.size() == newOI.size()) already handled in Case 1 ((oldOIn.size() == 1) || (newOIn.size() == 1)) && // (oldOIn.size() == newOI.size()) already handled in Case 1
((oldOO.size() == newOO.size())) ((oldOOut.size() == newOOut.size()))
) { ) {
// Case 2 // Case 2
if ((oldOI.size() == 1) && (inputParents[0].first)) { if ((oldOIn.size() == 1) && (inputParents[0].first)) {
for (std::size_t i = 0; i < newOI.size(); ++i) { for (std::size_t i = 0; i < newOIn.size(); ++i) {
inputParents[0].first -> addChild(newOI[i].first, inputParents[0].second, newOI[i].second); inputParents[0].first -> addChild(newOIn[i].first, inputParents[0].second, newOIn[i].second);
} }
} else { } else {
for (std::size_t i = 0; i < oldOI.size(); ++i) { for (std::size_t i = 0; i < oldOIn.size(); ++i) {
if (inputParents[i].first) { if (inputParents[i].first) {
inputParents[i].first -> addChild(newOI[0].first, inputParents[i].second, newOI[0].second); inputParents[i].first -> addChild(newOIn[0].first, inputParents[i].second, newOIn[0].second);
} }
} }
} }
for (std::size_t o = 0; o < oldOO.size(); ++o) { for (std::size_t o = 0; o < oldOOut.size(); ++o) {
if (outputChildren[o].first) { if (outputChildren[o].first) {
newOO[o].first -> addChild(outputChildren[o].first, newOO[o].second, outputChildren[o].second); newOOut[o].first -> addChild(outputChildren[o].first, newOOut[o].second, outputChildren[o].second);
} }
} }
} }
else { else {
for (const auto& nodePtr : oldNodes) { for (const auto& nodePtr : oldNodes) {
nodePtr->removeView(oldG); nodePtr->removeView(oldGraph);
} }
for (const auto& nodePtr : newNodes) { for (const auto& nodePtr : newNodes) {
nodePtr->removeView(newG); nodePtr->removeView(newGraph);
} }
return false; return false;
} }
} }
auto oldGOutputs = oldG->outputNodes(); auto oldGOutputs = oldGraph->outputNodes();
for (const auto& nodePtr : oldNodes) { for (const auto& nodePtr : oldNodes) {
bool removeFromGraphs = true; bool removeFromGraphs = true;
if (std::find(oldGOutputs.cbegin(), oldGOutputs.cend(), nodePtr) == oldGOutputs.cend()) { if (std::find(oldGOutputs.cbegin(), oldGOutputs.cend(), nodePtr) == oldGOutputs.cend()) {
...@@ -1041,10 +1051,10 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldG, const std ...@@ -1041,10 +1051,10 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldG, const std
} }
} }
for (const auto& nodePtr : oldNodes) { for (const auto& nodePtr : oldNodes) {
nodePtr -> removeView(oldG); nodePtr -> removeView(oldGraph);
} }
for (const auto& nodePtr : newNodes) { for (const auto& nodePtr : newNodes) {
nodePtr -> removeView(newG); nodePtr -> removeView(newGraph);
} }
return true; return true;
} }
...@@ -1247,7 +1257,14 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo ...@@ -1247,7 +1257,14 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo
} }
if (deletedNode == mRootNode) { if (deletedNode == mRootNode) {
mRootNode = nullptr; const std::pair<std::vector<NodePtr>, size_t> ranked_nodes = getRankedNodes();
if(ranked_nodes.second== 0 )
{
mRootNode = nullptr;
} else {
// The new root node will be the second node in the order of ranked nodes
setRootNode(*std::next(ranked_nodes.first.cbegin(),1));
}
} }
} }
......
...@@ -283,7 +283,7 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const { ...@@ -283,7 +283,7 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::Node::getChildren() const {
} }
std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const { std::vector<std::vector<std::shared_ptr<Aidge::Node>>> Aidge::Node::getOrderedChildren() const {
std::vector<std::vector<std::shared_ptr<Node>>> children = auto children =
std::vector<std::vector<std::shared_ptr<Node>>>(mChildren.size()); std::vector<std::vector<std::shared_ptr<Node>>>(mChildren.size());
for (std::size_t outId = 0; outId < mChildren.size(); ++outId) { for (std::size_t outId = 0; outId < mChildren.size(); ++outId) {
children[outId] = getChildren(outId); children[outId] = getChildren(outId);
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <memory>
#include <stdexcept> // std::runtime_error
#include <string>
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/GlobalAveragePooling.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
const std::string Aidge::GlobalAveragePooling_Op::Type = "GlobalAveragePooling";
void Aidge::GlobalAveragePooling_Op::computeOutputDims() {
// error checking
if (!getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error,
"GlobalAveragePooling : The input was not connected");
}
// necessary bc forward dims sometimes passes with an empty vector before
// doing another pass
else if (getInput(0)->empty()) {
return;
// computation
} else {
AIDGE_ASSERT(getInput(0)->dims().size() >= 3,
"GlobalAveragePooling : needs at least a 3 dimensions input, "
"number of input dim : {}",
getInput(0)->dims().size());
// Global average pooling takes each filter, averages its values and uses
// it as an output(Much like a fancier flatten). 1st dim is batch 2nd is
// number of filter
const std::vector<DimSize_t> out_dims{getInput(0)->dims().at(0),
getInput(0)->dims().at(1)};
mOutputs[0]->resize(out_dims);
}
}
void Aidge::GlobalAveragePooling_Op::setBackend(const std::string &name, Aidge::DeviceIdx_t device) {
SET_IMPL_MACRO(GlobalAveragePooling_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
\ No newline at end of file
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
namespace Aidge { namespace Aidge {
void removeFlatten(std::shared_ptr<Node> flatten) { void removeFlatten(std::shared_ptr<Node> flatten) {
GraphView::replace({flatten}, {}); GraphView::replace({flatten}, {});
} }
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <cstddef> // std::size_t
#include <memory>
#include <random> // std::random_device, std::mt19937, std::uniform_int_distribution
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/GlobalAveragePooling.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
TEST_CASE("[core/operator] GlobalAveragePooling_Op(computeOutputDims)",
"[GlobalAveragePooling][computeOutputDims]") {
constexpr std::uint16_t NB_TRIALS = 10;
// Create a random number generator
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<std::size_t> dimsDist(1, 10);
std::uniform_int_distribution<std::size_t> inf3DimsDistribution(1, 2);
std::uniform_int_distribution<std::size_t> sup3DimsDistribution(3, 10);
// Create the GlobalAveragePooling Operator
std::shared_ptr<Node> myGlobAvgPool = GlobalAveragePooling();
auto op =
std::static_pointer_cast<OperatorTensor>(myGlobAvgPool->getOperator());
// input_0
std::shared_ptr<Tensor> input_T = std::make_shared<Tensor>();
SECTION("Un-connected input leads to failure.") {
REQUIRE_THROWS(op->computeOutputDims());
}
op->associateInput(0, input_T);
SECTION("Connected Inputs") {
SECTION("empty tensor") {
for (uint16_t trial = 0; trial < NB_TRIALS; ++trial) {
const std::size_t nb_dims = 0;
std::vector<std::size_t> dims(nb_dims);
input_T->resize(dims);
REQUIRE_NOTHROW(op->computeOutputDims());
}
}
SECTION("Full tensor") {
SECTION("nbDim < 3") {
for (uint16_t trial = 0; trial < NB_TRIALS; ++trial) {
const std::size_t nb_dims = inf3DimsDistribution(gen);
std::vector<std::size_t> dims(nb_dims);
for (uint16_t i = 0; i < nb_dims; ++i) {
dims[i] = dimsDist(gen);
}
input_T->resize(dims);
REQUIRE_THROWS(op->computeOutputDims());
}
}
SECTION("nbDim > 3") {
for (uint16_t trial = 0; trial < NB_TRIALS; ++trial) {
const std::size_t nb_dims = sup3DimsDistribution(gen);
std::vector<std::size_t> dims(nb_dims);
for (uint16_t i = 0; i < nb_dims; ++i) {
dims[i] = dimsDist(gen) + 1;
}
std::vector<DimSize_t> dims_out{dims[0], dims[1]};
input_T->resize(dims);
op->setInput(0, input_T);
REQUIRE_NOTHROW(op->computeOutputDims());
REQUIRE(op->getOutput(0)->dims() == dims_out);
REQUIRE((op->getOutput(0)->dims().size()) == static_cast<size_t>(2));
}
}
}
}
}
} // namespace Aidge
...@@ -10,40 +10,92 @@ ...@@ -10,40 +10,92 @@
********************************************************************************/ ********************************************************************************/
#include <catch2/catch_test_macros.hpp> #include <catch2/catch_test_macros.hpp>
#include <memory>
#include <set> #include <set>
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/operator/GenericOperator.hpp" #include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/FC.hpp" #include "aidge/operator/FC.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/recipes/Recipes.hpp" #include "aidge/recipes/Recipes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge { namespace Aidge {
TEST_CASE("[cpu/recipies] RemoveFlatten", "[RemoveFlatten][recipies]") {
std::shared_ptr<Node> flatten =
GenericOperator("Flatten", 1, 0, 1, "myFlatten");
std::shared_ptr<Node> fc0 = FC(10, 10, "FC_1");
std::shared_ptr<Node> fc1 = FC(10, 10, "FC_2");
std::shared_ptr<Node> prod = Producer(std::array<DimSize_t, 10>(), "myProd");
SECTION("flatten last layer : nothing removed because pattern searched is "
"Flatten=>FC") {
std::shared_ptr<Aidge::GraphView> g = Sequential({fc0, flatten});
removeFlatten(g);
CHECK(g->getOrderedOutputs().size() == 1);
CHECK(g->getOrderedOutputs()[0].first == flatten);
TEST_CASE("[cpu/recipes] RemoveFlatten", "[RemoveFlatten][recipes]") { CHECK(g->getOrderedInputs().size() == 1);
// generate the original GraphView CHECK(g->getOrderedInputs()[0].first == fc0);
auto flatten = GenericOperator("Flatten", 1, 0, 1, "myFlatten");
auto fc = FC(10, 50, "myFC"); CHECK(fc0->getParent(0) == nullptr);
CHECK(fc0->getChildren(0).size() == 1);
CHECK(g->rootNode() == fc0);
}
SECTION("flatten first layer : flatten removed") {
auto g = Sequential({flatten, fc0});
flatten -> addChild(fc); removeFlatten(g);
CHECK(g->getOrderedInputs().size() == 1);
CHECK(g->getOrderedInputs()[0].first == fc0);
CHECK(g->getOrderedOutputs().size() == 1);
CHECK(g->getOrderedOutputs()[0].first == fc0);
CHECK(fc0->getParent(0) == nullptr);
CHECK(fc0->getChildren(0).size() == 0);
CHECK(g->rootNode() == fc0);
}
SECTION("flatten middle layer") {
auto g = Sequential({fc0, flatten, fc1});
removeFlatten(g);
auto g = std::make_shared<GraphView>(); CHECK(g->getOrderedInputs().size() == 1);
g->add({fc, flatten}); CHECK(g->getOrderedInputs()[0].first == fc0);
// Check original graph CHECK(g->getOrderedOutputs().size() == 1);
// g -> save("before_remove_flatten"); CHECK(g->getOrderedOutputs()[0].first == fc1);
CHECK(fc1->getParent(0) == fc0);
CHECK(fc0->getChildren(0)[0] == fc1);
CHECK(g->rootNode() == fc0);
}
SECTION("flatten right after a producer") {
auto g = Sequential({prod, flatten, fc0});
// prod->addChild(flatten, 0);
// flatten->addChild(fc0, 0);
// auto g = std::make_shared<GraphView>({prod, flatten, fc0});
// use recipie
removeFlatten(g); removeFlatten(g);
// Check transformed graph CHECK(g->getOrderedInputs().size() == 0);
// g -> save("after_remove_flatten");
CHECK(g->getOrderedOutputs().size() == 1);
CHECK(g->getOrderedOutputs()[0].first == fc0);
CHECK(fc0->getParent(0) == prod);
CHECK(fc0->getChildren(0).size() == 0);
REQUIRE(g->getOrderedInputs().size() == 1); CHECK(g->rootNode() == prod);
REQUIRE(g->getOrderedOutputs().size() == 1); }
REQUIRE(g->getOrderedInputs()[0].first == fc);
REQUIRE(g->getOrderedOutputs()[0].first == fc);
} }
} // namespace Aidge } // namespace Aidge
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment