Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
ExpandImpl.cpp 1.73 KiB

/********************************************************************************
 * Copyright (c) 2024 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/backend/cpu/operator/ExpandImpl.hpp"

#include <vector>

#include "aidge/backend/cpu/operator/ExpandImpl_kernels.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Expand.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"

namespace Aidge {

/**
 * @brief CPU forward pass of the Expand operator.
 *
 * Steps:
 *  1. Asserts that inputs 0 and 1 are connected.
 *  2. Selects the registered kernel that best matches the required spec.
 *  3. Runs that kernel. It receives both inputs, the raw buffer of output 0
 *     and the dims of output 0.
 */
template <>
void ExpandImpl_cpu::forward() {
    const auto &expandOp = static_cast<const Expand_Op &>(mOp);

    // Both inputs must be connected before any kernel can be dispatched.
    AIDGE_ASSERT(expandOp.getInput(0),
                 "{}: missing input 0: {}",
                 Expand_Op::Type,
                 Expand_Op::getInputsName()[0]);
    AIDGE_ASSERT(expandOp.getInput(1),
                 "{}: missing input 1: {}",
                 Expand_Op::Type,
                 Expand_Op::getInputsName()[1]);

    // Kernel selection: best registered match for the current requirements.
    const auto &bestSpec = getBestMatch(getRequiredSpec());
    const auto kernel = Registrar<ExpandImpl_cpu>::create(bestSpec);

    // The kernel writes directly into output 0's raw storage.
    const auto &output = expandOp.getOutput(0);
    kernel.forward(expandOp.getInput(0),
                   expandOp.getInput(1),
                   output->getImpl()->rawPtr(),
                   output->dims());
}

/**
 * @brief CPU backward pass of the Expand operator (not supported yet).
 *
 * Always ends in AIDGE_THROW_OR_ABORT with a std::runtime_error message
 * stating that the backward pass is unimplemented.
 */
template <>
void ExpandImpl_cpu::backward() {
    AIDGE_THROW_OR_ABORT(std::runtime_error, "Backward not yet implemented for Expand_Op on backend cpu");
}

} // namespace Aidge