
[core] ``Operator`` should operate on ``Data`` not ``Tensor``

Summary

Operator should be a base class for any operator, whatever the type of data it computes on. Hence, functions such as:

virtual std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const = 0;
virtual Tensor& input(const IOIndex_t /*inputIdx*/) const = 0;
virtual std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const = 0;
virtual Tensor& output(const IOIndex_t /*outputIdx*/) const = 0;

though very practical, should not exist in this class. They were implemented for convenience, because only Tensor-based Operators have been implemented so far.

Solutions

Short term

A short-term solution is to turn getInput() and the other Tensor accessors into utility functions:

std::shared_ptr<Tensor> myTensor = getInputTensor(myOperator, 0);

with

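// Returns input #inputIdx of op as a Tensor ("operator" is a C++ keyword and cannot be a parameter name).
// static_pointer_cast performs no runtime check: the input must actually be a Tensor.
std::shared_ptr<Tensor> getInputTensor(const std::shared_ptr<Operator>& op, const IOIndex_t inputIdx) {
    return std::static_pointer_cast<Tensor>(op->getRawInput(inputIdx));
}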
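A minimal, self-contained sketch of this short-term approach, assuming that Data is a polymorphic base class of Tensor and that Operator stores its inputs as std::shared_ptr<Data> behind getRawInput(). The Data, Tensor and Operator definitions below, including setRawInput() and Tensor::value(), are simplified hypothetical stand-ins, not the project's actual classes.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-ins for the core types.
using IOIndex_t = std::size_t;

// Base class for any kind of data an Operator can consume or produce.
class Data {
public:
    virtual ~Data() = default;
};

// A Tensor is one kind of Data (reduced to a single value here).
class Tensor : public Data {
public:
    explicit Tensor(float value = 0.0f) : value_(value) {}
    float value() const { return value_; }

private:
    float value_;
};

// Operator only exposes Data: no Tensor-specific accessors.
class Operator {
public:
    explicit Operator(std::size_t nbInputs) : inputs_(nbInputs) {}
    virtual ~Operator() = default;

    void setRawInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) {
        inputs_.at(inputIdx) = std::move(data);
    }
    std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const {
        return inputs_.at(inputIdx);
    }

private:
    std::vector<std::shared_ptr<Data>> inputs_;
};

// Free utility function: the Tensor view of an input lives outside Operator.
// static_pointer_cast does no runtime type check, so a debug-only assert
// flags inputs that are set but are not Tensors.
std::shared_ptr<Tensor> getInputTensor(const std::shared_ptr<Operator>& op, const IOIndex_t inputIdx) {
    std::shared_ptr<Data> raw = op->getRawInput(inputIdx);
    assert((raw == nullptr || std::dynamic_pointer_cast<Tensor>(raw) != nullptr) && "input is not a Tensor");
    return std::static_pointer_cast<Tensor>(raw);
}
```

A caller would set a Tensor through setRawInput() and read it back with getInputTensor(op, 0); an unset input simply comes back as nullptr. When the caller cannot guarantee that the input holds a Tensor, std::dynamic_pointer_cast (which returns nullptr on a type mismatch) is an alternative to the unchecked cast, at the cost of a runtime check on every access.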

Long term

There could be an intermediate class between Operator and the Tensor-based Operators that would provide a default implementation of every Tensor-specific member function (getInput(), input(), getOutput(), output()).

classDiagram

class Operator
class OperatorTensor
class Conv
class FC
class ReLU

Operator <|-- OperatorTensor
OperatorTensor <|-- Conv
OperatorTensor <|-- FC
OperatorTensor <|-- ReLU

This system could then be extended to event-based or stream-based Operators, leaving Operator with only what is really necessary.
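The Operator / OperatorTensor split in the diagram above can be sketched as follows. This is a minimal illustration under assumptions: only getRawInput() and the getInput()/input()/getOutput()/output() accessors come from this issue, while Data, Tensor, setRawInput(), getRawOutput(), setInput(), type(), forward() and the single-value ReLU are simplified hypothetical stand-ins rather than the project's actual API.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-ins for the core types.
using IOIndex_t = std::size_t;

class Data {
public:
    virtual ~Data() = default;
};

class Tensor : public Data {
public:
    explicit Tensor(float value = 0.0f) : value_(value) {}
    float value() const { return value_; }
    void setValue(float value) { value_ = value; }

private:
    float value_;
};

// Operator: only what every operator needs, expressed in terms of Data.
class Operator {
public:
    Operator(std::size_t nbInputs, std::size_t nbOutputs)
        : inputs_(nbInputs), outputs_(nbOutputs) {}
    virtual ~Operator() = default;

    virtual std::string type() const = 0;
    virtual void forward() = 0;

    void setRawInput(const IOIndex_t idx, std::shared_ptr<Data> data) { inputs_.at(idx) = std::move(data); }
    std::shared_ptr<Data> getRawInput(const IOIndex_t idx) const { return inputs_.at(idx); }
    std::shared_ptr<Data> getRawOutput(const IOIndex_t idx) const { return outputs_.at(idx); }

protected:
    std::vector<std::shared_ptr<Data>> inputs_;
    std::vector<std::shared_ptr<Data>> outputs_;
};

// OperatorTensor: implements the Tensor accessors once for Conv, FC, ReLU, ...
class OperatorTensor : public Operator {
public:
    OperatorTensor(std::size_t nbInputs, std::size_t nbOutputs) : Operator(nbInputs, nbOutputs) {
        // Tensor-based operators produce Tensor outputs.
        for (auto& out : outputs_) {
            out = std::make_shared<Tensor>();
        }
    }

    void setInput(const IOIndex_t idx, std::shared_ptr<Tensor> tensor) { setRawInput(idx, std::move(tensor)); }
    std::shared_ptr<Tensor> getInput(const IOIndex_t idx) const { return asTensor(getRawInput(idx)); }
    Tensor& input(const IOIndex_t idx) const { return *getInput(idx); }
    std::shared_ptr<Tensor> getOutput(const IOIndex_t idx) const { return asTensor(getRawOutput(idx)); }
    Tensor& output(const IOIndex_t idx) const { return *getOutput(idx); }

private:
    // Unchecked downcast in release builds; the assert catches non-Tensor data in debug builds.
    static std::shared_ptr<Tensor> asTensor(const std::shared_ptr<Data>& data) {
        assert((data == nullptr || std::dynamic_pointer_cast<Tensor>(data) != nullptr) && "not a Tensor");
        return std::static_pointer_cast<Tensor>(data);
    }
};

// A concrete Tensor-based operator (single-value ReLU for illustration).
class ReLU : public OperatorTensor {
public:
    ReLU() : OperatorTensor(1, 1) {}
    std::string type() const override { return "ReLU"; }
    void forward() override {
        const float x = input(0).value();
        output(0).setValue(x > 0.0f ? x : 0.0f);
    }
};
```

In this layout Operator never names Tensor: code holding a std::shared_ptr<Operator> only sees Data through getRawInput() and getRawOutput(), while Conv, FC and ReLU inherit the Tensor accessors from OperatorTensor. An event-based or stream-based family would presumably derive from Operator alongside OperatorTensor and add its own typed accessors, without changing Operator itself.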

Edited by Maxence Naud