
graph->compile() is messing up inputs

Required prerequisites

  • Make sure you've read the documentation. Your issue may be addressed there.
  • Search the issue tracker and discussions to verify that this hasn't already been reported. +1 or comment there if it has.

What commit version of aidge do you use?

  • aidge_core: dev
  • aidge_...: dev

Problem description

TEST_CASE("Graph Compile",
          "[GraphCompile]") {
    // Network parameters
    constexpr auto inChannels = 4;
    constexpr auto outChannels = 4;
    constexpr auto beta = 0.8;
    constexpr auto threshold = 1.0;

    // Temporal Dimensions
    constexpr auto nbTimeSteps = 3;
    constexpr auto batchSize = 2;

    // Initialize tensors
    auto fc1Weights =
        std::make_shared<Tensor>(Array2D<float, inChannels, outChannels>{{
            {0.1, 0.1, 0.1, 0.1},
            {0.1, 0.1, 0.1, 0.1},
            {0.1, 0.1, 0.1, 0.1},
            {0.1, 0.1, 0.1, 0.1},
        }});

    auto input = std::make_shared<Tensor>(
        Array3D<float, nbTimeSteps, batchSize, inChannels>{{
            {{1, 1, 1, 1}, {1, 1, 1, 1}},
            {{1, 1, 1, 1}, {1, 1, 1, 1}},
            {{1, 1, 1, 1}, {1, 1, 1, 1}},
        }});

    // Create network operators
    auto pop = Pop("pop");
    auto stack = Stack(nbTimeSteps, "stack");
    auto fc1 = FC(inChannels, outChannels, /*noBias=*/true, /*name=*/"fc1");

    // Get operator references
    auto popOp = std::static_pointer_cast<OperatorTensor>(pop->getOperator());
    auto stackOp =
        std::static_pointer_cast<OperatorTensor>(stack->getOperator());
    auto fc1Op = std::static_pointer_cast<OperatorTensor>(fc1->getOperator());

    // Connect operators
    fc1->input(1).first->getOperator()->setOutput(0, fc1Weights);
    pop->getOperator()->associateInput(0, input);
    pop->addChild(fc1, 0, 0);
    fc1->addChild(stack, 0, 0);

    // Create and compile graph
    auto graph = std::make_shared<GraphView>();
    graph->add({fc1, stack, pop});
    REQUIRE(not fc1Op->getInput(2)); // pass
    graph->compile("cpu", DataType::Float32);
    REQUIRE(not fc1Op->getInput(2)); // fail
}

The test above fails on its last line. After graph->compile(), fc1Op->getInput(2) is no longer nullptr, even though fc1 was created with noBias=true. Later on, this makes the FC CPU kernel fail at

const auto& input2 = (op_.getInput(2)) ? op_.getInput(2)->refCastFrom(input2Fallback, *(op_.getOutput(0))) : Tensor();

because refCastFrom is then called on a tensor that has no implementation.

Edited by Jerome Hue