Skip to content
Snippets Groups Projects
Commit 87a17ffa authored by Houssem ROUIS's avatar Houssem ROUIS Committed by Benjamin Halimi
Browse files

getAvailableImplSpecs now return a vector for binding purposes.

parent aca31cb5
No related branches found
No related tags found
1 merge request!54Fix the BatchNorm operator
Showing
with 20 additions and 20 deletions
......@@ -36,7 +36,7 @@ public:
return std::make_unique<AddImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<AndImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<ArgMaxImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -37,7 +37,7 @@ public:
return std::make_unique<AvgPoolingImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -37,7 +37,7 @@ public:
return std::make_unique<BatchNormImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -43,7 +43,7 @@ public:
return std::make_unique<ConvImpl_cuda<DIM>>(op, true);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Any}
};
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<DivImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<FCImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<GlobalAveragePoolingImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Any}
};
......
......@@ -37,7 +37,7 @@ public:
return std::make_unique<ILayerNormImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<LnImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -37,7 +37,7 @@ public:
return std::make_unique<MaxPoolingImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Any}
};
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<MulImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -37,7 +37,7 @@ public:
return std::make_unique<PadImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<PowImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<ReLUImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Any}
};
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<ReduceMeanImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<ReduceSumImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<ReshapeImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -37,7 +37,7 @@ public:
return std::make_unique<ShiftGELUImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment