Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
DivImpl.cpp 7.39 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <cstddef>     // std::size_t, std::ptrdiff_t
#include <functional>  // std::multiplies
#include <memory>
#include <numeric>     // std::accumulate
#include <vector>

#include "aidge/backend/cpu/data/Broadcasting.hpp"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/backend/cpu/operator/DivImpl.hpp"
#include "aidge/backend/cpu/operator/DivImpl_forward_kernels.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/Types.h"

Aidge::NbElts_t Aidge::DivImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
    // This implementation can run in-place: report zero protected elements,
    // whichever input index is queried.
    return static_cast<Aidge::NbElts_t>(0);
}

void Aidge::DivImpl_cpu::forward() {
    // Find the correct kernel type
    // auto kernelFunc = Registrar<DivImplForward_cpu>::create({
    //     std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
    //     std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(),
    //     std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});

    // const std::vector<std::size_t> inputDims0 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
    //                                                                std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims());
    // const std::vector<std::size_t> inputDims1 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
    //                                                                std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims());


    // auto a = std::static_pointer_cast<Tensor>(mOp.getRawInput(0));
    // auto b = std::static_pointer_cast<Tensor>(mOp.getRawInput(1));

    // // Call kernel
    // kernelFunc(inputDims0,
    //     inputDims1,
    //     std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
    //     getCPUPtr(mOp.getRawInput(0)),
    //     getCPUPtr(mOp.getRawInput(1)),
    //     getCPUPtr(mOp.getRawOutput(0)));

/////////////////////////////////////////////////////////////////

    // [5,2,1,7] & [2,6,7]
    // 1. Same number of dimensions -> [5,2,1,7] & [1,2,6,7]
    // 2. Find the highest equal dimension -> 3
    //    Exception: if the first diverging dimension is the last one, then -> 4 (dims.size())
    // 3. Compute the highest number of contiguous data -> 7
    // 4. Compute stride and offset step for the broadcast mechnism
    // 5. Call a simple kernel

    // Find the correct kernel type
    auto kernelFunc = Registrar<DivImplForward_cpu>::create({
        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
        std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(),
        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});

    // Compute compatible input dimensions
    std::vector<std::size_t>        dims0   = static_cast<const Div_Op&>(mOp).getInput(0)->dims();
    std::vector<std::size_t>        dims1   = static_cast<const Div_Op&>(mOp).getInput(1)->dims();
    const std::vector<std::size_t>& outDims = static_cast<const Div_Op&>(mOp).getOutput(0)->dims();

    // if (dims0 == dims1) {
    //     const std::size_t input0_contiguous_size = std::accumulate(dims0.cbegin(), dims0.cend(), std::size_t(1), std::multiplies<std::size_t>());
    //     kernelFunc(input0_contiguous_size, input0_contiguous_size, input0_contiguous_size,
    //                 getCPUPtr(mOp.getRawInput(0)),
    //                 getCPUPtr(mOp.getRawInput(1)),
    //                 getCPUPtr(mOp.getRawOutput(0)));
    //     return;
    // }

    if (dims0.size() > dims1.size()) {
        dims1.insert(dims1.cbegin(), dims0.size() - dims1.size(), std::size_t(1));
    }
    else if (dims1.size() > dims0.size()) {
        dims0.insert(dims0.cbegin(), dims1.size() - dims0.size(), std::size_t(1));
    }

    const std::size_t nbDims = dims0.size();

    // Find the highest equal dimension
    std::size_t contiguousIdx = nbDims - 1;
    for (; contiguousIdx+1 > 0; --contiguousIdx) {
        if (dims0[contiguousIdx] != dims1[contiguousIdx]) {
            if (contiguousIdx == (nbDims -1)) { // last dimensions of one of the input Tensor are of size 1
                const std::vector<std::size_t>& dims = (dims0[contiguousIdx] == 1) ? dims0 : dims1;
                while ((contiguousIdx+1 > 0) && (dims[contiguousIdx] == 1)) {
                    --contiguousIdx;
                }
            }
            break;
        }
    }
    ++contiguousIdx;

    // Compute the highest number of contiguous data for each Tensor
    const std::size_t input0_contiguous_size = std::accumulate(dims0.cbegin()+contiguousIdx, dims0.cend(), std::size_t(1), std::multiplies<std::size_t>());
    const std::size_t input1_contiguous_size = std::accumulate(dims1.cbegin()+contiguousIdx, dims1.cend(), std::size_t(1), std::multiplies<std::size_t>());
    const std::size_t output_contiguous_size = std::accumulate(outDims.cbegin()+contiguousIdx, outDims.cend(), std::size_t(1), std::multiplies<std::size_t>());

    // initialize strides to iterate through data because of broadcasting
    std::size_t *stride_post0;
    std::size_t *stride_post1;
    std::int32_t *stride_step0;
    std::int32_t *stride_step1;
    if (contiguousIdx > 0) {
        stride_post0 = new std::size_t[contiguousIdx];
        stride_post0[contiguousIdx - 1] = 1;
        stride_post1 = new std::size_t[contiguousIdx];
        stride_post1[contiguousIdx - 1] = 1;
        for (std::size_t i = contiguousIdx - 2; i != static_cast<std::size_t>(-1); --i) {
            stride_post0[i] = stride_post0[i+1]*dims0[i+1];
            stride_post1[i] = stride_post1[i+1]*dims1[i+1];
        }
        stride_step0 = new std::int32_t[contiguousIdx];
        stride_step1 = new std::int32_t[contiguousIdx];
        for (std::size_t i = 0; i != contiguousIdx; ++i) {
            stride_step0[i] = (dims0[i] == 1) ? 1 - static_cast<std::int32_t>(stride_post0[i]) : 1;
            stride_step1[i] = (dims1[i] == 1) ? 1 - static_cast<std::int32_t>(stride_post1[i]) : 1;
        }
    }

    // variables for arrays offsets
    std::size_t offsetIn0 = 0;
    std::size_t offsetIn1 = 0;
    std::size_t offsetOut = 0;


    std::size_t dim = contiguousIdx - 1;
    const std::size_t nbStacks = std::accumulate(outDims.cbegin(), outDims.cbegin() + contiguousIdx, std::size_t(1), std::multiplies<std::size_t>());
    for (std::size_t stack = 0; stack < nbStacks;) {
        kernelFunc(input0_contiguous_size, input1_contiguous_size, output_contiguous_size,
                    getCPUPtr(mOp.getRawInput(0), offsetIn0*input0_contiguous_size),
                    getCPUPtr(mOp.getRawInput(1), offsetIn1*input1_contiguous_size),
                    getCPUPtr(mOp.getRawOutput(0), offsetOut*output_contiguous_size));
        if (++stack < nbStacks) {
            std::size_t tmp_stack = stack;
            while(tmp_stack % outDims[dim] == 0) {
                tmp_stack /= outDims[dim];
                dim--;
            }
            offsetIn0 += stride_step0[dim];
            offsetIn1 += stride_step1[dim];
            ++offsetOut;
            dim = contiguousIdx - 1;
        }
    }
    if (contiguousIdx > 0) {
        delete[] stride_post0;
        delete[] stride_post1;
        delete[] stride_step0;
        delete[] stride_step1;
    }
}