Node setDatatype is too strict
Context
The current `setDatatype` is too strict: it overrides output datatypes where it should not.
For example, the Shape operator should always return an INT64 Tensor by design.
I propose introducing a new method, `forwardDataType`, which sets the datatype of a node's outputs based on the datatypes of its inputs.
A current use case is the shuffle layer: https://gitlab.eclipse.org/-/project/5139/uploads/6ad9fc3fc64740b31b696a3f20308553/shuffle_layer.onnx
Ideally, all dimension computations should be done in INT64.
Implementation idea
This would be done by adding `forwardDataType` to `OperatorTensor`. By default, this function checks that all inputs share the same datatype and assigns that datatype to every output.
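A minimal sketch of that default rule follows. The `DataType` enum and the stripped-down `OperatorTensor` (which only stores input and output datatypes) are hypothetical stand-ins, not Aidge's actual classes.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the framework's datatype enum.
enum class DataType { Float32, Int32, Int64 };

// Hypothetical, heavily simplified operator: only datatypes are modeled.
struct OperatorTensor {
    std::vector<DataType> inputs;
    std::vector<DataType> outputs;

    virtual ~OperatorTensor() = default;

    // Default rule: every input must have the same datatype, which is then
    // copied to all outputs. Returns false if the inputs disagree.
    virtual bool forwardDataType() {
        if (inputs.empty()) {
            return false;
        }
        const DataType dt = inputs.front();
        for (DataType in : inputs) {
            if (in != dt) {
                return false;
            }
        }
        for (DataType& out : outputs) {
            out = dt;
        }
        return true;
    }
};
```

Returning `bool` instead of throwing mirrors the `bool` return of the `GraphView::forwardDataType` signatures proposed in this issue.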
Specific operators would override this default, for example:
- Reshape: input 1 must be INT64; the output has the same datatype as input 0.
- Shape: Whatever the input, the output is always INT64.
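The two overrides listed above could look roughly like this. The class layout, the `DataType` enum and the assumption that Reshape's input 0 is the data and input 1 the target shape are illustrative only, not Aidge's real implementation.

```cpp
#include <cassert>
#include <vector>

// Hypothetical datatype enum and base class, for illustration only.
enum class DataType { Float32, Int32, Int64 };

struct OperatorTensor {
    std::vector<DataType> inputs;
    std::vector<DataType> outputs;
    virtual ~OperatorTensor() = default;
    virtual bool forwardDataType() = 0;
};

// Reshape: input 0 is the data, input 1 the target shape (assumed layout).
// The shape input must be INT64; the output follows input 0.
struct Reshape : OperatorTensor {
    Reshape() { inputs.resize(2); outputs.resize(1); }
    bool forwardDataType() override {
        if (inputs[1] != DataType::Int64) {
            return false;
        }
        outputs[0] = inputs[0];
        return true;
    }
};

// Shape: the output is INT64 whatever the input datatype is.
struct Shape : OperatorTensor {
    Shape() { inputs.resize(1); outputs.resize(1); }
    bool forwardDataType() override {
        outputs[0] = DataType::Int64;
        return true;
    }
};
```

Because Shape ignores its input datatype, a Float32 graph can still feed INT64 dimensions into Reshape, which is the behavior wanted for the shuffle layer.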
To preserve the current API, `GraphView::setDatatype` would call `GraphView::forwardDataType` and emit a deprecation warning.
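One way to keep `setDatatype` as a thin deprecated wrapper is sketched below. The `GraphView` here is a hypothetical reduction that only records the datatype of each graph input, and the `void` return type of `setDatatype` is an assumption. The standard `[[deprecated]]` attribute makes compilers warn at call sites, and a runtime message is printed as well.

```cpp
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical datatype enum, for illustration only.
enum class DataType { Float32, Int32, Int64 };

// Hypothetical, simplified GraphView storing one datatype per graph input.
class GraphView {
public:
    explicit GraphView(std::size_t nbInputs)
        : inputTypes_(nbInputs, DataType::Float32) {}

    // New entry point (propagation through the nodes is omitted here).
    bool forwardDataType(DataType dt) {
        for (DataType& t : inputTypes_) {
            t = dt;
        }
        return true;
    }

    // Legacy entry point kept for API compatibility.
    [[deprecated("use GraphView::forwardDataType instead")]]
    void setDatatype(DataType dt) {
        std::cerr << "Warning: GraphView::setDatatype is deprecated, "
                     "use GraphView::forwardDataType instead\n";
        forwardDataType(dt);
    }

    const std::vector<DataType>& inputTypes() const { return inputTypes_; }

private:
    std::vector<DataType> inputTypes_;
};
```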
`GraphView::forwardDataType` would have two signatures:
- `bool GraphView::forwardDataType(dtype)`: returns true if successful. The same dtype is forwarded to all graph inputs.
- `bool GraphView::forwardDataType(std::vector<dtype>)`: returns true if successful. Specifies a different dtype for each graph input.
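The two overloads could share one code path, with the single-dtype version expanding its argument into a vector. This is a sketch under assumptions: `DataType` stands in for `dtype`, the `GraphView` only stores input datatypes, and the size-mismatch check is an illustrative choice rather than part of the proposal.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for dtype.
enum class DataType { Float32, Int32, Int64 };

// Hypothetical GraphView reduced to the datatypes of its inputs.
class GraphView {
public:
    explicit GraphView(std::size_t nbInputs)
        : inputTypes_(nbInputs, DataType::Float32) {}

    // Overload 1: forward the same datatype to every graph input.
    bool forwardDataType(DataType dt) {
        return forwardDataType(std::vector<DataType>(inputTypes_.size(), dt));
    }

    // Overload 2: one datatype per graph input; fails on a size mismatch.
    bool forwardDataType(const std::vector<DataType>& dts) {
        if (dts.size() != inputTypes_.size()) {
            return false;
        }
        inputTypes_ = dts;
        // A full implementation would now visit the nodes in topological
        // order and call each operator's forwardDataType(), returning false
        // as soon as one operator rejects its input datatypes.
        return true;
    }

    const std::vector<DataType>& inputTypes() const { return inputTypes_; }

private:
    std::vector<DataType> inputTypes_;
};
```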