-
Maxence Naud authoredMaxence Naud authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Mul.cpp 2.44 KiB
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cstddef> // std::size_t
#include <memory>
#include <stdexcept> // std::runtime_error
#include <string>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Mul.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
const std::string Aidge::Mul_Op::Type = "Mul";
bool Aidge::Mul_Op::forwardDims(bool /*allowDataDependency*/) {
// check inputs have been associated
if (!getInput(0) || !getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
}
if (!getInput(0)->empty() && !getInput(1)->empty()) {
const std::vector<std::size_t>& inputsDims0 = getInput(0)->dims();
const std::vector<std::size_t>& inputsDims1 = getInput(1)->dims();
std::vector<std::size_t> outDims = (inputsDims0.size() >= inputsDims1.size()) ? inputsDims0 : inputsDims1;
const std::vector<std::size_t>& lowDims = (inputsDims0.size() < inputsDims1.size()) ? inputsDims0 : inputsDims1;
std::size_t out_id = outDims.size() - 1;
std::size_t low_id = lowDims.size() - 1;
std::size_t i = 0;
while (i++ < lowDims.size()) {
if (outDims[out_id] == 1) {
outDims[out_id] = lowDims[low_id];
}
else if ((lowDims[low_id] != 1) && (lowDims[low_id] != outDims[out_id])) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported Tensor shape for Mul Operation: {}", outDims);
}
--out_id;
--low_id;
}
mOutputs[0]->resize(outDims);
return true;
}
else if (!getInput(0)->empty() && !getInput(1)->empty()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Incompatible input dimensions for Operator Mul: {} and {}", getInput(0)->dims(), getInput(1)->dims());
}
return false;
}
void Aidge::Mul_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
SET_IMPL_MACRO(Mul_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}