Skip to content
Snippets Groups Projects
Commit bc61fe4e authored by Maxence Naud's avatar Maxence Naud
Browse files

Fix MetaOperator test

parent 58c5ee5e
No related branches found
No related tags found
3 merge requests!1190.2.1,!113Draft: Fix slice,!104Make forwardDims() optional and handle data dependency
Pipeline #44233 canceled
......@@ -9,6 +9,12 @@
*
********************************************************************************/
#include <cstddef> // std::size_t
#include <memory>
#include <string>
#include <utility> // std::pair
#include <vector>
#include <catch2/catch_test_macros.hpp>
#include "aidge/operator/Pop.hpp"
......@@ -17,7 +23,6 @@
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Testing.hpp"
#include "aidge/recipes/Recipes.hpp"
#include <cstddef>
using namespace Aidge;
......@@ -37,8 +42,7 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator][MetaOperator]") {
REQUIRE(op->nbData() == 1);
REQUIRE(op->nbOutputs() == 1);
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>();
myInput->resize({2,3,5,5});
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(std::vector<std::size_t>({2,1,5,5}));
std::shared_ptr<OperatorTensor> opTensor = std::static_pointer_cast<OperatorTensor>(op->getOperator());
opTensor->associateInput(0,myInput);
opTensor->forwardDims();
......
import onnx
from onnx.backend.test.case.node.lstm import LSTMHelper
from onnx.backend.test.case.node import expect
import numpy as np
input = np.array([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]]).astype(np.float32)
print(input.shape)
input_size = 2
hidden_size = 3
weight_scale = 0.1
number_of_gates = 4
node = onnx.helper.make_node(
"LSTM", inputs=["X", "W", "R"], outputs=["", "Y_h"], hidden_size=hidden_size
)
W = weight_scale * np.ones(
(1, number_of_gates * hidden_size, input_size)
).astype(np.float32)
R = weight_scale * np.ones(
(1, number_of_gates * hidden_size, hidden_size)
).astype(np.float32)
lstm = LSTMHelper(X=input, W=W, R=R)
_, Y_h = lstm.step()
print(lstm.C_0 )
seq_length = input.shape[0]
batch_size = input.shape[1]
print(seq_length)
print(np.split(input, input.shape[0], axis=0))
expect(
node,
inputs=[input, W, R],
outputs=[Y_h.astype(np.float32)],
name="test_lstm_defaults",
)
print(Y_h)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment