Skip to content
Snippets Groups Projects

Feat: Support empty permutation vector for Transpose

Merged Houssem ROUIS requested to merge feat/enhance_tranpose into dev
1 unresolved thread
Files
4
@@ -25,6 +25,11 @@
#include "aidge/utils/Types.h"
namespace Aidge {
/**
* @brief implementation of the operator Transpose.
* @note Since this operator implementation is agnostic to the backend it is
* located here instead of in aidge_backend.
*/
class TransposeImpl : public OperatorImpl {
public:
TransposeImpl(const Operator& op, const std::string& backend = "")
@@ -33,8 +38,22 @@ public:
void forward() override;
};
enum class TransposeAttr { OutputDimsOrder };
enum class TransposeAttr {
/**
* @brief order of the ouput dims from the input dims. If left empty,
* the dimensions of input will be reversed.
*/
OutputDimsOrder
};
/**
* @brief This operator has as purpose to transpose the axes of a given tensor.
* input#0 : Tensor to transpose
* @example Calling transpose() on a tensor of dimensions [1, 2, 3] with OutputDimsOrder=(1,0,2) result
* in a tensor of dim [2, 1, 3].
* @example Calling transpose() on a tensor of dimensions [1,2,3,4] with an empty OutputDimsOrder vector
* will result in a tensor of dim [4,3,2,1].
*/
class Transpose_Op : public OperatorTensor,
public Registrable<Transpose_Op, std::string, std::function<std::shared_ptr<OperatorImpl>(const Transpose_Op&)>> {
@@ -50,6 +69,10 @@ private:
public:
Transpose_Op() = delete;
/**
* @brief constructor for Transpose op
* @param[in] outputDimsOrder axes permutation order. By default axes are reversed.
*/
Transpose_Op(const std::vector<DimSize_t> &outputDimsOrder);
/**
@@ -70,6 +93,9 @@ public:
std::set<std::string> getAvailableBackends() const override;
inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; }
/**
* @brief axes new order, if left empty, axes will be reversed.
*/
inline std::vector<DimSize_t>& outputDimsOrder() const noexcept { return mAttributes -> getAttr<TransposeAttr::OutputDimsOrder>(); }
static const std::vector<std::string> getInputsName(){
@@ -80,8 +106,8 @@ public:
}
};
std::shared_ptr<Node> Transpose(const std::vector<DimSize_t> &outputDimsOrder,
const std::string& name = "");
std::shared_ptr<Node> Transpose(const std::vector<DimSize_t> &outputDimsOrder = {},
const std::string& name = "");
} // namespace Aidge
namespace {
Loading