Skip to content
Snippets Groups Projects

[Add] broadcasting for Arithmetic Operators

Merged Houssem ROUIS requested to merge hrouis/aidge_core:broadcasting into dev
4 files
+ 21
190
Compare changes
  • Side-by-side
  • Inline
Files
4
@@ -18,94 +18,41 @@
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Types.h"
#include "aidge/graph/Node.hpp"
namespace Aidge {
class ArithmeticOperator : public Operator {
/* TODO: Add an attribute specifying the type of Data used by the Operator.
* The same way ``Type`` attribute specifies the type of Operator. Hence this
* attribute could be checked in the forwardDims function to assert Operators
* being used work with Tensors and cast them to OpertorTensor instead of
* Operator.
*/
/* TODO: Maybe change type attribute of Data object by an enum instead of an
* array of char. Faster comparisons.
*/
protected:
std::vector<std::shared_ptr<Tensor>> mInputs;
std::vector<std::shared_ptr<Tensor>> mOutputs;
class ArithmeticOperator : public OperatorTensor {
public:
ArithmeticOperator() = delete;
ArithmeticOperator(const std::string& type)
: Operator(type, 2, 0, 1, OperatorType::Tensor),
mInputs(std::vector<std::shared_ptr<Tensor>>(2, nullptr)),
mOutputs(std::vector<std::shared_ptr<Tensor>>(1)) {
mOutputs[0] = std::make_shared<Tensor>();
mOutputs[0]->setDataType(DataType::Float32);
: OperatorTensor(type, 2, 0, 1) {
}
ArithmeticOperator(const ArithmeticOperator& other)
: Operator(other),
mInputs(std::vector<std::shared_ptr<Tensor>>(2, nullptr)),
mOutputs(std::vector<std::shared_ptr<Tensor>>(1)) {
mOutputs[0] = std::make_shared<Tensor>();
}
ArithmeticOperator(const ArithmeticOperator& other) : OperatorTensor(other){ }
~ArithmeticOperator();
std::shared_ptr<Operator> clone() const override {
return std::make_shared<ArithmeticOperator>(*this);
}
void setBackend(const std::string & /*name*/, DeviceIdx_t /*device*/ = 0) override { printf("setBackend: not available yet.\n"); }
public:
///////////////////////////////////////////////////
virtual void associateInput(const IOIndex_t inputIdx,
const std::shared_ptr<Data>& data) override;
///////////////////////////////////////////////////
///////////////////////////////////////////////////
// Tensor access
// input management
void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final;
const std::shared_ptr<Tensor>& getInput(const IOIndex_t inputIdx) const;
inline std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
return std::static_pointer_cast<Data>(getInput(inputIdx));
}
void computeOutputDims() override final;
// output management
void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) override;
void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) override;
virtual const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const;
inline std::shared_ptr<Aidge::Data> getRawOutput(const Aidge::IOIndex_t outputIdx) const override final {
return std::static_pointer_cast<Data>(getOutput(outputIdx));
}
static const std::vector<std::string> getInputsName(){
return {"data_input1", "data_input2"};
return {"data_input_1", "data_input_2"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
///////////////////////////////////////////////////
///////////////////////////////////////////////////
// Tensor dimensions
/**
* @brief For a given output feature area, compute the associated receptive
* field for each data input.
* @param firstIdx First index of the output feature.
* @param outputDims Size of output feature.
* @param outputIdx Index of the output. Default 0.
* @return std::vector<std::pair<std::size_t, std::vector<DimSize_t>>>
* For each dataInput Tensor of the Operator, the first index and dimensions of the feature area.
*/
virtual std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> computeReceptiveField(const std::vector<DimSize_t>& firstEltDims, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const;
virtual void computeOutputDims();
virtual bool outputDimsForwarded() const;
///////////////////////////////////////////////////
virtual void setDataType(const DataType& dataType) const override;
};
} // namespace Aidge
Loading