Skip to content
Snippets Groups Projects
Forked from Eclipse Projects / aidge / aidge_core
2303 commits behind the upstream repository.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
ArithmeticOperator.cpp 2.29 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <vector>

#include "aidge/operator/ArithmeticOperator.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"


// The destructor is compiler-generated, but defined here in the
// implementation file rather than inline at its declaration.
Aidge::ArithmeticOperator::~ArithmeticOperator() = default;

/**
 * @brief Compute the output shape by broadcasting the shapes of all inputs.
 *
 * Input shapes are aligned on their trailing (right-most) dimension; an input
 * with fewer dimensions than the others simply does not constrain the leading
 * positions. At each aligned position, sizes must be equal or one of them must
 * be 1, in which case the other size is kept. The resulting shape always has
 * at least one dimension and is applied to mOutputs[0] through resize().
 *
 * Errors are reported through AIDGE_THROW_OR_ABORT with std::runtime_error:
 *  - when one of the inputs is not connected;
 *  - when two inputs have different sizes, both other than 1, at the same
 *    aligned position.
 */
void Aidge::ArithmeticOperator::computeOutputDims() {
    const std::size_t nbIn = nbInputs();

    // Validate every input that is dereferenced below (not only the first
    // two) and find the highest rank among them in the same pass.
    std::size_t outNbDims = 1;
    for (std::size_t i = 0; i < nbIn; ++i) {
        if (!getInput(i)) {
            AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
        }
        outNbDims = std::max<std::size_t>(outNbDims, getInput(i)->dims().size());
    }

    // NOTE(review): an "at least one input is empty" check used to sit here
    // but was commented out, so empty tensors are currently let through.
    // Confirm that this is intended.

    // Start from an all-ones shape: 1 is the neutral size for broadcasting.
    std::vector<std::size_t> outDims(outNbDims, 1);

    for (std::size_t i = 0; i < nbIn; ++i) {
        const auto& inDims = getInput(i)->dims();
        // Right-align this input's shape against the output shape.
        const std::size_t offset = outNbDims - inDims.size();

        for (std::size_t d = 0; d < inDims.size(); ++d) {
            const std::size_t dim = inDims[d];
            std::size_t& outDim = outDims[offset + d];

            if (dim == 1 || dim == outDim) {
                continue;  // nothing to broadcast at this position
            }
            if (outDim != 1) {
                // Two different sizes, neither of which is 1.
                AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported Tensor shape for Arithmetic Operation");
            }
            outDim = dim;
        }
    }

    mOutputs[0]->resize(outDims);
}