/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/operator/Sub.hpp"

#include <cstddef>    // std::size_t
#include <stdexcept>  // std::runtime_error
#include <string>
#include <vector>

#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"

const std::string Aidge::Sub_Op::Type = "Sub";

/**
 * @brief Compute the output dimensions of the element-wise subtraction.
 *
 * The two input shapes are aligned on their trailing dimensions. For each
 * aligned pair:
 *   - a dimension of 1 in the higher-rank shape takes the other input's value;
 *   - a dimension of 1 in the lower-rank shape, or an equal pair, keeps the
 *     higher-rank value;
 *   - any other pair is an error.
 * Leading dimensions of the higher-rank shape (input#0 when ranks are equal)
 * are copied unchanged.
 *
 * @param allowDataDependency Unused by this operator.
 * @return true if both inputs are non-empty and the output tensor was resized;
 *         false if at least one input is empty, in which case the output is
 *         left untouched.
 * @throws Reported through AIDGE_THROW_OR_ABORT with std::runtime_error when an
 *         input is not connected or when the two shapes are incompatible.
 */
bool Aidge::Sub_Op::forwardDims(bool /*allowDataDependency*/) {
    // Both inputs must be associated before their shapes can be inspected.
    if (!getInput(0) || !getInput(1)) {
        AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
    }

    if (getInput(0)->empty() || getInput(1)->empty()) {
        // Input shapes are not available yet: nothing to propagate.
        return false;
    }

    const std::vector<std::size_t>& inputsDims0 = getInput(0)->dims();
    const std::vector<std::size_t>& inputsDims1 = getInput(1)->dims();

    // Start from the higher-rank shape (input#0 on ties) and broadcast the
    // lower-rank one into its trailing dimensions.
    const bool firstIsHigher = (inputsDims0.size() >= inputsDims1.size());
    std::vector<std::size_t> outDims = firstIsHigher ? inputsDims0 : inputsDims1;
    const std::vector<std::size_t>& lowDims = firstIsHigher ? inputsDims1 : inputsDims0;

    // Number of leading dimensions of outDims with no counterpart in lowDims.
    const std::size_t offset = outDims.size() - lowDims.size();
    for (std::size_t lowId = 0; lowId < lowDims.size(); ++lowId) {
        std::size_t& outDim = outDims[offset + lowId];
        const std::size_t lowDim = lowDims[lowId];
        if (outDim == 1) {
            outDim = lowDim;
        } else if ((lowDim != 1) && (lowDim != outDim)) {
            // Report the original input shapes: outDims may already be
            // partially overwritten here and would be misleading.
            AIDGE_THROW_OR_ABORT(std::runtime_error,
                "Unsupported Tensor shape for Sub Operation: {} for input#0 vs {} for input#1",
                inputsDims0, inputsDims1);
        }
    }

    mOutputs[0]->resize(outDims);
    return true;
}

/**
 * @brief Set the backend of this operator and of its output tensor.
 *
 * The implementation is selected via SET_IMPL_MACRO using @p name (presumably
 * a lookup in the Sub_Op registrar -- see Registrar.hpp). The output tensor is
 * then moved to the same backend on @p device.
 *
 * @param name   Backend name.
 * @param device Device index; only forwarded to the output tensor here.
 */
void Aidge::Sub_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
    SET_IMPL_MACRO(Sub_Op, *this, name);
    mOutputs[0]->setBackend(name, device);
}