Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Squeeze.cpp 5.94 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/operator/Squeeze.hpp"

#include <algorithm>
#include <bitset>
#include <cstdint>
#include <fmt/core.h>
#include <functional>
#include <iterator>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Log.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"

namespace Aidge {
// Type string of the Squeeze operator.
const std::string Squeeze_Op::Type{"Squeeze"};

bool Squeeze_Op::dimsForwarded() const {
  if ((getInput(1) && !getInput(1)->undefined())) {
    // output dims are data dependent
    return false;
  }

  return OperatorTensor::dimsForwarded();
}

bool Squeeze_Op::forwardDims(bool allowDataDependency) {
  // error checking
  if (!inputsAssociated(false) || getInput(0)->undefined()) {
    return false;
  }

  std::shared_ptr<Tensor> fallback;
  // Input 1 is axes to squeeze (can also be given via attribute)
  if (getInput(1)) {
    if (!this->axes().empty()) {
      Log::notice("{} : ignoring non-empty axes attribute because input#1 "
                  "takes precedence",
                  type());
    }

    if (!allowDataDependency) {
      Log::warn("{} : unable to forwardDims() because output dims are data "
                "dependent on input#1",
                type());
      return false;
    }

    this->axes().clear(); // If both are provided input would override attrs
    this->axes().reserve(getInput(1)->size());
    const auto &axes =
        getInput(1)->refCastFrom(fallback, NativeType<int8_t>::type, "cpu");
    if (axes.nbDims() == 0) {
      this->axes().clear();
    } else {
      AIDGE_ASSERT(
          axes.nbDims() == 1,
          "Axes input tensor should be of size 1. Received {} dimensions : {}",
          axes.nbDims(), axes.dims());
      std::copy_n(static_cast<int8_t *>(axes.getImpl()->hostPtr()), axes.size(),
                  std::back_inserter(this->axes()));
    }
  }

  std::vector<DimSize_t> input_dims = getInput(0)->dims();
  std::vector<DimSize_t> output_dims;
  output_dims.reserve(input_dims.size());
  std::vector<DimIdx_t> axes_rectified_idx;
  axes_rectified_idx.reserve(input_dims.size());

  if (this->axes().size() == 0) { // squeeze() => squeeze all 1 sized dimensions
    Log::debug("this->axes() is empty, all 1 sized dim will be squeezed. If "
               "this is an error ensure that the values are properly set via "
               "attribute or data input#1.");
    std::copy_if(input_dims.begin(), input_dims.end(),
                 std::back_inserter(output_dims),
                 [](DimSize_t dim) { return dim != 1; });
  } else { // squeeze({N,.....}) => squeeze all specified dimensions that are of
           // size 1.
    /////// ensure indexes validity and set pythonic negative indexes to their
    // positive value
    for (const int8_t &axis : this->axes()) {
      AIDGE_ASSERT(axis >= static_cast<int8_t>(-input_dims.size()) &&
                       axis < static_cast<int8_t>(input_dims.size()),
                   "{} : Axis index OutOfBounds error, expected value "
                   "within size limits of input tensor : "
                   "[-{},{}], got {}.",
                   type(), input_dims.size(), input_dims.size() - 1, axis);
      auto temp =
          static_cast<DimIdx_t>(axis >= 0 ? axis : axis + input_dims.size());
      if (axes_rectified_idx.end() == std::find(axes_rectified_idx.begin(),
                                                axes_rectified_idx.end(),
                                                temp)) {
        axes_rectified_idx.push_back(temp);
      }
    }

    // Create output_dims
    // speeds up binary search
    std::sort(axes_rectified_idx.begin(), axes_rectified_idx.end());
    DimSize_t i = 0;
    std::copy_if(
        input_dims.begin(), input_dims.end(), std::back_inserter(output_dims),
        [&axes_rectified_idx, &i, &input_dims](DimSize_t dim) {
          // if current dim index is found in axes to squeeze
          // we ensure that this axis is 1 sized, otherwise an error is thrown
          bool ok = true;
          if (std::binary_search(axes_rectified_idx.begin(),
                                 axes_rectified_idx.end(), i)) {
            AIDGE_ASSERT(dim == 1,
                         "{} : Tried to squeeze axis nb {} of a tensor of dim "
                         "{}. Dim to squeeze has to be 1-sized, got size {}."
                         "Axes to squeeze : {}",
                         __func__, i, input_dims, input_dims[i],
                         axes_rectified_idx);
            ok = false;
          }
          i++; // Incrementing counter since there is no enumerate
               // fctn (until C++23)
          return ok;
        });
  }
  mOutputs[0]->resize(output_dims);
  return true;
}

/**
 * @brief Select the implementation for backend `name` and move output#0 to it.
 *
 * If a Squeeze implementation is registered for `name`, it is installed.
 * Otherwise the generic Squeeze_OpImpl is used, whose forward() copies
 * input#0's data into output#0.
 */
void Squeeze_Op::setBackend(const std::string &name,
                            Aidge::DeviceIdx_t device) {
  const bool hasRegisteredImpl = Registrar<Squeeze_Op>::exists({name});
  if (!hasRegisteredImpl) {
    mImpl = std::make_shared<Squeeze_OpImpl>(*this);
  } else {
    SET_IMPL_MACRO(Squeeze_Op, *this, name);
  }
  mOutputs[0]->setBackend(name, device);
}

std::set<std::string> Aidge::Squeeze_Op::getAvailableBackends() const {
  return Registrar<Squeeze_Op>::getKeys();
}

void Aidge::Squeeze_OpImpl::forward() {
  const Squeeze_Op &op_ = static_cast<const Squeeze_Op &>(mOp);
  // Check if input is provided
  AIDGE_ASSERT(op_.getInput(0), "Squeeze : missing input 0");

  op_.getOutput(0)->getImpl()->copy(op_.getInput(0)->getImpl()->rawPtr(),
                                    op_.getInput(0)->size());
}

} // namespace Aidge