/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <cassert>

#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"

// Binds this implementation to operator `op` and creates one zero-initialised
// consumption counter per input and one production counter per output.
// The counter vectors are sized from the `op` parameter rather than from the
// `mOp` member. Reading `mOp` here is only well-defined if it is declared before
// the vectors in the class definition. Using `op` makes the result independent
// of member declaration order.
Aidge::OperatorImpl::OperatorImpl(const Operator& op):
    mOp(op),
    mNbConsumedData(op.nbInputs(), 0),
    mNbProducedData(op.nbOutputs(), 0)
{
    // Nothing has been consumed or produced yet: all counters start at zero.
}

// Default policy: the operator needs the entire input tensor `inputIdx`
// before it can run, so the requirement equals the tensor's element count.
// A missing input is a programming error (checked by assert only).
Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
    const auto rawInput = mOp.getRawInput(inputIdx);
    assert(rawInput && "requires valid input");

    const auto inputTensor = std::static_pointer_cast<Tensor>(rawInput);
    return inputTensor->size();
}

// Default policy: every element of input tensor `inputIdx` stays protected,
// so the protected amount equals the tensor's element count.
// A missing input is a programming error (checked by assert only).
Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const {
    const auto rawInput = mOp.getRawInput(inputIdx);
    assert(rawInput && "requires valid input");

    const auto inputTensor = std::static_pointer_cast<Tensor>(rawInput);
    return inputTensor->size();
}

// Default policy: output `outputIdx` needs room for its whole tensor.
// The available input sizes (second parameter) are deliberately ignored.
// A missing output is a programming error (checked by assert only).
Aidge::NbElts_t Aidge::OperatorImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
                                                         const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const {
    const auto rawOutput = mOp.getRawOutput(outputIdx);
    assert(rawOutput && "requires valid output");

    const auto outputTensor = std::static_pointer_cast<Tensor>(rawOutput);
    return outputTensor->size();
}

// Returns the running total of data consumed so far on input `inputIdx`.
// The index must be within range; this is only checked by assert.
Aidge::NbElts_t Aidge::OperatorImpl::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
    const auto slot = static_cast<std::size_t>(inputIdx);
    assert(slot < mNbConsumedData.size());
    return mNbConsumedData[slot];
}

// Returns the running total of data produced so far on output `outputIdx`.
// The index must be within range; this is only checked by assert.
Aidge::NbElts_t Aidge::OperatorImpl::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
    const auto slot = static_cast<std::size_t>(outputIdx);
    assert(slot < mNbProducedData.size());
    return mNbProducedData[slot];
}

// Advances the producer/consumer counters by one forward pass.
// Each input counter grows by getNbRequiredData() for that input.
// Each output counter grows by getRequiredMemory() for that output.
void Aidge::OperatorImpl::updateConsummerProducer(){
    // Each input is consumed by the minimum amount needed for a forward pass.
    for (std::size_t inputIdx = 0; inputIdx < mNbConsumedData.size(); ++inputIdx) {
        mNbConsumedData[inputIdx] += getNbRequiredData(static_cast<IOIndex_t>(inputIdx));
    }

    // Each output produces the amount of memory it requires. The index is cast
    // explicitly (mirroring the input loop) instead of being implicitly narrowed
    // from std::size_t. No input sizes are supplied, so an implementation that
    // reads `inputsSize` receives an empty vector here.
    for (std::size_t outputIdx = 0; outputIdx < mNbProducedData.size(); ++outputIdx) {
        mNbProducedData[outputIdx] += getRequiredMemory(static_cast<IOIndex_t>(outputIdx), {});
    }
}

// Fallback forward pass: this base implementation provides no computation.
// It reports the missing implementation via AIDGE_THROW_OR_ABORT
// (std::runtime_error).
void Aidge::OperatorImpl::forward() {
    AIDGE_THROW_OR_ABORT(
        std::runtime_error,
        "forward() not implemented");
}

// Fallback backward pass: this base implementation provides no computation.
// It reports the missing implementation via AIDGE_THROW_OR_ABORT
// (std::runtime_error).
void Aidge::OperatorImpl::backward() {
    AIDGE_THROW_OR_ABORT(
        std::runtime_error,
        "backward() not implemented");
}