Skip to content
Snippets Groups Projects

Multiple fixes to enable multi-GPUs forward execution

Merged Olivier BICHLER requested to merge htmlescape into dev
Files
46
@@ -885,6 +885,10 @@ public:
// First refFrom, to ensure that fallback, if required, is also on desired device
return refFrom(fallback, backend, device).refCast(fallback, dt);
}
const Tensor& refCastFrom(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device = 0) const {
// First refFrom, to ensure that fallback, if required, is also on desired device
return refFrom(fallback, backend, device).refCast(fallback, dt);
}
/**
* Return a reference to a Tensor with same characteristics
@@ -904,6 +908,10 @@ public:
const auto& device = targetReqs.getImpl()->device();
return refCastFrom(fallback, targetReqs.dataType(), device.first, device.second);
}
const Tensor& refCastFrom(std::shared_ptr<Tensor>& fallback, const Tensor& targetReqs) const {
const auto& device = targetReqs.getImpl()->device();
return refCastFrom(fallback, targetReqs.dataType(), device.first, device.second);
}
/**
* @brief Return a reference to a Tensor on desired data type and backend/device:
@@ -941,6 +949,10 @@ public:
const auto& device = targetReqs.getImpl()->device();
return ref(fallback, targetReqs.dataType(), device.first, device.second);
}
const Tensor& ref(std::shared_ptr<Tensor>& fallback, const Tensor& targetReqs) const {
const auto& device = targetReqs.getImpl()->device();
return ref(fallback, targetReqs.dataType(), device.first, device.second);
}
/**
Loading