Skip to content
Snippets Groups Projects
Commit 0f4c668b authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

Merge branch 'Fix/update_backend_cuda' into 'dev'

[Fix] update backend cuda

See merge request !42
parents 61e04aa3 3629ffc6
No related branches found
No related tags found
4 merge requests!61v0.4.0,!54Fix the BatchNorm operator,!47v0.4.0,!42[Fix] update backend cuda
Pipeline #57867 passed
Showing
with 20 additions and 20 deletions
......@@ -36,7 +36,7 @@ public:
return std::make_unique<AddImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<AndImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<ArgMaxImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -37,7 +37,7 @@ public:
return std::make_unique<AvgPoolingImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -37,7 +37,7 @@ public:
return std::make_unique<BatchNormImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -43,7 +43,7 @@ public:
return std::make_unique<ConvImpl_cuda<DIM>>(op, true);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Any}
};
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<DivImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<FCImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<GlobalAveragePoolingImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Any}
};
......
......@@ -37,7 +37,7 @@ public:
return std::make_unique<ILayerNormImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<LnImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -37,7 +37,7 @@ public:
return std::make_unique<MaxPoolingImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Any}
};
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<MulImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -37,7 +37,7 @@ public:
return std::make_unique<PadImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<PowImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<ReLUImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Any}
};
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<ReduceMeanImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<ReduceSumImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -36,7 +36,7 @@ public:
return std::make_unique<ReshapeImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
......@@ -37,7 +37,7 @@ public:
return std::make_unique<ShiftGELUImpl_cuda>(op);
}
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
virtual std::vector<ImplSpec> getAvailableImplSpecs() const override {
return {
{DataType::Float64},
{DataType::Float32},
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment