Commit 57c7bc62 authored by Houssem ROUIS, committed by Maxence Naud

make ArithmeticOperator inherit from OperatorTensor

parent ab902aea
2 merge requests: !105 "version 0.2.0", !65 "[Add] broadcasting for Arithmetic Operators"
Pipeline #38209 passed
@@ -18,94 +18,41 @@
 #include "aidge/backend/OperatorImpl.hpp"
 #include "aidge/data/Tensor.hpp"
-#include "aidge/operator/Operator.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/utils/Types.h"
-#include "aidge/graph/Node.hpp"

 namespace Aidge {
-class ArithmeticOperator : public Operator {
-    /* TODO: Add an attribute specifying the type of Data used by the Operator.
-     * The same way ``Type`` attribute specifies the type of Operator. Hence this
-     * attribute could be checked in the forwardDims function to assert Operators
-     * being used work with Tensors and cast them to OpertorTensor instead of
-     * Operator.
-     */
-    /* TODO: Maybe change type attribute of Data object by an enum instead of an
-     * array of char. Faster comparisons.
-     */
-protected:
-    std::vector<std::shared_ptr<Tensor>> mInputs;
-    std::vector<std::shared_ptr<Tensor>> mOutputs;
+class ArithmeticOperator : public OperatorTensor {
 public:
     ArithmeticOperator() = delete;

     ArithmeticOperator(const std::string& type)
-        : Operator(type, 2, 0, 1, OperatorType::Tensor),
-          mInputs(std::vector<std::shared_ptr<Tensor>>(2, nullptr)),
-          mOutputs(std::vector<std::shared_ptr<Tensor>>(1)) {
-        mOutputs[0] = std::make_shared<Tensor>();
-        mOutputs[0]->setDataType(DataType::Float32);
+        : OperatorTensor(type, 2, 0, 1) {
     }

-    ArithmeticOperator(const ArithmeticOperator& other)
-        : Operator(other),
-          mInputs(std::vector<std::shared_ptr<Tensor>>(2, nullptr)),
-          mOutputs(std::vector<std::shared_ptr<Tensor>>(1)) {
-        mOutputs[0] = std::make_shared<Tensor>();
-    }
+    ArithmeticOperator(const ArithmeticOperator& other) : OperatorTensor(other){ }

     ~ArithmeticOperator();

-    std::shared_ptr<Operator> clone() const override {
-        return std::make_shared<ArithmeticOperator>(*this);
-    }
-
-    void setBackend(const std::string & /*name*/, DeviceIdx_t /*device*/ = 0) override { printf("setBackend: not available yet.\n"); }
-
 public:
-    ///////////////////////////////////////////////////
-    virtual void associateInput(const IOIndex_t inputIdx,
-                                const std::shared_ptr<Data>& data) override;
-    ///////////////////////////////////////////////////
-
-    ///////////////////////////////////////////////////
-    // Tensor access
-    // input management
-    void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
-    void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final;
-    const std::shared_ptr<Tensor>& getInput(const IOIndex_t inputIdx) const;
-    inline std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
-        return std::static_pointer_cast<Data>(getInput(inputIdx));
-    }
-
-    // output management
-    void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) override;
-    void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) override;
-    virtual const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const;
-    inline std::shared_ptr<Aidge::Data> getRawOutput(const Aidge::IOIndex_t outputIdx) const override final {
-        return std::static_pointer_cast<Data>(getOutput(outputIdx));
-    }
+    void computeOutputDims() override final;

     static const std::vector<std::string> getInputsName(){
-        return {"data_input1", "data_input2"};
+        return {"data_input_1", "data_input_2"};
     }
     static const std::vector<std::string> getOutputsName(){
         return {"data_output"};
     }
-
-    ///////////////////////////////////////////////////
-    ///////////////////////////////////////////////////
-    // Tensor dimensions
-    /**
-     * @brief For a given output feature area, compute the associated receptive
-     * field for each data input.
-     * @param firstIdx First index of the output feature.
-     * @param outputDims Size of output feature.
-     * @param outputIdx Index of the output. Default 0.
-     * @return std::vector<std::pair<std::size_t, std::vector<DimSize_t>>>
-     * For each dataInput Tensor of the Operator, the first index and dimensions of the feature area.
-     */
-    virtual std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> computeReceptiveField(const std::vector<DimSize_t>& firstEltDims, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const;
-
-    virtual void computeOutputDims();
-    virtual bool outputDimsForwarded() const;
-    ///////////////////////////////////////////////////
-
-    virtual void setDataType(const DataType& dataType) const override;
 };
 } // namespace Aidge
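For readability, the "+" side of the hunk above reassembles into roughly the following header body (reconstructed from the diff itself; blank lines and indentation are approximate, and the surrounding license block and include guards are omitted):

#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Types.h"

namespace Aidge {

// ArithmeticOperator now only declares what is specific to binary arithmetic
// operators: a 2-input / 1-output OperatorTensor whose computeOutputDims()
// applies broadcasting. Tensor storage and accessors come from the base class.
class ArithmeticOperator : public OperatorTensor {
public:
    ArithmeticOperator() = delete;

    ArithmeticOperator(const std::string& type)
        : OperatorTensor(type, 2, 0, 1) {
    }

    ArithmeticOperator(const ArithmeticOperator& other) : OperatorTensor(other){ }

    ~ArithmeticOperator();

public:
    void computeOutputDims() override final;

    static const std::vector<std::string> getInputsName(){
        return {"data_input_1", "data_input_2"};
    }
    static const std::vector<std::string> getOutputsName(){
        return {"data_output"};
    }
};
} // namespace Aidge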
@@ -10,18 +10,19 @@
 ********************************************************************************/

 #include <pybind11/pybind11.h>

+#include "aidge/backend/OperatorImpl.hpp"
 #include "aidge/operator/ArithmeticOperator.hpp"
-#include "aidge/operator/Operator.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+#include <pybind11/stl.h>

 namespace py = pybind11;
 namespace Aidge {

 void init_ArithmeticOperator(py::module& m){
-    py::class_<ArithmeticOperator, std::shared_ptr<ArithmeticOperator>, Operator>(m, "ArithmeticOperator")
+    py::class_<ArithmeticOperator, std::shared_ptr<ArithmeticOperator>, OperatorTensor>(m, "ArithmeticOperator")
     .def("get_output", &ArithmeticOperator::getOutput, py::arg("outputIdx"))
     .def("get_input", &ArithmeticOperator::getInput, py::arg("inputIdx"))
-    .def("set_output", (void (ArithmeticOperator::*)(const IOIndex_t, const std::shared_ptr<Data>&)) &ArithmeticOperator::setOutput, py::arg("outputIdx"), py::arg("data"))
-    .def("set_input", (void (ArithmeticOperator::*)(const IOIndex_t, const std::shared_ptr<Data>&)) &ArithmeticOperator::setInput, py::arg("outputIdx"), py::arg("data"))
     .def("output_dims_forwarded", &ArithmeticOperator::outputDimsForwarded)
     ;
 }
-}
+} // namespace Aidge
@@ -17,12 +17,12 @@ namespace Aidge {
 void init_Data(py::module&);
 void init_Tensor(py::module&);
 void init_OperatorImpl(py::module&);
-void init_ArithmeticOperator(py::module&);
 void init_Attributes(py::module&);
 void init_Operator(py::module&);
 void init_OperatorTensor(py::module&);

 void init_Add(py::module&);
+void init_ArithmeticOperator(py::module&);
 void init_AvgPooling(py::module&);
 void init_BatchNorm(py::module&);
 void init_Concat(py::module&);
@@ -74,11 +74,11 @@ void init_Aidge(py::module& m){
     init_OpArgs(m);
     init_Connector(m);

-    init_ArithmeticOperator(m);
     init_OperatorImpl(m);
     init_Attributes(m);
     init_Operator(m);
     init_OperatorTensor(m);
+    init_ArithmeticOperator(m);

     init_Add(m);
     init_AvgPooling(m);
     init_BatchNorm(m);
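Moving init_ArithmeticOperator(m) below init_OperatorTensor(m) matters because pybind11 requires a base class to be registered before any class that lists it as a base. A minimal standalone sketch of that constraint (hypothetical module and class names, not Aidge code):

#include <memory>
#include <pybind11/pybind11.h>

namespace py = pybind11;

struct Base { virtual ~Base() = default; };
struct Derived : Base {};

PYBIND11_MODULE(ordering_demo, m) {
    // The base must be registered first; swapping these two lines makes the module
    // throw at import time because "Base" is still an unknown base type to pybind11.
    py::class_<Base, std::shared_ptr<Base>>(m, "Base");
    py::class_<Derived, std::shared_ptr<Derived>, Base>(m, "Derived");
}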
@@ -19,99 +19,8 @@
 #include "aidge/utils/ErrorHandling.hpp"

-void Aidge::ArithmeticOperator::associateInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Aidge::Data>& data) {
-    if (inputIdx >= 2) {
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has 2 inputs", type().c_str());
-    }
-    if (strcmp((data)->type(), Tensor::Type) != 0) {
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "Input data must be of Tensor type");
-    }
-    mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
-}
-
-void Aidge::ArithmeticOperator::setInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Aidge::Data>& data) {
-    if (strcmp(data->type(), "Tensor") != 0) {
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str());
-    }
-    if (getInput(inputIdx)) {
-        *mInputs[inputIdx] = *std::dynamic_pointer_cast<Tensor>(data);
-    } else {
-        mInputs[inputIdx] = std::make_shared<Tensor>(*std::dynamic_pointer_cast<Tensor>(data));
-    }
-}
-
 Aidge::ArithmeticOperator::~ArithmeticOperator() = default;

-void Aidge::ArithmeticOperator::setInput(const Aidge::IOIndex_t inputIdx, std::shared_ptr<Aidge::Data>&& data) {
-    if (strcmp(data->type(), "Tensor") != 0) {
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str());
-    }
-    if (getInput(inputIdx)) {
-        *mInputs[inputIdx] = std::move(*std::dynamic_pointer_cast<Tensor>(data));
-    } else {
-        mInputs[inputIdx] = std::make_shared<Tensor>(std::move(*std::dynamic_pointer_cast<Tensor>(data)));
-    }
-}
-
-const std::shared_ptr<Aidge::Tensor>& Aidge::ArithmeticOperator::getInput(const Aidge::IOIndex_t inputIdx) const {
-    if (inputIdx >= nbInputs()) {
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has 2 inputs", type().c_str());
-    }
-    return mInputs[inputIdx];
-}
-
-void Aidge::ArithmeticOperator::setOutput(const Aidge::IOIndex_t outputIdx, const std::shared_ptr<Aidge::Data>& data) {
-    if (strcmp(data->type(), "Tensor") != 0) {
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str());
-    }
-    if (outputIdx >= nbOutputs()) {
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has 1 outputs", type().c_str());
-    }
-    *mOutputs[outputIdx] = *std::dynamic_pointer_cast<Tensor>(data);
-}
-
-void Aidge::ArithmeticOperator::setOutput(const Aidge::IOIndex_t outputIdx, std::shared_ptr<Aidge::Data>&& data) {
-    if (strcmp(data->type(), "Tensor") != 0) {
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator only accepts Tensors as inputs", type().c_str());
-    }
-    if (outputIdx >= nbOutputs()) {
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has 1 output", type().c_str());
-    }
-    *mOutputs[outputIdx] = std::move(*std::dynamic_pointer_cast<Tensor>(data));
-}
-
-const std::shared_ptr<Aidge::Tensor>& Aidge::ArithmeticOperator::getOutput(const Aidge::IOIndex_t outputIdx) const {
-    if (outputIdx >= nbOutputs()) {
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has 1 output", type().c_str());
-    }
-    return mOutputs[outputIdx];
-}
-
-std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_t>>> Aidge::ArithmeticOperator::computeReceptiveField(
-        const std::vector<DimSize_t>& firstEltDims,
-        const std::vector<Aidge::DimSize_t>& outputDims,
-        const Aidge::IOIndex_t outputIdx) const
-{
-    static_cast<void>(outputIdx);
-    if (outputIdx >= nbOutputs()) {
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "Operator output index out of range.");
-    }
-    if (nbInputs() != nbData()) {
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "Operator has attributes. Must be handled in an overrided function.");
-    }
-    if (!outputDimsForwarded() || getOutput(0)->nbDims() != outputDims.size()) {
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range or output dim not forwarded yet.");
-    }
-    for (DimIdx_t i = 0; i < outputDims.size(); ++i) {
-        if (((outputDims[i] + firstEltDims[i]) > getOutput(0)->dims()[i]) || (outputDims[i] == 0)) {
-            AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension %lu (%lu + %lu)", static_cast<std::size_t>(i), firstEltDims[i], outputDims[i]);
-        }
-    }
-    // return the same Tensor description as given in function parameter for each data input
-    return std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_t>>>(nbData(), std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_t>>(firstEltDims, outputDims));
-}
-
 void Aidge::ArithmeticOperator::computeOutputDims() {
     // check inputs have been associated
     if (!getInput(0) || !getInput(1)) {
@@ -163,30 +72,4 @@ void Aidge::ArithmeticOperator::computeOutputDims() {
         }
     }
     mOutputs[0]->resize(outDims);
-}
-
-bool Aidge::ArithmeticOperator::outputDimsForwarded() const {
-    bool forwarded = true;
-    // check both inputs and outputs have been filled
-    for (IOIndex_t i = 0; i < nbInputs(); ++i) {
-        forwarded &= mInputs[i] ? !(getInput(i)->empty()) : false;
-    }
-    for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
-        forwarded &= !(getOutput(i)->empty());
-    }
-    return forwarded;
-}
-
-void Aidge::ArithmeticOperator::setDataType(const DataType& dataType) const {
-    for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
-        getOutput(i)->setDataType(dataType);
-    }
-    for (IOIndex_t i = 0; i < nbInputs(); ++i) {
-        if (!getInput(i)) {
-            AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not set");
-        }
-        else {
-            getInput(i)->setDataType(dataType);
-        }
-    }
 }
\ No newline at end of file
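The broadcasting body of computeOutputDims() is unchanged by this commit and is collapsed between the two hunks above, so it does not appear here. For reference only, a numpy-style broadcast of two input shapes can be computed as in the following standalone sketch (an illustration of the rule, not the Aidge implementation):

#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Align the two shapes on their trailing axes and take the larger extent per axis;
// differing extents are only compatible when one of them is 1.
std::vector<std::size_t> broadcastDims(const std::vector<std::size_t>& lhs,
                                       const std::vector<std::size_t>& rhs) {
    const std::size_t outRank = std::max(lhs.size(), rhs.size());
    std::vector<std::size_t> out(outRank, 1);
    for (std::size_t i = 0; i < outRank; ++i) {
        const std::size_t l = (i < lhs.size()) ? lhs[lhs.size() - 1 - i] : 1;
        const std::size_t r = (i < rhs.size()) ? rhs[rhs.size() - 1 - i] : 1;
        if (l != r && l != 1 && r != 1) {
            throw std::runtime_error("incompatible dimensions for broadcasting");
        }
        out[outRank - 1 - i] = std::max(l, r);
    }
    return out;
}

// Example: broadcastDims({3, 1, 5}, {4, 5}) returns {3, 4, 5}.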