Skip to content
Snippets Groups Projects
Commit 414f6632 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Make OperatorTensor::setDataType() more generic

parent ef618824
No related branches found
No related tags found
No related merge requests found
......@@ -105,16 +105,6 @@ public:
getInput(4)->setBackend(name, device);
}
void setDataType(const DataType& dt) const override {
mOutputs[0]->setDataType(dt);
// By default, automatically set data type for scale, shift, mean and variance
getInput(1)->setDataType(dt);
getInput(2)->setDataType(dt);
getInput(3)->setDataType(dt);
getInput(4)->setDataType(dt);
}
static const std::vector<std::string> getInputsName() {
return {"data_input", "scale", "shift", "mean", "variance"};
}
......
......@@ -180,14 +180,6 @@ std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveFiel
getInput(2)->setBackend(name, device);
}
void setDataType(const DataType& dt) const override {
mOutputs[0]->setDataType(dt);
// By default, automatically set data type for weight and bias inputs
getInput(1)->setDataType(dt);
getInput(2)->setDataType(dt);
}
static const std::vector<std::string> getInputsName(){
return {"data_input", "weight", "bias"};
}
......
......@@ -174,14 +174,6 @@ public:
getInput(2)->setBackend(name, device);
}
void setDataType(const DataType& dt) const override {
mOutputs[0]->setDataType(dt);
// By default, automatically set data type for weight and bias inputs
getInput(1)->setDataType(dt);
getInput(2)->setDataType(dt);
}
static const std::vector<std::string> getInputsName(){
return {"data_input", "weight", "bias"};
}
......
......@@ -104,14 +104,6 @@ public:
getInput(2)->setBackend(name, device);
}
void setDataType(const DataType& dt) const override {
mOutputs[0]->setDataType(dt);
// By default, automatically set data type for weight and bias inputs
getInput(1)->setDataType(dt);
getInput(2)->setDataType(dt);
}
static const std::vector<std::string> getInputsName(){
return {"data_input", "weight", "bias"};
}
......
......@@ -149,4 +149,8 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
getOutput(i)->setDataType(dataType);
}
for (IOIndex_t i = nbData(); i < nbInputs(); ++i) {
getInput(i)->setDataType(dataType);
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment