Skip to content
Snippets Groups Projects
Commit c8d3a3e4 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Add] working prototype for horizontal tiling

parent 114b4ec1
No related branches found
No related tags found
1 merge request!54horizontal tiling
Pipeline #35116 canceled
......@@ -84,7 +84,7 @@ void fuseBatchNorm(std::shared_ptr<MatchSolution> solution);
*/
void fuseBatchNorm(std::shared_ptr<GraphView> graphView);
// std::set<std::shared_ptr<Node>> getHorizontalTiling(const std::shared_ptr<Node>& node, const DimIdx_t axis, const std::size_t nbSlices);
std::set<std::shared_ptr<Node>> getConvHorizontalTiling(const std::shared_ptr<Node>& node, const DimIdx_t axis, const std::size_t nbSlices);
// void horizontalTiling(std::shared_ptr<Node> node, DimIdx_t dim, std::size_t nbSlices);
// std::set<std::shared_ptr<Node>> getHorizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices);
// void horizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices);
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <set>
#include <memory>
#include <vector>
#include <utility>
#include "aidge/recipies/Recipies.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/utils/Types.h"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/Concat.hpp"
#include "aidge/operator/Slice.hpp"
// TODO: assert Operator uses Tensors when implemented
std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std::shared_ptr<Aidge::Node>& node,
const Aidge::DimIdx_t axis,
const std::size_t nbSlices)
{
if (node->getOperator()->type() != "Conv") {
AIDGE_INTERNAL_ASSERT("Operator should be a Convolution.");
}
const auto& op = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator());
if (op->nbOutputs() != 1 || op->nbData() > 1) {
AIDGE_INTERNAL_ASSERT("Only slice Operators with one output and at most one input for now.");
}
if (!op->outputDimsForwarded()) {
AIDGE_INTERNAL_ASSERT("Dimensions must be forwarded before any tiling");
}
// start by doing a tiling with strict dimensions division
const auto& outTensor = op->getOutput(0);
if (op->getOutput(0)->dims()[axis] % nbSlices != 0) {
AIDGE_INTERNAL_ASSERT("axis should be a multiple of nbSlices");
}
// dimensions of a Slice
std::vector<DimSize_t> outputDims = outTensor->dims();
outputDims[axis] /= nbSlices;
std::vector<DimSize_t> currentFirstDims = std::vector<DimSize_t>(outTensor->nbDims(), 0);
std::set<std::shared_ptr<Aidge::Node>> res;
auto concat = Concat(nbSlices, axis);
res.insert(concat);
// check slice sizes
// const auto inputDims = op->computeReceptiveField(currentFirstDims[axis], outputDims, 0);
// std::vector<bool> shareTensor(node->nbInputs(), false);
// for (DimSize_t inputID = 0; inputID < node->nbInputs(); ++inputID) {
// const auto inTensor = std::dynamic_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputID));
// if (inTensor->dims() == inputDims[inputID].second)
// shareTensor[inputID] = true;
// }
std::vector<std::shared_ptr<Node>> clonedInputs = std::vector<std::shared_ptr<Node>>(node->nbInputs(), nullptr);
for (std::size_t i = node->nbData(); i < node ->nbInputs(); ++i) {
clonedInputs[i] = node -> getParent(i) -> cloneSharedOperators();
clonedInputs[i] -> setName(node -> name() + "_0");
res.insert(clonedInputs[i]);
}
for (; currentFirstDims[axis] < outTensor->dims()[axis]; currentFirstDims[axis] += outputDims[axis]) {
const auto inputDims = op->computeReceptiveField(outTensor->getIdx(currentFirstDims), outputDims, 0);
auto newNode = node -> clone(); // no input associated to clones
newNode -> setName(node->name() + "_" + std::to_string(currentFirstDims[axis]));
clonedInputs[1] -> addChild(newNode, 0, 1);
clonedInputs[2] -> addChild(newNode, 0, 2);
// Slice for input and each parameter
auto slice = Slice(inputDims[0].first, inputDims[0].second, "Slice_" + std::to_string(currentFirstDims[axis]));
slice -> addChild(newNode, 0, 0);
newNode -> addChild(concat, 0, currentFirstDims[axis]);
res.insert(slice);
res.insert(newNode);
}
return res;
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment