Skip to content
Snippets Groups Projects
Commit 87a17ffa authored by Houssem ROUIS's avatar Houssem ROUIS Committed by Benjamin Halimi
Browse files

getAvailableImplSpecs now return a vector for binding purposes.

parent aca31cb5
No related branches found
No related tags found
1 merge request!54Fix the BatchNorm operator
Showing
with 20 additions and 20 deletions
...@@ -36,7 +36,7 @@ public: ...@@ -36,7 +36,7 @@ public:
return std::make_unique<AddImpl_cuda>(op); return std::make_unique<AddImpl_cuda>(op);
} }
virtual std::set<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return { return {
{DataType::Float64}, {DataType::Float64},
{DataType::Float32}, {DataType::Float32},
......
...@@ -36,7 +36,7 @@ public: ...@@ -36,7 +36,7 @@ public:
return std::make_unique<AndImpl_cuda>(op); return std::make_unique<AndImpl_cuda>(op);
} }
virtual std::set<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return { return {
{DataType::Float64}, {DataType::Float64},
{DataType::Float32}, {DataType::Float32},
......
...@@ -36,7 +36,7 @@ public: ...@@ -36,7 +36,7 @@ public:
return std::make_unique<ArgMaxImpl_cuda>(op); return std::make_unique<ArgMaxImpl_cuda>(op);
} }
virtual std::set<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return { return {
{DataType::Float64}, {DataType::Float64},
{DataType::Float32}, {DataType::Float32},
......
...@@ -37,7 +37,7 @@ public: ...@@ -37,7 +37,7 @@ public:
return std::make_unique<AvgPoolingImpl_cuda>(op); return std::make_unique<AvgPoolingImpl_cuda>(op);
} }
virtual std::set<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return { return {
{DataType::Float64}, {DataType::Float64},
{DataType::Float32}, {DataType::Float32},
......
...@@ -37,7 +37,7 @@ public: ...@@ -37,7 +37,7 @@ public:
return std::make_unique<BatchNormImpl_cuda>(op); return std::make_unique<BatchNormImpl_cuda>(op);
} }
virtual std::set<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return { return {
{DataType::Float64}, {DataType::Float64},
{DataType::Float32}, {DataType::Float32},
......
...@@ -43,7 +43,7 @@ public: ...@@ -43,7 +43,7 @@ public:
return std::make_unique<ConvImpl_cuda<DIM>>(op, true); return std::make_unique<ConvImpl_cuda<DIM>>(op, true);
} }
virtual std::set<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return { return {
{DataType::Any} {DataType::Any}
}; };
......
...@@ -36,7 +36,7 @@ public: ...@@ -36,7 +36,7 @@ public:
return std::make_unique<DivImpl_cuda>(op); return std::make_unique<DivImpl_cuda>(op);
} }
virtual std::set<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return { return {
{DataType::Float64}, {DataType::Float64},
{DataType::Float32}, {DataType::Float32},
......
...@@ -36,7 +36,7 @@ public: ...@@ -36,7 +36,7 @@ public:
return std::make_unique<FCImpl_cuda>(op); return std::make_unique<FCImpl_cuda>(op);
} }
virtual std::set<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return { return {
{DataType::Float64}, {DataType::Float64},
{DataType::Float32}, {DataType::Float32},
......
...@@ -36,7 +36,7 @@ public: ...@@ -36,7 +36,7 @@ public:
return std::make_unique<GlobalAveragePoolingImpl_cuda>(op); return std::make_unique<GlobalAveragePoolingImpl_cuda>(op);
} }
virtual std::set<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return { return {
{DataType::Any} {DataType::Any}
}; };
......
...@@ -37,7 +37,7 @@ public: ...@@ -37,7 +37,7 @@ public:
return std::make_unique<ILayerNormImpl_cuda>(op); return std::make_unique<ILayerNormImpl_cuda>(op);
} }
virtual std::set<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return { return {
{DataType::Float64}, {DataType::Float64},
{DataType::Float32}, {DataType::Float32},
......
...@@ -36,7 +36,7 @@ public: ...@@ -36,7 +36,7 @@ public:
return std::make_unique<LnImpl_cuda>(op); return std::make_unique<LnImpl_cuda>(op);
} }
virtual std::set<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return { return {
{DataType::Float64}, {DataType::Float64},
{DataType::Float32}, {DataType::Float32},
......
...@@ -37,7 +37,7 @@ public: ...@@ -37,7 +37,7 @@ public:
return std::make_unique<MaxPoolingImpl_cuda>(op); return std::make_unique<MaxPoolingImpl_cuda>(op);
} }
virtual std::set<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return { return {
{DataType::Any} {DataType::Any}
}; };
......
...@@ -36,7 +36,7 @@ public: ...@@ -36,7 +36,7 @@ public:
return std::make_unique<MulImpl_cuda>(op); return std::make_unique<MulImpl_cuda>(op);
} }
virtual std::set<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return { return {
{DataType::Float64}, {DataType::Float64},
{DataType::Float32}, {DataType::Float32},
......
...@@ -37,7 +37,7 @@ public: ...@@ -37,7 +37,7 @@ public:
return std::make_unique<PadImpl_cuda>(op); return std::make_unique<PadImpl_cuda>(op);
} }
virtual std::set<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return { return {
{DataType::Float64}, {DataType::Float64},
{DataType::Float32}, {DataType::Float32},
......
...@@ -36,7 +36,7 @@ public: ...@@ -36,7 +36,7 @@ public:
return std::make_unique<PowImpl_cuda>(op); return std::make_unique<PowImpl_cuda>(op);
} }
virtual std::set<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return { return {
{DataType::Float64}, {DataType::Float64},
{DataType::Float32}, {DataType::Float32},
......
...@@ -36,7 +36,7 @@ public: ...@@ -36,7 +36,7 @@ public:
return std::make_unique<ReLUImpl_cuda>(op); return std::make_unique<ReLUImpl_cuda>(op);
} }
virtual std::set<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return { return {
{DataType::Any} {DataType::Any}
}; };
......
...@@ -36,7 +36,7 @@ public: ...@@ -36,7 +36,7 @@ public:
return std::make_unique<ReduceMeanImpl_cuda>(op); return std::make_unique<ReduceMeanImpl_cuda>(op);
} }
virtual std::set<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return { return {
{DataType::Float64}, {DataType::Float64},
{DataType::Float32}, {DataType::Float32},
......
...@@ -36,7 +36,7 @@ public: ...@@ -36,7 +36,7 @@ public:
return std::make_unique<ReduceSumImpl_cuda>(op); return std::make_unique<ReduceSumImpl_cuda>(op);
} }
virtual std::set<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return { return {
{DataType::Float64}, {DataType::Float64},
{DataType::Float32}, {DataType::Float32},
......
...@@ -36,7 +36,7 @@ public: ...@@ -36,7 +36,7 @@ public:
return std::make_unique<ReshapeImpl_cuda>(op); return std::make_unique<ReshapeImpl_cuda>(op);
} }
virtual std::set<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return { return {
{DataType::Float64}, {DataType::Float64},
{DataType::Float32}, {DataType::Float32},
......
...@@ -37,7 +37,7 @@ public: ...@@ -37,7 +37,7 @@ public:
return std::make_unique<ShiftGELUImpl_cuda>(op); return std::make_unique<ShiftGELUImpl_cuda>(op);
} }
virtual std::set<ImplSpec> getAvailableImplSpecs() const override { virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return { return {
{DataType::Float64}, {DataType::Float64},
{DataType::Float32}, {DataType::Float32},
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment