diff --git a/src/operator/DivImpl.cpp b/src/operator/DivImpl.cpp index 098b20776888c6d72110e4bc4c0c3e191febd41c..cfd74be45b29852c89e4a27035ce2d38fc7266cc 100644 --- a/src/operator/DivImpl.cpp +++ b/src/operator/DivImpl.cpp @@ -55,7 +55,7 @@ void Aidge::DivImpl_cpu::forward() { // 2. Find the highest equal dimension -> 3 // Exception: if the first diverging dimension is the last one, then -> 4 (dims.size()) // 3. Compute the highest number of contiguous data -> 7 - // 4. Compute stride and offset step for the broadcast mechnism + // 4. Compute stride and offset step for the broadcast mechanism // 5. Call a simple kernel const auto& opTensor = static_cast<const Div_Op&>(mOp); @@ -70,15 +70,17 @@ void Aidge::DivImpl_cpu::forward() { std::vector<std::size_t> dims1 = opTensor.getInput(1)->dims(); const std::vector<std::size_t>& outDims = opTensor.getOutput(0)->dims(); - // if (dims0 == dims1) { - // const std::size_t input0_contiguous_size = std::accumulate(dims0.cbegin(), dims0.cend(), std::size_t(1), std::multiplies<std::size_t>()); - // kernelFunc(input0_contiguous_size, input0_contiguous_size, input0_contiguous_size, - // getCPUPtr(mOp.getRawInput(0)), - // getCPUPtr(mOp.getRawInput(1)), - // getCPUPtr(mOp.getRawOutput(0))); - // return; - // } + // special case for equal dimensions, the kernel is called with the entire arrays at once + if (dims0 == dims1) { + const std::size_t input0_contiguous_size = std::accumulate(dims0.cbegin(), dims0.cend(), std::size_t(1), std::multiplies<std::size_t>()); + kernelFunc(input0_contiguous_size, input0_contiguous_size, input0_contiguous_size, + getCPUPtr(mOp.getRawInput(0)), + getCPUPtr(mOp.getRawInput(1)), + getCPUPtr(mOp.getRawOutput(0))); + return; + } + // set dimensions to be of equal size by filling the smallest one with ones. if (dims0.size() > dims1.size()) { dims1.insert(dims1.cbegin(), dims0.size() - dims1.size(), std::size_t(1)); } @@ -89,8 +91,10 @@ void Aidge::DivImpl_cpu::forward() { const std::size_t nbDims = dims0.size(); // Find the highest equal dimension - std::size_t contiguousIdx = nbDims - 1; - for (; contiguousIdx+1 > 0; --contiguousIdx) { + // std::size_t contiguousIdx = nbDims - 1; + std::size_t contiguousIdx = nbDims; + while (contiguousIdx-- > 0) { + // for (; contiguousIdx+1 > 0; --contiguousIdx) { if (dims0[contiguousIdx] != dims1[contiguousIdx]) { if (contiguousIdx == (nbDims -1)) { // last dimensions of one of the input Tensor are of size 1 const std::vector<std::size_t>& dims = (dims0[contiguousIdx] == 1) ? dims0 : dims1; @@ -109,21 +113,17 @@ void Aidge::DivImpl_cpu::forward() { const std::size_t output_contiguous_size = std::accumulate(outDims.cbegin()+contiguousIdx, outDims.cend(), std::size_t(1), std::multiplies<std::size_t>()); // initialize strides to iterate through data because of broadcasting - std::int32_t *stride_post0; - std::int32_t *stride_post1; - std::int32_t *stride_step0; - std::int32_t *stride_step1; + std::unique_ptr<std::int32_t[]> stride_post0 = std::make_unique<std::int32_t[]>(contiguousIdx); + std::unique_ptr<std::int32_t[]> stride_post1 = std::make_unique<std::int32_t[]>(contiguousIdx); + std::unique_ptr<std::int32_t[]> stride_step0 = std::make_unique<std::int32_t[]>(contiguousIdx); + std::unique_ptr<std::int32_t[]> stride_step1 = std::make_unique<std::int32_t[]>(contiguousIdx); if (contiguousIdx > 0) { - stride_post0 = new std::int32_t[contiguousIdx]; stride_post0[contiguousIdx - 1] = 1; - stride_post1 = new std::int32_t[contiguousIdx]; stride_post1[contiguousIdx - 1] = 1; for (std::size_t i = contiguousIdx - 2; i != static_cast<std::size_t>(-1); --i) { stride_post0[i] = stride_post0[i+1]*static_cast<std::int32_t>(dims0[i+1]); stride_post1[i] = stride_post1[i+1]*static_cast<std::int32_t>(dims1[i+1]); } - stride_step0 = new std::int32_t[contiguousIdx]; - stride_step1 = new std::int32_t[contiguousIdx]; for (std::size_t i = 0; i != contiguousIdx; ++i) { stride_step0[i] = (dims0[i] == 1) ? 1 - stride_post0[i] : 1; stride_step1[i] = (dims1[i] == 1) ? 1 - stride_post1[i] : 1; @@ -155,10 +155,4 @@ void Aidge::DivImpl_cpu::forward() { dim = contiguousIdx - 1; } } - if (contiguousIdx > 0) { - delete[] stride_post0; - delete[] stride_post1; - delete[] stride_step0; - delete[] stride_step1; - } }