Skip to content
Snippets Groups Projects
Commit 7c21dac5 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added new input category system

parent cb56677f
No related branches found
No related tags found
No related merge requests found
......@@ -32,7 +32,6 @@ void Aidge::ConvDepthWiseImpl1D_cpu::forward() {
assert(mOp.getRawInput(0) && "missing input #0");
assert(mOp.getRawInput(1) && "missing input #1");
assert(mOp.getRawInput(2) && "missing input #2");
assert((std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->nbDims() == 3) && "support for 3-dimensions tensors only");
......@@ -84,7 +83,6 @@ void Aidge::ConvDepthWiseImpl2D_cpu::forward() {
assert(mOp.getRawInput(0) && "missing input #0");
assert(mOp.getRawInput(1) && "missing input #1");
assert(mOp.getRawInput(2) && "missing input #2");
assert((std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->nbDims() == 4) && "support for 4-dimensions tensors only");
......
......@@ -33,7 +33,6 @@ void Aidge::ConvImpl1D_cpu::forward() {
// FIXME: uncomment the following code once memory handling will work
assert(mOp.getRawInput(0) && "missing input #0");
assert(mOp.getRawInput(1) && "missing input #1");
assert(mOp.getRawInput(2) && "missing input #2");
// Find the correct kernel type
const auto outputDataType = opTensor.getOutput(0)->dataType();
......@@ -85,7 +84,6 @@ void Aidge::ConvImpl2D_cpu::forward() {
// FIXME: uncomment the following code once memory handling will work
assert(mOp.getRawInput(0) && "missing input #0");
assert(mOp.getRawInput(1) && "missing input #1");
assert(mOp.getRawInput(2) && "missing input #2");
// Find the correct kernel type
const auto outputDataType = opTensor.getOutput(0)->dataType();
......
......@@ -29,7 +29,6 @@ void Aidge::FCImpl_cpu::forward()
const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp);
AIDGE_ASSERT(op_.getInput(0), "missing input #0");
AIDGE_ASSERT(op_.getInput(1), "missing input #1");
AIDGE_ASSERT(op_.getInput(2), "missing input #2");
// Find the correct kernel type
const auto outputDataType = op_.getOutput(0)->dataType();
......@@ -77,7 +76,6 @@ void Aidge::FCImpl_cpu::backward()
AIDGE_ASSERT(fc_grad, "missing output #0 gradient");
AIDGE_ASSERT(op_.getInput(0)->grad(), "missing input #0 gradient");
AIDGE_ASSERT(op_.getInput(1)->grad(), "missing input #1 gradient");
AIDGE_ASSERT(op_.getInput(2)->grad(), "missing input #2 gradient");
// Find the correct kernel type
const Registrar<FCImplBackward_cpu>::registrar_key registrarKey = {
......
......@@ -200,7 +200,13 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
microGraph->save("lstm", false, true);
REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
REQUIRE(myLSTM->nbData() == 1);
REQUIRE(myLSTM->inputCategory(0) == InputCategory::Data);
for (size_t i = 1; i < 9; ++i) {
REQUIRE(myLSTM->inputCategory(i) == InputCategory::Param);
}
for (size_t i = 9; i < 17; ++i) {
REQUIRE(myLSTM->inputCategory(i) == InputCategory::OptionalParam);
}
REQUIRE(myLSTM->nbOutputs() == 2);
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
......@@ -259,7 +265,13 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
microGraph->save("lstm", false, false);
REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
REQUIRE(myLSTM->nbData() == 1);
REQUIRE(myLSTM->inputCategory(0) == InputCategory::Data);
for (size_t i = 1; i < 9; ++i) {
REQUIRE(myLSTM->inputCategory(i) == InputCategory::Param);
}
for (size_t i = 9; i < 17; ++i) {
REQUIRE(myLSTM->inputCategory(i) == InputCategory::OptionalParam);
}
REQUIRE(myLSTM->nbOutputs() == 2);
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
......@@ -316,7 +328,13 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator());
REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
REQUIRE(myLSTM->nbData() == 1);
REQUIRE(myLSTM->inputCategory(0) == InputCategory::Data);
for (size_t i = 1; i < 9; ++i) {
REQUIRE(myLSTM->inputCategory(i) == InputCategory::Param);
}
for (size_t i = 9; i < 17; ++i) {
REQUIRE(myLSTM->inputCategory(i) == InputCategory::OptionalParam);
}
REQUIRE(myLSTM->nbOutputs() == 2);
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
......@@ -378,7 +396,13 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
myGraph->add(pop);
REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
REQUIRE(myLSTM->nbData() == 1);
REQUIRE(myLSTM->inputCategory(0) == InputCategory::Data);
for (size_t i = 1; i < 9; ++i) {
REQUIRE(myLSTM->inputCategory(i) == InputCategory::Param);
}
for (size_t i = 9; i < 17; ++i) {
REQUIRE(myLSTM->inputCategory(i) == InputCategory::OptionalParam);
}
REQUIRE(myLSTM->nbOutputs() == 2);
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
......@@ -441,7 +465,13 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
myGraph->add(pop);
REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
REQUIRE(myLSTM->nbData() == 1);
REQUIRE(myLSTM->inputCategory(0) == InputCategory::Data);
for (size_t i = 1; i < 9; ++i) {
REQUIRE(myLSTM->inputCategory(i) == InputCategory::Param);
}
for (size_t i = 9; i < 17; ++i) {
REQUIRE(myLSTM->inputCategory(i) == InputCategory::OptionalParam);
}
REQUIRE(myLSTM->nbOutputs() == 2);
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment