Skip to content
Snippets Groups Projects
Commit 7dbf6d79 authored by Maxence Naud's avatar Maxence Naud
Browse files

fix GraphRegex test and update OperatorTensor output functions for Identity operator

parent fd0e25a7
No related branches found
No related tags found
No related merge requests found
...@@ -75,18 +75,14 @@ public: ...@@ -75,18 +75,14 @@ public:
void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final; void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final; void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final;
const std::shared_ptr<Tensor>& getInput(const IOIndex_t inputIdx) const; const std::shared_ptr<Tensor>& getInput(const IOIndex_t inputIdx) const;
inline Tensor& input(const IOIndex_t inputIdx) const { return *getInput(inputIdx); }
inline std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { inline std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
return std::static_pointer_cast<Data>(getInput(inputIdx)); return std::static_pointer_cast<Data>(getInput(inputIdx));
} }
// output management // output management
void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) override final; void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) override;
void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) override final; void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) override;
const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const; virtual const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const;
inline Tensor& output(const IOIndex_t outputIdx) const {
return *getOutput(outputIdx);
}
inline std::shared_ptr<Aidge::Data> getRawOutput(const Aidge::IOIndex_t outputIdx) const override final { inline std::shared_ptr<Aidge::Data> getRawOutput(const Aidge::IOIndex_t outputIdx) const override final {
return std::static_pointer_cast<Data>(getOutput(outputIdx)); return std::static_pointer_cast<Data>(getOutput(outputIdx));
} }
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "aidge/operator/FC.hpp" #include "aidge/operator/FC.hpp"
#include "aidge/operator/MatMul.hpp" #include "aidge/operator/MatMul.hpp"
#include "aidge/operator/Producer.hpp" #include "aidge/operator/Producer.hpp"
#include "aidge/utils/Recipies.hpp" #include "aidge/recipies/Recipies.hpp"
#include "aidge/operator/Conv.hpp" #include "aidge/operator/Conv.hpp"
#include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/GenericOperator.hpp"
...@@ -94,10 +94,10 @@ TEST_CASE("GraphRegexUser") { ...@@ -94,10 +94,10 @@ TEST_CASE("GraphRegexUser") {
std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>(); std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>();
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c");
std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1");
std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2");
std::shared_ptr<Node> conv3 = GenericOperator("FC", 1, 1, 1, "c3"); std::shared_ptr<Node> conv3 = GenericOperator("FC", 1, 0, 1, "c3");
g1->add(conv); g1->add(conv);
g1->addChild(conv1, "c"); g1->addChild(conv1, "c");
...@@ -126,10 +126,10 @@ TEST_CASE("GraphRegexUser") { ...@@ -126,10 +126,10 @@ TEST_CASE("GraphRegexUser") {
SECTION("Applied Recipes"){ SECTION("Applied Recipes"){
// generate the original GraphView // generate the original GraphView
auto matmul0 = MatMul(5, "matmul0"); auto matmul0 = MatMul(5, 5, "matmul0");
auto add0 = Add<2>("add0"); auto add0 = Add(2, "add0");
auto matmul1 = MatMul(5, "matmul1"); auto matmul1 = MatMul(5, 5, "matmul1");
auto add1 = Add<2>("add1"); auto add1 = Add(2, "add1");
auto b0 = Producer({5}, "B0"); auto b0 = Producer({5}, "B0");
auto w0 = Producer({5, 5}, "W0"); auto w0 = Producer({5, 5}, "W0");
...@@ -149,8 +149,8 @@ TEST_CASE("GraphRegexUser") { ...@@ -149,8 +149,8 @@ TEST_CASE("GraphRegexUser") {
matmul1->addChild(add1, 0, 0); matmul1->addChild(add1, 0, 0);
b1->addChild(add1, 0, 1); b1->addChild(add1, 0, 1);
auto fc = GenericOperator("FC", 1, 1, 1, "c"); auto fc = GenericOperator("FC", 1, 0, 1, "c");
auto fl = GenericOperator("Flatten", 1, 1, 1, "c"); auto fl = GenericOperator("Flatten", 1, 0, 1, "c");
auto g = std::make_shared<GraphView>(); auto g = std::make_shared<GraphView>();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment