Skip to content
Snippets Groups Projects
HorizontalTiling.cpp 5.29 KiB
Newer Older
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <set>
#include <memory>
#include <numeric>   // std::iota
#include <vector>
#include <utility>

Olivier BICHLER's avatar
Olivier BICHLER committed
#include "aidge/recipes/Recipes.hpp"

#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/utils/Types.h"
#include "aidge/operator/Producer.hpp"

#include "aidge/operator/Add.hpp"
#include "aidge/operator/Concat.hpp"
#include "aidge/operator/Slice.hpp"

// TODO: assert Operator uses Tensors when implemented
std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std::shared_ptr<Aidge::Node>& node,
                                                            const Aidge::DimIdx_t axis,
                                                            const std::size_t nbSlices)
{
    // for now, Tiling works only with Conv Operators
    if (node->getOperator()->type() != "Conv") {
        AIDGE_INTERNAL_ASSERT("Operator should be a Convolution.");
    }
    // TODO: back when tiling works with other Operators
    // AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
    const auto& op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
    // TODO: back when tiling works with other Operators
    // if (op->nbOutputs() != 1 || op->nbData() > 1) {
    //     AIDGE_INTERNAL_ASSERT("Only slice Operators with one output and at most one input for now.");
    // }
    if (!op->dimsForwarded()) {
        AIDGE_INTERNAL_ASSERT("Dimensions must be forwarded before any tiling");
    }

    const std::shared_ptr<Tensor>& outTensor = op->getOutput(0);
    std::vector<DimSize_t> outputDims = outTensor->dims();

    // start by doing a tiling with strict dimensions division
    if (outputDims[axis] % nbSlices != 0) {
        AIDGE_INTERNAL_ASSERT("axis should be a multiple of nbSlices");
    }

    // dimensions of a Slice
    outputDims[axis] /= nbSlices;


    auto concat = Concat(nbSlices, axis);
    std::set<std::shared_ptr<Aidge::Node>> tiledOperator{concat};

    // check slice sizes
    // const auto inputDims = op->computeReceptiveField(currentFirstDims[axis], outputDims, 0);
    // std::vector<bool> shareTensor(node->nbInputs(), false);
    // for (DimSize_t inputID = 0; inputID < node->nbInputs(); ++inputID) {
    //     const auto inTensor = std::dynamic_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputID));
    //     if (inTensor->dims() == inputDims[inputID].second)
    //         shareTensor[inputID] = true;
    // }

    std::vector<std::shared_ptr<Node>> clonedInputs = std::vector<std::shared_ptr<Node>>(node->nbInputs(), nullptr);
    for (std::size_t i = node->nbData(); i < node ->nbInputs(); ++i) {
        clonedInputs[i] = node -> getParent(i) -> cloneSharedOperators();
        clonedInputs[i] -> setName(node -> getParent(i) -> name() + "_0");
        tiledOperator.insert(clonedInputs[i]);
    const std::vector<std::string> sliceInputsNames = Slice_Op::getInputsName();
    // coordinates of the first value of the current output slice
    std::vector<DimSize_t> currentFirstDims = std::vector<DimSize_t>(outTensor->nbDims(), 0);
Maxence Naud's avatar
Maxence Naud committed
    for (IOIndex_t i = 0; currentFirstDims[axis] < outTensor->dims()[axis]; currentFirstDims[axis] += outputDims[axis], ++i) {
        const auto inputDims = op->computeReceptiveField(currentFirstDims, outputDims, 0);
        auto newNode = node -> clone(); // no input associated to clones
        newNode -> setName(node->name() + "_" + std::to_string(currentFirstDims[axis]));
        clonedInputs[1] -> addChild(newNode, 0, 1);
        clonedInputs[2] -> addChild(newNode, 0, 2);

        auto backend = outTensor->getImpl()->backend();
        // Create Slice's Starts attribute
        std::vector<std::int64_t> inputDimsStart(inputDims[0].first.size());
        for (std::size_t dim = 0; dim < inputDimsStart.size(); ++dim) {
            inputDimsStart[dim] = static_cast<std::int64_t>(inputDims[0].first[dim]);
        // Create Slice's Ends attribute
        std::vector<std::int64_t> inputDimsEnd(inputDims[0].first.size());
        for (std::size_t dim = 0; dim < inputDimsEnd.size(); ++dim) {
            inputDimsEnd[dim] = static_cast<std::int64_t>(inputDims[0].first[dim] + inputDims[0].second[dim]);
        }

        // Create Slice's Axes attribute
        std::vector<std::int8_t> usedDims(inputDimsEnd.size());
        std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int8_t>(0));

Houssem ROUIS's avatar
Houssem ROUIS committed
        // Create Slice's Steps attribute
        std::vector<std::int64_t> steps(inputDimsEnd.size(), static_cast<std::int64_t>(1));
Houssem ROUIS's avatar
Houssem ROUIS committed

        auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, steps, "Slice_" + std::to_string(currentFirstDims[axis]));
        slice -> addChild(newNode, 0, 0);
Maxence Naud's avatar
Maxence Naud committed
        newNode -> addChild(concat, 0, i);
        tiledOperator.insert({slice, newNode});
    return tiledOperator;