Skip to content
Snippets Groups Projects
Commit 09064c35 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Upd] some function names

parent d83ddf49
No related branches found
No related tags found
2 merge requests!408[Add] Dropout Operator,!355[Fix] Tensor::setDataFormat handling of DataFormat::Default
Pipeline #66351 passed
......@@ -66,7 +66,7 @@ void bindEnum(py::module& m, const std::string& name) {
void init_DataFormat(py::module& m) {
bindEnum<DataFormat>(m, "dformat");
m.def("format_as", (const char* (*)(DataFormat)) &format_as, py::arg("df"));
m.def("get_data_format_transpose", &getDataFormatTranspose, py::arg("src"), py::arg("dst"));
m.def("get_permutation_mapping", &getPermutationMapping, py::arg("src"), py::arg("dst"));
}
} // namespace Aidge
......@@ -196,7 +196,7 @@ bool Aidge::OperatorImpl::checkIOSpec(const ImplSpec::IOSpec& required, const Im
&& spec.format != DataFormat::Any
&& required.format != spec.format)
{
const auto transpose = getDataFormatTranspose(required.format, spec.format);
const auto transpose = getPermutationMapping(required.format, spec.format);
std::vector<size_t> identity(transpose.size());
std::iota(std::begin(identity), std::end(identity), 0);
......@@ -261,7 +261,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
&& IOSpec.format != DataFormat::Any
&& requiredIOSpec.format != IOSpec.format)
{
const auto transpose = getDataFormatTranspose(requiredIOSpec.format, IOSpec.format);
const auto transpose = getPermutationMapping(requiredIOSpec.format, IOSpec.format);
auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end()));
transposeOp->getOperator()->setDataFormat(IOSpec.format);
transposeOp->getOperator()->setDataType(requiredIOSpec.type);
......@@ -315,7 +315,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
&& IOSpec.format != DataFormat::Any
&& requiredIOSpec.format != IOSpec.format)
{
const auto transpose = getDataFormatTranspose(IOSpec.format, requiredIOSpec.format);
const auto transpose = getPermutationMapping(IOSpec.format, requiredIOSpec.format);
auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end()));
transposeOp->getOperator()->setDataFormat(requiredIOSpec.format);
transposeOp->getOperator()->setDataType(requiredIOSpec.type);
......
......@@ -94,14 +94,14 @@ void Aidge::explicitTranspose(std::shared_ptr<GraphView> graph) {
else {
// Case 2: change of format
// => compute the new permutation array
const auto transpose = getDataFormatTranspose(parentInput->dataFormat(), output->dataFormat());
const auto transpose = getPermutationMapping(parentInput->dataFormat(), output->dataFormat());
auto transposeOp = std::static_pointer_cast<Transpose_Op>(parent.first->getOperator());
transposeOp->setDataFormat(output->dataFormat());
transposeOp->outputDimsOrder() = std::vector<DimSize_t>(transpose.begin(), transpose.end());
}
}
else {
const auto transpose = getDataFormatTranspose(input->dataFormat(), output->dataFormat());
const auto transpose = getPermutationMapping(input->dataFormat(), output->dataFormat());
auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end()));
transposeOp->getOperator()->setDataFormat(output->dataFormat());
transposeOp->getOperator()->setDataType(output->dataType());
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment