Skip to content
Snippets Groups Projects

feat_operator_convtranspose

Merged Grégoire Kubler requested to merge feat_operator_convtranspose into dev
5 files
+ 884
0
Compare changes
  • Side-by-side
  • Inline
Files
5
+ 208
0
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_CONVTRANSPOSE_H_
#define AIDGE_CORE_OPERATOR_CONVTRANSPOSE_H_
#include <array>
#include <cmath> // std::floor
#include <string>
#include <utility> // std::pair
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/ArrayHelpers.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Registrar.hpp" // SET_IMPL_MACRO
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
enum class ConvTransposeAttr { StrideDims, DilationDims, KernelDims };
template <DimIdx_t DIM>
class ConvTranspose_Op
: public OperatorTensor,
public Registrable<ConvTranspose_Op<DIM>,
std::string,
std::function<std::shared_ptr<OperatorImpl>(
const ConvTranspose_Op<DIM> &)>> {
public:
static const std::string Type;
private:
using Attributes_ = StaticAttributes<ConvTransposeAttr,
std::array<DimSize_t, DIM>,
std::array<DimSize_t, DIM>,
std::array<DimSize_t, DIM>>;
template <ConvTransposeAttr e>
using attr = typename Attributes_::template attr<e>;
const std::shared_ptr<Attributes_> mAttributes;
public:
ConvTranspose_Op() = delete;
constexpr explicit ConvTranspose_Op(
const std::array<DimSize_t, DIM> &kernelDims,
const std::array<DimSize_t, DIM> &strideDims =
create_array<DimSize_t, DIM>(1),
const std::array<DimSize_t, DIM> &dilationDims =
create_array<DimSize_t, DIM>(1))
: OperatorTensor(Type,
{InputCategory::Data,
InputCategory::Param,
InputCategory::OptionalParam},
1),
mAttributes(std::make_shared<Attributes_>(
attr<ConvTransposeAttr::StrideDims>(strideDims),
attr<ConvTransposeAttr::DilationDims>(dilationDims),
attr<ConvTransposeAttr::KernelDims>(kernelDims))) {}
/**
* @brief Copy-constructor. Copy the operator attributes and its output
* tensor(s), but not its input tensors (the new operator has no input
* associated).
* @param op Operator to copy.
*/
ConvTranspose_Op(const ConvTranspose_Op<DIM> &op);
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::Conv_Op
*/
std::shared_ptr<Operator> clone() const override {
return std::make_shared<ConvTranspose_Op<DIM>>(*this);
}
bool forwardDims(bool /*allowDataDependency*/ = false) override final;
std::vector<std::pair<std::vector<DimSize_t>, std::vector<DimSize_t>>>
computeReceptiveField(const std::vector<DimSize_t> &firstEltDims,
const std::vector<DimSize_t> &outputDims,
const IOIndex_t outputIdx = 0) const override;
void setBackend(const std::string &name, DeviceIdx_t device = 0) override;
std::set<std::string> getAvailableBackends() const override;
DimSize_t inChannels() const {
if (!getInput(1)) {
AIDGE_THROW_OR_ABORT(
std::runtime_error,
"{}: operator has no weight Tensor associated so no "
"specific number of input channel imposed.",
Type);
}
return getInput(1)->template dims<DIM + 2>()[0];
}
DimSize_t outChannels() const {
if (!getInput(1)) {
AIDGE_THROW_OR_ABORT(
std::runtime_error,
"{}: operator has no weight Tensor associated so no "
"specific number of output channel imposed.",
Type);
}
return getInput(1)->template dims<DIM + 2>()[1];
}
inline std::shared_ptr<Attributes> attributes() const override {
return mAttributes;
}
inline std::array<DimSize_t, DIM> &strideDims() const {
return mAttributes->template getAttr<ConvTransposeAttr::StrideDims>();
}
inline std::array<DimSize_t, DIM> &dilationDims() const {
return mAttributes
->template getAttr<ConvTransposeAttr::DilationDims>();
}
inline std::array<DimSize_t, DIM> &kernelDims() const {
return mAttributes->template getAttr<ConvTransposeAttr::KernelDims>();
}
static const std::vector<std::string> getInputsName() {
return {"data_input", "weight", "bias"};
}
static const std::vector<std::string> getOutputsName() {
return {"data_output"};
}
};
/**
* @brief Perform a convTranspose(/deconvolution) on the input Tensor.
*
* @tparam DIM Number of dimensions for the feature map.
* @param inChannels Number of input channels.
* @param outChannels Number of output channels.
* @param kernelDims Dimensions of the kernel. Must be the same number of
* dimensions as the feature map.
* @param name Name of the operator.
* @param strideDims Dimensions of the stride attribute. Must be the same
* number of dimensions as the feature map.
* @param dilationDims Dimensions of the dilation attribute. Must be the same
* number of dimensions as the feature map.
* @return std::shared_ptr<Node> A Node containing the operator.
*/
template <std::array<DimIdx_t, 1>::size_type DIM>
std::shared_ptr<Node>
ConvTranspose(const DimSize_t &inChannels,
const DimSize_t &outChannels,
const std::array<DimSize_t, DIM> &kernelDims,
const std::array<DimSize_t, DIM> &strideDims =
create_array<DimSize_t, DIM>(1),
const std::array<DimSize_t, DIM> &dilationDims =
create_array<DimSize_t, DIM>(1),
const bool noBias = false,
const std::string &name = "");
// helper with C-style array instead of std::array for kernel_dims to allow
// automatic template DIM deduction
/**
* @brief Conv Transpose node constructor
* @param[in] inChannels number of input channels of the conv transpose
* operator
* @param[in] outChannels number of ouptut channels of the convTranspose
* operator
* @param[in] kernelDims array of size DIM describing the dimensions of the
* kernel
* @param[in] name name of the node
* @param[in] strideDims stride along each dimension of the operator
* @param[in] dilationDims dilation along each dimension of the operator
* @param[in] noBias describes if the operator has biases or just weights
*/
template <DimIdx_t DIM>
inline std::shared_ptr<Node>
ConvTranspose(const DimSize_t &inChannels,
const DimSize_t &outChannels,
DimSize_t const (&kernelDims)[DIM],
const std::array<DimSize_t, DIM> &strideDims =
create_array<DimSize_t, DIM>(1),
const std::array<DimSize_t, DIM> &dilationDims =
create_array<DimSize_t, DIM>(1),
const bool noBias = false,
const std::string &name = "");
} // namespace Aidge
extern template class Aidge::ConvTranspose_Op<1>;
extern template class Aidge::ConvTranspose_Op<2>;
namespace {
template <>
const char *const EnumStrings<Aidge::ConvTransposeAttr>::data[] = {
"stride_dims",
"dilation_dims",
"kernel_dims"};
}
#endif /* AIDGE_CORE_OPERATOR_CONVTRANSPOSE_H_ */
Loading