Skip to content
Snippets Groups Projects
Commit 628b9466 authored by Iryna de Albuquerque Silva's avatar Iryna de Albuquerque Silva Committed by Olivier BICHLER
Browse files

Add explicit names setting to nodes created in the getConvHorizontalTiling() function.

parent 68aaf9b7
No related branches found
No related tags found
3 merge requests!414Update version 0.5.1 -> 0.6.0,!408[Add] Dropout Operator,!364Add explicit names setting to nodes created in the HorizontalTiling recipe
...@@ -62,6 +62,7 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: ...@@ -62,6 +62,7 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
auto concat = Concat(nbSlices, axis); auto concat = Concat(nbSlices, axis);
concat -> setName(concat->type());
std::set<std::shared_ptr<Aidge::Node>> tiledOperator{concat}; std::set<std::shared_ptr<Aidge::Node>> tiledOperator{concat};
// check slice sizes // check slice sizes
...@@ -77,7 +78,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: ...@@ -77,7 +78,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
for (std::size_t i = 0; i < node ->nbInputs(); ++i) { for (std::size_t i = 0; i < node ->nbInputs(); ++i) {
if (node->inputCategory(i) == InputCategory::Param || node->inputCategory(i) == InputCategory::OptionalParam) { if (node->inputCategory(i) == InputCategory::Param || node->inputCategory(i) == InputCategory::OptionalParam) {
clonedInputs[i] = node -> getParent(i) -> cloneSharedOperators(); clonedInputs[i] = node -> getParent(i) -> cloneSharedOperators();
clonedInputs[i] -> setName(node -> getParent(i) -> name() + "_0"); // clonedInputs[i] -> setName(node -> getParent(i) -> name() + "_0");
clonedInputs[i] -> setName(node -> getParent(i) -> name());
tiledOperator.insert(clonedInputs[i]); tiledOperator.insert(clonedInputs[i]);
} }
} }
...@@ -93,6 +95,7 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: ...@@ -93,6 +95,7 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
clonedInputs[2] -> addChild(newNode, 0, 2); clonedInputs[2] -> addChild(newNode, 0, 2);
auto slice = Slice(); auto slice = Slice();
slice -> setName(node->name() + '_' + slice->type() + '_' + std::to_string(i));
auto backend = outTensor->getImpl()->backend(); auto backend = outTensor->getImpl()->backend();
// Create Slice's Starts producer node // Create Slice's Starts producer node
...@@ -105,7 +108,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: ...@@ -105,7 +108,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
starts -> setBackend(backend); starts -> setBackend(backend);
starts -> resize(std::vector<std::size_t>({inputDimsStart.size()})); starts -> resize(std::vector<std::size_t>({inputDimsStart.size()}));
starts -> getImpl() -> copyFromHost(inputDimsStart.data(), inputDimsStart.size()); starts -> getImpl() -> copyFromHost(inputDimsStart.data(), inputDimsStart.size());
auto startsNode = Producer(starts, slice->name() + sliceInputsNames[1]); // auto startsNode = Producer(starts, slice->name() + sliceInputsNames[1]);
auto startsNode = Producer(starts, slice->name() + "_" + sliceInputsNames[1]);
startsNode -> addChild(slice, 0, 1); startsNode -> addChild(slice, 0, 1);
// Create Slice's Ends producer node // Create Slice's Ends producer node
...@@ -118,7 +122,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: ...@@ -118,7 +122,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
ends -> setBackend(backend); ends -> setBackend(backend);
ends -> resize(std::vector<std::size_t>({inputDimsEnd.size()})); ends -> resize(std::vector<std::size_t>({inputDimsEnd.size()}));
ends -> getImpl() -> copyFromHost(inputDimsEnd.data(), inputDimsEnd.size()); ends -> getImpl() -> copyFromHost(inputDimsEnd.data(), inputDimsEnd.size());
auto endsNode = Producer(ends, slice->name() + sliceInputsNames[2]); // auto endsNode = Producer(ends, slice->name() + sliceInputsNames[2]);
auto endsNode = Producer(ends, slice->name() + "_" + sliceInputsNames[2]);
endsNode -> addChild(slice, 0, 2); endsNode -> addChild(slice, 0, 2);
// Create Slice's Axes producer node // Create Slice's Axes producer node
...@@ -129,7 +134,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: ...@@ -129,7 +134,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
axes -> setBackend(backend); axes -> setBackend(backend);
axes -> resize(std::vector<std::size_t>({usedDims.size()})); axes -> resize(std::vector<std::size_t>({usedDims.size()}));
axes -> getImpl() -> copyFromHost(usedDims.data(), usedDims.size()); axes -> getImpl() -> copyFromHost(usedDims.data(), usedDims.size());
auto axesNode = Producer(axes, slice->name() + sliceInputsNames[3]); // auto axesNode = Producer(axes, slice->name() + sliceInputsNames[3]);
auto axesNode = Producer(axes, slice->name() + "_" + sliceInputsNames[3]);
axesNode -> addChild(slice, 0, 3); axesNode -> addChild(slice, 0, 3);
// Create Slice's Steps producer node // Create Slice's Steps producer node
...@@ -139,7 +145,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: ...@@ -139,7 +145,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
steps -> setBackend(backend); steps -> setBackend(backend);
steps -> resize(std::vector<std::size_t>({inputDimsSteps.size()})); steps -> resize(std::vector<std::size_t>({inputDimsSteps.size()}));
steps -> getImpl() -> copyFromHost(inputDimsSteps.data(), inputDimsSteps.size()); steps -> getImpl() -> copyFromHost(inputDimsSteps.data(), inputDimsSteps.size());
auto stepsNode = Producer(steps, slice->name() + sliceInputsNames[4]); // auto stepsNode = Producer(steps, slice->name() + sliceInputsNames[4]);
auto stepsNode = Producer(steps, slice->name() + "_" + sliceInputsNames[4]);
stepsNode -> addChild(slice, 0, 4); stepsNode -> addChild(slice, 0, 4);
// auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, inputDimsSteps); // auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, inputDimsSteps);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment