Skip to content
Snippets Groups Projects

[Add] broadcasting for Arithmetic Operators

Merged Houssem ROUIS requested to merge hrouis/aidge_core:broadcasting into dev
1 file
+ 5
5
Compare changes
  • Side-by-side
  • Inline
@@ -21,7 +21,7 @@
void Aidge::ArithmeticOperator::associateInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Aidge::Data>& data) {
if (inputIdx >= 2) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu inputs", type().c_str(),2);
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has 2 inputs", type().c_str());
}
if (strcmp((data)->type(), Tensor::Type) != 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input data must be of Tensor type");
@@ -55,7 +55,7 @@ void Aidge::ArithmeticOperator::setInput(const Aidge::IOIndex_t inputIdx, std::s
const std::shared_ptr<Aidge::Tensor>& Aidge::ArithmeticOperator::getInput(const Aidge::IOIndex_t inputIdx) const {
if (inputIdx >= nbInputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu inputs", type().c_str(), nbInputs());
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has 2 inputs", type().c_str());
}
return mInputs[inputIdx];
}
@@ -65,7 +65,7 @@ void Aidge::ArithmeticOperator::setOutput(const Aidge::IOIndex_t outputIdx, cons
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str());
}
if (outputIdx >= nbOutputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbOutputs());
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has 1 outputs", type().c_str());
}
*mOutputs[outputIdx] = *std::dynamic_pointer_cast<Tensor>(data);
}
@@ -75,14 +75,14 @@ void Aidge::ArithmeticOperator::setOutput(const Aidge::IOIndex_t outputIdx, std:
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str());
}
if (outputIdx >= nbOutputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbOutputs());
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has 1 output", type().c_str());
}
*mOutputs[outputIdx] = std::move(*std::dynamic_pointer_cast<Tensor>(data));
}
const std::shared_ptr<Aidge::Tensor>& Aidge::ArithmeticOperator::getOutput(const Aidge::IOIndex_t outputIdx) const {
if (outputIdx >= nbOutputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbOutputs());
AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has 1 output", type().c_str());
}
return mOutputs[outputIdx];
}
Loading