Skip to content
Snippets Groups Projects
Commit 40b8dddd authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

Merge branch 'feat/enhance_tranpose' into 'dev'

Feat: Support empty permutation vector for Transpose

See merge request eclipse/aidge/aidge_core!237
parents 929a5059 f32515a9
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!237Feat: Support empty permutation vector for Transpose
Pipeline #58315 passed
......@@ -25,6 +25,11 @@
#include "aidge/utils/Types.h"
namespace Aidge {
/**
* @brief implementation of the operator Transpose.
* @note Since this operator implementation is agnostic to the backend it is
* located here instead of in aidge_backend.
*/
class TransposeImpl : public OperatorImpl {
public:
TransposeImpl(const Operator& op, const std::string& backend = "")
......@@ -33,8 +38,22 @@ public:
void forward() override;
};
enum class TransposeAttr { OutputDimsOrder };
enum class TransposeAttr {
/**
* @brief order of the ouput dims from the input dims. If left empty,
* the dimensions of input will be reversed.
*/
OutputDimsOrder
};
/**
* @brief This operator has as purpose to transpose the axes of a given tensor.
* input#0 : Tensor to transpose
* @example Calling transpose() on a tensor of dimensions [1, 2, 3] with OutputDimsOrder=(1,0,2) result
* in a tensor of dim [2, 1, 3].
* @example Calling transpose() on a tensor of dimensions [1,2,3,4] with an empty OutputDimsOrder vector
* will result in a tensor of dim [4,3,2,1].
*/
class Transpose_Op : public OperatorTensor,
public Registrable<Transpose_Op, std::string, std::function<std::shared_ptr<OperatorImpl>(const Transpose_Op&)>> {
......@@ -50,6 +69,10 @@ private:
public:
Transpose_Op() = delete;
/**
* @brief constructor for Transpose op
* @param[in] outputDimsOrder axes permutation order. By default axes are reversed.
*/
Transpose_Op(const std::vector<DimSize_t> &outputDimsOrder);
/**
......@@ -70,6 +93,9 @@ public:
std::set<std::string> getAvailableBackends() const override;
inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; }
/**
* @brief axes new order, if left empty, axes will be reversed.
*/
inline std::vector<DimSize_t>& outputDimsOrder() const noexcept { return mAttributes -> getAttr<TransposeAttr::OutputDimsOrder>(); }
static const std::vector<std::string> getInputsName(){
......@@ -80,8 +106,8 @@ public:
}
};
std::shared_ptr<Node> Transpose(const std::vector<DimSize_t> &outputDimsOrder,
const std::string& name = "");
std::shared_ptr<Node> Transpose(const std::vector<DimSize_t> &outputDimsOrder = {},
const std::string& name = "");
} // namespace Aidge
namespace {
......
......@@ -28,13 +28,26 @@ namespace Aidge {
void declare_Transpose(py::module &m) {
const std::string pyClassName("TransposeOp");
py::class_<Transpose_Op, std::shared_ptr<Transpose_Op>, OperatorTensor>(
m, "TransposeOp", py::multiple_inheritance())
.def(py::init<const std::vector<DimSize_t>&>(), py::arg("output_dims_order"))
m, "TransposeOp", py::multiple_inheritance(),
R"mydelimiter(
Initialize transpose operator
:param output_dims_order : axes permutation order, must be of rank = r and values between [0;r-1]
with r = input_tensor.nbDims()
:type output_dims_order : :py:class: List[Int]
)mydelimiter")
.def(py::init<const std::vector<DimSize_t>&>(), py::arg("output_dims_order")=std::vector<std::size_t>())
.def_static("get_inputs_name", &Transpose_Op::getInputsName)
.def_static("get_outputs_name", &Transpose_Op::getOutputsName)
.def_readonly_static("Type", &Transpose_Op::Type);
declare_registrable<Transpose_Op>(m, pyClassName);
m.def("Transpose", &Transpose, py::arg("output_dims_order"), py::arg("name") = "");
m.def("Transpose", &Transpose, py::arg("output_dims_order")=std::vector<std::size_t>(), py::arg("name") = "",
R"mydelimiter(
Initialize a node containing a transpose operator.
:param output_dims_order : axes permutation order, must be of rank = r and values between [0;r-1]
with r = input_tensor.nbDims()
:type output_dims_order : :py:class: List[Int]
:param name : name of the node.
)mydelimiter");
}
void init_Transpose(py::module &m) {
......
......@@ -59,6 +59,15 @@ std::shared_ptr<Aidge::Operator> Aidge::Transpose_Op::clone() const {
bool Aidge::Transpose_Op::forwardDims(bool /*allowDataDependency*/) {
if (inputsAssociated()) {
AIDGE_ASSERT(!getInput(0)->empty(), "Not applicable on scalars.");
// If permutation vector is not given, reverse the dims of input tensor
if (outputDimsOrder().empty())
{
this->outputDimsOrder().resize(getInput(0)->nbDims());
std::iota(this->outputDimsOrder().rbegin(), this->outputDimsOrder().rend(), 0);
}
AIDGE_ASSERT(outputDimsOrder().size() == getInput(0)->nbDims(),
"Permutation vector must have the same rank as input tensor.");
std::vector<DimSize_t> outputDims;
for (std::size_t i = 0; i < outputDimsOrder().size(); ++i) {
outputDims.push_back(getInput(0)->dims()[outputDimsOrder()[i]]);
......@@ -86,6 +95,6 @@ std::set<std::string> Aidge::Transpose_Op::getAvailableBackends() const {
//////////////////////////////////////////////////
std::shared_ptr<Aidge::Node> Aidge::Transpose(const std::vector<Aidge::DimSize_t> &outputDimsOrder,
const std::string& name) {
const std::string& name) {
return std::make_shared<Node>(std::make_shared<Transpose_Op>(outputDimsOrder), name);
}
\ No newline at end of file
......@@ -128,6 +128,75 @@ TEST_CASE("[cpu/operator] Transpose(forward)") {
op->setBackend("cpu");
myTranspose->forward();
REQUIRE(*(op->getOutput(0)) == *output);
}
SECTION("Default permutation") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array4D<int,2,3,1,4> {
{
{
{
{1, 2, 3, 4}
},
{
{5, 6, 7, 8}
},
{
{9, 10, 11, 12}
}
},
{
{
{13, 14, 15, 16}
},
{
{17, 18, 19, 20}
},
{
{21, 22, 23, 24}
}
}
}
});
std::shared_ptr<Tensor> output = std::make_shared<Tensor>(Array4D<int,4,1,3,2> {
{
{
{
{ 1, 13},
{ 5, 17},
{ 9, 21}
}
},
{
{
{ 2, 14},
{ 6, 18},
{10, 22}
}
},
{
{
{ 3, 15},
{ 7, 19},
{11, 23}
}
},
{
{
{ 4, 16},
{ 8, 20},
{12, 24}
}
}
}
});
std::shared_ptr<Node> myTranspose = Transpose({});
auto op = std::static_pointer_cast<OperatorTensor>(myTranspose -> getOperator());
op->associateInput(0,input);
op->setDataType(DataType::Int32);
op->setBackend("cpu");
myTranspose->forward();
REQUIRE(*(op->getOutput(0)) == *output);
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment