Skip to content
Snippets Groups Projects
Commit f1d89973 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Removed temporary workaround

parent e45c2e38
No related branches found
No related tags found
No related merge requests found
Showing
with 11 additions and 67 deletions
......@@ -79,11 +79,6 @@ public:
void setBackend(const std::string& name, int device = 0) override {
mImpl = Registrar<Add_Op>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
// FIXME: temporary workaround
for (std::size_t i = 0; i < nbInputs(); ++i) {
getInput(i)->setBackend(name, device);
}
}
static const std::vector<std::string> getInputsName(){
......
......@@ -137,9 +137,6 @@ public:
void setBackend(const std::string &name, int device = 0) override {
mImpl = Registrar<AvgPooling_Op<DIM>>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
// FIXME: temporary workaround
getInput(0)->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -98,13 +98,23 @@ public:
mImpl = Registrar<BatchNorm_Op<DIM>>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
// FIXME: temporary workaround
// By default, automatically set backend for scale, shift, mean and variance
getInput(1)->setBackend(name, device);
getInput(2)->setBackend(name, device);
getInput(3)->setBackend(name, device);
getInput(4)->setBackend(name, device);
}
void setDataType(const DataType& dt) const override {
mOutputs[0]->setDataType(dt);
// By default, automatically set data type for scale, shift, mean and variance
getInput(1)->setDataType(dt);
getInput(2)->setDataType(dt);
getInput(3)->setDataType(dt);
getInput(4)->setDataType(dt);
}
static const std::vector<std::string> getInputsName() {
return {"data_input", "scale", "shift", "mean", "variance"};
}
......
......@@ -104,11 +104,6 @@ public:
void setBackend(const std::string& name, int device = 0) override {
mImpl = Registrar<Concat_Op>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
// FIXME: temporary workaround
for (std::size_t i = 0; i < nbInputs(); ++i) {
getInput(i)->setBackend(name, device);
}
}
static const std::vector<std::string> getInputsName(){
......
......@@ -57,10 +57,6 @@ public:
void setBackend(const std::string& name, int device = 0) override {
mImpl = Registrar<Div_Op>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
// FIXME: temporary workaround
getInput(0)->setBackend(name, device);
getInput(1)->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -70,9 +70,6 @@ public:
void setBackend(const std::string& name, int device = 0) override {
mImpl = Registrar<LeakyReLU_Op>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
// FIXME: temporary workaround
getInput(0)->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -86,10 +86,6 @@ public:
void setBackend(const std::string& name, int device = 0) override {
mImpl = Registrar<MatMul_Op>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
// FIXME: temporary workaround
getInput(0)->setBackend(name, device);
getInput(1)->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -107,9 +107,6 @@ public:
void setBackend(const std::string &name, int device = 0) override {
mImpl = Registrar<MaxPooling_Op<DIM>>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
// FIXME: temporary workaround
getInput(0)->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -59,10 +59,6 @@ public:
void setBackend(const std::string& name, int device = 0) override {
mImpl = Registrar<Mul_Op>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
// FIXME: temporary workaround
getInput(0)->setBackend(name, device);
getInput(1)->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -100,9 +100,6 @@ public:
void setBackend(const std::string &name, int device = 0) override {
mImpl = Registrar<Pad_Op<DIM>>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
// FIXME: temporary workaround
getInput(0)->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -57,10 +57,6 @@ public:
void setBackend(const std::string& name, int device = 0) override {
mImpl = Registrar<Pow_Op>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
// FIXME: temporary workaround
getInput(0)->setBackend(name, device);
getInput(1)->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -54,9 +54,6 @@ public:
void setBackend(const std::string& name, int device = 0) override {
mImpl = Registrar<ReLU_Op>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
// FIXME: temporary workaround
getInput(0)->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -69,8 +69,6 @@ public:
void setBackend(const std::string& name, int device = 0) override {
mImpl = Registrar<Scaling_Op>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
// FIXME: temporary workaround
mInputs[0]->setBackend(name, device);
}
static const std::vector<std::string> getInputsName() {
......
......@@ -93,9 +93,6 @@ public:
void setBackend(const std::string &name, int device = 0) override {
mImpl = Registrar<Slice_Op>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
// FIXME: temporary workaround
getInput(0)->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -54,9 +54,6 @@ public:
void setBackend(const std::string& name, int device = 0) override {
mImpl = Registrar<Softmax_Op>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
// FIXME: temporary workaround
getInput(0)->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -59,9 +59,6 @@ public:
void setBackend(const std::string& name, int device = 0) override {
mImpl = Registrar<Sqrt_Op>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
// FIXME: temporary workaround
getInput(0)->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -62,10 +62,6 @@ public:
void setBackend(const std::string& name, int device = 0) override {
mImpl = Registrar<Sub_Op>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
// FIXME: temporary workaround
getInput(0)->setBackend(name, device);
getInput(1)->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){
......
......@@ -149,14 +149,4 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
getOutput(i)->setDataType(dataType);
}
/*
for (IOIndex_t i = 0; i < nbInputs(); ++i) {
if (!getInput(i)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not set");
}
else {
getInput(i)->setDataType(dataType);
}
}
*/
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment