// ConvTransposeImpl.cpp
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/backend/cpu/operator/ConvTransposeImpl.hpp"
#include "aidge/backend/cpu/operator/ConvTransposeImpl_kernels.hpp"

#include "aidge/operator/Conv.hpp"

/// Forward pass of the 1D transposed convolution on the CPU backend.
/// Input layout: #0 = data, #1 = weight, #2 = bias (optional — the kernel
/// receives a null pointer when it is absent).
template <> void Aidge::ConvTransposeImpl1D_cpu::forward() {
    const auto &op = static_cast<const ConvTranspose_Op<1> &>(mOp);

    AIDGE_ASSERT(op.getInput(0), "{}: missing data input (#0).", op.type());
    AIDGE_ASSERT(op.getInput(1), "{}: missing weight input (#1).", op.type());
    // No assert on input #2: bias is optional and handled below.

    std::shared_ptr<Tensor> inputDataFallback, inputWeightFallback,
        inputBiasFallback;
    // Convert each input to the output's backend/data type if necessary;
    // the fallback tensors keep any converted copy alive for the call.
    const auto &inputData =
        op.getInput(0)->refCastFrom(inputDataFallback, *op.getOutput(0));
    const auto &inputWeight =
        op.getInput(1)->refCastFrom(inputWeightFallback, *op.getOutput(0));
    const auto &inputBias =
        (op.getInput(2))
            ? op.getInput(2)->refCastFrom(inputBiasFallback, *op.getOutput(0))
            : Tensor();

    // Call kernel
    const auto impl = Registrar<ConvTransposeImpl1D_cpu>::create(
        getBestMatch(getRequiredSpec()));
    impl.forward(op.strideDims()[0],
                 op.dilationDims()[0],
                 op.kernelDims()[0],
                 op.getInput(0)->template dims<3>(),
                 op.getOutput(0)->template dims<3>(),
                 inputData.getImpl()->hostPtr(),
                 inputWeight.getImpl()->hostPtr(),
                 op.getInput(2) ? inputBias.getImpl()->hostPtr() : nullptr,
                 op.getOutput(0)->getImpl()->rawPtr());
}

/// Backward pass of the 1D transposed convolution: not implemented on the
/// CPU backend — always throws (or aborts, depending on build configuration).
template <> void Aidge::ConvTransposeImpl1D_cpu::backward() {
    AIDGE_THROW_OR_ABORT(
        std::runtime_error,
        "Backward not yet implemented for ConvTranspose_Op<1> on backend cpu");
}

/// Forward pass of the 2D transposed convolution on the CPU backend.
/// Input layout: #0 = data, #1 = weight, #2 = bias (optional — the kernel
/// receives a null pointer when it is absent).
template <> void Aidge::ConvTransposeImpl2D_cpu::forward() {
    const auto &op = static_cast<const ConvTranspose_Op<2> &>(mOp);

    AIDGE_ASSERT(op.getInput(0), "{}: missing data input (#0).", op.type());
    AIDGE_ASSERT(op.getInput(1), "{}: missing weight input (#1).", op.type());
    // No assert on input #2: bias is optional and handled below.

    std::shared_ptr<Tensor> inputDataFallback, inputWeightFallback,
        inputBiasFallback;
    // Convert each input to the output's backend/data type if necessary;
    // the fallback tensors keep any converted copy alive for the call.
    const auto &inputData =
        op.getInput(0)->refCastFrom(inputDataFallback, *op.getOutput(0));
    const auto &inputWeight =
        op.getInput(1)->refCastFrom(inputWeightFallback, *op.getOutput(0));
    const auto &inputBias =
        (op.getInput(2))
            ? op.getInput(2)->refCastFrom(inputBiasFallback, *op.getOutput(0))
            : Tensor();

    // Call kernel
    const auto impl = Registrar<ConvTransposeImpl2D_cpu>::create(
        getBestMatch(getRequiredSpec()));

    impl.forward(op.strideDims(),
                 op.dilationDims(),
                 op.kernelDims(),
                 op.getInput(0)->template dims<4>(),
                 op.getOutput(0)->template dims<4>(),
                 inputData.getImpl()->hostPtr(),
                 inputWeight.getImpl()->hostPtr(),
                 op.getInput(2) ? inputBias.getImpl()->hostPtr() : nullptr,
                 op.getOutput(0)->getImpl()->rawPtr());
}

/// Backward pass of the 2D transposed convolution: not implemented on the
/// CPU backend — always throws (or aborts, depending on build configuration).
template <> void Aidge::ConvTransposeImpl2D_cpu::backward() {
    AIDGE_THROW_OR_ABORT(
        std::runtime_error,
        "Backward not yet implemented for ConvTranspose_Op<2> on backend cpu");
}