Skip to content
Snippets Groups Projects
Commit 46e15f55 authored by Maxence Naud's avatar Maxence Naud
Browse files

Remove 'Tensor' from 'Operator' class. Only keep 'Data'

- Add a common parent class to each Operator using Tensors: OperatorTensor
- Gather shared operator functions in OperatorTensor
- Add generic mInputs and mOutputs attributes for OperatorTensor
- Add an enum to identify the type of Data used by each Operator
- Change Inputs, DataInputs, Outputs for Inputs, Data, Attr, Outputs for less confusion
parent 57ad2929
No related branches found
No related tags found
2 merge requests!46Remove Operator reference to Tensor,!20Draft: Introduction of Tiling
...@@ -20,12 +20,16 @@ ...@@ -20,12 +20,16 @@
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Data.hpp" #include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/hook/Hook.hpp" #include "aidge/hook/Hook.hpp"
namespace Aidge { namespace Aidge {
enum class OperatorType {
Data,
Tensor
};
class Operator : public std::enable_shared_from_this<Operator> { class Operator : public std::enable_shared_from_this<Operator> {
protected: protected:
std::shared_ptr<OperatorImpl> mImpl; // implementation of the operator std::shared_ptr<OperatorImpl> mImpl; // implementation of the operator
...@@ -33,27 +37,28 @@ protected: ...@@ -33,27 +37,28 @@ protected:
private: private:
std::string mType; std::string mType;
const OperatorType mOperatorType;
const IOIndex_t mNbData; const IOIndex_t mNbData;
const IOIndex_t mNbAttr; const IOIndex_t mNbParam;
const IOIndex_t mNbOut; const IOIndex_t mNbOut;
public: public:
Operator() = delete; Operator() = delete;
Operator(const char* type, const IOIndex_t nbData, const IOIndex_t nbAttr, const IOIndex_t nbOut) Operator(const char* type, const IOIndex_t nbData, const IOIndex_t nbParam, const IOIndex_t nbOut, const OperatorType operatorType = OperatorType::Data)
: mType(type), : mType(type),
mOperatorType(operatorType),
mNbData(nbData), mNbData(nbData),
mNbAttr(nbAttr), mNbParam(nbParam),
mNbOut(nbOut) mNbOut(nbOut)
{ {
// ctor // ctor
} }
virtual std::shared_ptr<Operator> clone() const = 0;
virtual ~Operator();
Operator(const Operator& op): Operator(const Operator& op):
std::enable_shared_from_this<Operator>(), std::enable_shared_from_this<Operator>(),
mOperatorType(op.mOperatorType),
mNbData(op.mNbData), mNbData(op.mNbData),
mNbAttr(op.mNbAttr), mNbParam(op.mNbParam),
mNbOut(op.mNbOut) mNbOut(op.mNbOut)
{ {
mType = op.mType; mType = op.mType;
...@@ -63,9 +68,12 @@ public: ...@@ -63,9 +68,12 @@ public:
// Hooks are not copied. // Hooks are not copied.
} }
virtual ~Operator();
public: public:
virtual std::shared_ptr<Operator> clone() const = 0;
virtual void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>* data) = 0; virtual void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) = 0;
/** /**
* @brief For a given output feature area, compute the associated receptive * @brief For a given output feature area, compute the associated receptive
* field for each data input. * field for each data input.
...@@ -92,7 +100,7 @@ public: ...@@ -92,7 +100,7 @@ public:
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
virtual void setBackend(const std::string& name) = 0; virtual void setBackend(const std::string& name) = 0;
virtual void setDatatype(const DataType& datatype) = 0; virtual void setDataType(const DataType& dataType) const = 0;
/** /**
* @brief Set the a new OperatorImpl to the Operator * @brief Set the a new OperatorImpl to the Operator
...@@ -135,13 +143,17 @@ public: ...@@ -135,13 +143,17 @@ public:
// INNER // INNER
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
std::string type() const { inline std::string type() const noexcept {
return mType; return mType;
} }
inline IOIndex_t nbInputs() const noexcept { return mNbData+mNbAttr; }; inline OperatorType operatorType() const noexcept{
return mOperatorType;
}
inline IOIndex_t nbInputs() const noexcept { return mNbData+mNbParam; };
inline IOIndex_t nbData() const noexcept { return mNbData; }; inline IOIndex_t nbData() const noexcept { return mNbData; };
inline IOIndex_t nbAttr() const noexcept { return mNbAttr; }; inline IOIndex_t nbParam() const noexcept { return mNbParam; };
inline IOIndex_t nbOutputs() const noexcept { return mNbOut; }; inline IOIndex_t nbOutputs() const noexcept { return mNbOut; };
static const std::vector<std::string> getInputsName(){ static const std::vector<std::string> getInputsName(){
......
...@@ -18,63 +18,84 @@ ...@@ -18,63 +18,84 @@
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/utils/Types.h"
#include "aidge/operator/Operator.hpp" #include "aidge/operator/Operator.hpp"
#include "aidge/utils/Types.h"
namespace Aidge { namespace Aidge {
class OperatorTensor : public Operator { class OperatorTensor : public Operator {
/* TODO: Add an attribute specifying the type of Data used by the Operator. /* TODO: Add an attribute specifying the type of Data used by the Operator.
* The same way ``Type`` attribute specifies the type of Operator. Hence this * The same way ``Type`` attribute specifies the type of Operator. Hence this
* attribute could be checked in the forwardDims function to assert Operators * attribute could be checked in the forwardDims function to assert Operators
* being used work with Tensors and cast them to OpertorTensor instead of * being used work with Tensors and cast them to OpertorTensor instead of
* Operator. * Operator.
*/ */
/* TODO: Maybe change type attribute of Data object by an enum instead of an /* TODO: Maybe change type attribute of Data object by an enum instead of an
* array of char. Faster comparisons. * array of char. Faster comparisons.
*/ */
protected: protected:
std::vector<std::shared_ptr<Tensor>*> mInputs; std::vector<std::shared_ptr<Tensor>> mInputs;
std::vector<std::shared_ptr<Tensor>> mOutputs; std::vector<std::shared_ptr<Tensor>> mOutputs;
public: public:
OperatorTensor(const char* type, const IOIndex_t nbData, const IOIndex_t nbAttr, const IOIndex_t nbOut) OperatorTensor() = delete;
: Operator(type, nbData, nbAttr, nbOut),
mInputs(std::vector<std::shared_ptr<Tensor>*>(nbData + nbAttr, nullptr)), OperatorTensor(const char* type, const IOIndex_t nbData, const IOIndex_t nbParam,
mOutputs(std::vector<std::shared_ptr<Tensor>>(nbOut)) const IOIndex_t nbOut)
{ : Operator(type, nbData, nbParam, nbOut, OperatorType::Tensor),
mInputs(std::vector<std::shared_ptr<Tensor>>(nbData + nbParam, nullptr)),
mOutputs(std::vector<std::shared_ptr<Tensor>>(nbOut)) {
for (std::size_t i = 0; i < static_cast<std::size_t>(nbOut); ++i) { for (std::size_t i = 0; i < static_cast<std::size_t>(nbOut); ++i) {
mOutputs[i] = std::make_shared<Tensor>(); mOutputs[i] = std::make_shared<Tensor>();
mOutputs[i]->setDataType(DataType::Float32);
}
}
OperatorTensor(const OperatorTensor& other)
: Operator(other),
mInputs(std::vector<std::shared_ptr<Tensor>>(other.nbInputs(), nullptr)),
mOutputs(std::vector<std::shared_ptr<Tensor>>(other.nbOutputs())) {
for (std::size_t i = 0; i < static_cast<std::size_t>(nbOutputs()); ++i) {
mOutputs[i] = std::make_shared<Tensor>(other.output(i));
// datatype already copied
} }
} }
virtual ~OperatorTensor() = default;
public: public:
/////////////////////////////////////////////////// ///////////////////////////////////////////////////
virtual void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>* data) override; virtual void associateInput(const IOIndex_t inputIdx,
const std::shared_ptr<Data>& data) override;
/////////////////////////////////////////////////// ///////////////////////////////////////////////////
/////////////////////////////////////////////////// ///////////////////////////////////////////////////
// Tensor access // Tensor access
// input management // input management
std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const; const std::shared_ptr<Tensor>& getInput(const IOIndex_t inputIdx) const;
Tensor& input(const IOIndex_t inputIdx) const; inline Tensor& input(const IOIndex_t inputIdx) const { return *getInput(inputIdx); }
std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final; inline std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
return std::static_pointer_cast<Data>(getInput(inputIdx));
}
//output management // output management
std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const; const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const;
Tensor& output(const IOIndex_t outputIdx) const; inline Tensor& output(const IOIndex_t outputIdx) const {
std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final; return *getOutput(outputIdx);
}
inline std::shared_ptr<Aidge::Data> getRawOutput(const Aidge::IOIndex_t outputIdx) const override final {
return std::static_pointer_cast<Data>(getOutput(outputIdx));
}
/////////////////////////////////////////////////// ///////////////////////////////////////////////////
/////////////////////////////////////////////////// ///////////////////////////////////////////////////
// Tensor dimensions // Tensor dimensions
virtual void computeOutputDims() = 0; virtual void computeOutputDims();
virtual bool outputDimsForwarded() const; virtual bool outputDimsForwarded() const;
/////////////////////////////////////////////////// ///////////////////////////////////////////////////
virtual void setDataType(const DataType& dataType) const; virtual void setDataType(const DataType& dataType) const override;
}; };
} // namespace Aidge } // namespace Aidge
#endif // AIDGE_CORE_OPERATOR_OPERATORTENSOR_H_ #endif // AIDGE_CORE_OPERATOR_OPERATORTENSOR_H_
\ No newline at end of file \ No newline at end of file
...@@ -19,50 +19,58 @@ ...@@ -19,50 +19,58 @@
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
void Aidge::OperatorTensor::associateInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Aidge::Data>* data) { void Aidge::OperatorTensor::associateInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Aidge::Data>& data) {
if (inputIdx >= nbInputs()) { if (inputIdx >= nbInputs()) {
AIDGE_ASSERT("%s Operator has %hu inputs", type().c_str(), nbInputs()); AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu inputs", type().c_str(), nbInputs());
} }
if (strcmp((*data)->type(), Tensor::Type) != 0) { if (strcmp((data)->type(), Tensor::Type) != 0) {
printf("input data must be of Tensor type"); AIDGE_THROW_OR_ABORT(std::runtime_error, "Input data must be of Tensor type");
exit(-1);
} }
mInputs[inputIdx] = &std::dynamic_pointer_cast<Tensor>(*data); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
} }
std::shared_ptr<Aidge::Tensor> Aidge::OperatorTensor::getInput(const Aidge::IOIndex_t inputIdx) const { const std::shared_ptr<Aidge::Tensor>& Aidge::OperatorTensor::getInput(const Aidge::IOIndex_t inputIdx) const {
if (inputIdx >= nbInputs()) { if (inputIdx >= nbInputs()) {
AIDGE_ASSERT("%s Operator has %hu inputs", type().c_str(), nbInputs()); AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu inputs", type().c_str(), nbInputs());
} }
return *mInputs[inputIdx]; return mInputs[inputIdx];
} }
Aidge::Tensor& Aidge::OperatorTensor::input(const Aidge::IOIndex_t inputIdx) const {
return *getInput(inputIdx);
}
std::shared_ptr<Aidge::Data> Aidge::OperatorTensor::getRawInput(const Aidge::IOIndex_t inputIdx) const {
return std::static_pointer_cast<Data>(getInput(inputIdx));
}
const std::shared_ptr<Aidge::Tensor>& Aidge::OperatorTensor::getOutput(const Aidge::IOIndex_t outputIdx) const {
std::shared_ptr<Aidge::Tensor> Aidge::OperatorTensor::getOutput(const Aidge::IOIndex_t outputIdx) const {
if (outputIdx >= nbOutputs()) { if (outputIdx >= nbOutputs()) {
AIDGE_ASSERT("%s Operator has %hu outputs", type().c_str(), nbOutputs()); AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbOutputs());
} }
return mOutputs[outputIdx]; return mOutputs[outputIdx];
} }
Aidge::Tensor& Aidge::OperatorTensor::output(const Aidge::IOIndex_t outputIdx) const {
return *getOutput(outputIdx);
}
std::shared_ptr<Aidge::Data> Aidge::OperatorTensor::getRawOutput(const Aidge::IOIndex_t outputIdx) const { void Aidge::OperatorTensor::computeOutputDims() {
return std::static_pointer_cast<Data>(getOutput(outputIdx)); // check inputs have been associated
bool associated = (nbInputs() > 0); // do not compute anything if no input
for (IOIndex_t i = 0; i < nbInputs(); ++i) {
if (!getInput(i)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor");
}
associated &= !(getInput(i)->empty());
}
if (associated) {
const auto expectedDims = getInput(0)->dims();
for (std::size_t i = 1; i < nbInputs(); ++i) {
if (expectedDims != getInput(i)->dims()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Operator's inputs should have the same dimensions");
}
}
mOutputs[0]->resize(expectedDims);
}
} }
bool Aidge::OperatorTensor::outputDimsForwarded() const { bool Aidge::OperatorTensor::outputDimsForwarded() const {
bool forwarded = true; bool forwarded = true;
// check both inputs and outputs have been filled
for (IOIndex_t i = 0; i < nbInputs(); ++i) {
forwarded &= !(getInput(i)->empty());
}
for (IOIndex_t i = 0; i < nbOutputs(); ++i) { for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
forwarded &= !(getOutput(i)->empty()); forwarded &= !(getOutput(i)->empty());
} }
...@@ -71,9 +79,9 @@ bool Aidge::OperatorTensor::outputDimsForwarded() const { ...@@ -71,9 +79,9 @@ bool Aidge::OperatorTensor::outputDimsForwarded() const {
void Aidge::OperatorTensor::setDataType(const DataType& dataType) const { void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
for (IOIndex_t i = 0; i < nbOutputs(); ++i) { for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
getOutput(i)->setDatatype(dataType); getOutput(i)->setDataType(dataType);
} }
for (IOIndex_t i = 0; i < nbInputs(); ++i) { for (IOIndex_t i = 0; i < nbInputs(); ++i) {
getInput(i)->setDatatype(dataType); getInput(i)->setDataType(dataType);
} }
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment