Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
aidge_core
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Cyril Moineau
aidge_core
Commits
4b3540a3
Commit
4b3540a3
authored
1 year ago
by
Maxence Naud
Browse files
Options
Downloads
Patches
Plain Diff
[Upd] replace() member function for tiling handling
parent
f646e90d
No related branches found
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
include/aidge/graph/GraphView.hpp
+5
-5
5 additions, 5 deletions
include/aidge/graph/GraphView.hpp
src/graph/GraphView.cpp
+47
-24
47 additions, 24 deletions
src/graph/GraphView.cpp
unit_tests/graph/Test_GraphView.cpp
+33
-10
33 additions, 10 deletions
unit_tests/graph/Test_GraphView.cpp
with
85 additions
and
39 deletions
include/aidge/graph/GraphView.hpp
+
5
−
5
View file @
4b3540a3
...
@@ -344,17 +344,17 @@ public:
...
@@ -344,17 +344,17 @@ public:
bool
replaceWith
(
std
::
set
<
NodePtr
>
newNodes
);
bool
replaceWith
(
std
::
set
<
NodePtr
>
newNodes
);
/**
/**
* @brief Replace a set of Nodes in
the current
GraphView with a new set of Nodes if possible.
* @brief Replace a set of Nodes in
every available
GraphView with a new set of Nodes if possible.
* Both sets should include all the necessary Producers.
* Both sets should include all the necessary Producers.
* @details Replaced Nodes are
only
removed from
the current GraphView. Other GraphView containing
* @details Replaced Nodes are removed from
any GraphView pointing at them all.
*
them will not be affected by the replacement.
The oldNodes set should have only one input/output
* The oldNodes set should have only one input/output
*
Node
for automatic connections of newNodes set.
*
Tensor
for automatic connections of newNodes set.
* @param oldNodes actual set of shared_ptr<Node> to replace.
* @param oldNodes actual set of shared_ptr<Node> to replace.
* @param newNodes new set of shared_ptr<Node>.
* @param newNodes new set of shared_ptr<Node>.
* @return true
* @return true
* @return false
* @return false
*/
*/
bool
replace
(
std
::
set
<
NodePtr
>&
oldNodes
,
std
::
set
<
NodePtr
>&
newNodes
);
static
bool
replace
(
const
std
::
set
<
NodePtr
>&
oldNodes
,
const
std
::
set
<
NodePtr
>&
newNodes
);
void
updateInputNodes
();
void
updateInputNodes
();
/**
/**
...
...
This diff is collapsed.
Click to expand it.
src/graph/GraphView.cpp
+
47
−
24
View file @
4b3540a3
...
@@ -595,41 +595,38 @@ bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) {
...
@@ -595,41 +595,38 @@ bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) {
return
replacable
;
return
replacable
;
}
}
bool
Aidge
::
GraphView
::
replace
(
std
::
set
<
Aidge
::
NodePtr
>&
oldNodes
,
std
::
set
<
Aidge
::
NodePtr
>&
newNodes
)
{
bool
Aidge
::
GraphView
::
replace
(
const
std
::
set
<
Aidge
::
NodePtr
>&
oldNodes
,
const
std
::
set
<
Aidge
::
NodePtr
>&
newNodes
)
{
for
(
const
auto
&
node
:
oldNodes
)
{
if
(
mNodes
.
find
(
node
)
==
mNodes
.
end
())
{
AIDGE_INTERNAL_ASSERT
(
"GraphView asked to replace a Node it does not contain."
);
}
}
// TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes)
// TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes)
// How to distinguish it from data input?
// How to distinguish it from data input?
// TODO: Parameter Tensors could be identified with their dimensions
// TODO: Parameter Tensors could be identified with their dimensions
// TODO: Take GraphView as input parameters since new Nodes should be connected whatever.
// TODO: Take GraphView as input parameters since new Nodes should be connected whatever.
// It also avoids specifying each producer since they are automatically included
// It also avoids specifying each producer since they are automatically included
auto
oldG
=
std
::
make_shared
<
GraphView
>
();
auto
oldG
=
std
::
make_shared
<
GraphView
>
(
"oldG"
);
oldG
->
add
(
oldNodes
,
false
);
oldG
->
add
(
oldNodes
,
false
);
auto
newG
=
std
::
make_shared
<
GraphView
>
();
auto
newG
=
std
::
make_shared
<
GraphView
>
(
"newG"
);
newG
->
add
(
newNodes
,
false
);
newG
->
add
(
newNodes
,
false
);
if
((
oldG
->
inputNodes
().
size
()
!
=
1
)
||
(
oldG
->
outputNodes
().
size
()
!=
1
))
{
if
((
oldG
->
inputNodes
().
size
()
=
=
0
)
||
(
oldG
->
outputNodes
().
size
()
!=
1
))
{
return
false
;
return
false
;
}
}
if
(
!
(
newNodes
.
empty
())
&&
((
newG
->
inputNodes
().
size
()
!
=
1
)
||
if
(
!
(
newNodes
.
empty
())
&&
((
newG
->
inputNodes
().
size
()
=
=
0
)
||
(
newG
->
outputNodes
().
size
()
!=
1
)))
{
(
newG
->
outputNodes
().
size
()
!=
1
)))
{
return
false
;
return
false
;
}
}
std
::
shared_ptr
<
Node
>
previousInputNode
=
(
*
(
oldG
->
inputNodes
()).
begin
());
// there is at least one inputNode in the old/new GraphView
std
::
shared_ptr
<
Node
>
previousOutputNode
=
(
*
(
oldG
->
outputNodes
()).
begin
());
std
::
shared_ptr
<
Node
>
firstPreviousInputNode
=
(
*
(
oldG
->
inputNodes
()).
begin
());
std
::
shared_ptr
<
Node
>
firstPreviousOutputNode
=
(
*
(
oldG
->
outputNodes
()).
begin
());
// find Node to link to new input Node
// find Node to link to new input Node
//compute number of input for
p
reviousInputNode not in oldNodes set
//compute number of input for
firstP
reviousInputNode not in oldNodes set
std
::
size_t
nbExternalInputs
=
0
;
std
::
size_t
nbExternalInputs
=
0
;
std
::
shared_ptr
<
Node
>
externalInput
=
nullptr
;
std
::
shared_ptr
<
Node
>
externalInput
=
nullptr
;
IOIndex_t
externalInputId
=
gk_IODefaultIndex
;
IOIndex_t
externalInputId
=
gk_IODefaultIndex
;
for
(
const
auto
&
input
:
p
reviousInputNode
->
inputs
())
{
for
(
const
auto
&
input
:
firstP
reviousInputNode
->
inputs
())
{
if
(
oldNodes
.
find
(
input
.
first
)
==
oldNodes
.
end
())
{
if
(
oldNodes
.
find
(
input
.
first
)
==
oldNodes
.
end
())
{
// Node connected to another Node outside of oldG
nbExternalInputs
++
;
nbExternalInputs
++
;
externalInput
=
input
.
first
;
externalInput
=
input
.
first
;
externalInputId
=
input
.
second
;
externalInputId
=
input
.
second
;
...
@@ -638,14 +635,28 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg
...
@@ -638,14 +635,28 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg
if
(
nbExternalInputs
>
1
)
{
if
(
nbExternalInputs
>
1
)
{
AIDGE_INTERNAL_ASSERT
(
"To many input to link for oldNodes set"
);
AIDGE_INTERNAL_ASSERT
(
"To many input to link for oldNodes set"
);
}
}
if
(
previousOutputNode
->
nbOutputs
()
!=
1
)
{
if
(
oldG
->
inputNodes
().
size
()
>
1
){
// one or no input has been identified. Checking every input points to the same source
for
(
const
auto
&
previousInputNode
:
oldG
->
inputNodes
())
{
for
(
const
auto
&
input
:
previousInputNode
->
inputs
())
{
if
(
oldNodes
.
find
(
input
.
first
)
==
oldNodes
.
end
())
{
if
(
(
externalInput
!=
input
.
first
)
||
(
externalInputId
!=
input
.
second
)
)
{
return
false
;
// an inputNode points to an external Node different from the registered one
}
}
}
}
}
if
(
firstPreviousOutputNode
->
nbOutputs
()
!=
1
)
{
return
false
;
return
false
;
}
}
// find Node to replicate output connections
// find Node to replicate output connections
std
::
shared_ptr
<
Node
>
newOutputNode
=
newNodes
.
empty
()
?
externalInput
:
*
(
newG
->
outputNodes
().
begin
());
std
::
shared_ptr
<
Node
>
newOutputNode
=
newNodes
.
empty
()
?
externalInput
:
*
(
newG
->
outputNodes
().
begin
());
auto
copyOutputs
=
p
reviousOutputNode
->
outputs
();
auto
copyOutputs
=
firstP
reviousOutputNode
->
outputs
();
// manage Views for newNodes
// manage Views for newNodes
// only keep common views to each node for the new set
// only keep common views to each node for the new set
std
::
set
<
std
::
shared_ptr
<
GraphView
>>
commonGraphViews
=
(
*
oldNodes
.
begin
())
->
views
();
std
::
set
<
std
::
shared_ptr
<
GraphView
>>
commonGraphViews
=
(
*
oldNodes
.
begin
())
->
views
();
...
@@ -657,6 +668,8 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg
...
@@ -657,6 +668,8 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg
std
::
inserter
(
intersection
,
intersection
.
begin
()));
std
::
inserter
(
intersection
,
intersection
.
begin
()));
commonGraphViews
=
intersection
;
commonGraphViews
=
intersection
;
}
}
commonGraphViews
.
erase
(
oldG
);
commonGraphViews
.
erase
(
newG
);
// clean Nodes to replace
// clean Nodes to replace
// Do not include common Nodes to avoid cleaning Producers linked to newNodes
// Do not include common Nodes to avoid cleaning Producers linked to newNodes
...
@@ -667,7 +680,7 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg
...
@@ -667,7 +680,7 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg
for
(
auto
&
nodePtr
:
nodesToClean
)
{
nodePtr
->
resetConnections
(
true
);
}
for
(
auto
&
nodePtr
:
nodesToClean
)
{
nodePtr
->
resetConnections
(
true
);
}
// copy output connections
// copy output connections
for
(
IOIndex_t
o
=
0
;
o
<
p
reviousOutputNode
->
nbOutputs
();
++
o
)
{
for
(
IOIndex_t
o
=
0
;
o
<
firstP
reviousOutputNode
->
nbOutputs
();
++
o
)
{
auto
outputPairs
=
copyOutputs
[
o
];
auto
outputPairs
=
copyOutputs
[
o
];
for
(
const
auto
&
onePair
:
outputPairs
)
{
for
(
const
auto
&
onePair
:
outputPairs
)
{
newOutputNode
->
addChild
(
onePair
.
first
,
o
,
onePair
.
second
);
newOutputNode
->
addChild
(
onePair
.
first
,
o
,
onePair
.
second
);
...
@@ -675,22 +688,32 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg
...
@@ -675,22 +688,32 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg
}
}
// copy input connections
// copy input connections
if
(
!
newNodes
.
empty
())
{
if
(
!
newNodes
.
empty
())
{
std
::
shared_ptr
<
Node
>
newInputNode
=
(
*
(
newG
->
inputNodes
()).
begin
());
for
(
const
auto
&
newInputNode
:
newG
->
inputNodes
())
{
if
(
newInputNode
->
getNbFreeDataInputs
()
>
1
)
{
IOIndex_t
inputId
=
0
;
return
false
;
for
(
const
auto
&
input
:
newInputNode
->
inputs
())
{
if
(
newNodes
.
find
(
input
.
first
)
==
newNodes
.
end
())
{
externalInput
->
addChild
(
newInputNode
,
externalInputId
,
inputId
);
}
inputId
++
;
}
}
}
// one non-connected input in newNodes set
externalInput
->
addChild
(
newInputNode
,
externalInputId
,
newInputNode
->
getFirstFreeDataInput
());
}
}
// insert new Nodes in the right GraphViews
// insert new Nodes in the right GraphViews
for
(
auto
&
graphPtr
:
commonGraphViews
)
{
for
(
const
auto
&
graphPtr
:
commonGraphViews
)
{
graphPtr
->
add
(
newNodes
,
false
);
graphPtr
->
add
(
newNodes
,
false
);
if
(
newNodes
.
empty
())
{
if
(
newNodes
.
empty
())
{
graphPtr
->
updateInputNodes
();
graphPtr
->
updateInputNodes
();
graphPtr
->
updateOutputNodes
();
graphPtr
->
updateOutputNodes
();
}
}
}
}
for
(
const
auto
&
node
:
oldNodes
)
{
node
->
removeView
(
oldG
);
}
for
(
const
auto
&
node
:
newNodes
)
{
node
->
removeView
(
newG
);
}
return
true
;
return
true
;
}
}
...
...
This diff is collapsed.
Click to expand it.
unit_tests/graph/Test_GraphView.cpp
+
33
−
10
View file @
4b3540a3
...
@@ -364,7 +364,7 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") {
...
@@ -364,7 +364,7 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") {
std
::
set
<
std
::
shared_ptr
<
Node
>>
newNodes
=
std
::
set
<
std
::
shared_ptr
<
Node
>>
({
myFC
,
newMatmulWeight
,
newAddBias
});
std
::
set
<
std
::
shared_ptr
<
Node
>>
newNodes
=
std
::
set
<
std
::
shared_ptr
<
Node
>>
({
myFC
,
newMatmulWeight
,
newAddBias
});
// replace
// replace
g
->
replace
(
nodeToReplace
,
newNodes
);
GraphView
::
replace
(
nodeToReplace
,
newNodes
);
REQUIRE
(
g
->
getNodes
()
==
std
::
set
<
std
::
shared_ptr
<
Node
>>
({
newMatmulWeight
,
newAddBias
,
other1
,
other2
,
myFC
}));
REQUIRE
(
g
->
getNodes
()
==
std
::
set
<
std
::
shared_ptr
<
Node
>>
({
newMatmulWeight
,
newAddBias
,
other1
,
other2
,
myFC
}));
REQUIRE
(((
myFC
->
getParent
(
0
)
==
other1
)
&&
(
myFC
->
getParent
(
1
)
==
newMatmulWeight
)
&&
(
myFC
->
getParent
(
2
)
==
newAddBias
)));
REQUIRE
(((
myFC
->
getParent
(
0
)
==
other1
)
&&
(
myFC
->
getParent
(
1
)
==
newMatmulWeight
)
&&
(
myFC
->
getParent
(
2
)
==
newAddBias
)));
...
@@ -381,19 +381,42 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") {
...
@@ -381,19 +381,42 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") {
g
->
add
({
r1
,
r2
,
r3
,
r4
});
g
->
add
({
r1
,
r2
,
r3
,
r4
});
auto
nodesToReplace
=
std
::
set
<
std
::
shared_ptr
<
Node
>>
({
r2
,
r3
});
auto
nodesToReplace
=
std
::
set
<
std
::
shared_ptr
<
Node
>>
({
r2
,
r3
});
auto
newNodes
=
std
::
set
<
std
::
shared_ptr
<
Node
>>
({});
auto
newNodes
=
std
::
set
<
std
::
shared_ptr
<
Node
>>
({});
g
->
replace
(
nodesToReplace
,
newNodes
);
GraphView
::
replace
(
nodesToReplace
,
newNodes
);
REQUIRE
(
g
->
getNodes
()
==
std
::
set
<
std
::
shared_ptr
<
Node
>>
({
r1
,
r4
}));
REQUIRE
(
g
->
getNodes
()
==
std
::
set
<
std
::
shared_ptr
<
Node
>>
({
r1
,
r4
}));
REQUIRE
((
r1
->
output
(
0
))[
0
].
first
==
r4
);
REQUIRE
((
r1
->
output
(
0
))[
0
].
first
==
r4
);
}
}
// SECTION("replace for tiling") {
// std::shared_ptr<GraphView> g = std::make_shared<GraphView>();
SECTION
(
"replace for tiling"
)
{
// auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input");
std
::
shared_ptr
<
GraphView
>
g
=
std
::
make_shared
<
GraphView
>
(
"test_graph"
);
// auto other1 = GenericOperator("Other", 1, 1, 1, "other1");
auto
otherInput
=
GenericOperator
(
"Producer"
,
0
,
0
,
1
,
"other_input"
);
// auto myConv = GenericOperator("Conv", 1, 1, 1, "myConv");
auto
other1
=
GenericOperator
(
"Other"
,
1
,
1
,
1
,
"other1"
);
// auto other2 = GenericOperator("Other", 1, 1, 1, "other2");
auto
myConv
=
GenericOperator
(
"Conv"
,
1
,
1
,
1
,
"myConv"
);
// otherInput->addChild()
auto
other2
=
GenericOperator
(
"Other"
,
1
,
1
,
1
,
"other2"
);
// }
otherInput
->
addChild
(
other1
);
other1
->
addChild
(
myConv
);
myConv
->
addChild
(
other2
);
g
->
add
({
other1
,
myConv
,
other2
});
// create tiled Conv
auto
conv1
=
GenericOperator
(
"Conv"
,
1
,
1
,
1
,
"myConv1"
);
auto
conv2
=
GenericOperator
(
"Conv"
,
1
,
1
,
1
,
"myConv2"
);
auto
conv3
=
GenericOperator
(
"Conv"
,
1
,
1
,
1
,
"myConv3"
);
auto
conv4
=
GenericOperator
(
"Conv"
,
1
,
1
,
1
,
"myConv4"
);
auto
concat
=
GenericOperator
(
"Concat"
,
4
,
4
,
1
,
"myConcat"
);
conv1
->
addChild
(
concat
);
conv2
->
addChild
(
concat
);
conv3
->
addChild
(
concat
);
conv4
->
addChild
(
concat
);
GraphView
::
replace
({
myConv
},
{
conv1
,
conv2
,
conv3
,
conv4
,
concat
});
REQUIRE
(
g
->
getNodes
()
==
std
::
set
<
std
::
shared_ptr
<
Node
>>
({
other1
,
conv1
,
conv2
,
conv3
,
conv4
,
concat
,
other2
}));
GraphView
::
replace
({
conv1
,
conv2
,
conv3
,
conv4
,
concat
},
{
myConv
});
REQUIRE
(
g
->
getNodes
()
==
std
::
set
<
std
::
shared_ptr
<
Node
>>
({
other1
,
myConv
,
other2
}));
}
}
}
TEST_CASE
(
"[GraphView] clone"
)
{
TEST_CASE
(
"[GraphView] clone"
)
{
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment