Skip to content
Snippets Groups Projects
Forked from Eclipse Projects / aidge / aidge_core
2428 commits behind the upstream repository.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Conv.hpp 9.83 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#ifndef AIDGE_CORE_OPERATOR_CONV_H_
#define AIDGE_CORE_OPERATOR_CONV_H_

#include <array>
#include <cmath>
#include <numeric>
#include <vector>

#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/Parameter.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"

namespace Aidge {
/**
 * @brief Keys naming the attributes held by Conv_Op.
 *
 * The enumerators are listed in the same order as the types given to
 * Parameterizable in Conv_Op (stride, dilation, in-channels, out-channels,
 * kernel, padding). They appear to select an entry of that type list via
 * `param<e>`, so keep both lists in the same order when editing either one.
 */
enum class ConvParam {
    StrideDims,
    DilationDims,
    InChannels,
    OutChannels,
    KernelDims,
    PaddingDims
};

template <DimIdx_t DIM>
class Conv_Op : public Operator,
                public Registrable<Conv_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Conv_Op<DIM> &)>,
                public Parameterizable<ConvParam, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t,
                                       DimSize_t, std::array<DimSize_t, DIM>, std::array<DimSize_t, (DIM<<1) >> {
public:
    // FIXME: change accessibility
    // NOTE(review): these tensor slots are public data members; the FIXME above
    // suggests they are intended to become non-public eventually.
    // Three input slots, each initialised with its own default-constructed Tensor.
    // Presumably (data input, weights, bias) in that order — TODO confirm.
    std::array<std::shared_ptr<Tensor>, 3> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>(),
                                                      std::make_shared<Tensor>()};
    // Single output tensor. The shared_ptr itself is const (it cannot be reseated
    // to another Tensor), but the Tensor it points to is still mutable.
    const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();

   public:
    /// Operator type name; forwarded to the Operator base constructor.
    static constexpr const char *Type = "Conv";

    // No default constructor: channel counts and kernel dimensions must always be supplied.
    Conv_Op() = delete;

    // Alias for the Parameterizable base. Its template argument list repeats the one
    // in the class head verbatim; the two must be kept identical (and in ConvParam order).
    using Parameterizable_ = Parameterizable<ConvParam, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>,
                                             DimSize_t, DimSize_t, std::array<DimSize_t, DIM>, std::array<DimSize_t, (DIM<<1) >>;
    // Shorthand for the parameter type stored under key `e` (presumably the wrapper
    // type Parameterizable_ expects for that entry — verify in Parameter.hpp).
    template <ConvParam e>
    using param = typename Parameterizable_::template param<e>;

    /**
     * @brief Build a Conv operator from its channel counts and kernel geometry.
     *
     * @param in_channels   Number of input channels, stored under ConvParam::InChannels.
     * @param out_channels  Number of output channels, stored under ConvParam::OutChannels.
     * @param kernel_dims   Kernel extent along each of the DIM dimensions (ConvParam::KernelDims).
     * @param stride_dims   Stride per dimension (ConvParam::StrideDims);
     *                      default built with create_array<DimSize_t, DIM>(1), presumably all ones.
     * @param padding_dims  2*DIM padding values (ConvParam::PaddingDims);
     *                      default built with create_array<DimSize_t, 2*DIM>(0), presumably all zeros.
     * @param dilation_dims Dilation per dimension (ConvParam::DilationDims);
     *                      default built with create_array<DimSize_t, DIM>(1), presumably all ones.
     *
     * The operator's datatype is set to DataType::Float32 once the bases are built.
     */
    constexpr Conv_Op(DimSize_t in_channels,
                      DimSize_t out_channels,
                      const std::array<DimSize_t, DIM> &kernel_dims,
                      const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1),
                      const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0),
                      const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1))
        : Operator(Conv_Op::Type),
          // Arguments are listed in ConvParam order, which differs from the order
          // of the constructor's parameters.
          Parameterizable_(
              param<ConvParam::StrideDims>(stride_dims),
              param<ConvParam::DilationDims>(dilation_dims),
              param<ConvParam::InChannels>(in_channels),
              param<ConvParam::OutChannels>(out_channels),
              param<ConvParam::KernelDims>(kernel_dims),
              param<ConvParam::PaddingDims>(padding_dims))
    {
        // Every freshly built Conv operator starts out in single-precision float.
        setDatatype(DataType::Float32);
    }

    /**
     * @brief Copy-constructor. Copies the operator parameters and its output tensor(s), but not its input tensors (the new operator has no associated inputs).
     * @param op Operator to copy.