Skip to content
Snippets Groups Projects
Commit 89cadd14 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Add] 'removeFlatten' recipie test and fix bug in 'GraphView::replace'

parent 92738e81
No related branches found
No related tags found
No related merge requests found
...@@ -74,7 +74,7 @@ void Aidge::GraphView::save(std::string path, bool verbose) const { ...@@ -74,7 +74,7 @@ void Aidge::GraphView::save(std::string path, bool verbose) const {
std::string givenName = std::string givenName =
(node_ptr->name().empty()) (node_ptr->name().empty())
? "<em>" + currentType + "#" + std::to_string(typeCounter[currentType]) + "</em>" ? "<em>" + currentType + "#" + std::to_string(typeCounter[currentType]) + "</em>"
: "\"" + node_ptr->name() + "\\n<sub><em>( " + currentType + " )</em></sub>\""; : "\"" + node_ptr->name() + "\\n<sub><em>( " + currentType + "#" + std::to_string(typeCounter[currentType]) + " )</em></sub>\"";
namePtrTable[node_ptr] = namePtrTable[node_ptr] =
(currentType + "_" + std::to_string(typeCounter[currentType])); (currentType + "_" + std::to_string(typeCounter[currentType]));
...@@ -789,7 +789,8 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s ...@@ -789,7 +789,8 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s
// Case 3 // Case 3
if (oldOI.size() == oldOO.size()) { if (oldOI.size() == oldOO.size()) {
for (std::size_t i = 0; i < oldOI.size(); ++i) { for (std::size_t i = 0; i < oldOI.size(); ++i) {
inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second); if (inputParents[i].first)
inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second);
} }
} }
else if (oldOI.size() == 1) { else if (oldOI.size() == 1) {
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <set>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/recipies/Recipies.hpp"
namespace Aidge {
TEST_CASE("[cpu/recipies] RemoveFlatten", "[RemoveFlatten][recipies]") {
// generate the original GraphView
auto flatten = GenericOperator("Flatten", 1, 0, 1, "myFlatten");
auto fc = FC(10, 50, "myFC");
flatten -> addChild(fc);
auto g = std::make_shared<GraphView>();
g->add({fc, flatten});
// Check original graph
// g -> save("before_remove_flatten");
// use recipie
removeFlatten(g);
// Check transformed graph
// g -> save("after_remove_flatten");
REQUIRE(g->getOrderedInputs().size() == 1);
REQUIRE(g->getOrderedOutputs().size() == 1);
REQUIRE(g->getOrderedInputs()[0].first == fc);
REQUIRE(g->getOrderedOutputs()[0].first == fc);
}
} // namespace Aidge
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment