From 7c21dac5e2c79a224c0bc1ed1cb16eb5089ac66d Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Thu, 20 Jun 2024 17:37:23 +0200
Subject: [PATCH] Adapted CPU kernels and tests to the new input category system

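The bias input (#2) of the Conv, ConvDepthWise and FC CPU kernels is
now optional, so the unconditional "missing input #2" asserts are
removed from their forward (and FC backward) implementations. The
MetaOperator LSTM unit tests now check each input's category (Data,
Param or OptionalParam) instead of the former nbData() count.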
---
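Reviewer note: with input #2 optional, the kernels must now tolerate a
missing bias rather than assert on it. A minimal sketch of the guarded
access pattern, assuming an OperatorTensor-style handle `op` (the
null-bias fallback below is illustrative, not code from this patch):

    // Input #2 (bias) may legitimately be absent now that it is an
    // optional input: fall back to a null pointer instead of asserting,
    // and let the kernel skip the bias addition when biasPtr is null.
    const void* biasPtr = op.getInput(2)
        ? op.getInput(2)->getImpl()->rawPtr()
        : nullptr;

The updated tests mirror this layout on the LSTM meta-operator: input 0
is Data, inputs 1-8 are Param and inputs 9-16 are OptionalParam.
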
 src/operator/ConvDepthWiseImpl.cpp        |  2 --
 src/operator/ConvImpl.cpp                 |  2 --
 src/operator/FCImpl.cpp                   |  2 --
 unit_tests/operator/Test_MetaOperator.cpp | 40 ++++++++++++++++++++---
 4 files changed, 35 insertions(+), 11 deletions(-)

diff --git a/src/operator/ConvDepthWiseImpl.cpp b/src/operator/ConvDepthWiseImpl.cpp
index 9e0daa46..11e50aaa 100644
--- a/src/operator/ConvDepthWiseImpl.cpp
+++ b/src/operator/ConvDepthWiseImpl.cpp
@@ -32,7 +32,6 @@ void Aidge::ConvDepthWiseImpl1D_cpu::forward() {
 
     assert(mOp.getRawInput(0) && "missing input #0");
     assert(mOp.getRawInput(1) && "missing input #1");
-    assert(mOp.getRawInput(2) && "missing input #2");
 
     assert((std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->nbDims() == 3) && "support for 3-dimensional tensors only");
 
@@ -84,7 +83,6 @@ void Aidge::ConvDepthWiseImpl2D_cpu::forward() {
 
     assert(mOp.getRawInput(0) && "missing input #0");
     assert(mOp.getRawInput(1) && "missing input #1");
-    assert(mOp.getRawInput(2) && "missing input #2");
 
     assert((std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->nbDims() == 4) && "support for 4-dimensional tensors only");
 
diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp
index b7deb6b8..e522aac1 100644
--- a/src/operator/ConvImpl.cpp
+++ b/src/operator/ConvImpl.cpp
@@ -33,7 +33,6 @@ void Aidge::ConvImpl1D_cpu::forward() {
     // FIXME: uncomment the following code once memory handling works
     assert(mOp.getRawInput(0) && "missing input #0");
     assert(mOp.getRawInput(1) && "missing input #1");
-    assert(mOp.getRawInput(2) && "missing input #2");
 
     // Find the correct kernel type
     const auto outputDataType = opTensor.getOutput(0)->dataType();
@@ -85,7 +84,6 @@ void Aidge::ConvImpl2D_cpu::forward() {
     // FIXME: uncomment the following code once memory handling works
     assert(mOp.getRawInput(0) && "missing input #0");
     assert(mOp.getRawInput(1) && "missing input #1");
-    assert(mOp.getRawInput(2) && "missing input #2");
 
     // Find the correct kernel type
     const auto outputDataType = opTensor.getOutput(0)->dataType();
diff --git a/src/operator/FCImpl.cpp b/src/operator/FCImpl.cpp
index ad3727f6..f7eebb7b 100644
--- a/src/operator/FCImpl.cpp
+++ b/src/operator/FCImpl.cpp
@@ -29,7 +29,6 @@ void Aidge::FCImpl_cpu::forward()
     const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp);
     AIDGE_ASSERT(op_.getInput(0), "missing input #0");
     AIDGE_ASSERT(op_.getInput(1), "missing input #1");
-    AIDGE_ASSERT(op_.getInput(2), "missing input #2");
 
     // Find the correct kernel type
     const auto outputDataType = op_.getOutput(0)->dataType();
@@ -77,7 +76,6 @@ void Aidge::FCImpl_cpu::backward()
     AIDGE_ASSERT(fc_grad, "missing output #0 gradient");
     AIDGE_ASSERT(op_.getInput(0)->grad(), "missing input #0 gradient");
     AIDGE_ASSERT(op_.getInput(1)->grad(), "missing input #1 gradient");
-    AIDGE_ASSERT(op_.getInput(2)->grad(), "missing input #2 gradient");
 
     // Find the correct kernel type
     const Registrar<FCImplBackward_cpu>::registrar_key registrarKey = {
diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp
index 78058eca..2893788e 100644
--- a/unit_tests/operator/Test_MetaOperator.cpp
+++ b/unit_tests/operator/Test_MetaOperator.cpp
@@ -200,7 +200,13 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
         microGraph->save("lstm", false, true);
 
         REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
-        REQUIRE(myLSTM->nbData() == 1);
+        REQUIRE(myLSTM->inputCategory(0) == InputCategory::Data);
+        for (size_t i = 1; i < 9; ++i) {
+            REQUIRE(myLSTM->inputCategory(i) == InputCategory::Param);
+        }
+        for (size_t i = 9; i < 17; ++i) {
+            REQUIRE(myLSTM->inputCategory(i) == InputCategory::OptionalParam);
+        }
         REQUIRE(myLSTM->nbOutputs() == 2);
 
         std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
@@ -259,7 +265,13 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
         microGraph->save("lstm", false, false);
 
         REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
-        REQUIRE(myLSTM->nbData() == 1);
+        REQUIRE(myLSTM->inputCategory(0) == InputCategory::Data);
+        for (size_t i = 1; i < 9; ++i) {
+            REQUIRE(myLSTM->inputCategory(i) == InputCategory::Param);
+        }
+        for (size_t i = 9; i < 17; ++i) {
+            REQUIRE(myLSTM->inputCategory(i) == InputCategory::OptionalParam);
+        }
         REQUIRE(myLSTM->nbOutputs() == 2);
 
         std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
@@ -316,7 +328,13 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
         auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator());
 
         REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
-        REQUIRE(myLSTM->nbData() == 1);
+        REQUIRE(myLSTM->inputCategory(0) == InputCategory::Data);
+        for (size_t i = 1; i < 9; ++i) {
+            REQUIRE(myLSTM->inputCategory(i) == InputCategory::Param);
+        }
+        for (size_t i = 9; i < 17; ++i) {
+            REQUIRE(myLSTM->inputCategory(i) == InputCategory::OptionalParam);
+        }
         REQUIRE(myLSTM->nbOutputs() == 2);
 
         std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
@@ -378,7 +396,13 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
         myGraph->add(pop);
 
         REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
-        REQUIRE(myLSTM->nbData() == 1);
+        REQUIRE(myLSTM->inputCategory(0) == InputCategory::Data);
+        for (size_t i = 1; i < 9; ++i) {
+            REQUIRE(myLSTM->inputCategory(i) == InputCategory::Param);
+        }
+        for (size_t i = 9; i < 17; ++i) {
+            REQUIRE(myLSTM->inputCategory(i) == InputCategory::OptionalParam);
+        }
         REQUIRE(myLSTM->nbOutputs() == 2);
 
         std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
@@ -441,7 +465,13 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
         myGraph->add(pop);
 
         REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
-        REQUIRE(myLSTM->nbData() == 1);
+        REQUIRE(myLSTM->inputCategory(0) == InputCategory::Data);
+        for (size_t i = 1; i < 9; ++i) {
+            REQUIRE(myLSTM->inputCategory(i) == InputCategory::Param);
+        }
+        for (size_t i = 9; i < 17; ++i) {
+            REQUIRE(myLSTM->inputCategory(i) == InputCategory::OptionalParam);
+        }
         REQUIRE(myLSTM->nbOutputs() == 2);
 
         std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
-- 
GitLab