Commit d759b214, authored by Vincent Lorrain

merge main and fix

Parents: d93381a6, daec3e64
Merge request !29: GraphRegex interface class
Pipeline #33940: canceled
Showing 206 additions and 40 deletions
@@ -9,6 +9,14 @@ variables:
   GIT_SSL_NO_VERIFY: 1
   DEBIAN_FRONTEND: noninteractive
 
+# See https://docs.gitlab.com/ee/ci/yaml/workflow.html#switch-between-branch-pipelines-and-merge-request-pipelines
+workflow:
+  rules:
+    - if: $CI_PIPELINE_SOURCE == "merge_request_event"
+    - if: $CI_COMMIT_BRANCH && $CI_OPEN_MERGE_REQUESTS
+      when: never
+    - if: $CI_COMMIT_BRANCH
+
 default:
   image: nvidia/cuda:12.2.0-devel-ubuntu22.04
   before_script:
......
@@ -88,8 +88,7 @@ build:ubuntu_python:
     - virtualenv venv
     - source venv/bin/activate
     # Numpy dependancy for unit test
-    - python3 -m pip install numpy
-    - export AIDGE_INSTALL=`pwd`/install
+    - python3 -m pip install -r requirements.txt
     - python3 -m pip install .
   artifacts:
     expire_in: 1 week
@@ -147,8 +146,7 @@ build:windows_python:
     - virtualenv venv
     - venv\Scripts\Activate.ps1
     # Numpy dependancy for unit test
-    - python -m pip install numpy
-    - $env:AIDGE_INSTALL = "$pwd" + "install"
+    - python -m pip install -r requirements.txt
     - python -m pip install .
   artifacts:
     expire_in: 1 week
......
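Both CI jobs now install their test dependencies from a requirements file instead of hard-coding numpy. The file itself is not part of this diff; given the "Numpy dependancy for unit test" comment above, a minimal version presumably looks like:

# requirements.txt (assumed content, not shown in this commit)
numpy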
@@ -6,7 +6,7 @@ file(READ "${CMAKE_SOURCE_DIR}/project_name.txt" project)
 message(STATUS "Project name: ${project}")
 message(STATUS "Project version: ${version}")
 
 # Note : project name is {project} and python module name is also {project}
 set(module_name _${project}) # target name
@@ -57,7 +57,7 @@ if (PYBIND)
     # Handles Python + pybind11 headers dependencies
     target_link_libraries(${module_name}
         PUBLIC
             pybind11::pybind11
         PRIVATE
             Python::Python
@@ -101,8 +101,8 @@ install(DIRECTORY include/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
 install(EXPORT ${project}-targets
     FILE "${project}-targets.cmake"
     DESTINATION ${INSTALL_CONFIGDIR}
     # COMPONENT ${module_name}
 )
 
 #Create a ConfigVersion.cmake file
 include(CMakePackageConfigHelpers)
@@ -136,4 +136,4 @@ export(EXPORT ${project}-targets
 if(TEST)
     enable_testing()
     add_subdirectory(unit_tests)
-endif()
\ No newline at end of file
+endif()
@@ -8,3 +8,4 @@ http://www.eclipse.org/legal/epl-2.0.
 SPDX-License-Identifier: EPL-2.0
 """
 from aidge_core.aidge_core import * # import so generated by PyBind
-from aidge_core.export import ExportNode
+from .node_export import *
......
+import aidge_core
+from abc import ABC, abstractmethod
+
+
+class ExportNode(ABC):
+    """Abstract class to interface a node with export generation.
+    """
+
+    @abstractmethod
+    def __init__(self, aidge_node: aidge_core.Node) -> None:
+        """Create an ExportNode and retrieve attributes from ``aidge_node``:
+
+        - name: aidge Node name
+        - attributes: dictionary of attributes of the aidge Operator linked to the node; attribute names follow the aidge naming convention
+        - parameters: list of parameter nodes, in the same order as defined by the aidge operator
+        """
+        super().__init__()
+        self.node = aidge_node
+        self.operator = aidge_node.get_operator()
+        self.name = self.node.name()
+        self.attributes = {}  # Attributes are automatically fetched from aidge operators
+        if isinstance(self.operator, aidge_core.Attributes):
+            for attr_name in self.operator.get_attrs_name():
+                self.attributes[attr_name] = self.operator.get_attr(attr_name)
+        # TODO: rename to is_leaf?
+        self.is_last = len(self.node.get_children()) == 0
+
+        self.inputs = []
+        self.outputs = []
+        self.inputs_dims = []
+        self.outputs_dims = []
+        for idx, parent_node in enumerate(self.node.get_parents()):
+            self.inputs.append(parent_node)
+            if parent_node is not None:
+                self.inputs_dims.append(self.operator.input(idx).dims())
+            else:
+                self.inputs_dims.append(None)
+        for idx, child_node in enumerate(self.node.get_children()):
+            self.outputs.append(child_node)
+            # Dirty hot fix, change it quickly
+            self.outputs_dims.append(self.operator.output(0).dims())
+
+    @abstractmethod
+    def export(self, export_folder: str, list_configs: list):
+        """Define how to export the node definition.
+        """
+        pass
+
+    @abstractmethod
+    def forward(self, list_actions: list):
+        """Define how to generate code to perform a forward pass.
+        """
+        pass
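To make the contract concrete, here is a minimal sketch of a subclass (assuming ExportNode is reachable through aidge_core after the __init__.py change above); the ReLUExportNode name, file names, and emitted strings are purely illustrative, not part of this commit:

import aidge_core

class ReLUExportNode(aidge_core.ExportNode):
    """Hypothetical export of a ReLU node to a C-like kernel call."""
    def __init__(self, aidge_node: aidge_core.Node) -> None:
        # Base __init__ gathers name, attributes, inputs/outputs and their dims.
        super().__init__(aidge_node)

    def export(self, export_folder: str, list_configs: list):
        # Emit one config header per layer; path layout is illustrative.
        list_configs.append(f"{self.name}.h")
        return list_configs

    def forward(self, list_actions: list):
        # Emit the call performing the forward pass for this layer.
        list_actions.append(f"relu_forward({self.name}_in, {self.name}_out);")
        return list_actions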
@@ -102,5 +102,30 @@ class test_operator_binding(unittest.TestCase):
         genOp.get_operator().compute_output_dims()
         self.assertListEqual(genOp.get_operator().output(0).dims(), in_dims)
 
+    def test_set_impl(self):
+
+        class PythonCustomImpl(aidge_core.OperatorImpl):
+            """Dummy implementation to test that C++ calls Python code.
+            """
+            def __init__(self, op: aidge_core.Operator):
+                aidge_core.OperatorImpl.__init__(self, op)  # Required to avoid a type error!
+                self.idx = 0
+
+            def forward(self):
+                """Increment the idx attribute on forward.
+                """
+                self.idx += 1
+
+        generic_node = aidge_core.GenericOperator("Relu", 1, 1, 1, name="myReLu")
+        generic_op = generic_node.get_operator()
+        customImpl = PythonCustomImpl(generic_op)
+
+        generic_op.forward()  # Does nothing: no implementation set yet
+        generic_op.set_impl(customImpl)
+        generic_op.forward()  # Now increments idx
+        self.assertEqual(customImpl.idx, 1)
+
 if __name__ == '__main__':
     unittest.main()
@@ -38,6 +38,7 @@
 #include "aidge/operator/MetaOperator.hpp"
 #include "aidge/operator/MetaOperatorDefs.hpp"
 #include "aidge/operator/Operator.hpp"
+#include "aidge/operator/Pad.hpp"
 #include "aidge/operator/Producer.hpp"
 #include "aidge/operator/ReLU.hpp"
 #include "aidge/operator/Softmax.hpp"
......
@@ -18,11 +18,13 @@
 #include "aidge/utils/Types.h"
 
 namespace Aidge {
+class Operator;
+
 class OperatorImpl {
 public:
-    virtual void forward(){};
-    virtual void backward(){};
+    OperatorImpl(const Operator& op);
+    virtual void forward();
+    virtual void backward();
 
     /**
      * @brief Minimum amount of data from a specific input required by the
@@ -31,13 +33,13 @@ public:
      * @param inputIdx Index of the input analysed.
      * @return std::size_t
      */
-    virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const = 0;
+    virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const;
 
     // Amount of input data that cannot be overwritten during the execution.
-    virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const = 0;
+    virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const;
 
     // Memory required at an output for a given input size.
-    virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const = 0;
+    virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const;
 
     /**
      * @brief Total amount of consumed data from a specific input.
@@ -45,7 +47,7 @@ public:
      * @param inputIdx Index of the input analysed.
      * @return DimSize_t
      */
-    virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const = 0;
+    virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const;
 
     /**
      * @brief Total amount of produced data ready to be used on a specific output.
@@ -53,15 +55,20 @@ public:
      * @param outputIdx Index of the output analysed.
      * @return DimSize_t
      */
-    virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const = 0;
+    virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const;
 
     /**
      * @brief Update the Consummer Producer system by simulating the consumption and production of i/o
      *
     */
-    virtual void updateConsummerProducer() = 0;
+    virtual void updateConsummerProducer();
 
     virtual ~OperatorImpl() = default;
+
+protected:
+    const Operator &mOp;
+    std::vector<NbElts_t> mNbConsumedData;
+    std::vector<NbElts_t> mNbProducedData;
 };
 } // namespace Aidge
......
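OperatorImpl thus changes from a pure-virtual interface into a base class with default implementations and bookkeeping state (mOp, mNbConsumedData, mNbProducedData). This is what lets the Python unit test above register an implementation that overrides only forward(). A minimal sketch of such an implementation, mirroring the test (class and attribute names are illustrative):

import aidge_core

class CountingImpl(aidge_core.OperatorImpl):
    """Illustrative implementation: overrides only forward() and inherits the
    new default behaviour for the scheduling-related methods."""
    def __init__(self, op: aidge_core.Operator):
        # The base class constructor now takes the operator it implements.
        aidge_core.OperatorImpl.__init__(self, op)
        self.calls = 0

    def forward(self):
        self.calls += 1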
@@ -477,13 +477,14 @@ class Tensor : public Data,
         if (dims().empty()) { return "{}"; }
         std::string res;
         std::size_t dim = 0;
-        std::size_t *dimVals = new std::size_t[nbDims()];
-        for (std::size_t i = 0; i < nbDims(); ++i) {
-            dimVals[i] = 0;
-        }
         std::size_t counter = 0;
-        res += "{\n";
-        if (nbDims()>=2){
+        if (nbDims()>=2) {
+            std::size_t *dimVals = new std::size_t[nbDims()];
+            for (std::size_t i = 0; i < nbDims(); ++i) {
+                dimVals[i] = 0;
+            }
+            // std::vector<std::size_t> dimVals = std::vector<std::size_t>(nbDims(), 0);
+            res += "{\n";
             while (counter < mSize) {
@@ -532,31 +533,35 @@ class Tensor : public Data,
                     }
                     res += "\n";
                 }
+                if (dim == 0) {
+                    break;
+                }
                 dimVals[dim--] = 0;
                 dimVals[dim]++;
             }
-            for(int i = static_cast<int>(dim); i>=0; --i) {
+            delete[] dimVals;
+            for(int i = static_cast<int>(dim); i > 0; --i) {
                 res += std::string((dim+1)<<1,' ') + "}\n";
             }
-        }else{
+        } else {
+            res += "{";
             for (DimSize_t j = 0; j < dims()[0]; ++j) {
                 switch (mDataType)
                 {
                 case DataType::Int32:
-                    res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "\n");
+                    res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "");
                     break;
                 case DataType::Float64:
-                    res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "\n");
+                    res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "");
                     break;
                 default:
-                    res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "\n");
+                    res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "");
                     break;
                 }
             }
         }
         res += "}";
         return res;
     }
......
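Net effect of the toString() rework: the dimVals scratch array is only allocated in the multi-dimensional branch and is now freed (delete[]), the while loop can no longer underflow dim thanks to the break at dim == 0, and 1-D tensors print on a single line. Tracing the string building by hand for a 1-D Int32 tensor holding {1, 2, 3} (hand-derived illustration, not output captured from this commit):

// before: "{\n 1, 2, 3\n}"  -- "{\n" was emitted unconditionally and the last
//                              element carried a trailing "\n"
// after:  "{ 1, 2, 3}"      -- the 1-D branch opens with "{" and elements
//                              drop the trailing newline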
@@ -21,7 +21,7 @@ class GraphRegex{
     std::vector<std::string> mQuery;
     std::vector<std::shared_ptr<ConditionalInterpreter>> mAllTest;
-    std::map<std::string, std::function<bool(NodePtr)>&> mAllLambda;
+    std::map<std::string, std::function<bool(NodePtr)>> mAllLambda;
 
 public:
     GraphRegex(){};
......
@@ -162,6 +162,12 @@ public:
     inline IOIndex_t nbInputs() const noexcept override final { return NUM; }
     inline IOIndex_t nbDataInputs() const noexcept override final { return NUM; }
     inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
+    static const std::vector<std::string> getInputsName(){
+        return {"data_input_0", "data_input_n"};
+    }
+    static const std::vector<std::string> getOutputsName(){
+        return {"data_output"};
+    }
 };
 
 template <std::size_t NUM>
......
@@ -157,18 +157,23 @@ public:
     inline IOIndex_t nbInputs() const noexcept override final { return 1; }
     inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
     inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
+    static const std::vector<std::string> getInputsName(){
+        return {"data_input"};
+    }
+    static const std::vector<std::string> getOutputsName(){
+        return {"data_output"};
+    }
 };
 
 template <std::array<DimSize_t, 1>::size_type DIM>
 inline std::shared_ptr<Node> AvgPooling(const std::array<DimSize_t, DIM> &kernel_dims,
                                            const std::string& name = "",
                                            const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1)) {
-    // FIXME: properly handle default w&b initialization in every cases
     static_assert(DIM<=MaxDim,"Too many kernel dimensions required by AvgPooling, not supported");
-    auto avgPool = std::make_shared<Node>(std::make_shared<AvgPooling_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims), name);
-    return avgPool;
+    return std::make_shared<Node>(std::make_shared<AvgPooling_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims), name);
 }
 
+// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction
 template <DimSize_t DIM>
 inline std::shared_ptr<Node> AvgPooling(
     DimSize_t const (&kernel_dims)[DIM],
......
@@ -160,6 +160,12 @@ public:
     inline IOIndex_t nbInputs() const noexcept override final { return 5; }
     inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
     inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
+    static const std::vector<std::string> getInputsName(){
+        return {"data_input", "scale", "shift", "mean", "variance"};
+    }
+    static const std::vector<std::string> getOutputsName(){
+        return {"data_output"};
+    }
 };
 
 template <DimSize_t DIM>
......
@@ -177,6 +177,12 @@ public:
     inline IOIndex_t nbInputs() const noexcept override final { return 3; }
     inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
     inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
+    static const std::vector<std::string> getInputsName(){
+        return {"data_input", "weight", "bias"};
+    }
+    static const std::vector<std::string> getOutputsName(){
+        return {"data_output"};
+    }
 };
 
 template <std::array<DimSize_t, 1>::size_type DIM>
@@ -195,6 +201,7 @@ inline std::shared_ptr<Node> Conv(DimSize_t in_channels,
     return conv;
 }
 
+// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction
 template <DimSize_t DIM>
 inline std::shared_ptr<Node> Conv(
     DimSize_t in_channels,
......
@@ -176,6 +176,12 @@ class ConvDepthWise_Op : public Operator,
     inline IOIndex_t nbInputs() const noexcept override final { return 3; }
     inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
     inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
+    static const std::vector<std::string> getInputsName(){
+        return {"data_input", "weight", "bias"};
+    }
+    static const std::vector<std::string> getOutputsName(){
+        return {"data_output"};
+    }
 };
 
 template <std::array<DimSize_t, 1>::size_type DIM>
@@ -191,6 +197,7 @@ inline std::shared_ptr<Node> ConvDepthWise(const std::array<DimSize_t, DIM> &ker
     return convDW;
 }
 
+// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction
 template <DimSize_t DIM>
 inline std::shared_ptr<Node> ConvDepthWise(
     DimSize_t const (&kernel_dims)[DIM],
......
@@ -158,6 +158,12 @@ public:
     inline IOIndex_t nbInputs() const noexcept override final { return 3; }
     inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
     inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
+    static const std::vector<std::string> getInputsName(){
+        return {"data_input", "weight", "bias"};
+    }
+    static const std::vector<std::string> getOutputsName(){
+        return {"data_output"};
+    }
 };
 
 inline std::shared_ptr<Node> FC(DimSize_t out_channels, bool noBias = false, const std::string& name = "") {
@@ -175,4 +181,4 @@ const char *const EnumStrings<Aidge::FCAttr>::data[] = {"OutChannels",
                                                         "NoBias"};
 }
 
-#endif /* AIDGE_CORE_OPERATOR_FC_H_ */
\ No newline at end of file
+#endif /* AIDGE_CORE_OPERATOR_FC_H_ */
@@ -168,9 +168,20 @@ class GenericOperator_Op
     void setBackend(const std::string & /*name*/) override { printf("setBackend: not available yet.\n"); }
     void setDatatype(const DataType & /*datatype*/) override { printf("setDatatype: not available yet.\n"); }
-    void forward() override final { printf("forward: not available yet.\n"); }
-    void backward() override final { printf("backward: not available yet.\n"); }
+    void forward() override final {
+        if(mImpl){
+            mImpl->forward();
+        }else{
+            printf("forward: No implementation is linked.\n");
+        }
+    }
+    void backward() override final {
+        if(mImpl){
+            mImpl->backward();
+        }else{
+            printf("backward: No implementation is linked.\n");
+        }
+    }
 
     inline IOIndex_t nbInputs() const noexcept override final { return mNbIn; };
     inline IOIndex_t nbDataInputs() const noexcept override final { return mNbDataIn; };
     inline IOIndex_t nbOutputs() const noexcept override final { return mNbOut; };
......
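This is the C++ side of what test_set_impl exercises above: GenericOperator_Op::forward()/backward() now dispatch to an attached implementation instead of unconditionally printing a stub message. From Python, the flow is simply:

import aidge_core

node = aidge_core.GenericOperator("Relu", 1, 1, 1, name="myReLu")
op = node.get_operator()

op.forward()                       # prints "forward: No implementation is linked."
op.set_impl(PythonCustomImpl(op))  # PythonCustomImpl as defined in the unit test above
op.forward()                       # dispatches to the Python implementation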
@@ -137,10 +137,15 @@ public:
     inline IOIndex_t nbInputs() const noexcept override final { return 1; }
     inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
     inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
+    static const std::vector<std::string> getInputsName(){
+        return {"data_input"};
+    }
+    static const std::vector<std::string> getOutputsName(){
+        return {"data_output"};
+    }
 };
 
 inline std::shared_ptr<Node> LeakyReLU(float negativeSlope = 0.0f, const std::string& name = "") {
-    // FIXME: properly handle default w&b initialization in every cases
     return std::make_shared<Node>(std::make_shared<LeakyReLU_Op>(negativeSlope), name);
 }
 }
......
@@ -148,6 +148,12 @@ public:
     inline IOIndex_t nbInputs() const noexcept override final { return 2; }
     inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
     inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
+    static const std::vector<std::string> getInputsName(){
+        return {"data_input", "weight"};
+    }
+    static const std::vector<std::string> getOutputsName(){
+        return {"data_output"};
+    }
 };
 
 inline std::shared_ptr<Node> MatMul(DimSize_t out_channels, const std::string& name = "") {
......