Skip to content
Snippets Groups Projects
Slice.cpp 9.51 KiB
Newer Older
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/operator/Slice.hpp"
Houssem ROUIS's avatar
Houssem ROUIS committed
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>
#include <fmt/format.h>

#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
// Type name of the Slice operator.
const std::string Aidge::Slice_Op::Type{"Slice"};

bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) {
Houssem ROUIS's avatar
Houssem ROUIS committed
    // check inputs have been associated
    if (!getInput(0)) {
        AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
    if(!getInput(0)->empty())
        if(this->template getAttr<SliceAttr::Starts>().empty() || this->template getAttr<SliceAttr::Ends>().empty() || this->template getAttr<SliceAttr::Axes>().empty())
        {
            if(getInput(1)->empty() || getInput(2)->empty() || getInput(3)->empty()) {
                AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Starts, Ends and Axes must be provided either as input or attributes", type());
            }

            AIDGE_ASSERT((mInputs[1]->dataType() == mInputs[2]->dataType()) && (mInputs[1]->dataType() == mInputs[3]->dataType()), "Slice inputs must have the same dataType.");

Houssem ROUIS's avatar
Houssem ROUIS committed
            this->template getAttr<SliceAttr::Starts>().clear();
            this->template getAttr<SliceAttr::Starts>().reserve(getInput(1)->size());
            this->template getAttr<SliceAttr::Ends>().clear();
            this->template getAttr<SliceAttr::Ends>().reserve(getInput(1)->size());
            this->template getAttr<SliceAttr::Axes>().clear();
            this->template getAttr<SliceAttr::Axes>().reserve(getInput(1)->size());
            switch (mInputs[1]->dataType()) {
                case DataType::Float64:
                    std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()),
                                getInput(1)->size(),
                                std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
                    std::copy_n(static_cast<double*>(mInputs[2]->getImpl()->rawPtr()),
                                getInput(2)->size(),
                                std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
                    std::copy_n(static_cast<double*>(mInputs[3]->getImpl()->rawPtr()),
                                getInput(3)->size(),
                                std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
                    break;
                case DataType::Float32:
                    std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()),
                                getInput(1)->size(),
                                std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
                    std::copy_n(static_cast<float*>(mInputs[2]->getImpl()->rawPtr()),
                                getInput(2)->size(),
                                std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
                    std::copy_n(static_cast<float*>(mInputs[3]->getImpl()->rawPtr()),
                                getInput(3)->size(),
                                std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
                    break;
                case DataType::Int64:
                    std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()),
                                getInput(1)->size(),
                                std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
                    std::copy_n(static_cast<std::int64_t*>(mInputs[2]->getImpl()->rawPtr()),
                                getInput(2)->size(),
                                std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
                    std::copy_n(static_cast<std::int64_t*>(mInputs[3]->getImpl()->rawPtr()),
                                getInput(3)->size(),
                                std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
                    break;
                case DataType::Int32:
                    std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()),
                                getInput(1)->size(),
                                std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
                    std::copy_n(static_cast<std::int32_t*>(mInputs[2]->getImpl()->rawPtr()),
                                getInput(2)->size(),
                                std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
                    std::copy_n(static_cast<std::int32_t*>(mInputs[3]->getImpl()->rawPtr()),
                                getInput(3)->size(),
                                std::back_inserter(this->template getAttr<SliceAttr::Axes>()));                                
                    break;
                default:
Houssem ROUIS's avatar
Houssem ROUIS committed
                    AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Input DataType is not supported.", type());
Houssem ROUIS's avatar
Houssem ROUIS committed
        // Fill Steps attr if empty
        if(this->template getAttr<SliceAttr::Steps>().empty()) {
            // In case the input Steps is not provided, default value is 1
            this->template getAttr<SliceAttr::Steps>() = std::vector<std::int64_t>(getInput(1)->size(), 1);

            if (getInput(4) && !getInput(4)->empty()) {
                this->template getAttr<SliceAttr::Steps>().clear();
                this->template getAttr<SliceAttr::Steps>().reserve(getInput(1)->size());
                switch (mInputs[1]->dataType()) {
                    case DataType::Float64:
                        std::copy_n(static_cast<double*>(mInputs[4]->getImpl()->rawPtr()),
                                    getInput(4)->size(),
                                    std::back_inserter(this->template getAttr<SliceAttr::Steps>()));
                        break;
                    case DataType::Float32:
                        std::copy_n(static_cast<float*>(mInputs[4]->getImpl()->rawPtr()),
                                    getInput(4)->size(),
                                    std::back_inserter(this->template getAttr<SliceAttr::Steps>()));
                        break;
                    case DataType::Int64:
                        std::copy_n(static_cast<std::int64_t*>(mInputs[4]->getImpl()->rawPtr()),
                                    getInput(4)->size(),
                                    std::back_inserter(this->template getAttr<SliceAttr::Steps>()));
                        break;
                    case DataType::Int32:
                        std::copy_n(static_cast<std::int32_t*>(mInputs[4]->getImpl()->rawPtr()),
                                    getInput(4)->size(),
                                    std::back_inserter(this->template getAttr<SliceAttr::Steps>()));                              
                        break;
                    default:
                        AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type());
                        break;
                }
            }
        }
        DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
        std::vector<DimSize_t> outDims = getInput(0)->dims();
        for (std::size_t i = 0; i < nbAxes; ++i) {
            DimIdx_t axis = this->template getAttr<SliceAttr::Axes>()[i] >= 0 ?
                            static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i]) :
                            static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(getInput(0)->nbDims()));
            DimSize_t start = this->template getAttr<SliceAttr::Starts>()[i] >= 0 ?
                              static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i]) :
                              static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
            DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ?
                            static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) :
                            static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
Houssem ROUIS's avatar
Houssem ROUIS committed
            if(this->template getAttr<SliceAttr::Steps>()[i] == 0) {
                AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step must be a non-zero value", type());
            }
            const std::size_t sliceLength = (end - start) / static_cast<DimSize_t>(std::abs(this->template getAttr<SliceAttr::Steps>()[i]));
            // Check if slice length is valid
            if (sliceLength > getInput(0)->dims()[axis])
            {
                AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds");
            }
            outDims[axis] = sliceLength;
        mOutputs[0]->resize(outDims);
Olivier BICHLER's avatar
Olivier BICHLER committed

void Aidge::Slice_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
    SET_IMPL_MACRO(Slice_Op, *this, name);
Olivier BICHLER's avatar
Olivier BICHLER committed
    mOutputs[0]->setBackend(name, device);
}