-
Cyril Moineau authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Squeeze.cpp 5.94 KiB
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/operator/Squeeze.hpp"

#include <algorithm>
#include <bitset>
#include <cstdint>
#include <functional>
#include <iterator>
#include <limits>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

#include <fmt/core.h>

#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Log.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
/// Type identifier string of the Squeeze operator.
const std::string Squeeze_Op::Type{"Squeeze"};
bool Squeeze_Op::dimsForwarded() const {
if ((getInput(1) && !getInput(1)->undefined())) {
// output dims are data dependent
return false;
}
return OperatorTensor::dimsForwarded();
}
bool Squeeze_Op::forwardDims(bool allowDataDependency) {
// error checking
if (!inputsAssociated(false) || getInput(0)->undefined()) {
return false;
}
std::shared_ptr<Tensor> fallback;
// Input 1 is axes to squeeze (can also be given via attribute)
if (getInput(1)) {
if (!this->axes().empty()) {
Log::notice("{} : ignoring non-empty axes attribute because input#1 "
"takes precedence",
type());
}
if (!allowDataDependency) {
Log::warn("{} : unable to forwardDims() because output dims are data "
"dependent on input#1",
type());
return false;
}
this->axes().clear(); // If both are provided input would override attrs
this->axes().reserve(getInput(1)->size());
const auto &axes =
getInput(1)->refCastFrom(fallback, NativeType<int8_t>::type, "cpu");
if (axes.nbDims() == 0) {
this->axes().clear();
} else {
AIDGE_ASSERT(
axes.nbDims() == 1,
"Axes input tensor should be of size 1. Received {} dimensions : {}",
axes.nbDims(), axes.dims());
std::copy_n(static_cast<int8_t *>(axes.getImpl()->hostPtr()), axes.size(),
std::back_inserter(this->axes()));
}
}
std::vector<DimSize_t> input_dims = getInput(0)->dims();
std::vector<DimSize_t> output_dims;
output_dims.reserve(input_dims.size());
std::vector<DimIdx_t> axes_rectified_idx;
axes_rectified_idx.reserve(input_dims.size());
if (this->axes().size() == 0) { // squeeze() => squeeze all 1 sized dimensions
Log::debug("this->axes() is empty, all 1 sized dim will be squeezed. If "
"this is an error ensure that the values are properly set via "
"attribute or data input#1.");
std::copy_if(input_dims.begin(), input_dims.end(),
std::back_inserter(output_dims),
[](DimSize_t dim) { return dim != 1; });
} else { // squeeze({N,.....}) => squeeze all specified dimensions that are of
// size 1.
/////// ensure indexes validity and set pythonic negative indexes to their
// positive value
for (const int8_t &axis : this->axes()) {
AIDGE_ASSERT(axis >= static_cast<int8_t>(-input_dims.size()) &&
axis < static_cast<int8_t>(input_dims.size()),
"{} : Axis index OutOfBounds error, expected value "
"within size limits of input tensor : "
"[-{},{}], got {}.",
type(), input_dims.size(), input_dims.size() - 1, axis);
auto temp =
static_cast<DimIdx_t>(axis >= 0 ? axis : axis + input_dims.size());
if (axes_rectified_idx.end() == std::find(axes_rectified_idx.begin(),
axes_rectified_idx.end(),
temp)) {
axes_rectified_idx.push_back(temp);
}
}
// Create output_dims
// speeds up binary search
std::sort(axes_rectified_idx.begin(), axes_rectified_idx.end());
DimSize_t i = 0;
std::copy_if(
input_dims.begin(), input_dims.end(), std::back_inserter(output_dims),
[&axes_rectified_idx, &i, &input_dims](DimSize_t dim) {
// if current dim index is found in axes to squeeze
// we ensure that this axis is 1 sized, otherwise an error is thrown
bool ok = true;
if (std::binary_search(axes_rectified_idx.begin(),
axes_rectified_idx.end(), i)) {
AIDGE_ASSERT(dim == 1,
"{} : Tried to squeeze axis nb {} of a tensor of dim "
"{}. Dim to squeeze has to be 1-sized, got size {}."
"Axes to squeeze : {}",
__func__, i, input_dims, input_dims[i],
axes_rectified_idx);
ok = false;
}
i++; // Incrementing counter since there is no enumerate
// fctn (until C++23)
return ok;
});
}
mOutputs[0]->resize(output_dims);
return true;
}
void Squeeze_Op::setBackend(const std::string &name,
Aidge::DeviceIdx_t device) {
if (Registrar<Squeeze_Op>::exists({name})) {
SET_IMPL_MACRO(Squeeze_Op, *this, name);
} else {
mImpl = std::make_shared<Squeeze_OpImpl>(*this);
}
mOutputs[0]->setBackend(name, device);
}
std::set<std::string> Aidge::Squeeze_Op::getAvailableBackends() const {
return Registrar<Squeeze_Op>::getKeys();
}
void Aidge::Squeeze_OpImpl::forward() {
const Squeeze_Op &op_ = static_cast<const Squeeze_Op &>(mOp);
// Check if input is provided
AIDGE_ASSERT(op_.getInput(0), "Squeeze : missing input 0");
op_.getOutput(0)->getImpl()->copy(op_.getInput(0)->getImpl()->rawPtr(),
op_.getInput(0)->size());
}
} // namespace Aidge