Simplify forward_dims(), highlight compile()
Required prerequisites
- Make sure you've read the documentation. Your issue may be addressed there.
- Search the issue tracker and discussions to verify that this hasn't already been reported. +1 or comment there if it has.
What commit version of aidge do you use
- aidge_core: 2.1.0
Problem description
Here is the code of GraphView::forward_dims(): https://gitlab.eclipse.org/eclipse/aidge/aidge_core/-/blob/main/src/graph/GraphView.cpp?ref_type=heads#L401
And here is the beginning of the member function:
```cpp
bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>>& dims, bool allowDataDependency) {
    // setInputs
    // Link every tensor to the right pointer
    // following parent - children informations
    if (!dims.empty()){
        AIDGE_ASSERT(dims.size() == mInputNodes.size(), "GraphView forwardDims error - Inconsistent number of given dimensions ({}) and graph inputs ({})", dims.size(), mInputNodes.size());
        for (std::size_t i = 0; i < dims.size(); ++i){
            auto tensor = std::make_shared<Tensor>(dims[i]);
            mInputNodes[i].first->getOperator()->setInput(mInputNodes[i].second, tensor);
        }
    }
    // ... (rest of the function)
}
```
It starts by replacing the input Tensor of each input Node with a new Tensor whenever the dims parameter is provided. Issue: the new Tensor is built from dims alone, so the data type, backend and any other property set on the previous Tensor are lost.
It is thus impossible to simply compute the dimensions of a sub-graph without interfering with the entries of its input Nodes.
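The mechanism can be illustrated with a self-contained sketch. It uses hypothetical mock types and helpers (MockTensor, DataType, setInputReplacing(), setInputPreserving(), makeIntInput()) that are not aidge's API; it also assumes, consistently with the int-to-float conversion reported below, that a freshly built Tensor defaults to Float32. Replacing the input with a Tensor built only from dims drops its data type and backend, whereas updating the dims of an already associated Tensor keeps them. A real Tensor would also need its storage resized; the mock only tracks metadata.

```cpp
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-ins for aidge types; NOT the real aidge API.
enum class DataType { Float32, Int32 };

struct MockTensor {
    std::vector<std::size_t> dims;
    DataType dataType = DataType::Float32;  // assumed default for a freshly built Tensor
    std::string backend;                    // empty string means "no backend set"

    explicit MockTensor(std::vector<std::size_t> d) : dims(std::move(d)) {}
};

using TensorPtr = std::shared_ptr<MockTensor>;

// Mirrors the current pattern: the input is always replaced by a new Tensor
// built from dims alone, so the previous dataType and backend are lost.
void setInputReplacing(TensorPtr& input, const std::vector<std::size_t>& dims) {
    input = std::make_shared<MockTensor>(dims);
}

// Alternative: if a Tensor is already associated, only update its dims and
// keep everything else; create a new Tensor only when there is none.
void setInputPreserving(TensorPtr& input, const std::vector<std::size_t>& dims) {
    if (input) {
        input->dims = dims;  // dataType and backend are left untouched
    } else {
        input = std::make_shared<MockTensor>(dims);
    }
}

// Builds an Int32 input with a backend set (the "cpu" string is illustrative).
TensorPtr makeIntInput() {
    auto t = std::make_shared<MockTensor>(std::vector<std::size_t>{1, 3, 8, 8});
    t->dataType = DataType::Int32;
    t->backend = "cpu";
    return t;
}
```

Calling setInputReplacing() on the result of makeIntInput() leaves a Float32 Tensor with no backend, while setInputPreserving() keeps Int32 and "cpu" and only changes the dims.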
This behaviour points to a bigger issue with the forward_dims() function, which does several things beyond its scope without telling the user:
- Set inputs
- Check the validity of Node connections
- Propagate dimensions
Suggested solution
Reduce the function to forwarding dimensions only, as its name suggests. Clarify how forward_dims() behaves depending on its inputs (are Tensors already associated? is the dims parameter provided?) and move GraphView validation to the compile() function.
- Set inputs: inputs should not be changed. This function should only propagate dimensions, as its name implies.
- Check the validity of Node connections: this should be done by the existing GraphView::compile() function, which was actually created to perform this task.
- Propagate dimensions: keep this.
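A minimal sketch of the proposed split, under simplifying assumptions: MockNode, MockGraph and connectionsAreValid() are hypothetical, every node has a single input and is element-wise (output dims equal input dims), and the sketch does not reflect the real aidge classes or the actual signature of GraphView::compile(). forwardDims() only reads dimensions that are already there and never creates or replaces an input; compile() owns the connection check and then delegates to forwardDims().

```cpp
#include <algorithm>
#include <cstddef>
#include <memory>
#include <optional>
#include <vector>

// Hypothetical, minimal graph model; NOT the real aidge GraphView API.
// Every node has at most one input and is assumed to be element-wise,
// so its output dims equal its input dims.
using Dims = std::vector<std::size_t>;

struct MockNode {
    std::shared_ptr<MockNode> parent;  // null for a graph input node
    std::optional<Dims> inputDims;     // set by the user for graph inputs
    std::optional<Dims> outputDims;    // filled in by forwardDims()
};

using NodePtr = std::shared_ptr<MockNode>;

struct MockGraph {
    std::vector<NodePtr> nodes;  // assumed to be in topological order

    // Only propagates dimensions: never creates or replaces input data.
    // Returns false as soon as a node has no input dims to propagate.
    bool forwardDims() {
        for (const auto& node : nodes) {
            if (node->parent) {
                node->inputDims = node->parent->outputDims;
            }
            if (!node->inputDims) {
                return false;
            }
            node->outputDims = node->inputDims;  // element-wise assumption
        }
        return true;
    }

    // Checks that every parent belongs to the graph and comes before its child.
    bool connectionsAreValid() const {
        for (auto it = nodes.begin(); it != nodes.end(); ++it) {
            const NodePtr& parent = (*it)->parent;
            if (parent && std::find(nodes.begin(), it, parent) == it) {
                return false;
            }
        }
        return true;
    }

    // compile() validates the connections, then delegates to forwardDims().
    bool compile() {
        return connectionsAreValid() && forwardDims();
    }
};
```

With this split, running forwardDims() on a sub-graph whose inputs are already associated leaves those inputs untouched, and a missing input makes it return false instead of silently creating a new one.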
Reproducible example code
In the Test_HorizontalTiling.cpp file, the input Tensor myInput is set to an int type but is surprisingly converted to float during the call to forward_dims(). This makes the test fail, as there is no float32->int32 kernel on CPU for the ReLU Operator.