Skip to content
Snippets Groups Projects
Commit f32515a9 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

support empty permutation vector for Transpose

parent 929a5059
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!237Feat: Support empty permutation vector for Transpose
Pipeline #58281 passed
......@@ -25,6 +25,11 @@
#include "aidge/utils/Types.h"
namespace Aidge {
/**
* @brief implementation of the operator Transpose.
* @note Since this operator implementation is agnostic to the backend it is
* located here instead of in aidge_backend.
*/
class TransposeImpl : public OperatorImpl {
public:
TransposeImpl(const Operator& op, const std::string& backend = "")
......@@ -33,8 +38,22 @@ public:
void forward() override;
};
enum class TransposeAttr { OutputDimsOrder };
enum class TransposeAttr {
/**
* @brief order of the ouput dims from the input dims. If left empty,
* the dimensions of input will be reversed.
*/
OutputDimsOrder
};
/**
* @brief This operator has as purpose to transpose the axes of a given tensor.
* input#0 : Tensor to transpose
* @example Calling transpose() on a tensor of dimensions [1, 2, 3] with OutputDimsOrder=(1,0,2) result
* in a tensor of dim [2, 1, 3].
* @example Calling transpose() on a tensor of dimensions [1,2,3,4] with an empty OutputDimsOrder vector
* will result in a tensor of dim [4,3,2,1].
*/
class Transpose_Op : public OperatorTensor,
public Registrable<Transpose_Op, std::string, std::function<std::shared_ptr<OperatorImpl>(const Transpose_Op&)>> {
......@@ -50,6 +69,10 @@ private:
public:
Transpose_Op() = delete;
/**
* @brief constructor for Transpose op
* @param[in] outputDimsOrder axes permutation order. By default axes are reversed.
*/
Transpose_Op(const std::vector<DimSize_t> &outputDimsOrder);
/**
......@@ -70,6 +93,9 @@ public:
std::set<std::string> getAvailableBackends() const override;
inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; }
/**
* @brief axes new order, if left empty, axes will be reversed.
*/
inline std::vector<DimSize_t>& outputDimsOrder() const noexcept { return mAttributes -> getAttr<TransposeAttr::OutputDimsOrder>(); }
static const std::vector<std::string> getInputsName(){
......@@ -80,8 +106,8 @@ public:
}
};
std::shared_ptr<Node> Transpose(const std::vector<DimSize_t> &outputDimsOrder,
const std::string& name = "");
std::shared_ptr<Node> Transpose(const std::vector<DimSize_t> &outputDimsOrder = {},
const std::string& name = "");
} // namespace Aidge
namespace {
......
......@@ -28,13 +28,26 @@ namespace Aidge {
void declare_Transpose(py::module &m) {
const std::string pyClassName("TransposeOp");
py::class_<Transpose_Op, std::shared_ptr<Transpose_Op>, OperatorTensor>(
m, "TransposeOp", py::multiple_inheritance())
.def(py::init<const std::vector<DimSize_t>&>(), py::arg("output_dims_order"))
m, "TransposeOp", py::multiple_inheritance(),
R"mydelimiter(
Initialize transpose operator
:param output_dims_order : axes permutation order, must be of rank = r and values between [0;r-1]
with r = input_tensor.nbDims()
:type output_dims_order : :py:class: List[Int]
)mydelimiter")
.def(py::init<const std::vector<DimSize_t>&>(), py::arg("output_dims_order")=std::vector<std::size_t>())
.def_static("get_inputs_name", &Transpose_Op::getInputsName)
.def_static("get_outputs_name", &Transpose_Op::getOutputsName)
.def_readonly_static("Type", &Transpose_Op::Type);
declare_registrable<Transpose_Op>(m, pyClassName);
m.def("Transpose", &Transpose, py::arg("output_dims_order"), py::arg("name") = "");
m.def("Transpose", &Transpose, py::arg("output_dims_order")=std::vector<std::size_t>(), py::arg("name") = "",
R"mydelimiter(
Initialize a node containing a transpose operator.
:param output_dims_order : axes permutation order, must be of rank = r and values between [0;r-1]
with r = input_tensor.nbDims()
:type output_dims_order : :py:class: List[Int]
:param name : name of the node.
)mydelimiter");
}
void init_Transpose(py::module &m) {
......
......@@ -59,6 +59,15 @@ std::shared_ptr<Aidge::Operator> Aidge::Transpose_Op::clone() const {
bool Aidge::Transpose_Op::forwardDims(bool /*allowDataDependency*/) {
if (inputsAssociated()) {
AIDGE_ASSERT(!getInput(0)->empty(), "Not applicable on scalars.");
// If permutation vector is not given, reverse the dims of input tensor
if (outputDimsOrder().empty())
{
this->outputDimsOrder().resize(getInput(0)->nbDims());
std::iota(this->outputDimsOrder().rbegin(), this->outputDimsOrder().rend(), 0);
}
AIDGE_ASSERT(outputDimsOrder().size() == getInput(0)->nbDims(),
"Permutation vector must have the same rank as input tensor.");
std::vector<DimSize_t> outputDims;
for (std::size_t i = 0; i < outputDimsOrder().size(); ++i) {
outputDims.push_back(getInput(0)->dims()[outputDimsOrder()[i]]);
......@@ -86,6 +95,6 @@ std::set<std::string> Aidge::Transpose_Op::getAvailableBackends() const {
//////////////////////////////////////////////////
std::shared_ptr<Aidge::Node> Aidge::Transpose(const std::vector<Aidge::DimSize_t> &outputDimsOrder,
const std::string& name) {
const std::string& name) {
return std::make_shared<Node>(std::make_shared<Transpose_Op>(outputDimsOrder), name);
}
\ No newline at end of file
......@@ -128,6 +128,75 @@ TEST_CASE("[cpu/operator] Transpose(forward)") {
op->setBackend("cpu");
myTranspose->forward();
REQUIRE(*(op->getOutput(0)) == *output);
}
SECTION("Default permutation") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array4D<int,2,3,1,4> {
{
{
{
{1, 2, 3, 4}
},
{
{5, 6, 7, 8}
},
{
{9, 10, 11, 12}
}
},
{
{
{13, 14, 15, 16}
},
{
{17, 18, 19, 20}
},
{
{21, 22, 23, 24}
}
}
}
});
std::shared_ptr<Tensor> output = std::make_shared<Tensor>(Array4D<int,4,1,3,2> {
{
{
{
{ 1, 13},
{ 5, 17},
{ 9, 21}
}
},
{
{
{ 2, 14},
{ 6, 18},
{10, 22}
}
},
{
{
{ 3, 15},
{ 7, 19},
{11, 23}
}
},
{
{
{ 4, 16},
{ 8, 20},
{12, 24}
}
}
}
});
std::shared_ptr<Node> myTranspose = Transpose({});
auto op = std::static_pointer_cast<OperatorTensor>(myTranspose -> getOperator());
op->associateInput(0,input);
op->setDataType(DataType::Int32);
op->setBackend("cpu");
myTranspose->forward();
REQUIRE(*(op->getOutput(0)) == *output);
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment