Skip to content
Snippets Groups Projects
Commit 17948093 authored by Maxence Naud's avatar Maxence Naud
Browse files

minor changes

parent 143a9693
No related branches found
No related tags found
1 merge request!73version 0.2.3
Pipeline #47050 passed
...@@ -37,10 +37,9 @@ void Aidge::SoftmaxImpl_cpu::forward() { ...@@ -37,10 +37,9 @@ void Aidge::SoftmaxImpl_cpu::forward() {
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()}); std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
Softmax_Op::Attrs attr = dynamic_cast<const Softmax_Op&>(mOp).getStaticAttributes(); Softmax_Op::Attrs attr = dynamic_cast<const Softmax_Op&>(mOp).getStaticAttributes();
const int& axisIdx = static_cast<const int&>(std::get<0>(attr));
// Call kernel // Call kernel
kernelFunc(axisIdx, kernelFunc(std::get<0>(attr), // axisIdx
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(), std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(), std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()); std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
......
...@@ -194,10 +194,10 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") { ...@@ -194,10 +194,10 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
SECTION("LSTM(forward)") { SECTION("LSTM(forward)") {
auto pop = Pop(); auto pop = Pop();
auto myLSTM = LSTM(32, 64, 0, true, "ltsm"); auto myLSTM = LSTM(32, 64, 0, true, "ltsm");
auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator()); auto op = std::dynamic_pointer_cast<MetaOperator_Op>(myLSTM->getOperator());
auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph(); auto microGraph = op->getMicroGraph();
microGraph->save("lstm", false, false); microGraph->save("lstm", false, true);
REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8); REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
REQUIRE(myLSTM->nbData() == 1); REQUIRE(myLSTM->nbData() == 1);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment