Skip to content
Snippets Groups Projects
Commit 7c21dac5 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added new input category system

parent cb56677f
No related branches found
No related tags found
No related merge requests found
...@@ -32,7 +32,6 @@ void Aidge::ConvDepthWiseImpl1D_cpu::forward() { ...@@ -32,7 +32,6 @@ void Aidge::ConvDepthWiseImpl1D_cpu::forward() {
assert(mOp.getRawInput(0) && "missing input #0"); assert(mOp.getRawInput(0) && "missing input #0");
assert(mOp.getRawInput(1) && "missing input #1"); assert(mOp.getRawInput(1) && "missing input #1");
assert(mOp.getRawInput(2) && "missing input #2");
assert((std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->nbDims() == 3) && "support for 3-dimensions tensors only"); assert((std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->nbDims() == 3) && "support for 3-dimensions tensors only");
...@@ -84,7 +83,6 @@ void Aidge::ConvDepthWiseImpl2D_cpu::forward() { ...@@ -84,7 +83,6 @@ void Aidge::ConvDepthWiseImpl2D_cpu::forward() {
assert(mOp.getRawInput(0) && "missing input #0"); assert(mOp.getRawInput(0) && "missing input #0");
assert(mOp.getRawInput(1) && "missing input #1"); assert(mOp.getRawInput(1) && "missing input #1");
assert(mOp.getRawInput(2) && "missing input #2");
assert((std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->nbDims() == 4) && "support for 3-dimensions tensors only"); assert((std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->nbDims() == 4) && "support for 3-dimensions tensors only");
......
...@@ -33,7 +33,6 @@ void Aidge::ConvImpl1D_cpu::forward() { ...@@ -33,7 +33,6 @@ void Aidge::ConvImpl1D_cpu::forward() {
// FIXME: uncomment the following code once memory handling will work // FIXME: uncomment the following code once memory handling will work
assert(mOp.getRawInput(0) && "missing input #0"); assert(mOp.getRawInput(0) && "missing input #0");
assert(mOp.getRawInput(1) && "missing input #1"); assert(mOp.getRawInput(1) && "missing input #1");
assert(mOp.getRawInput(2) && "missing input #2");
// Find the correct kernel type // Find the correct kernel type
const auto outputDataType = opTensor.getOutput(0)->dataType(); const auto outputDataType = opTensor.getOutput(0)->dataType();
...@@ -85,7 +84,6 @@ void Aidge::ConvImpl2D_cpu::forward() { ...@@ -85,7 +84,6 @@ void Aidge::ConvImpl2D_cpu::forward() {
// FIXME: uncomment the following code once memory handling will work // FIXME: uncomment the following code once memory handling will work
assert(mOp.getRawInput(0) && "missing input #0"); assert(mOp.getRawInput(0) && "missing input #0");
assert(mOp.getRawInput(1) && "missing input #1"); assert(mOp.getRawInput(1) && "missing input #1");
assert(mOp.getRawInput(2) && "missing input #2");
// Find the correct kernel type // Find the correct kernel type
const auto outputDataType = opTensor.getOutput(0)->dataType(); const auto outputDataType = opTensor.getOutput(0)->dataType();
......
...@@ -29,7 +29,6 @@ void Aidge::FCImpl_cpu::forward() ...@@ -29,7 +29,6 @@ void Aidge::FCImpl_cpu::forward()
const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp); const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp);
AIDGE_ASSERT(op_.getInput(0), "missing input #0"); AIDGE_ASSERT(op_.getInput(0), "missing input #0");
AIDGE_ASSERT(op_.getInput(1), "missing input #1"); AIDGE_ASSERT(op_.getInput(1), "missing input #1");
AIDGE_ASSERT(op_.getInput(2), "missing input #2");
// Find the correct kernel type // Find the correct kernel type
const auto outputDataType = op_.getOutput(0)->dataType(); const auto outputDataType = op_.getOutput(0)->dataType();
...@@ -77,7 +76,6 @@ void Aidge::FCImpl_cpu::backward() ...@@ -77,7 +76,6 @@ void Aidge::FCImpl_cpu::backward()
AIDGE_ASSERT(fc_grad, "missing ouput #0 gradient"); AIDGE_ASSERT(fc_grad, "missing ouput #0 gradient");
AIDGE_ASSERT(op_.getInput(0)->grad(), "missing input #0 gradient"); AIDGE_ASSERT(op_.getInput(0)->grad(), "missing input #0 gradient");
AIDGE_ASSERT(op_.getInput(1)->grad(), "missing input #1 gradient"); AIDGE_ASSERT(op_.getInput(1)->grad(), "missing input #1 gradient");
AIDGE_ASSERT(op_.getInput(2)->grad(), "missing input #2 gradient");
// Find the correct kernel type // Find the correct kernel type
const Registrar<FCImplBackward_cpu>::registrar_key registrarKey = { const Registrar<FCImplBackward_cpu>::registrar_key registrarKey = {
......
...@@ -200,7 +200,13 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") { ...@@ -200,7 +200,13 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
microGraph->save("lstm", false, true); microGraph->save("lstm", false, true);
REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8); REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
REQUIRE(myLSTM->nbData() == 1); REQUIRE(myLSTM->inputCategory(0) == InputCategory::Data);
for (size_t i = 1; i < 9; ++i) {
REQUIRE(myLSTM->inputCategory(i) == InputCategory::Param);
}
for (size_t i = 9; i < 17; ++i) {
REQUIRE(myLSTM->inputCategory(i) == InputCategory::OptionalParam);
}
REQUIRE(myLSTM->nbOutputs() == 2); REQUIRE(myLSTM->nbOutputs() == 2);
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>( std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
...@@ -259,7 +265,13 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") { ...@@ -259,7 +265,13 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
microGraph->save("lstm", false, false); microGraph->save("lstm", false, false);
REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8); REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
REQUIRE(myLSTM->nbData() == 1); REQUIRE(myLSTM->inputCategory(0) == InputCategory::Data);
for (size_t i = 1; i < 9; ++i) {
REQUIRE(myLSTM->inputCategory(i) == InputCategory::Param);
}
for (size_t i = 9; i < 17; ++i) {
REQUIRE(myLSTM->inputCategory(i) == InputCategory::OptionalParam);
}
REQUIRE(myLSTM->nbOutputs() == 2); REQUIRE(myLSTM->nbOutputs() == 2);
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>( std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
...@@ -316,7 +328,13 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") { ...@@ -316,7 +328,13 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator()); auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator());
REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8); REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
REQUIRE(myLSTM->nbData() == 1); REQUIRE(myLSTM->inputCategory(0) == InputCategory::Data);
for (size_t i = 1; i < 9; ++i) {
REQUIRE(myLSTM->inputCategory(i) == InputCategory::Param);
}
for (size_t i = 9; i < 17; ++i) {
REQUIRE(myLSTM->inputCategory(i) == InputCategory::OptionalParam);
}
REQUIRE(myLSTM->nbOutputs() == 2); REQUIRE(myLSTM->nbOutputs() == 2);
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>( std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
...@@ -378,7 +396,13 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") { ...@@ -378,7 +396,13 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
myGraph->add(pop); myGraph->add(pop);
REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8); REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
REQUIRE(myLSTM->nbData() == 1); REQUIRE(myLSTM->inputCategory(0) == InputCategory::Data);
for (size_t i = 1; i < 9; ++i) {
REQUIRE(myLSTM->inputCategory(i) == InputCategory::Param);
}
for (size_t i = 9; i < 17; ++i) {
REQUIRE(myLSTM->inputCategory(i) == InputCategory::OptionalParam);
}
REQUIRE(myLSTM->nbOutputs() == 2); REQUIRE(myLSTM->nbOutputs() == 2);
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>( std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
...@@ -441,7 +465,13 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") { ...@@ -441,7 +465,13 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
myGraph->add(pop); myGraph->add(pop);
REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8); REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
REQUIRE(myLSTM->nbData() == 1); REQUIRE(myLSTM->inputCategory(0) == InputCategory::Data);
for (size_t i = 1; i < 9; ++i) {
REQUIRE(myLSTM->inputCategory(i) == InputCategory::Param);
}
for (size_t i = 9; i < 17; ++i) {
REQUIRE(myLSTM->inputCategory(i) == InputCategory::OptionalParam);
}
REQUIRE(myLSTM->nbOutputs() == 2); REQUIRE(myLSTM->nbOutputs() == 2);
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>( std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment