Skip to content
Snippets Groups Projects
Commit 026123ea authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge branch 'NamingInOut' into 'main'

[Operator] Add getInputsName & getOutputsName methods.

See merge request eclipse/aidge/aidge_core!30
parents a5eb92e1 2d4f518a
No related branches found
No related tags found
No related merge requests found
Showing
with 123 additions and 23 deletions
...@@ -162,6 +162,12 @@ public: ...@@ -162,6 +162,12 @@ public:
inline IOIndex_t nbInputs() const noexcept override final { return NUM; } inline IOIndex_t nbInputs() const noexcept override final { return NUM; }
inline IOIndex_t nbDataInputs() const noexcept override final { return NUM; } inline IOIndex_t nbDataInputs() const noexcept override final { return NUM; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input_0", "data_input_n"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
}; };
template <std::size_t NUM> template <std::size_t NUM>
......
...@@ -157,6 +157,12 @@ public: ...@@ -157,6 +157,12 @@ public:
inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbInputs() const noexcept override final { return 1; }
inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
}; };
template <std::array<DimSize_t, 1>::size_type DIM> template <std::array<DimSize_t, 1>::size_type DIM>
......
...@@ -160,6 +160,12 @@ public: ...@@ -160,6 +160,12 @@ public:
inline IOIndex_t nbInputs() const noexcept override final { return 5; } inline IOIndex_t nbInputs() const noexcept override final { return 5; }
inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input", "scale", "shift", "mean", "variance"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
}; };
template <DimSize_t DIM> template <DimSize_t DIM>
......
...@@ -177,6 +177,12 @@ public: ...@@ -177,6 +177,12 @@ public:
inline IOIndex_t nbInputs() const noexcept override final { return 3; } inline IOIndex_t nbInputs() const noexcept override final { return 3; }
inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input", "weight", "bias"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
}; };
template <std::array<DimSize_t, 1>::size_type DIM> template <std::array<DimSize_t, 1>::size_type DIM>
......
...@@ -176,6 +176,12 @@ class ConvDepthWise_Op : public Operator, ...@@ -176,6 +176,12 @@ class ConvDepthWise_Op : public Operator,
inline IOIndex_t nbInputs() const noexcept override final { return 3; } inline IOIndex_t nbInputs() const noexcept override final { return 3; }
inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input", "weight", "bias"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
}; };
template <std::array<DimSize_t, 1>::size_type DIM> template <std::array<DimSize_t, 1>::size_type DIM>
......
...@@ -158,6 +158,12 @@ public: ...@@ -158,6 +158,12 @@ public:
inline IOIndex_t nbInputs() const noexcept override final { return 3; } inline IOIndex_t nbInputs() const noexcept override final { return 3; }
inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input", "weight", "bias"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
}; };
inline std::shared_ptr<Node> FC(DimSize_t out_channels, bool noBias = false, const std::string& name = "") { inline std::shared_ptr<Node> FC(DimSize_t out_channels, bool noBias = false, const std::string& name = "") {
...@@ -175,4 +181,4 @@ const char *const EnumStrings<Aidge::FCAttr>::data[] = {"OutChannels", ...@@ -175,4 +181,4 @@ const char *const EnumStrings<Aidge::FCAttr>::data[] = {"OutChannels",
"NoBias"}; "NoBias"};
} }
#endif /* AIDGE_CORE_OPERATOR_FC_H_ */ #endif /* AIDGE_CORE_OPERATOR_FC_H_ */
\ No newline at end of file
...@@ -137,6 +137,12 @@ public: ...@@ -137,6 +137,12 @@ public:
inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbInputs() const noexcept override final { return 1; }
inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
}; };
inline std::shared_ptr<Node> LeakyReLU(float negativeSlope = 0.0f, const std::string& name = "") { inline std::shared_ptr<Node> LeakyReLU(float negativeSlope = 0.0f, const std::string& name = "") {
......
...@@ -148,6 +148,12 @@ public: ...@@ -148,6 +148,12 @@ public:
inline IOIndex_t nbInputs() const noexcept override final { return 2; } inline IOIndex_t nbInputs() const noexcept override final { return 2; }
inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input", "weight"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
}; };
inline std::shared_ptr<Node> MatMul(DimSize_t out_channels, const std::string& name = "") { inline std::shared_ptr<Node> MatMul(DimSize_t out_channels, const std::string& name = "") {
......
...@@ -158,6 +158,12 @@ public: ...@@ -158,6 +158,12 @@ public:
inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbInputs() const noexcept override final { return 1; }
inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
}; };
template <std::array<DimSize_t, 1>::size_type DIM> template <std::array<DimSize_t, 1>::size_type DIM>
......
...@@ -116,6 +116,12 @@ public: ...@@ -116,6 +116,12 @@ public:
virtual IOIndex_t nbInputs() const noexcept = 0; virtual IOIndex_t nbInputs() const noexcept = 0;
virtual IOIndex_t nbDataInputs() const noexcept = 0; virtual IOIndex_t nbDataInputs() const noexcept = 0;
virtual IOIndex_t nbOutputs() const noexcept = 0; virtual IOIndex_t nbOutputs() const noexcept = 0;
static const std::vector<std::string> getInputsName(){
return {};
}
static const std::vector<std::string> getOutputsName(){
return {};
}
}; };
} // namespace Aidge } // namespace Aidge
......
...@@ -160,6 +160,12 @@ public: ...@@ -160,6 +160,12 @@ public:
inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbInputs() const noexcept override final { return 1; }
inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
}; };
template <std::array<DimSize_t, 1>::size_type DIM> template <std::array<DimSize_t, 1>::size_type DIM>
......
...@@ -79,7 +79,7 @@ public: ...@@ -79,7 +79,7 @@ public:
* @brief Set the Output Tensor of the Producer operator. * @brief Set the Output Tensor of the Producer operator.
* This method will create a copy of the Tensor. * This method will create a copy of the Tensor.
* *
* @param newOutput Tensor containing the values to copy * @param newOutput Tensor containing the values to copy
*/ */
void setOutputTensor(const Tensor& newOutput) { void setOutputTensor(const Tensor& newOutput) {
*mOutput = newOutput; *mOutput = newOutput;
...@@ -132,6 +132,12 @@ public: ...@@ -132,6 +132,12 @@ public:
inline IOIndex_t nbInputs() const noexcept override final { return 0; }; inline IOIndex_t nbInputs() const noexcept override final { return 0; };
inline IOIndex_t nbDataInputs() const noexcept override final { return 0; }; inline IOIndex_t nbDataInputs() const noexcept override final { return 0; };
inline IOIndex_t nbOutputs() const noexcept override final { return 1; }; inline IOIndex_t nbOutputs() const noexcept override final { return 1; };
static const std::vector<std::string> getInputsName(){
return {""};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
public: public:
void forward() override final { void forward() override final {
......
...@@ -125,6 +125,12 @@ public: ...@@ -125,6 +125,12 @@ public:
inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbInputs() const noexcept override final { return 1; }
inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
}; };
inline std::shared_ptr<Node> ReLU(const std::string& name = "") { inline std::shared_ptr<Node> ReLU(const std::string& name = "") {
......
...@@ -146,6 +146,12 @@ public: ...@@ -146,6 +146,12 @@ public:
inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbInputs() const noexcept override final { return 1; }
inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
}; };
inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, const std::string& name = "") { inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, const std::string& name = "") {
......
...@@ -125,6 +125,12 @@ public: ...@@ -125,6 +125,12 @@ public:
inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbInputs() const noexcept override final { return 1; }
inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
}; };
inline std::shared_ptr<Node> Softmax(const std::string& name = "") { inline std::shared_ptr<Node> Softmax(const std::string& name = "") {
......
...@@ -20,7 +20,9 @@ namespace py = pybind11; ...@@ -20,7 +20,9 @@ namespace py = pybind11;
namespace Aidge { namespace Aidge {
template <std::size_t NUM> void declare_Add(py::module &m) { template <std::size_t NUM> void declare_Add(py::module &m) {
py::class_<Add_Op<NUM>, std::shared_ptr<Add_Op<NUM>>, Operator>(m, "Add_Op", py::multiple_inheritance()); py::class_<Add_Op<NUM>, std::shared_ptr<Add_Op<NUM>>, Operator>(m, "AddOp", py::multiple_inheritance())
.def("get_inputs_name", &Add_Op<NUM>::getInputsName)
.def("get_outputs_name", &Add_Op<NUM>::getOutputsName);
m.def("Add", &Add<NUM>, py::arg("name") = ""); m.def("Add", &Add<NUM>, py::arg("name") = "");
} }
......
...@@ -32,13 +32,15 @@ template <DimIdx_t DIM> void declare_AvgPoolingOp(py::module &m) { ...@@ -32,13 +32,15 @@ template <DimIdx_t DIM> void declare_AvgPoolingOp(py::module &m) {
.def(py::init<const std::array<DimSize_t, DIM> &, .def(py::init<const std::array<DimSize_t, DIM> &,
const std::array<DimSize_t, DIM> &>(), const std::array<DimSize_t, DIM> &>(),
py::arg("kernel_dims"), py::arg("kernel_dims"),
py::arg("stride_dims")); py::arg("stride_dims"))
.def("get_inputs_name", &AvgPooling_Op<DIM>::getInputsName)
m.def(("AvgPooling" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, .def("get_outputs_name", &AvgPooling_Op<DIM>::getOutputsName);
m.def(("AvgPooling" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims,
const std::string& name, const std::string& name,
const std::vector<DimSize_t> &stride_dims) { const std::vector<DimSize_t> &stride_dims) {
// Lambda function wrapper because PyBind fails to convert const array. // Lambda function wrapper because PyBind fails to convert const array.
// So we use a vector that we convert in this function to a const DimeSize_t [DIM] array. // So we use a vector that we convert in this function to a const DimeSize_t [DIM] array.
if (kernel_dims.size() != DIM) { if (kernel_dims.size() != DIM) {
throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]");
} }
...@@ -59,7 +61,7 @@ template <DimIdx_t DIM> void declare_AvgPoolingOp(py::module &m) { ...@@ -59,7 +61,7 @@ template <DimIdx_t DIM> void declare_AvgPoolingOp(py::module &m) {
}, py::arg("kernel_dims"), }, py::arg("kernel_dims"),
py::arg("name") = "", py::arg("name") = "",
py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1)); py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1));
} }
...@@ -67,10 +69,10 @@ void init_AvgPooling(py::module &m) { ...@@ -67,10 +69,10 @@ void init_AvgPooling(py::module &m) {
declare_AvgPoolingOp<1>(m); declare_AvgPoolingOp<1>(m);
declare_AvgPoolingOp<2>(m); declare_AvgPoolingOp<2>(m);
declare_AvgPoolingOp<3>(m); declare_AvgPoolingOp<3>(m);
// FIXME: // FIXME:
// m.def("AvgPooling1D", static_cast<NodeAPI(*)(const char*, int, int, int const // m.def("AvgPooling1D", static_cast<NodeAPI(*)(const char*, int, int, int const
// (&)[1])>(&AvgPooling)); // (&)[1])>(&AvgPooling));
} }
} // namespace Aidge } // namespace Aidge
#endif #endif
\ No newline at end of file
...@@ -21,7 +21,9 @@ namespace Aidge { ...@@ -21,7 +21,9 @@ namespace Aidge {
template <DimSize_t DIM> template <DimSize_t DIM>
void declare_BatchNormOp(py::module& m) { void declare_BatchNormOp(py::module& m) {
py::class_<BatchNorm_Op<DIM>, std::shared_ptr<BatchNorm_Op<DIM>>, Operator, Attributes>(m, ("BatchNorm_Op" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance()); py::class_<BatchNorm_Op<DIM>, std::shared_ptr<BatchNorm_Op<DIM>>, Operator, Attributes>(m, ("BatchNormOp" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance())
.def("get_inputs_name", &BatchNorm_Op<DIM>::getInputsName)
.def("get_outputs_name", &BatchNorm_Op<DIM>::getOutputsName);
m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = ""); m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = "");
} }
......
...@@ -37,16 +37,19 @@ template <DimIdx_t DIM> void declare_ConvOp(py::module &m) { ...@@ -37,16 +37,19 @@ template <DimIdx_t DIM> void declare_ConvOp(py::module &m) {
py::arg("out_channels"), py::arg("out_channels"),
py::arg("kernel_dims"), py::arg("kernel_dims"),
py::arg("stride_dims"), py::arg("stride_dims"),
py::arg("dilation_dims")); py::arg("dilation_dims"))
.def("get_inputs_name", &Conv_Op<DIM>::getInputsName)
.def("get_outputs_name", &Conv_Op<DIM>::getOutputsName)
;
m.def(("Conv" + std::to_string(DIM) + "D").c_str(), [](DimSize_t in_channels, m.def(("Conv" + std::to_string(DIM) + "D").c_str(), [](DimSize_t in_channels,
DimSize_t out_channels, DimSize_t out_channels,
const std::vector<DimSize_t>& kernel_dims, const std::vector<DimSize_t>& kernel_dims,
const std::string& name, const std::string& name,
const std::vector<DimSize_t> &stride_dims, const std::vector<DimSize_t> &stride_dims,
const std::vector<DimSize_t> &dilation_dims) { const std::vector<DimSize_t> &dilation_dims) {
// Lambda function wrapper because PyBind fails to convert const array. // Lambda function wrapper because PyBind fails to convert const array.
// So we use a vector that we convert in this function to a const DimeSize_t [DIM] array. // So we use a vector that we convert in this function to a const DimeSize_t [DIM] array.
if (kernel_dims.size() != DIM) { if (kernel_dims.size() != DIM) {
throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]");
} }
...@@ -78,7 +81,6 @@ template <DimIdx_t DIM> void declare_ConvOp(py::module &m) { ...@@ -78,7 +81,6 @@ template <DimIdx_t DIM> void declare_ConvOp(py::module &m) {
py::arg("name") = "", py::arg("name") = "",
py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1),
py::arg("dilation_dims") = std::vector<DimSize_t>(DIM,1)); py::arg("dilation_dims") = std::vector<DimSize_t>(DIM,1));
} }
...@@ -86,7 +88,7 @@ void init_Conv(py::module &m) { ...@@ -86,7 +88,7 @@ void init_Conv(py::module &m) {
declare_ConvOp<1>(m); declare_ConvOp<1>(m);
declare_ConvOp<2>(m); declare_ConvOp<2>(m);
declare_ConvOp<3>(m); declare_ConvOp<3>(m);
// FIXME: // FIXME:
// m.def("Conv1D", static_cast<NodeAPI(*)(const char*, int, int, int const // m.def("Conv1D", static_cast<NodeAPI(*)(const char*, int, int, int const
// (&)[1])>(&Conv)); // (&)[1])>(&Conv));
......
...@@ -34,14 +34,16 @@ template <DimIdx_t DIM> void declare_ConvDepthWiseOp(py::module &m) { ...@@ -34,14 +34,16 @@ template <DimIdx_t DIM> void declare_ConvDepthWiseOp(py::module &m) {
const std::array<DimSize_t, DIM> &>(), const std::array<DimSize_t, DIM> &>(),
py::arg("kernel_dims"), py::arg("kernel_dims"),
py::arg("stride_dims"), py::arg("stride_dims"),
py::arg("dilation_dims")); py::arg("dilation_dims"))
.def("get_inputs_name", &ConvDepthWise_Op<DIM>::getInputsName)
m.def(("ConvDepthWise" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, .def("get_outputs_name", &ConvDepthWise_Op<DIM>::getOutputsName);
m.def(("ConvDepthWise" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims,
const std::string& name, const std::string& name,
const std::vector<DimSize_t> &stride_dims, const std::vector<DimSize_t> &stride_dims,
const std::vector<DimSize_t> &dilation_dims) { const std::vector<DimSize_t> &dilation_dims) {
// Lambda function wrapper because PyBind fails to convert const array. // Lambda function wrapper because PyBind fails to convert const array.
// So we use a vector that we convert in this function to a const DimeSize_t [DIM] array. // So we use a vector that we convert in this function to a const DimeSize_t [DIM] array.
if (kernel_dims.size() != DIM) { if (kernel_dims.size() != DIM) {
throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]");
} }
...@@ -71,7 +73,7 @@ template <DimIdx_t DIM> void declare_ConvDepthWiseOp(py::module &m) { ...@@ -71,7 +73,7 @@ template <DimIdx_t DIM> void declare_ConvDepthWiseOp(py::module &m) {
py::arg("name") = "", py::arg("name") = "",
py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1),
py::arg("dilation_dims") = std::vector<DimSize_t>(DIM,1)); py::arg("dilation_dims") = std::vector<DimSize_t>(DIM,1));
} }
...@@ -79,7 +81,7 @@ void init_ConvDepthWise(py::module &m) { ...@@ -79,7 +81,7 @@ void init_ConvDepthWise(py::module &m) {
declare_ConvDepthWiseOp<1>(m); declare_ConvDepthWiseOp<1>(m);
declare_ConvDepthWiseOp<2>(m); declare_ConvDepthWiseOp<2>(m);
declare_ConvDepthWiseOp<3>(m); declare_ConvDepthWiseOp<3>(m);
// FIXME: // FIXME:
// m.def("ConvDepthWise1D", static_cast<NodeAPI(*)(const char*, int, int, int const // m.def("ConvDepthWise1D", static_cast<NodeAPI(*)(const char*, int, int, int const
// (&)[1])>(&ConvDepthWise)); // (&)[1])>(&ConvDepthWise));
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment