Skip to content
Snippets Groups Projects
Commit 09064c35 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Upd] some function names

parent d83ddf49
No related branches found
No related tags found
3 merge requests!414Update version 0.5.1 -> 0.6.0,!408[Add] Dropout Operator,!355[Fix] Tensor::setDataFormat handling of DataFormat::Default
Pipeline #66351 passed
...@@ -66,7 +66,7 @@ void bindEnum(py::module& m, const std::string& name) { ...@@ -66,7 +66,7 @@ void bindEnum(py::module& m, const std::string& name) {
void init_DataFormat(py::module& m) { void init_DataFormat(py::module& m) {
bindEnum<DataFormat>(m, "dformat"); bindEnum<DataFormat>(m, "dformat");
m.def("format_as", (const char* (*)(DataFormat)) &format_as, py::arg("df")); m.def("format_as", (const char* (*)(DataFormat)) &format_as, py::arg("df"));
m.def("get_data_format_transpose", &getDataFormatTranspose, py::arg("src"), py::arg("dst")); m.def("get_permutation_mapping", &getPermutationMapping, py::arg("src"), py::arg("dst"));
} }
} // namespace Aidge } // namespace Aidge
...@@ -196,7 +196,7 @@ bool Aidge::OperatorImpl::checkIOSpec(const ImplSpec::IOSpec& required, const Im ...@@ -196,7 +196,7 @@ bool Aidge::OperatorImpl::checkIOSpec(const ImplSpec::IOSpec& required, const Im
&& spec.format != DataFormat::Any && spec.format != DataFormat::Any
&& required.format != spec.format) && required.format != spec.format)
{ {
const auto transpose = getDataFormatTranspose(required.format, spec.format); const auto transpose = getPermutationMapping(required.format, spec.format);
std::vector<size_t> identity(transpose.size()); std::vector<size_t> identity(transpose.size());
std::iota(std::begin(identity), std::end(identity), 0); std::iota(std::begin(identity), std::end(identity), 0);
...@@ -261,7 +261,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& ...@@ -261,7 +261,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
&& IOSpec.format != DataFormat::Any && IOSpec.format != DataFormat::Any
&& requiredIOSpec.format != IOSpec.format) && requiredIOSpec.format != IOSpec.format)
{ {
const auto transpose = getDataFormatTranspose(requiredIOSpec.format, IOSpec.format); const auto transpose = getPermutationMapping(requiredIOSpec.format, IOSpec.format);
auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end())); auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end()));
transposeOp->getOperator()->setDataFormat(IOSpec.format); transposeOp->getOperator()->setDataFormat(IOSpec.format);
transposeOp->getOperator()->setDataType(requiredIOSpec.type); transposeOp->getOperator()->setDataType(requiredIOSpec.type);
...@@ -315,7 +315,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec& ...@@ -315,7 +315,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
&& IOSpec.format != DataFormat::Any && IOSpec.format != DataFormat::Any
&& requiredIOSpec.format != IOSpec.format) && requiredIOSpec.format != IOSpec.format)
{ {
const auto transpose = getDataFormatTranspose(IOSpec.format, requiredIOSpec.format); const auto transpose = getPermutationMapping(IOSpec.format, requiredIOSpec.format);
auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end())); auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end()));
transposeOp->getOperator()->setDataFormat(requiredIOSpec.format); transposeOp->getOperator()->setDataFormat(requiredIOSpec.format);
transposeOp->getOperator()->setDataType(requiredIOSpec.type); transposeOp->getOperator()->setDataType(requiredIOSpec.type);
......
...@@ -94,14 +94,14 @@ void Aidge::explicitTranspose(std::shared_ptr<GraphView> graph) { ...@@ -94,14 +94,14 @@ void Aidge::explicitTranspose(std::shared_ptr<GraphView> graph) {
else { else {
// Case 2: change of format // Case 2: change of format
// => compute the new permutation array // => compute the new permutation array
const auto transpose = getDataFormatTranspose(parentInput->dataFormat(), output->dataFormat()); const auto transpose = getPermutationMapping(parentInput->dataFormat(), output->dataFormat());
auto transposeOp = std::static_pointer_cast<Transpose_Op>(parent.first->getOperator()); auto transposeOp = std::static_pointer_cast<Transpose_Op>(parent.first->getOperator());
transposeOp->setDataFormat(output->dataFormat()); transposeOp->setDataFormat(output->dataFormat());
transposeOp->outputDimsOrder() = std::vector<DimSize_t>(transpose.begin(), transpose.end()); transposeOp->outputDimsOrder() = std::vector<DimSize_t>(transpose.begin(), transpose.end());
} }
} }
else { else {
const auto transpose = getDataFormatTranspose(input->dataFormat(), output->dataFormat()); const auto transpose = getPermutationMapping(input->dataFormat(), output->dataFormat());
auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end())); auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end()));
transposeOp->getOperator()->setDataFormat(output->dataFormat()); transposeOp->getOperator()->setDataFormat(output->dataFormat());
transposeOp->getOperator()->setDataType(output->dataType()); transposeOp->getOperator()->setDataType(output->dataType());
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment