You need to sign in or sign up before continuing.
Connection mismatch: Input#0 for Gemm MetaOp
This is my code for creating a Gemm MetaOp:
auto graph = std::make_shared<Aidge::GraphView>();
auto fc = FC(in_channels, out_channels, no_bias, name);
graph->add(fc, false);
auto transA = Transpose(std::vector<Aidge::DimSize_t>{}, name + "_transposeA");
auto transB = Transpose(std::vector<Aidge::DimSize_t>{}, name + "_transposeB");
if (transposeA) {
transA->addChild(fc,0,0);
graph->add(transA);
}
if (alpha != 1.0f) {
std::shared_ptr<Aidge::Tensor> alphaTensor = std::make_shared<Aidge::Tensor>(alpha);
auto alphaProducer = Producer(alphaTensor, name + "_alpha");
auto mulA = Mul(name + "_mulA");
alphaProducer->addChild(mulA,0,1);
graph->add(alphaProducer);
auto mulOutput = transposeA ? transA : fc;
mulA->addChild(mulOutput,0,0);
graph->add(mulA);
}
if(transposeB ||beta != 1.0f) {
auto weightsInput = std::make_pair(fc, 1);
graph->replace({weightsInput.first->getParents()[weightsInput.second]}, {});
if (transposeB) {
transB->addChild(fc,0,1);
graph->add(transB);
weightsInput = std::make_pair(transB, 0);
}
if (beta != 1.0f) {
// add the new weight graph
std::shared_ptr<Aidge::Tensor> betaTensor = std::make_shared<Aidge::Tensor>(beta);
auto betaProducer = Producer(betaTensor, name + "_beta");
auto mulB = Mul(name + "_mulB");
betaProducer->addChild(mulB,0,1);
graph->add(mulB->getParent(1));
mulB->addChild(weightsInput.first, 0, weightsInput.second);
graph->add(mulB);
weightsInput = std::make_pair(mulB, 0);
}
addProducer(weightsInput.first, weightsInput.second, {out_channels, in_channels}, "w");
// graph->add(weightsInput.first->getParent(weightsInput.second));
auto metaOpNode = MetaOperator("Gemm", graph, {}, name);
return metaOpNode;
The generated graph looks correct:
but when I test the forward pass:
float alpha = 0.5, beta = 1.5;
auto gemm = Gemm(3, 5, alpha, beta, false, true, true);
auto op =
std::static_pointer_cast<MetaOperator_Op>(gemm->getOperator());
op->associateInput(0, input_t);
op->associateInput(1, weights_t);
op->associateInput(2, bias);
auto g = getConnectedGraphView(gemm);
g->setDataType(DataType::Float32);
g->setBackend("cpu");
auto scheduler = SequentialScheduler(g);
REQUIRE_NOTHROW(scheduler.generateScheduling());
REQUIRE_NOTHROW(scheduler.forward(true));
op->getOutput(0)->print();
the output is empty, and I get this error:
[WARNING] - Remaining consumers: ["_mulB (Mul#1)", " (FC#0)"].
[WARNING] - Remaining consumers: [" (Gemm#0)"].
[ERROR] - Connection mismatch: Input#0 of node [_mulB (Mul)] -> Output#0 of node [_mulB_w -
[ERROR] (Producer)]
[WARNING] - Unable to forward dimensions (circular dependency and/or wrong dimensions and/or data
[WARNING] dependent dimension?). Unable to compute output dims for nodes [" (Gemm)"].
[]
[ERROR] - approxEq: Dimension mismatch.
[ERROR] t1 : []
[ERROR] t2 : [2, 5]
Edited by Houssem ROUIS