Skip to content
Snippets Groups Projects
Commit 37a9cafe authored by Maxence Naud's avatar Maxence Naud
Browse files

Update Operators with device function

parent 08608567
No related branches found
No related tags found
No related merge requests found
......@@ -51,12 +51,9 @@ public:
return std::make_shared<Erf_Op>(*this);
}
void setBackend(const std::string& name) override {
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Erf_Op>::create(name)(*this);
mOutputs[0]->setBackend(name);
// FIXME: temporary workaround
getInput(0)->setBackend(name);
mOutputs[0]->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -40,7 +40,7 @@ public:
Gather_Op() = delete;
using Attributes_ = StaticAttributes<GatherAttr, int>;
template <GatherAttr e> using attr = typename Attributes_::template attr<e>;
Gather_Op(int axis)
......@@ -70,13 +70,9 @@ public:
void computeOutputDims() override final;
void setBackend(const std::string& name) override {
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Gather_Op>::create(name)(*this);
mOutputs[0]->setBackend(name);
// FIXME: temporary workaround
getInput(0)->setBackend(name);
getInput(1)->setBackend(name);
mOutputs[0]->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -89,7 +89,7 @@ class ReduceMean_Op : public OperatorTensor,
}
else
outDims.push_back(getInput(0)->dims()[d]);
}
}
if(outDims.size()>0)
mOutputs[0]->resize(outDims);
else
......@@ -97,12 +97,9 @@ class ReduceMean_Op : public OperatorTensor,
}
}
void setBackend(const std::string &name) override {
void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
mImpl = Registrar<ReduceMean_Op<DIM>>::create(name)(*this);
mOutputs[0]->setBackend(name);
// FIXME: temporary workaround
getInput(0)->setBackend(name);
mOutputs[0]->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -66,12 +66,9 @@ public:
void computeOutputDims() override final;
void setBackend(const std::string& name) override {
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Reshape_Op>::create(name)(*this);
mOutputs[0]->setBackend(name);
// FIXME: temporary workaround
getInput(0)->setBackend(name);
mOutputs[0]->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -79,12 +79,9 @@ class Transpose_Op : public OperatorTensor,
}
}
void setBackend(const std::string &name) override {
void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Transpose_Op<DIM>>::create(name)(*this);
mOutputs[0]->setBackend(name);
// FIXME: temporary workaround
getInput(0)->setBackend(name);
mOutputs[0]->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment