Skip to content
Snippets Groups Projects

Update TensorImpl constructor to take the tensor dimensions instead of the number of elements.

Merged Thibault Allenet requested to merge TensorImpl into dev
1 file
+ 4
3
Compare changes
  • Side-by-side
  • Inline
@@ -69,12 +69,13 @@ private:
@@ -69,12 +69,13 @@ private:
public:
public:
static constexpr const char *Backend = "cuda";
static constexpr const char *Backend = "cuda";
TensorImpl_cuda(DeviceIdx_t device, NbElts_t length) : TensorImpl(Backend, device, length), mDataOwner(nullptr, cudaDelete) {}
TensorImpl_cuda(DeviceIdx_t device, std::vector<DimSize_t> dims) : TensorImpl(Backend, device, dims), mDataOwner(nullptr, cudaDelete) {}
 
bool operator==(const TensorImpl &otherImpl) const override final;
bool operator==(const TensorImpl &otherImpl) const override final;
static std::shared_ptr<TensorImpl_cuda> create(DeviceIdx_t device, NbElts_t length) {
static std::shared_ptr<TensorImpl_cuda> create(DeviceIdx_t device, std::vector<DimSize_t> dims) {
return std::make_shared<TensorImpl_cuda<T>>(device, length);
return std::make_shared<TensorImpl_cuda<T>>(device, dims);
}
}
// native interface
// native interface
Loading