Commit 3ba465b3, authored by Olivier BICHLER, committed by Maxence Naud

Added meta op input category

parent 2abb961c
3 merge requests: !279 (v0.4.0), !253 (v0.4.0), !203 (Fix LSTM ONNX compatibility)
@@ -37,7 +37,7 @@ public:
     std::weak_ptr<Node> mUpperNode;
 public:
-    MetaOperator_Op(const std::string& type, const std::shared_ptr<GraphView>& graph);
+    MetaOperator_Op(const std::string& type, const std::shared_ptr<GraphView>& graph, const std::vector<InputCategory>& forcedInputsCategory = {});
     /**
      * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
@@ -113,6 +113,7 @@ public:
 std::shared_ptr<Node> MetaOperator(const char *type,
                                    const std::shared_ptr<GraphView>& graph,
+                                   const std::vector<InputCategory>& forcedInputsCategory = {},
                                    const std::string& name = "");
 } // namespace Aidge
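
Reviewer note: a minimal sketch of a call site for the new overload (the graph, node name, and chosen categories are illustrative, not part of this patch). Entries in forcedInputsCategory override, by position, whatever category the inner graph's operators report; inputs beyond the vector's length keep the inner operators' categories.

    // Hypothetical usage (sketch): force the first two inputs of the
    // meta-operator to OptionalData so they may be left unconnected; any
    // remaining inputs keep the category of the inner operator they map to.
    auto node = MetaOperator("MyMetaOp", innerGraph,
                             {InputCategory::OptionalData,
                              InputCategory::OptionalData},
                             "my_meta_op");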
@@ -126,7 +126,7 @@ inline std::shared_ptr<Node> PaddedMaxPooling(const std::array<DimSize_t, DIM> &
         MaxPooling(kernel_dims, (!name.empty()) ? name + "_maxpooling" : "", stride_dims, ceil_mode)
     });
-    return MetaOperator("PaddedMaxPooling", graph, name);
+    return MetaOperator("PaddedMaxPooling", graph, {}, name);
 }
 template <std::array<DimSize_t, 1>::size_type DIM>
@@ -195,15 +195,17 @@ void init_MetaOperatorDefs(py::module &m) {
     declare_LSTMOp(m);
     py::class_<MetaOperator_Op, std::shared_ptr<MetaOperator_Op>, OperatorTensor>(m, "MetaOperator_Op", py::multiple_inheritance())
-    .def(py::init<const char *, const std::shared_ptr<GraphView>&>(),
+    .def(py::init<const char *, const std::shared_ptr<GraphView>&, const std::vector<InputCategory>&>(),
          py::arg("type"),
-         py::arg("graph"))
+         py::arg("graph"),
+         py::arg("forced_inputs_category") = std::vector<InputCategory>())
     .def("get_micro_graph", &MetaOperator_Op::getMicroGraph)
     .def("set_upper_node", &MetaOperator_Op::setUpperNode);
     m.def("meta_operator", &MetaOperator,
           py::arg("type"),
           py::arg("graph"),
+          py::arg("forced_inputs_category") = std::vector<InputCategory>(),
           py::arg("name") = ""
     );
@@ -20,17 +20,22 @@
 #include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/utils/DynamicAttributes.hpp"
-Aidge::MetaOperator_Op::MetaOperator_Op(const std::string& type, const std::shared_ptr<GraphView>& graph)
-    : OperatorTensor(type, [graph]() {
+Aidge::MetaOperator_Op::MetaOperator_Op(const std::string& type, const std::shared_ptr<GraphView>& graph, const std::vector<InputCategory>& forcedInputsCategory)
+    : OperatorTensor(type, [graph, forcedInputsCategory]() {
+        IOIndex_t inputIdx = 0;
         std::vector<InputCategory> inputsCategory;
         for (const auto& in : graph->getOrderedInputs()) {
-            if (in.first) {
+            if (inputIdx < forcedInputsCategory.size()) {
+                inputsCategory.push_back(forcedInputsCategory[inputIdx]);
+            }
+            else if (in.first) {
                 inputsCategory.push_back(in.first->getOperator()->inputCategory(in.second));
             }
             else {
                 // Dummy input, default to OptionalData
                 inputsCategory.push_back(InputCategory::OptionalData);
             }
+            ++inputIdx;
         }
         return inputsCategory;
     }(), graph->getOrderedOutputs().size()),
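
Reviewer note: the initializer lambda above resolves each ordered input's category with a three-step rule. A standalone restatement of that rule (hypothetical helper for illustration only; assumes <optional> and Aidge's InputCategory enum):

    #include <cstddef>
    #include <optional>
    #include <vector>

    // Sketch of the resolution order implemented by the lambda:
    // 1. an entry in forcedInputsCategory overrides everything at that index;
    // 2. otherwise a connected inner node supplies its own input category;
    // 3. otherwise (dummy/dangling input) default to OptionalData.
    InputCategory resolveCategory(std::size_t inputIdx,
                                  const std::vector<InputCategory>& forced,
                                  const std::optional<InputCategory>& innerCategory) {
        if (inputIdx < forced.size()) {
            return forced[inputIdx];
        }
        if (innerCategory.has_value()) {
            return *innerCategory;
        }
        return InputCategory::OptionalData;
    }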
@@ -245,9 +250,10 @@ void Aidge::MetaOperator_Op::forward() {
 std::shared_ptr<Aidge::Node> Aidge::MetaOperator(const char *type,
                                                  const std::shared_ptr<Aidge::GraphView>& graph,
+                                                 const std::vector<InputCategory>& forcedInputsCategory,
                                                  const std::string& name)
 {
-    auto op = std::make_shared<MetaOperator_Op>(type, graph);
+    auto op = std::make_shared<MetaOperator_Op>(type, graph, forcedInputsCategory);
     auto node = std::make_shared<Node>(op, name);
     op->setUpperNode(node);
     return node;
@@ -115,7 +115,7 @@ std::shared_ptr<Node> LSTM(const DimSize_t inChannel,
                                  {hiddenState, 1}, {cellState, 1}});
     microGraph->setOrderedOutputs({{hiddenState, 0}, {cellState, 0}});
-    auto metaOp = MetaOperator("LSTM", microGraph, name);
+    auto metaOp = MetaOperator("LSTM", microGraph, {}, name);
     addProducer(metaOp, 1, {hiddenChannel, inChannel}, "wi");
     addProducer(metaOp, 2, {hiddenChannel, inChannel}, "wo");
     addProducer(metaOp, 3, {hiddenChannel, inChannel}, "wf");
@@ -41,7 +41,7 @@ std::shared_ptr<Node> PaddedAvgPooling(const std::array<DimSize_t, DIM> &kernel_
         AvgPooling(kernel_dims, (!name.empty()) ? name + "_avgpooling" : "", stride_dims)
     });
-    return MetaOperator("PaddedAvgPooling", graph, name);
+    return MetaOperator("PaddedAvgPooling", graph, {}, name);
 }
 template std::shared_ptr<Node> PaddedAvgPooling<1>(const std::array<DimSize_t,1>&, const std::string&, const std::array<DimSize_t,1>&, const std::array<DimSize_t,2>&);
@@ -43,7 +43,7 @@ std::shared_ptr<Aidge::Node> Aidge::PaddedConv(Aidge::DimSize_t in_channels,
         Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""),
         std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : "")
     });
-    auto metaOpNode = MetaOperator("PaddedConv", graph, name);
+    auto metaOpNode = MetaOperator("PaddedConv", graph, {}, name);
     addProducer(metaOpNode, 1, append(out_channels, append(in_channels, kernel_dims)), "w");
     if (!no_bias) {
         addProducer(metaOpNode, 2, {out_channels}, "b");
@@ -40,7 +40,7 @@ std::shared_ptr<Aidge::Node> Aidge::PaddedConvDepthWise(const Aidge::DimSize_t n
         Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""),
         std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv_depth_wise" : "")
     });
-    auto metaOpNode = MetaOperator("PaddedConvDepthWise", graph, name);
+    auto metaOpNode = MetaOperator("PaddedConvDepthWise", graph, {}, name);
     addProducer(metaOpNode, 1, append(nb_channels, append(Aidge::DimSize_t(1), kernel_dims)), "w");
     if (!no_bias) {
         addProducer(metaOpNode, 2, {nb_channels}, "b");
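
Reviewer note: every pre-existing helper (LSTM, PaddedMaxPooling, PaddedAvgPooling, PaddedConv, PaddedConvDepthWise) passes an empty vector, which keeps its previous behavior: with no forced entries, the index check never matches and every input falls through to the inner operator's own category. A sketch of the equivalence, using PaddedConv as the example:

    // New signature with an empty forcedInputsCategory...
    auto node = MetaOperator("PaddedConv", graph, {}, name);
    // ...resolves input categories exactly as the pre-patch call did:
    //   MetaOperator("PaddedConv", graph, name);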