``Operator`` should operate on ``Data`` not ``Tensor``
Summary
``Operator`` should be a base class for any operator, whatever the type of data it computes. Hence, functions such as:
```cpp
virtual std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const = 0;
virtual Tensor& input(const IOIndex_t /*inputIdx*/) const = 0;
virtual std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const = 0;
virtual Tensor& output(const IOIndex_t /*outputIdx*/) const = 0;
```
though very practical, should not exist in this class. They were added for convenience because, so far, only ``Operator``s that use ``Tensor``s have been implemented.
Solutions
Short term
A short-term solution is to turn ``getInput()`` and the other ``Tensor``-specific accessors into utility functions:
```cpp
std::shared_ptr<Tensor> myTensor = getInputTensor(myOperator, 0);
```
defined as:
```cpp
// Free function: casts the Data returned by Operator::getRawInput() to a Tensor.
std::shared_ptr<Tensor> getInputTensor(const std::shared_ptr<Operator>& op, const IOIndex_t inputIdx) {
    return std::static_pointer_cast<Tensor>(op->getRawInput(inputIdx));
}
```
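A minimal, self-contained sketch of this short-term approach follows. It assumes heavily simplified stand-ins for ``Data``, ``Tensor`` and ``Operator`` (the real classes have many more members); only ``getRawInput()`` comes from this issue, while ``setRawInput()`` and the ``values`` member are hypothetical helpers added for illustration. One deliberate deviation from the snippet above: it uses ``std::dynamic_pointer_cast`` instead of ``std::static_pointer_cast``, so an input that is not actually a ``Tensor`` yields ``nullptr`` instead of undefined behaviour, at the cost of a runtime type check.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the real types, reduced to what the sketch needs.
using IOIndex_t = std::size_t;

struct Data {
    virtual ~Data() = default;  // polymorphic, so dynamic_pointer_cast works
};

struct Tensor : Data {
    std::vector<float> values;  // assumed payload, for illustration only
};

// Data-agnostic base class: it only stores and returns Data.
class Operator {
public:
    explicit Operator(std::size_t nbInputs) : mInputs(nbInputs) {}
    virtual ~Operator() = default;

    // Hypothetical setter used to feed inputs in this sketch.
    void setRawInput(IOIndex_t idx, std::shared_ptr<Data> data) { mInputs.at(idx) = std::move(data); }
    std::shared_ptr<Data> getRawInput(IOIndex_t idx) const { return mInputs.at(idx); }

private:
    std::vector<std::shared_ptr<Data>> mInputs;
};

// Tensor-specific convenience lives outside Operator as a free function.
// Returns nullptr if the input at inputIdx is not a Tensor (or is unset).
std::shared_ptr<Tensor> getInputTensor(const std::shared_ptr<Operator>& op, IOIndex_t inputIdx) {
    return std::dynamic_pointer_cast<Tensor>(op->getRawInput(inputIdx));
}
```

Keeping the cast in a free function means ``Operator`` itself never names ``Tensor``; if the unchecked ``static_pointer_cast`` is preferred for speed, only this one helper has to change.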
Long term
There could be an intermediate class, ``OperatorTensor``, between ``Operator`` and the ``Tensor``-based operators that would implement the ``Tensor``-specific member functions listed above once, for all of them.
```mermaid
classDiagram
class Operator
class OperatorTensor
class Conv
class FC
class ReLU
Operator <|-- OperatorTensor
OperatorTensor <|-- Conv
OperatorTensor <|-- FC
OperatorTensor <|-- ReLU
```
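The hierarchy in the diagram could be sketched in C++ as follows, again with simplified, hypothetical stand-ins rather than the actual classes. ``Operator`` only exposes ``Data``-typed accessors (``getRawInput()`` from this issue, plus an assumed ``getRawOutput()``); ``OperatorTensor`` stores ``Tensor``s and implements ``getInput()``/``getOutput()`` once; ``ReLU`` only implements its computation. ``associateInput()``, ``forward()``, ``type()`` and the ``values`` member are assumptions made for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

using IOIndex_t = std::size_t;

struct Data { virtual ~Data() = default; };
struct Tensor : Data { std::vector<float> values; };  // assumed payload

// Base class: only data-agnostic members.
class Operator {
public:
    Operator(std::string type, std::size_t nbIn, std::size_t nbOut)
        : mType(std::move(type)), mNbIn(nbIn), mNbOut(nbOut) {}
    virtual ~Operator() = default;
    const std::string& type() const { return mType; }
    std::size_t nbInputs() const { return mNbIn; }
    std::size_t nbOutputs() const { return mNbOut; }
    virtual std::shared_ptr<Data> getRawInput(IOIndex_t idx) const = 0;
    virtual std::shared_ptr<Data> getRawOutput(IOIndex_t idx) const = 0;
    virtual void forward() = 0;
private:
    std::string mType;
    std::size_t mNbIn;
    std::size_t mNbOut;
};

// Intermediate class: stores Tensors and implements the Tensor accessors once.
class OperatorTensor : public Operator {
public:
    OperatorTensor(std::string type, std::size_t nbIn, std::size_t nbOut)
        : Operator(std::move(type), nbIn, nbOut), mInputs(nbIn), mOutputs(nbOut) {
        for (auto& out : mOutputs) out = std::make_shared<Tensor>();  // outputs owned by the operator
    }
    void associateInput(IOIndex_t idx, std::shared_ptr<Tensor> t) { mInputs.at(idx) = std::move(t); }
    std::shared_ptr<Tensor> getInput(IOIndex_t idx) const { return mInputs.at(idx); }
    std::shared_ptr<Tensor> getOutput(IOIndex_t idx) const { return mOutputs.at(idx); }
    // Data-typed accessors required by Operator, implemented via the Tensor ones.
    std::shared_ptr<Data> getRawInput(IOIndex_t idx) const override { return getInput(idx); }
    std::shared_ptr<Data> getRawOutput(IOIndex_t idx) const override { return getOutput(idx); }
private:
    std::vector<std::shared_ptr<Tensor>> mInputs;
    std::vector<std::shared_ptr<Tensor>> mOutputs;
};

// Concrete Tensor-based operator: only its computation remains to write.
class ReLU : public OperatorTensor {
public:
    ReLU() : OperatorTensor("ReLU", 1, 1) {}
    void forward() override {
        if (!getInput(0)) throw std::runtime_error("ReLU: input 0 is not associated");
        const auto& in = getInput(0)->values;
        auto& out = getOutput(0)->values;
        out.resize(in.size());
        std::transform(in.begin(), in.end(), out.begin(),
                       [](float x) { return std::max(x, 0.0f); });
    }
};
```

``Conv`` or ``FC`` could derive from ``OperatorTensor`` the same way, while code that only holds an ``Operator`` still reaches the data through the ``Data``-typed accessors.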
This system could then be extended to event-based or stream-based operators, leaving ``Operator`` with only what is really necessary.
Edited by Maxence Naud