Skip to content
Snippets Groups Projects
Commit 0fbf84e7 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

[RemoveFlatten] Add python unittest.

parent f4eb185f
No related branches found
No related tags found
1 merge request!9Fuse bn
"""
Copyright (c) 2023 CEA-List
This program and the accompanying materials are made available under the
terms of the Eclipse Public License 2.0 which is available at
http://www.eclipse.org/legal/epl-2.0.
SPDX-License-Identifier: EPL-2.0
"""
import unittest
import aidge_core
class test_parameters(unittest.TestCase):
"""Very basic test to make sure the python APi is not broken.
Can be remove in later stage of the developpement.
"""
def setUp(self):
pass
def tearDown(self):
pass
def test_conv(self):
graph_view = aidge_core.sequential([
aidge_core.GenericOperator("Flatten", 1, 1, 1, name="Flatten0"),
aidge_core.FC(50, name='0')
])
old_nodes = graph_view.get_nodes()
aidge_core.remove_flatten(graph_view)
self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 1)
self.assertTrue("Flatten0" not in [i.name for i in graph_view.get_nodes()])
self.assertTrue(all([i in old_nodes for i in graph_view.get_nodes()]))
if __name__ == '__main__':
unittest.main()
...@@ -23,16 +23,17 @@ ...@@ -23,16 +23,17 @@
namespace Aidge { namespace Aidge {
void removeFlatten(std::set<std::shared_ptr<Node>> nodes) { void removeFlatten(std::set<std::shared_ptr<Node>> nodes) {
assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
std::shared_ptr<Node> fc; std::shared_ptr<Node> flatten;
for (const auto& element : nodes) { for (const auto& element : nodes) {
assert((element->type() == "FC" || element->type() == "Flatten") && "Wrong type for the nodes to replace"); assert((element->type() == "FC" || element->type() == "Flatten") && "Wrong type for the nodes to replace");
if (element->type() == "FC"){ if (element->type() == "Flatten"){
fc = element; flatten = element;
} }
} }
auto g = std::make_shared<GraphView>(); auto g = std::make_shared<GraphView>();
g->add(std::set<std::shared_ptr<Node>>({nodes})); // TODO : avoid using replace_with and use a remove method instead
g->replaceWith({fc}); g->add(std::set<std::shared_ptr<Node>>({flatten}));
g->replaceWith({});
} }
void removeFlatten(std::shared_ptr<GraphView> graphView){ void removeFlatten(std::shared_ptr<GraphView> graphView){
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment