Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
aidge_core
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Cyril Moineau
aidge_core
Commits
080743a9
Commit
080743a9
authored
1 year ago
by
Maxence Naud
Browse files
Options
Downloads
Patches
Plain Diff
[Upd] replace() instead of replaceWith() in GraphView
parent
4b783082
No related branches found
Branches containing commit
No related tags found
Tags containing commit
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
include/aidge/graph/GraphView.hpp
+25
-11
25 additions, 11 deletions
include/aidge/graph/GraphView.hpp
src/graph/GraphView.cpp
+101
-0
101 additions, 0 deletions
src/graph/GraphView.cpp
unit_tests/graph/Test_GraphView.cpp
+64
-0
64 additions, 0 deletions
unit_tests/graph/Test_GraphView.cpp
with
190 additions
and
11 deletions
include/aidge/graph/GraphView.hpp
+
25
−
11
View file @
080743a9
...
...
@@ -322,17 +322,17 @@ public:
/**
* @brief Insert a node (newParentNode) as a parent of the passed node (childNode).
*
*
* @param childNode Node that gets a new parent.
* @param newParentNode Inserted Node.
* @param childInputTensorIdx Index of the input Tensor for the childNode linked to the inserted Node output.
* @param newParentInputTensorIdx Index of the input Tensor for the newParentNode linked to the former parent of childNode.
* @param newParentOutputTensorIdx Index of the output Tensor for the newParentNode linked to the childNode's input Tensor.
*/
void
insertParent
(
NodePtr
childNode
,
NodePtr
newParentNode
,
IOIndex_t
childInputTensorIdx
,
IOIndex_t
newParentInputTensorIdx
,
void
insertParent
(
NodePtr
childNode
,
NodePtr
newParentNode
,
IOIndex_t
childInputTensorIdx
,
IOIndex_t
newParentInputTensorIdx
,
IOIndex_t
newParentOutputTensorIdx
);
/**
...
...
@@ -342,6 +342,20 @@ public:
* @return false
*/
bool
replaceWith
(
std
::
set
<
NodePtr
>
newNodes
);
/**
* @brief Replace a set of Nodes in the current GraphView with a new set of Nodes if possible.
* Both sets should include all the necessary Producers.
* @details Replaced Nodes are only removed from the current GraphView. Other GraphView containing
* them will not be affected by the replacement. The oldNodes set should have only one input/output
* Node for automatic connections of newNodes set.
* @param oldNodes actual set of shared_ptr<Node> to replace.
* @param newNodes new set of shared_ptr<Node>.
* @return true
* @return false
*/
bool
replace
(
std
::
set
<
NodePtr
>&
oldNodes
,
std
::
set
<
NodePtr
>&
newNodes
);
void
updateInputNodes
();
/**
* @brief Process from zero the set of output Nodes.
...
...
@@ -379,6 +393,12 @@ public:
*/
std
::
shared_ptr
<
GraphView
>
cloneCallback
(
NodePtr
(
*
cloneNode
)(
NodePtr
))
const
;
/**
* @brief Get the sum of the number of free dataInput connection for all inputNodes of the GraphView object.
* @return IOIndex_t
*/
IOIndex_t
getNbFreeDataInputs
()
const
;
private
:
///////////////////////////////////////////////////////
// TENSOR MANAGEMENT
...
...
@@ -390,12 +410,6 @@ private:
*/
IOIndex_t
getNbDataInputs
()
const
;
/**
* @brief Get the sum of the number of free dataInput connection for all inputNodes of the GraphView object.
* @return IOIndex_t
*/
IOIndex_t
getNbFreeDataInputs
()
const
;
/**
* @brief Update the set of inputNodes with a new Node, checking if it can be
* added and removing any Node not part of mInputNode anymore.
...
...
This diff is collapsed.
Click to expand it.
src/graph/GraphView.cpp
+
101
−
0
View file @
080743a9
...
...
@@ -17,6 +17,7 @@
#include
"aidge/utils/Types.h"
#include
"aidge/graph/GraphView.hpp"
#include
"aidge/data/Tensor.hpp"
#include
"aidge/utils/ErrorHandling.hpp"
///////////////////////////////////////////////////////
// FUNCTIONAL DESCRIPTION
...
...
@@ -594,6 +595,106 @@ bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) {
return
replacable
;
}
bool
Aidge
::
GraphView
::
replace
(
std
::
set
<
Aidge
::
NodePtr
>&
oldNodes
,
std
::
set
<
Aidge
::
NodePtr
>&
newNodes
)
{
for
(
const
auto
&
node
:
oldNodes
)
{
if
(
mNodes
.
find
(
node
)
==
mNodes
.
end
())
{
AIDGE_INTERNAL_ASSERT
(
"GraphView asked to replace a Node it does not contain."
);
}
}
// TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes)
// How to distinguish it from data input?
// TODO: Parameter Tensors could be identified with their dimensions
// TODO: Take GraphView as input parameters since new Nodes should be connected whatever.
// It also avoids specifying each producer since they are automatically included
auto
oldG
=
std
::
make_shared
<
GraphView
>
();
oldG
->
add
(
oldNodes
,
false
);
auto
newG
=
std
::
make_shared
<
GraphView
>
();
newG
->
add
(
newNodes
,
false
);
if
((
oldG
->
inputNodes
().
size
()
!=
1
)
||
(
oldG
->
outputNodes
().
size
()
!=
1
))
{
return
false
;
}
if
(
!
(
newNodes
.
empty
())
&&
((
newG
->
inputNodes
().
size
()
!=
1
)
||
(
newG
->
outputNodes
().
size
()
!=
1
)))
{
return
false
;
}
std
::
shared_ptr
<
Node
>
previousInputNode
=
(
*
(
oldG
->
inputNodes
()).
begin
());
std
::
shared_ptr
<
Node
>
previousOutputNode
=
(
*
(
oldG
->
outputNodes
()).
begin
());
// find Node to link to new input Node
//compute number of input for previousInputNode not in oldNodes set
std
::
size_t
nbExternalInputs
=
0
;
std
::
shared_ptr
<
Node
>
externalInput
=
nullptr
;
IOIndex_t
externalInputId
=
gk_IODefaultIndex
;
for
(
const
auto
&
input
:
previousInputNode
->
inputs
())
{
if
(
oldNodes
.
find
(
input
.
first
)
==
oldNodes
.
end
())
{
nbExternalInputs
++
;
externalInput
=
input
.
first
;
externalInputId
=
input
.
second
;
}
}
if
(
nbExternalInputs
>
1
)
{
AIDGE_INTERNAL_ASSERT
(
"To many input to link for oldNodes set"
);
}
if
(
previousOutputNode
->
nbOutputs
()
!=
1
)
{
return
false
;
}
// find Node to replicate output connections
std
::
shared_ptr
<
Node
>
newOutputNode
=
newNodes
.
empty
()
?
externalInput
:
*
(
newG
->
outputNodes
().
begin
());
auto
copyOutputs
=
previousOutputNode
->
outputs
();
// manage Views for newNodes
// only keep common views to each node for the new set
std
::
set
<
std
::
shared_ptr
<
GraphView
>>
commonGraphViews
=
(
*
oldNodes
.
begin
())
->
views
();
for
(
const
auto
&
nodePtr
:
oldNodes
)
{
const
auto
nodeView
=
nodePtr
->
views
();
std
::
set
<
std
::
shared_ptr
<
GraphView
>>
intersection
;
std
::
set_intersection
(
commonGraphViews
.
begin
(),
commonGraphViews
.
end
(),
nodeView
.
begin
(),
nodeView
.
end
(),
std
::
inserter
(
intersection
,
intersection
.
begin
()));
commonGraphViews
=
intersection
;
}
// clean Nodes to replace
// Do not include common Nodes to avoid cleaning Producers linked to newNodes
std
::
set
<
std
::
shared_ptr
<
Node
>>
nodesToClean
;
std
::
set_difference
(
oldNodes
.
begin
(),
oldNodes
.
end
(),
newNodes
.
begin
(),
newNodes
.
end
(),
std
::
inserter
(
nodesToClean
,
nodesToClean
.
begin
()));
for
(
auto
&
nodePtr
:
nodesToClean
)
{
nodePtr
->
resetConnections
(
true
);
}
// copy output connections
for
(
IOIndex_t
o
=
0
;
o
<
previousOutputNode
->
nbOutputs
();
++
o
)
{
auto
outputPairs
=
copyOutputs
[
o
];
for
(
const
auto
&
onePair
:
outputPairs
)
{
newOutputNode
->
addChild
(
onePair
.
first
,
o
,
onePair
.
second
);
}
}
// copy input connections
if
(
!
newNodes
.
empty
())
{
std
::
shared_ptr
<
Node
>
newInputNode
=
(
*
(
newG
->
inputNodes
()).
begin
());
if
(
newInputNode
->
getNbFreeDataInputs
()
>
1
)
{
return
false
;
}
// one non-connected input in newNodes set
externalInput
->
addChild
(
newInputNode
,
externalInputId
,
newInputNode
->
getFirstFreeDataInput
());
}
// insert new Nodes in the right GraphViews
for
(
auto
&
graphPtr
:
commonGraphViews
)
{
graphPtr
->
add
(
newNodes
,
false
);
if
(
newNodes
.
empty
())
{
graphPtr
->
updateInputNodes
();
graphPtr
->
updateOutputNodes
();
}
}
return
true
;
}
void
Aidge
::
GraphView
::
updateInputNodes
()
{
mInputNodes
.
clear
();
for
(
const
std
::
shared_ptr
<
Node
>&
go_ptr
:
mNodes
)
{
...
...
This diff is collapsed.
Click to expand it.
unit_tests/graph/Test_GraphView.cpp
+
64
−
0
View file @
080743a9
...
...
@@ -332,6 +332,70 @@ TEST_CASE("[core/graph] GraphView(replaceWith)") {
}
}
TEST_CASE
(
"[core/graph] GraphView(replace)"
,
"[GraphView][replace]"
)
{
SECTION
(
"replace small pattern"
)
{
// create original graph
std
::
shared_ptr
<
GraphView
>
g
=
std
::
make_shared
<
GraphView
>
(
"TestGraph"
);
auto
otherInput
=
GenericOperator
(
"Producer"
,
0
,
0
,
1
,
"other_input"
);
auto
matmulWeight
=
GenericOperator
(
"Producer"
,
0
,
0
,
1
,
"matmul_w"
);
auto
addBias
=
GenericOperator
(
"Producer"
,
0
,
0
,
1
,
"add_b"
);
auto
other1
=
GenericOperator
(
"Other"
,
1
,
1
,
1
,
"other1"
);
auto
other2
=
GenericOperator
(
"Other"
,
1
,
1
,
1
,
"other2"
);
auto
matmul
=
GenericOperator
(
"MatMul"
,
1
,
2
,
1
,
"matmul"
);
auto
add
=
GenericOperator
(
"Add"
,
1
,
2
,
1
,
"add"
);
otherInput
->
addChild
(
other1
);
other1
->
addChild
(
matmul
);
matmul
->
addChild
(
add
);
add
->
addChild
(
other2
);
matmulWeight
->
addChild
(
matmul
,
0
,
1
);
addBias
->
addChild
(
add
,
0
,
1
);
g
->
add
({
other1
,
matmul
,
add
,
other2
});
REQUIRE
(
g
->
getNodes
()
==
std
::
set
<
std
::
shared_ptr
<
Node
>>
({
matmulWeight
,
addBias
,
other1
,
other2
,
matmul
,
add
}));
// create graph to replace
std
::
set
<
std
::
shared_ptr
<
Node
>>
nodeToReplace
=
std
::
set
<
std
::
shared_ptr
<
Node
>>
({
matmulWeight
,
addBias
,
matmul
,
add
});
// create replacing graph
std
::
shared_ptr
<
Node
>
myFC
=
GenericOperator
(
"FC"
,
1
,
3
,
1
,
"fc"
);
auto
newMatmulWeight
=
matmulWeight
->
cloneSharedOperators
();
newMatmulWeight
->
addChild
(
myFC
,
0
,
1
);
auto
newAddBias
=
addBias
->
cloneSharedOperators
();
newAddBias
->
addChild
(
myFC
,
0
,
2
);
std
::
set
<
std
::
shared_ptr
<
Node
>>
newNodes
=
std
::
set
<
std
::
shared_ptr
<
Node
>>
({
myFC
,
newMatmulWeight
,
newAddBias
});
// replace
g
->
replace
(
nodeToReplace
,
newNodes
);
REQUIRE
(
g
->
getNodes
()
==
std
::
set
<
std
::
shared_ptr
<
Node
>>
({
newMatmulWeight
,
newAddBias
,
other1
,
other2
,
myFC
}));
REQUIRE
(((
myFC
->
getParent
(
0
)
==
other1
)
&&
(
myFC
->
getParent
(
1
)
==
newMatmulWeight
)
&&
(
myFC
->
getParent
(
2
)
==
newAddBias
)));
}
SECTION
(
"replace with nothing"
)
{
std
::
shared_ptr
<
GraphView
>
g
=
std
::
make_shared
<
GraphView
>
(
"TestGraph"
);
auto
r1
=
GenericOperator
(
"relu"
,
0
,
0
,
1
);
auto
r2
=
GenericOperator
(
"relu"
,
1
,
1
,
1
);
auto
r3
=
GenericOperator
(
"relu"
,
1
,
1
,
1
);
auto
r4
=
GenericOperator
(
"relu"
,
1
,
1
,
0
);
r1
->
addChild
(
r2
);
r2
->
addChild
(
r3
);
r3
->
addChild
(
r4
);
g
->
add
({
r1
,
r2
,
r3
,
r4
});
auto
nodesToReplace
=
std
::
set
<
std
::
shared_ptr
<
Node
>>
({
r2
,
r3
});
auto
newNodes
=
std
::
set
<
std
::
shared_ptr
<
Node
>>
({});
g
->
replace
(
nodesToReplace
,
newNodes
);
REQUIRE
(
g
->
getNodes
()
==
std
::
set
<
std
::
shared_ptr
<
Node
>>
({
r1
,
r4
}));
REQUIRE
((
r1
->
output
(
0
))[
0
].
first
==
r4
);
}
// SECTION("replace for tiling") {
// std::shared_ptr<GraphView> g = std::make_shared<GraphView>();
// auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input");
// auto other1 = GenericOperator("Other", 1, 1, 1, "other1");
// auto myConv = GenericOperator("Conv", 1, 1, 1, "myConv");
// auto other2 = GenericOperator("Other", 1, 1, 1, "other2");
// otherInput->addChild()
// }
}
TEST_CASE
(
"[GraphView] clone"
)
{
auto
dataProvider
=
Producer
({
16
,
3
,
224
,
224
},
"dataProvider"
);
auto
conv1
=
Conv
(
3
,
32
,
{
3
,
3
},
"conv1"
);
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment