graph->compile() is messing up inputs
Required prerequisites
- Make sure you've read the documentation. Your issue may be addressed there.
- Search the issue tracker and discussions to verify that this hasn't already been reported. +1 or comment there if it has.
What commit version of aidge do you use
- aidge_core: dev
- aidge_...: dev
Problem description
TEST_CASE("Graph Compile",
"[GraphCompile]") {
// Network parameters
constexpr auto inChannels = 4;
constexpr auto outChannels = 4;
constexpr auto beta = 0.8;
constexpr auto threshold = 1.0;
// Temporal Dimensions
constexpr auto nbTimeSteps = 3;
constexpr auto batchSize = 2;
// Initialize tensors
auto fc1Weights =
std::make_shared<Tensor>(Array2D<float, inChannels, outChannels>{{
{0.1, 0.1, 0.1, 0.1},
{0.1, 0.1, 0.1, 0.1},
{0.1, 0.1, 0.1, 0.1},
{0.1, 0.1, 0.1, 0.1},
}});
auto input = std::make_shared<Tensor>(
Array3D<float, nbTimeSteps, batchSize, inChannels>{{
{{1, 1, 1, 1}, {1, 1, 1, 1}},
{{1, 1, 1, 1}, {1, 1, 1, 1}},
{{1, 1, 1, 1}, {1, 1, 1, 1}},
}});
// Create network operators
auto pop = Pop("pop");
auto stack = Stack(nbTimeSteps, "stack");
auto fc1 = FC(inChannels, outChannels, /*noBias=*/true, /*name=*/"fc1");
// Get operator references
auto popOp = std::static_pointer_cast<OperatorTensor>(pop->getOperator());
auto stackOp =
std::static_pointer_cast<OperatorTensor>(stack->getOperator());
auto fc1Op = std::static_pointer_cast<OperatorTensor>(fc1->getOperator());
// Connect operators
fc1->input(1).first->getOperator()->setOutput(0, fc1Weights);
pop->getOperator()->associateInput(0, input);
pop->addChild(fc1, 0, 0);
fc1->addChild(stack, 0, 0);
// Create and compile graph
auto graph = std::make_shared<GraphView>();
graph->add({fc1, stack, pop});
REQUIRE(not fc1Op->getInput(2)); // pass
graph->compile("cpu", DataType::Float32);
REQUIRE(not fc1Op->getInput(2)); // fail
}
The test above fails on its last line, which indicates that `fc1Op->getInput(2)` is no longer `nullptr` after `graph->compile()`. Later on, this causes the FC CPU kernel to fail at

`const auto& input2 = (op_.getInput(2)) ? op_.getInput(2)->refCastFrom(input2Fallback, *(op_.getOutput(0))) : Tensor();`

since `refCastFrom` is called on a tensor with no implementation.
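For reference, here is a minimal sketch of the kind of guard that would avoid the crash at that point in the kernel. It assumes the Tensor API exposes a `hasImpl()` check (the method name is an assumption and may differ from the actual aidge API); the more likely proper fix is in `graph->compile()` itself, so that the optional bias input is not populated with an empty tensor in the first place:

```cpp
// Sketch only: treat the optional bias input as absent unless it actually
// carries data. Tensor::hasImpl() is assumed here; the real check may differ.
const auto& input2 = (op_.getInput(2) && op_.getInput(2)->hasImpl())
    ? op_.getInput(2)->refCastFrom(input2Fallback, *(op_.getOutput(0)))
    : Tensor();
```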