Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • eclipse/aidge/aidge_core
  • hrouis/aidge_core
  • mszczep/aidge_core
  • oantoni/aidge_core
  • cguillon/aidge_core
  • jeromeh/aidge_core
  • axelfarr/aidge_core
  • cmoineau/aidge_core
  • noamzerah/aidge_core
  • lrakotoarivony/aidge_core
  • silvanosky/aidge_core
  • maab05/aidge_core
  • mick94/aidge_core
  • lucaslopez/aidge_core_ll
  • wboussella/aidge_core
  • farnez/aidge_core
  • mnewson/aidge_core
17 results
Show changes
......@@ -42,6 +42,27 @@ public:
setDatatype(DataType::Float32);
}
/**
* @brief Copy-constructor. Copy the operator parameters and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy.
*/
Softmax_Op(const Softmax_Op& op)
: Operator(Type),
mOutput(std::make_shared<Tensor>(*op.mOutput))
{
// cpy-ctor
setDatatype(op.mOutput->dataType());
mImpl = op.mImpl ? Registrar<Softmax_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr;
}
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::Softmax_Op
*/
std::shared_ptr<Operator> clone() const override {
return std::make_shared<Softmax_Op>(*this);
}
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
assert(inputIdx == 0 && "operator supports only 1 input");
(void) inputIdx; // avoid unused warning
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_RECIPIES_LABELGRAPH_H_
#define AIDGE_RECIPIES_LABELGRAPH_H_
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
namespace Aidge {
NodePtr nodeLabel(NodePtr node);
/**
* @brief Generate the graph for the pixel-wise labels corresponding to a data graph, taking into account the scaling changes (padding, stride, pooling...).
* @details Right now, the behavior is to replace the following operators:
* - Conv: MaxPooling
* - ConvDepthWie: MaxPooling
* - AvgPooling: MaxPooling
* - MaxPooling: MaxPooling
* - all others: identity (removed)
* @param graph Data graph
* @param return Computing graph for the labels derived from the data graph
*/
std::shared_ptr<GraphView> labelGraph(std::shared_ptr<GraphView> graph);
} // namespace Aidge
#endif /* AIDGE_RECIPIES_LABELGRAPH_H_ */
......@@ -14,6 +14,7 @@
#include <map>
#include <vector>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <assert.h>
......@@ -41,11 +42,6 @@ private:
throw std::bad_cast();
}
public:
// not copyable, not movable
CParameter(CParameter const &) = delete;
CParameter(CParameter &&) = delete;
CParameter &operator=(CParameter const &) = delete;
CParameter &operator=(CParameter &&) = delete;
CParameter() : m_Params({}){};
~CParameter() = default;
......
......@@ -94,6 +94,12 @@ public:
(void)p; // avoid unused warning
}
Parameterizable(const Parameterizable& params):
mParams(params.mParams)
{
// cpy-ctor (required for Operator cpy-ctor)
}
// Compile-time access with enum
template <PARAM_ENUM paramEnum>
constexpr typename std::tuple_element<static_cast<std::size_t>(paramEnum),std::tuple<T...>>::type& get() {
......
......@@ -11,6 +11,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include <stdio.h>
#include "aidge/backend/OperatorImpl.hpp"
......@@ -59,7 +60,11 @@ void init_GenericOperator(py::module& m) {
throw py::key_error("Failed to convert parameter type " + key + ", this issue may come from typeid function which gave an unknown key : [" + paramType + "]. Please open an issue asking to add the support for this key.");
}
return res;
});
})
.def_readonly_static("identity", &GenericOperator_Op::Identity)
.def("compute_output_dims", &GenericOperator_Op::computeOutputDims)
.def("set_compute_output_dims", &GenericOperator_Op::setComputeOutputDims, py::arg("computation_function"))
;
m.def("GenericOperator", &GenericOperator, py::arg("type"), py::arg("nbDataIn"), py::arg("nbIn"), py::arg("nbOut"),
py::arg("name") = "");
......
......@@ -62,11 +62,11 @@ class CMakeBuild(build_ext):
os.chdir(str(build_temp))
# Impose to use the executable of the python
# Impose to use the executable of the python
# used to launch setup.py to setup PythonInterp
param_py = "-DPYTHON_EXECUTABLE=" + sys.executable
install_path = f"{build_temp}/install" if "AIDGE_INSTALL" not in os.environ else os.environ["AIDGE_INSTALL"]
install_path = os.path.join(sys.prefix, "lib", "libAidge") if "AIDGE_INSTALL" not in os.environ else os.environ["AIDGE_INSTALL"]
self.spawn(['cmake', str(cwd), param_py, '-DTEST=OFF', f'-DCMAKE_INSTALL_PREFIX:PATH={install_path}'])
if not self.dry_run:
......@@ -83,11 +83,11 @@ class CMakeBuild(build_ext):
for file in files:
if file.endswith('.so') and (root != str(aidge_package.absolute())):
currentFile=os.path.join(root, file)
shutil.copy(currentFile, str(aidge_package.absolute()))
shutil.copy(currentFile, str(aidge_package.absolute()))
# Copy version.txt in aidge_package
os.chdir(os.path.dirname(__file__))
shutil.copy("version.txt", str(aidge_package.absolute()))
shutil.copy("version.txt", str(aidge_package.absolute()))
if __name__ == '__main__':
......
......@@ -682,4 +682,55 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) {
mOutputNodes.erase(val);
}
}
}
\ No newline at end of file
}
std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*cloneNode)(NodePtr)) const {
std::shared_ptr<GraphView> newGraph = std::make_shared<GraphView>(mName);
// Map for old node -> new node correspondance
std::map<NodePtr, NodePtr> oldToNewNodes;
for (const std::shared_ptr<Node> &node_ptr : mNodes) {
oldToNewNodes[node_ptr] = cloneNode(node_ptr);
}
// For each node, convert old node -> new node connections
for (auto &oldToNewNode : oldToNewNodes) {
if (oldToNewNode.second == nullptr)
continue; // deleted node
// Add new node to new GraphView
newGraph->add(oldToNewNode.second, false);
// Connect parent nodes. Nodes that were removed with cloneNode() are set to nullptr
size_t parentId = 0;
for (auto parent : oldToNewNode.first->inputs()) {
while (oldToNewNodes[parent.first] == nullptr) {
// Find next valid parent in line, going backward in the graph
assert(parent.first->nbDataInputs() <= 1 && "deleted nodes in GraphView::clone() cannot have multiple data inputs");
const auto& parents = parent.first->inputs();
if (!parents.empty() && parents[0].first != nullptr // a valid parent exists
&& oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView
{
parent = parents[0];
}
else {
break;
}
}
if (oldToNewNodes[parent.first]) {
oldToNewNodes[parent.first]->addChild(oldToNewNode.second, parent.second, parentId);
}
++parentId;
}
}
// Update OutputNodes/inputNodes
newGraph->updateInputNodes();
newGraph->updateOutputNodes();
return newGraph;
}
......@@ -321,6 +321,26 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) {
}
}
///////////////////////////////////////////////////////
// CLONE
///////////////////////////////////////////////////////
Aidge::NodePtr Aidge::Node::cloneSharedOperators() const {
return std::make_shared<Node>(mOperator, mName);
}
Aidge::NodePtr Aidge::Node::cloneSharedProducers() const {
std::shared_ptr<Operator> op = (mOperator->type() == Producer_Op::Type)
? mOperator
: mOperator->clone();
return std::make_shared<Node>(op, mName);
}
Aidge::NodePtr Aidge::Node::clone() const {
return std::make_shared<Node>(mOperator->clone(), mName);
}
/////////////////////////////////////////////////////////////////////////////////////////////
// private
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <vector>
#include "aidge/operator/GenericOperator.hpp"
const Aidge::GenericOperator_Op::ComputeDimsFunc Aidge::GenericOperator_Op::Identity
= [](const std::vector<std::vector<size_t>>& inputsDims) { return inputsDims; };
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <memory>
#include "aidge/recipies/LabelGraph.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ConvDepthWise.hpp"
#include "aidge/operator/AvgPooling.hpp"
#include "aidge/operator/MaxPooling.hpp"
Aidge::NodePtr Aidge::nodeLabel(NodePtr node) {
// Conv => MaxPooling
if (node->type() == Conv_Op<2>::Type) {
auto op = std::dynamic_pointer_cast<Conv_Op<2>>(node->getOperator());
auto newOp = std::make_shared<MaxPooling_Op<2>>(op->get<ConvParam::KernelDims>(), op->get<ConvParam::StrideDims>());
return std::make_shared<Node>(newOp, node->name());
}
// ConvDepthWise => MaxPooling
if (node->type() == ConvDepthWise_Op<2>::Type) {
auto op = std::dynamic_pointer_cast<ConvDepthWise_Op<2>>(node->getOperator());
auto newOp = std::make_shared<MaxPooling_Op<2>>(op->get<ConvDepthWiseParam::KernelDims>(), op->get<ConvDepthWiseParam::StrideDims>());
return std::make_shared<Node>(newOp, node->name());
}
// AvgPooling => MaxPooling
if (node->type() == AvgPooling_Op<2>::Type) {
auto op = std::dynamic_pointer_cast<AvgPooling_Op<2>>(node->getOperator());
auto newOp = std::make_shared<MaxPooling_Op<2>>(op->get<AvgPoolingParam::KernelDims>(), op->get<AvgPoolingParam::StrideDims>());
return std::make_shared<Node>(newOp, node->name());
}
// MaxPooling => MaxPooling
if (node->type() == MaxPooling_Op<2>::Type) {
return node->clone();
}
// By default, remove the node from the graph
return nullptr;
}
std::shared_ptr<Aidge::GraphView> Aidge::labelGraph(std::shared_ptr<GraphView> graph) {
return graph->cloneCallback(&nodeLabel);
}
......@@ -332,6 +332,234 @@ TEST_CASE("[core/graph] GraphView(replaceWith)") {
}
}
TEST_CASE("[GraphView] clone") {
auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
auto conv1 = Conv(3, 32, {3, 3}, "conv1");
auto conv2 = Conv(32, 64, {3, 3}, "conv2");
auto conv3 = Conv(64, 10, {1, 1}, "conv3");
auto g1 = std::make_shared<GraphView>("TestGraph");
dataProvider->addChild(conv1, 0);
g1->add(conv1);
g1->addChild(conv2, conv1, 0);
g1->addChild(conv3, conv2, 0);
g1->save("clone_g1");
SECTION("Check input-output connections") {
REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0));
REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0));
REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0));
REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0));
REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0));
REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0));
REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0));
REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0));
}
auto g2 = g1->clone();
auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider");
dataProvider2->addChild(g2->getNode("conv1"), 0);
g2->forwardDims();
g2->save("clone_g2");
SECTION("Check node cloning") {
REQUIRE(g1->getNode("conv1") != g2->getNode("conv1"));
REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w"));
REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b"));
REQUIRE(g1->getNode("conv2") != g2->getNode("conv2"));
REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w"));
REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b"));
REQUIRE(g1->getNode("conv3") != g2->getNode("conv3"));
REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w"));
REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b"));
}
SECTION("Check operator cloning") {
REQUIRE(g1->getNode("conv1")->getOperator() != g2->getNode("conv1")->getOperator());
REQUIRE(g1->getNode("conv1_w")->getOperator() != g2->getNode("conv1_w")->getOperator());
REQUIRE(g1->getNode("conv1_b")->getOperator() != g2->getNode("conv1_b")->getOperator());
REQUIRE(g1->getNode("conv2")->getOperator() != g2->getNode("conv2")->getOperator());
REQUIRE(g1->getNode("conv2_w")->getOperator() != g2->getNode("conv2_w")->getOperator());
REQUIRE(g1->getNode("conv2_b")->getOperator() != g2->getNode("conv2_b")->getOperator());
REQUIRE(g1->getNode("conv3")->getOperator() != g2->getNode("conv3")->getOperator());
REQUIRE(g1->getNode("conv3_w")->getOperator() != g2->getNode("conv3_w")->getOperator());
REQUIRE(g1->getNode("conv3_b")->getOperator() != g2->getNode("conv3_b")->getOperator());
}
SECTION("Check new connections") {
REQUIRE(dataProvider->getOperator()->getOutput(0) != g2->getNode("conv1")->getOperator()->getInput(0));
REQUIRE(g1->getNode("conv1")->getOperator()->getInput(1) != g2->getNode("conv1_w")->getOperator()->getOutput(0));
REQUIRE(g1->getNode("conv1")->getOperator()->getInput(2) != g2->getNode("conv1_b")->getOperator()->getOutput(0));
REQUIRE(g1->getNode("conv1")->getOperator()->getOutput(0) != g2->getNode("conv2")->getOperator()->getInput(0));
REQUIRE(g1->getNode("conv2")->getOperator()->getInput(1) != g2->getNode("conv2_w")->getOperator()->getOutput(0));
REQUIRE(g1->getNode("conv2")->getOperator()->getInput(2) != g2->getNode("conv2_b")->getOperator()->getOutput(0));
REQUIRE(g1->getNode("conv2")->getOperator()->getOutput(0) != g2->getNode("conv3")->getOperator()->getInput(0));
REQUIRE(g1->getNode("conv3")->getOperator()->getInput(1) != g2->getNode("conv3_w")->getOperator()->getOutput(0));
REQUIRE(g1->getNode("conv3")->getOperator()->getInput(2) != g2->getNode("conv3_b")->getOperator()->getOutput(0));
}
SECTION("Check input-output connections") {
REQUIRE(dataProvider2->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0));
}
}
TEST_CASE("[GraphView] cloneSharedProducers") {
auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
auto conv1 = Conv(3, 32, {3, 3}, "conv1");
auto conv2 = Conv(32, 64, {3, 3}, "conv2");
auto conv3 = Conv(64, 10, {1, 1}, "conv3");
auto g1 = std::make_shared<GraphView>("TestGraph");
dataProvider->addChild(conv1, 0);
g1->add(conv1);
g1->addChild(conv2, conv1, 0);
g1->addChild(conv3, conv2, 0);
g1->save("cloneSharedProducers_g1");
SECTION("Check input-output connections") {
REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0));
REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0));
REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0));
REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0));
REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0));
REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0));
REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0));
REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0));
}
auto g2 = g1->cloneSharedProducers();
auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider");
dataProvider2->addChild(g2->getNode("conv1"), 0);
g2->forwardDims();
g2->save("cloneSharedProducers_g2");
SECTION("Check node cloning") {
REQUIRE(g1->getNode("conv1") != g2->getNode("conv1"));
REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w"));
REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b"));
REQUIRE(g1->getNode("conv2") != g2->getNode("conv2"));
REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w"));
REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b"));
REQUIRE(g1->getNode("conv3") != g2->getNode("conv3"));
REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w"));
REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b"));
}
SECTION("Check operator cloning") {
REQUIRE(g1->getNode("conv1")->getOperator() != g2->getNode("conv1")->getOperator());
REQUIRE(g1->getNode("conv1_w")->getOperator() == g2->getNode("conv1_w")->getOperator());
REQUIRE(g1->getNode("conv1_b")->getOperator() == g2->getNode("conv1_b")->getOperator());
REQUIRE(g1->getNode("conv2")->getOperator() != g2->getNode("conv2")->getOperator());
REQUIRE(g1->getNode("conv2_w")->getOperator() == g2->getNode("conv2_w")->getOperator());
REQUIRE(g1->getNode("conv2_b")->getOperator() == g2->getNode("conv2_b")->getOperator());
REQUIRE(g1->getNode("conv3")->getOperator() != g2->getNode("conv3")->getOperator());
REQUIRE(g1->getNode("conv3_w")->getOperator() == g2->getNode("conv3_w")->getOperator());
REQUIRE(g1->getNode("conv3_b")->getOperator() == g2->getNode("conv3_b")->getOperator());
}
SECTION("Check new connections") {
REQUIRE(dataProvider->getOperator()->getOutput(0) != g2->getNode("conv1")->getOperator()->getInput(0));
REQUIRE(g1->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0));
REQUIRE(g1->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0));
REQUIRE(g1->getNode("conv1")->getOperator()->getOutput(0) != g2->getNode("conv2")->getOperator()->getInput(0));
REQUIRE(g1->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0));
REQUIRE(g1->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0));
REQUIRE(g1->getNode("conv2")->getOperator()->getOutput(0) != g2->getNode("conv3")->getOperator()->getInput(0));
REQUIRE(g1->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0));
REQUIRE(g1->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0));
}
SECTION("Check input-output connections") {
REQUIRE(dataProvider2->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0));
}
}
TEST_CASE("[GraphView] cloneSharedOperators") {
auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
auto conv1 = Conv(3, 32, {3, 3}, "conv1");
auto conv2 = Conv(32, 64, {3, 3}, "conv2");
auto conv3 = Conv(64, 10, {1, 1}, "conv3");
auto g1 = std::make_shared<GraphView>("TestGraph");
dataProvider->addChild(conv1, 0);
g1->add(conv1);
g1->addChild(conv2, conv1, 0);
g1->addChild(conv3, conv2, 0);
g1->save("cloneSharedOperators_g1");
SECTION("Check input-output connections") {
REQUIRE(dataProvider->getOperator()->getOutput(0) == conv1->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getInput(1) == g1->getNode("conv1_w")->getOperator()->getOutput(0));
REQUIRE(conv1->getOperator()->getInput(2) == g1->getNode("conv1_b")->getOperator()->getOutput(0));
REQUIRE(conv1->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0));
REQUIRE(conv2->getOperator()->getInput(1) == g1->getNode("conv2_w")->getOperator()->getOutput(0));
REQUIRE(conv2->getOperator()->getInput(2) == g1->getNode("conv2_b")->getOperator()->getOutput(0));
REQUIRE(conv2->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0));
REQUIRE(conv3->getOperator()->getInput(1) == g1->getNode("conv3_w")->getOperator()->getOutput(0));
REQUIRE(conv3->getOperator()->getInput(2) == g1->getNode("conv3_b")->getOperator()->getOutput(0));
}
auto g2 = g1->cloneSharedOperators();
g2->forwardDims();
g2->save("cloneSharedOperators_g2");
SECTION("Check node cloning") {
REQUIRE(g1->getNode("conv1") != g2->getNode("conv1"));
REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w"));
REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b"));
REQUIRE(g1->getNode("conv2") != g2->getNode("conv2"));
REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w"));
REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b"));
REQUIRE(g1->getNode("conv3") != g2->getNode("conv3"));
REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w"));
REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b"));
}
SECTION("Check operator cloning") {
REQUIRE(g1->getNode("conv1")->getOperator() == g2->getNode("conv1")->getOperator());
REQUIRE(g1->getNode("conv1_w")->getOperator() == g2->getNode("conv1_w")->getOperator());
REQUIRE(g1->getNode("conv1_b")->getOperator() == g2->getNode("conv1_b")->getOperator());
REQUIRE(g1->getNode("conv2")->getOperator() == g2->getNode("conv2")->getOperator());
REQUIRE(g1->getNode("conv2_w")->getOperator() == g2->getNode("conv2_w")->getOperator());
REQUIRE(g1->getNode("conv2_b")->getOperator() == g2->getNode("conv2_b")->getOperator());
REQUIRE(g1->getNode("conv3")->getOperator() == g2->getNode("conv3")->getOperator());
REQUIRE(g1->getNode("conv3_w")->getOperator() == g2->getNode("conv3_w")->getOperator());
REQUIRE(g1->getNode("conv3_b")->getOperator() == g2->getNode("conv3_b")->getOperator());
}
SECTION("Check input-output connections") {
REQUIRE(dataProvider->getOperator()->getOutput(0) == g2->getNode("conv1")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getInput(1) == g2->getNode("conv1_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getInput(2) == g2->getNode("conv1_b")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getInput(1) == g2->getNode("conv2_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getInput(2) == g2->getNode("conv2_b")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->getInput(1) == g2->getNode("conv3_w")->getOperator()->getOutput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0));
}
}
TEST_CASE("[core/graph] GraphView(insertParent)") {
auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
auto conv1 = Conv(3, 32, {3, 3}, "conv1");
......@@ -352,7 +580,7 @@ TEST_CASE("[core/graph] GraphView(insertParent)") {
std::set<NodePtr> expectedConv1Children = {conv3, newConv};
std::set<NodePtr> expectedNewConvChildren = {conv2};
REQUIRE(conv1->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getOutput(0) == newConv->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getOutput(0) != conv2->getOperator()->getInput(0));
......@@ -374,4 +602,4 @@ TEST_CASE("[core/graph] GraphView(insertParent)") {
REQUIRE((conv1->getChildren()) == expectedConv1Children2);
}
}
\ No newline at end of file
}
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include "aidge/recipies/LabelGraph.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/AvgPooling.hpp"
#include "aidge/operator/MaxPooling.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/graph/OpArgs.hpp"
#include <cstddef>
using namespace Aidge;
TEST_CASE("[LabelGraph] conv") {
auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
auto conv1 = Conv(3, 32, {3, 3}, "conv1");
auto conv2 = Conv(32, 64, {3, 3}, "conv2");
auto conv3 = Conv(64, 10, {1, 1}, "conv3");
auto g1 = std::make_shared<GraphView>("TestGraph");
dataProvider->addChild(conv1, 0);
g1->add(conv1);
g1->addChild(conv2, conv1, 0);
g1->addChild(conv3, conv2, 0);
g1->save("LabelGraph_conv_graph");
auto g2 = labelGraph(g1);
auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider");
dataProvider2->addChild(g2->getNode("conv1"), 0);
g2->forwardDims();
g2->save("LabelGraph_conv_label");
SECTION("Check resulting nodes") {
REQUIRE(g2->getNodes().size() == 3);
REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling");
REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling");
REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling");
}
}
TEST_CASE("[LabelGraph] deleted node") {
auto g1 = Sequential({
Producer({16, 3, 224, 224}, "dataProvider"),
Conv(3, 32, {3, 3}, "conv1"),
GenericOperator("Dummy_to_be_removed", 1, 1, 1),
Conv(32, 64, {3, 3}, "conv2"),
Conv(64, 10, {1, 1}, "conv3", {2, 2})
});
g1->save("LabelGraph_deleted_graph");
auto g2 = labelGraph(g1);
auto dataProvider2 = Producer({16, 1, 224, 224}, "dataProvider");
dataProvider2->addChild(g2->getNode("conv1"), 0);
g2->forwardDims();
g2->save("LabelGraph_deleted_label");
SECTION("Check resulting nodes") {
REQUIRE(g2->getNodes().size() == 3);
REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling");
REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling");
REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling");
}
SECTION("Check dimensions") {
REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 222, 222}));
REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 220, 220}));
REQUIRE(g2->getNode("conv3")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 110, 110}));
}
}
TEST_CASE("[LabelGraph] deleted nodes") {
auto g1 = Sequential({
Producer({16, 3, 224, 224}, "dataProvider"),
Conv(3, 32, {3, 3}, "conv1"),
GenericOperator("Dummy_to_be_removed", 1, 1, 1),
GenericOperator("Dummy_to_be_removed", 1, 1, 1),
GenericOperator("Dummy_to_be_removed", 1, 1, 1),
Conv(32, 64, {3, 3}, "conv2"),
GenericOperator("Dummy_to_be_removed", 1, 1, 1),
Conv(64, 10, {1, 1}, "conv3")
});
g1->save("LabelGraph_deleteds_graph");
auto g2 = labelGraph(g1);
auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider");
dataProvider2->addChild(g2->getNode("conv1"), 0);
g2->forwardDims();
g2->save("LabelGraph_deleteds_label");
SECTION("Check resulting nodes") {
REQUIRE(g2->getNodes().size() == 3);
REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling");
REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling");
REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling");
}
}
TEST_CASE("[LabelGraph] pooling") {
auto g1 = Sequential({
Producer({16, 3, 224, 224}, "dataProvider"),
AvgPooling({2, 2}, "pool1"),
MaxPooling({2, 2}, "pool2"),
MaxPooling({2, 2}, "pool3", {2, 2})
});
g1->save("LabelGraph_deleted_graph");
auto g2 = labelGraph(g1);
auto dataProvider2 = Producer({16, 1, 224, 224}, "dataProvider");
dataProvider2->addChild(g2->getNode("pool1"), 0);
g2->forwardDims();
g2->save("LabelGraph_pooling");
SECTION("Check resulting nodes") {
REQUIRE(g2->getNodes().size() == 3);
REQUIRE(g2->getNode("pool1")->getOperator()->type() == "MaxPooling");
REQUIRE(g2->getNode("pool1")->getOperator()->getOutput(0) == g2->getNode("pool2")->getOperator()->getInput(0));
REQUIRE(g2->getNode("pool2")->getOperator()->type() == "MaxPooling");
REQUIRE(g2->getNode("pool2")->getOperator()->getOutput(0) == g2->getNode("pool3")->getOperator()->getInput(0));
REQUIRE(g2->getNode("pool3")->getOperator()->type() == "MaxPooling");
}
SECTION("Check dimensions") {
REQUIRE(g2->getNode("pool1")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 223, 223}));
REQUIRE(g2->getNode("pool2")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 222, 222}));
REQUIRE(g2->getNode("pool3")->getOperator()->getOutput(0)->dims() == std::vector<DimSize_t>({16, 1, 111, 111}));
}
}