#include <set>
#include <memory>
#include <numeric> // std::iota
#include <vector>
#include <utility>
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/utils/Types.h"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/Concat.hpp"
#include "aidge/operator/Slice.hpp"
// TODO: assert Operator uses Tensors when implemented
// Builds the nodes needed to tile the output of a Conv node along one axis.
//
// Visible behaviour: the output dimensions of `node` are divided by `nbSlices`
// along `axis`; a Concat node is created; then, for each slice of the output,
// a clone of `node`, a Slice node and Producer nodes carrying the Slice's
// starts / ends / axes values are created and accumulated in `tiledOperator`.
//
// Parameters:
//   node     - node to tile; its operator type is compared against "Conv".
//   axis     - index of the output dimension that is split.
//   nbSlices - number of slices; used as a divisor of the output size along
//              `axis`, so it must be non-zero.
//
// Returns (per the signature): a set of shared pointers to nodes. The visible
// code collects the created nodes in `tiledOperator`.
//
// NOTE(review): in this view the braces delimiting the function body, the
// `if` blocks and the `for` loops are not visible, and no `return` statement
// appears before the end of the visible text - confirm against the full file.
std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std::shared_ptr<Aidge::Node>& node,
const Aidge::DimIdx_t axis,
const std::size_t nbSlices)
// for now, Tiling works only with Conv Operators
// NOTE(review): AIDGE_INTERNAL_ASSERT receives a string literal here and in the
// two checks below; if the macro evaluates its argument as a condition, a
// non-null literal would never make it fail - confirm the macro's semantics.
if (node->getOperator()->type() != "Conv") {
AIDGE_INTERNAL_ASSERT("Operator should be a Convolution.");
// TODO: back when tiling works with other Operators
// AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
// Unchecked downcast: static_pointer_cast performs no runtime type verification.
const auto& op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
// TODO: back when tiling works with other Operators
// if (op->nbOutputs() != 1 || op->nbData() > 1) {
// AIDGE_INTERNAL_ASSERT("Only slice Operators with one output and at most one input for now.");
// }
// The output dimensions are read just below, hence the dimsForwarded() check.
if (!op->dimsForwarded()) {
AIDGE_INTERNAL_ASSERT("Dimensions must be forwarded before any tiling");
const std::shared_ptr<Tensor>& outTensor = op->getOutput(0);
// Local copy of the output dimensions; entry `axis` is reduced below to the size of one slice.
std::vector<DimSize_t> outputDims = outTensor->dims();
// start by doing a tiling with strict dimensions division
if (outputDims[axis] % nbSlices != 0) {
AIDGE_INTERNAL_ASSERT("axis should be a multiple of nbSlices");
// dimensions of a Slice
outputDims[axis] /= nbSlices;
// NOTE(review): `concat` is inserted into the result set, but no edge towards
// it is created in the visible code - confirm whether the connection of each
// clone to the Concat is done elsewhere.
auto concat = Concat(nbSlices, axis);
std::set<std::shared_ptr<Aidge::Node>> tiledOperator{concat};
// check slice sizes
// const auto inputDims = op->computeReceptiveField(currentFirstDims[axis], outputDims, 0);
// std::vector<bool> shareTensor(node->nbInputs(), false);
// for (DimSize_t inputID = 0; inputID < node->nbInputs(); ++inputID) {
// const auto inTensor = std::dynamic_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputID));
// if (inTensor->dims() == inputDims[inputID].second)
// shareTensor[inputID] = true;
// }
// One entry per input of `node`; only indices in [nbData(), nbInputs()) are
// filled, with a clone of the corresponding parent renamed with a "_0" suffix.
// Entries below nbData() stay null.
std::vector<std::shared_ptr<Node>> clonedInputs = std::vector<std::shared_ptr<Node>>(node->nbInputs(), nullptr);
for (std::size_t i = node->nbData(); i < node ->nbInputs(); ++i) {
clonedInputs[i] = node -> getParent(i) -> cloneSharedOperators();
clonedInputs[i] -> setName(node -> getParent(i) -> name() + "_0");
// Input names of the Slice operator, used to name the Producer nodes created below.
const std::vector<std::string> sliceInputsNames = Slice_Op::getInputsName();
// coordinates of the first value of the current output slice
std::vector<DimSize_t> currentFirstDims = std::vector<DimSize_t>(outTensor->nbDims(), 0);
// Each iteration advances the first coordinate along `axis` by one slice extent.
// NOTE(review): the counter `i` is incremented but not read in the visible loop body.
for (IOIndex_t i = 0; currentFirstDims[axis] < outTensor->dims()[axis]; currentFirstDims[axis] += outputDims[axis], ++i) {
// Entry 0 of the result is used below as a (first coordinates, extents) pair.
const auto inputDims = op->computeReceptiveField(currentFirstDims, outputDims, 0);
auto newNode = node -> clone(); // no input associated to clones
newNode -> setName(node->name() + "_" + std::to_string(currentFirstDims[axis]));
// NOTE(review): indices 1 and 2 are hard-coded; this relies on nbData() <= 1 and
// nbInputs() >= 3 so that both entries were filled above (presumably the weight
// and bias inputs of a Conv - confirm).
clonedInputs[1] -> addChild(newNode, 0, 1);
clonedInputs[2] -> addChild(newNode, 0, 2);
// The tensors created below use the same backend as the output tensor's implementation.
auto backend = outTensor->getImpl()->backend();
auto slice = Slice("Slice_" + std::to_string(currentFirstDims[axis]));
// Create Slice's Starts producer node
std::vector<std::int64_t> inputDimsStart(inputDims[0].first.size());
for (std::size_t dim = 0; dim < inputDimsStart.size(); ++dim) {
inputDimsStart[dim] = static_cast<std::int64_t>(inputDims[0].first[dim]);
// 1-D Int64 tensor with one element per start coordinate.
const std::shared_ptr<Tensor> starts = std::make_shared<Tensor>();
starts -> setDataType(DataType::Int64);
starts -> setBackend(backend);
starts -> resize(std::vector<std::size_t>({inputDimsStart.size()}));
// NOTE(review): the first argument of copyFromHost is missing in this view
// (presumably the host buffer of inputDimsStart) - confirm against the full file.
starts -> getImpl() -> copyFromHost(, inputDimsStart.size());
auto startsNode = Producer(starts, slice->name() + sliceInputsNames[1]);
startsNode -> addChild(slice, 0, 1);
// Create Slice's Ends producer node: end = first coordinate + extent, per dimension.
std::vector<std::int64_t> inputDimsEnd(inputDims[0].first.size());
for (std::size_t dim = 0; dim < inputDimsEnd.size(); ++dim) {
inputDimsEnd[dim] = static_cast<std::int64_t>(inputDims[0].first[dim] + inputDims[0].second[dim]);
const std::shared_ptr<Tensor> ends = std::make_shared<Tensor>();
ends -> setDataType(DataType::Int64);
ends -> setBackend(backend);
ends -> resize(std::vector<std::size_t>({inputDimsEnd.size()}));
// NOTE(review): first argument of copyFromHost missing here as well - confirm.
ends -> getImpl() -> copyFromHost(, inputDimsEnd.size());
auto endsNode = Producer(ends, slice->name() + sliceInputsNames[2]);
endsNode -> addChild(slice, 0, 2);
// Create Slice's Axes producer node
// The axes list is 0, 1, ..., N-1 (filled by std::iota), i.e. every dimension is listed.
std::vector<std::int64_t> usedDims(inputDimsEnd.size());
std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int64_t>(0));
const std::shared_ptr<Tensor> axes = std::make_shared<Tensor>();
axes -> setDataType(DataType::Int64);
axes -> setBackend(backend);
axes -> resize(std::vector<std::size_t>({usedDims.size()}));
// NOTE(review): first argument of copyFromHost missing here as well - confirm.
axes -> getImpl() -> copyFromHost(, usedDims.size());
// NOTE(review): unlike startsNode and endsNode, axesNode is not attached to
// `slice` in the visible code - confirm whether an addChild(slice, 0, 3) is intended.
auto axesNode = Producer(axes, slice->name() + sliceInputsNames[3]);
// Output 0 of the Slice feeds input 0 of the cloned node.
slice -> addChild(newNode, 0, 0);
tiledOperator.insert({slice, newNode, startsNode, endsNode, axesNode});