Commit e45c2e38 authored by Olivier BICHLER

Various fixes

parent 927950e9
@@ -377,11 +377,8 @@ class Tensor : public Data,
      */
     void setDataType(const DataType dt) {
         if (mImpl && (dataType() != dt)) {
-            // get ptr before changing Tensor backend or the type difference will trigger a warning
-            const void *data = mImpl->rawPtr();
-            mDataType = dt;
-            std::unique_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), dt})(*this);
-            newImpl->copy(data, size()); // /!\ it does not cast data but reinterpret them
+            std::unique_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), dt})(*this);
+            newImpl->copyCast(mImpl->rawPtr(), size(), mDataType);
             mImpl = std::move(newImpl);
         }
         mDataType = dt;
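The functional change in this first hunk is copy() -> copyCast(). The old code grabbed the raw pointer and byte-copied it into the new implementation, so changing a tensor's data type reinterpreted memory instead of converting values, as the removed warning comment admitted. copyCast() takes the data type of the source buffer as its third argument, and mDataType still holds the old type at the call site, which is why the assignment is now deferred until after the copy. A minimal sketch of the resulting behaviour; the Array1D initializer and backend name are assumptions for illustration, not taken from this diff:

    Tensor t = Array1D<float, 3>{{1.5f, 2.5f, -3.0f}};
    t.setBackend("cpu");
    t.setDataType(DataType::Int32);
    // Before this commit: the 12 bytes of float data were reinterpreted as
    // int32, yielding garbage. After: each element is value-cast, giving {1, 2, -3}.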
@@ -487,7 +484,7 @@ class Tensor : public Data,
-    std::string toString() {
+    std::string toString() const {
         if (dims().empty()) { return "{}"; }
         std::string res;
         std::size_t dim = 0;
@@ -580,7 +577,7 @@ class Tensor : public Data,
         return res;
     }
-    inline void print() { printf("%s\n", toString().c_str()); }
+    inline void print() const { printf("%s\n", toString().c_str()); }
     std::shared_ptr<Tensor> grad() {
         if (!mGrad) {
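Const-qualifying toString() and print() is a small API fix: neither method mutates the tensor, and without the qualifier they could not be called through a const reference. A sketch of what this enables (header path assumed from the aidge/ include layout visible further down):

    #include <iostream>
    #include "aidge/data/Tensor.hpp"

    // Only compiles now that toString() is const-qualified.
    void dump(const Aidge::Tensor& t) {
        std::cout << t.toString() << std::endl;
    }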
@@ -169,11 +169,19 @@ public:
         mImpl = Registrar<ConvDepthWise_Op<DIM>>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
-        // FIXME: temporary workaround
+        // By default, automatically set backend for weight and bias inputs
         getInput(1)->setBackend(name, device);
         getInput(2)->setBackend(name, device);
     }
 
+    void setDataType(const DataType& dt) const override {
+        mOutputs[0]->setDataType(dt);
+
+        // By default, automatically set data type for weight and bias inputs
+        getInput(1)->setDataType(dt);
+        getInput(2)->setDataType(dt);
+    }
+
     static const std::vector<std::string> getInputsName(){
         return {"data_input", "weight", "bias"};
     }
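Here the old "FIXME: temporary workaround" comment becomes documented default behaviour, and the same pattern is extended from setBackend() to a new setDataType() override: configuring the operator now also configures its weight (input 1) and bias (input 2) tensors. A sketch of the effect, where op stands for a connected ConvDepthWise_Op<2> instance (a placeholder, not a name from this diff):

    op.setDataType(DataType::Float64);
    // Now equivalent to setting the type on each tensor individually:
    //   output 0, input 1 (weight) and input 2 (bias) all become Float64.
    // Previously only the output followed, and the weight and bias
    // producers had to be converted one by one.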
@@ -57,10 +57,6 @@ public:
         mOutputs[0]->setBackend(name, device);
     }
 
-    void setDataType(const DataType& dataType) const override {
-        mOutputs[0]->setDataType(dataType);
-    }
-
     void forward() override;
 
     static const std::vector<std::string> getInputsName(){
@@ -99,12 +99,19 @@ public:
         mImpl = Registrar<FC_Op>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name, device);
+        // By default, automatically set backend for weight and bias inputs
         getInput(1)->setBackend(name, device);
         getInput(2)->setBackend(name, device);
     }
 
+    void setDataType(const DataType& dt) const override {
+        mOutputs[0]->setDataType(dt);
+
+        // By default, automatically set data type for weight and bias inputs
+        getInput(1)->setDataType(dt);
+        getInput(2)->setDataType(dt);
+    }
+
     static const std::vector<std::string> getInputsName(){
         return {"data_input", "weight", "bias"};
     }
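The FC hunk applies the same pattern, with one extra fix: the old workaround also forced the backend of input 0, the data input. That tensor belongs to the upstream node in the graph, so setting it from the consumer side could silently move another operator's output; only the parameters the operator owns (weight and bias) should follow it. The short hunk in between removes a per-operator setDataType() override that only set the output's type, presumably because that much is now handled by default. A sketch of the corrected behaviour, with "cuda" as a hypothetical second backend and fc1/fc2 as placeholder nodes:

    // Two-layer graph: fc1 -> fc2, both initially on "cpu".
    fc2->getOperator()->setBackend("cuda", 0);
    // fc2's output, weight and bias move to "cuda"; fc1's output (fc2's data
    // input) stays on "cpu". Before this commit it would have been moved too.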
@@ -31,10 +31,10 @@
  * @param absolute absolute error allowed (should be positive)
  * @return true if both tensors are approximately equal and have the same data type and shape, false otherwise
  */
-template <typename T>
+template <typename T1, typename T2 = T1>
 bool approxEq(Aidge::Tensor t1, Aidge::Tensor t2, float relative, float absolute){
-    assert(t1.dataType() == t2.dataType());
-    assert(t1.dataType() == NativeType<T>::type);
+    assert(t1.dataType() == NativeType<T1>::type);
+    assert(t2.dataType() == NativeType<T2>::type);
     assert(relative >= 0);
     assert(absolute >= 0 && absolute <= 1);
@@ -42,7 +42,7 @@ bool approxEq(Aidge::Tensor t1, Aidge::Tensor t2, float relative, float absolute
         return false;
     }
     for (size_t i = 0; i < t1.size(); ++i) {
-        if (static_cast<float>(std::abs(t1.get<T>(i) - t2.get<T>(i))) > (absolute + (relative * static_cast<float>(std::abs(t2.get<T>(i)))))){
+        if (static_cast<float>(std::abs(t1.get<T1>(i) - t2.get<T2>(i))) > (absolute + (relative * static_cast<float>(std::abs(t2.get<T2>(i)))))){
             return false;
         }
     }
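Splitting the template parameter in two lets tests compare tensors whose native element types differ, for instance a computed float output against a double-precision reference, while the T2 = T1 default keeps every existing single-type call source-compatible. Usage sketch (tensor names are placeholders):

    bool okSame  = approxEq<float>(output, expected, 1e-5f, 1e-8f);
    bool okMixed = approxEq<float, double>(output, expectedF64, 1e-5f, 1e-8f);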
@@ -15,6 +15,10 @@
 #include "aidge/utils/ErrorHandling.hpp"
 
 void Aidge::TensorImpl::copyFrom(const TensorImpl& srcImpl, NbElts_t length) {
+    if (srcImpl == *this) {
+        return;
+    }
+
     if (srcImpl.device() != device()) {
         if (srcImpl.backend() == backend()) {
             // Same backend, but different device
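The new guard turns a self-copy into a no-op. Without it, copying an implementation onto itself would fall through to the backend copy routine, at best wasting a full buffer copy (and a memcpy over identical source and destination pointers is formally undefined). How TensorImpl's operator== is defined is not visible in this diff; presumably it compares object identity or backend and device. Sketch of the case being short-circuited, with t a placeholder tensor:

    TensorImpl& impl = *t.getImpl();
    impl.copyFrom(impl, t.size());   // now returns immediately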
@@ -42,7 +42,7 @@ const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, c
         const auto device = getImpl()->device();
         fallback->setBackend(device.first, device.second);
         fallback->resize(dims());
-        fallback->getImpl()->copyCast(getImpl()->rawPtr(), size(), dt);
+        fallback->getImpl()->copyCast(getImpl()->rawPtr(), size(), dataType());
         return *fallback;
     }
 }
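This last hunk is a genuine bug fix rather than a cleanup. As the setDataType() hunk at the top shows, copyCast()'s third argument is the data type of the source buffer. refCast() was passing dt, the requested target type, so whenever the two types differed the source bytes were read as if they already had the target type. The fallback tensor is presumably given type dt earlier in refCast() (not shown in this hunk); what the cast needs is dataType(), the type of this tensor's data. For example, with t a placeholder Float32 tensor:

    std::shared_ptr<Tensor> fallback;
    const Tensor& ti = t.refCast(fallback, DataType::Int32);
    // Before: copyCast(..., size(), Int32) reinterpreted the float bytes.
    // After:  copyCast(..., size(), Float32) performs a real float -> int32
    // element-wise conversion into the Int32 fallback buffer.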