Skip to content
Snippets Groups Projects
Commit fca0a230 authored by Thibault Allenet's avatar Thibault Allenet
Browse files

Update TensorImpl constructor to take the tensor dimensions instead of the...

Update the TensorImpl constructor to take the tensor dimensions instead of the number of elements, and override the resize function.
parent ce2fe8c7
No related branches found
No related tags found
2 merge requests!10Update backend_opencv with modifications from aidge_core,!4Change tensorimpl opencv `future_std::span<cv::Mat>` to `cv::Mat`
...@@ -31,27 +31,30 @@ public: ...@@ -31,27 +31,30 @@ public:
virtual void setCvMat(const cv::Mat& mat ) = 0; virtual void setCvMat(const cv::Mat& mat ) = 0;
}; };
template <class T> class TensorImpl_opencv : public TensorImpl, public TensorImpl_opencv_ { template <class T>
class TensorImpl_opencv : public TensorImpl, public TensorImpl_opencv_ {
private: private:
const Tensor &mTensor; // Impl needs to access Tensor information, but is not // Stores the cv::Mat
// supposed to change it!
cv::Mat mData; cv::Mat mData;
protected:
std::vector<DimSize_t> mDims;
public: public:
static constexpr const char *Backend = "opencv"; static constexpr const char *Backend = "opencv";
TensorImpl_opencv() = delete; TensorImpl_opencv() = delete;
TensorImpl_opencv(const Tensor &tensor) TensorImpl_opencv(DeviceIdx_t device, std::vector<DimSize_t> dims)
: TensorImpl(Backend), mTensor(tensor) : TensorImpl(Backend, device, dims)
{} {
mDims = dims;
}
bool operator==(const TensorImpl &otherImpl) const override final { bool operator==(const TensorImpl &otherImpl) const override final {
// Create iterators for both matrices // Create iterators for both matrices
cv::MatConstIterator_<T> it1 = mData.begin<T>(); cv::MatConstIterator_<T> it1 = mData.begin<T>();
const cv::Mat & otherData = reinterpret_cast<const TensorImpl_opencv<T> &>(otherImpl).data(); const cv::Mat & otherData = reinterpret_cast<const TensorImpl_opencv<T> &>(otherImpl).data();
cv::MatConstIterator_<T> it2 = otherData.begin<T>(); cv::MatConstIterator_<T> it2 = otherData.begin<T>();
// Iterate over the elements and compare them // Iterate over the elements and compare them
...@@ -63,117 +66,127 @@ public: ...@@ -63,117 +66,127 @@ public:
return true; return true;
} }
/// Factory used by the TensorImpl registrar.
/// @param device backend device index (must be 0 for the "opencv" backend)
/// @param dims   tensor dimensions
/// @return owning pointer to a freshly created implementation
static std::unique_ptr<TensorImpl_opencv> create(DeviceIdx_t device, std::vector<DimSize_t> dims) {
    // Move the dimension vector into the constructor instead of copying it.
    return std::make_unique<TensorImpl_opencv<T>>(device, std::move(dims));
}
void resize(std::vector<DimSize_t> dims) override{
mDims = dims;
size_t product = 1;
for (size_t num : dims) {
product *= num;
}
mNbElts = product;
} }
// --- native interface ---

/// Direct read-only access to the underlying cv::Mat.
const cv::Mat & data() const override { return mData; }

/// Size in bytes of a single element of type T.
inline std::size_t scalarSize() const noexcept override final { return sizeof(T); }
std::size_t size() const override { return mData.total() * mData.channels();} void copy(const void *src, NbElts_t length, NbElts_t offset = 0) override final {
const T* srcT = static_cast<const T *>(src);
T* dstT = static_cast<T *>(rawPtr(offset));
void setDevice(DeviceIdx_t device) override { AIDGE_ASSERT(length <= (mData.total() * mData.channels()) || length <= mNbElts, "copy length is above capacity");
AIDGE_ASSERT(device == 0, "device cannot be != 0 for Opencv backend"); AIDGE_ASSERT(dstT < srcT || dstT >= srcT + length, "overlapping copy is not supported");
} std::copy(srcT, srcT + length, dstT);
void copy(const void *src, NbElts_t length, NbElts_t offset = 0) override {
AIDGE_ASSERT(length <= size() || length <= mTensor.size(), "copy length is above capacity");
std::copy(static_cast<const T *>(src), static_cast<const T *>(src) + length,
static_cast<T *>(rawPtr()) + offset);
} }
void copyCast(const void *src, NbElts_t length, const DataType srcDt) override { void copyCast(const void *src, const DataType srcDt, NbElts_t length, NbElts_t offset = 0) override final{
if (length == 0) { if (length == 0) {
return; return;
} }
AIDGE_ASSERT(length <= size() || length <= mTensor.size(), "copy length is above capacity"); T* dstT = static_cast<T *>(rawPtr(offset));
if (srcDt == DataType::Float64) { AIDGE_ASSERT(length <= (mData.total() * mData.channels()) || length <= mNbElts, "copy length is above capacity");
std::copy(static_cast<const double*>(src), static_cast<const double*>(src) + length, switch (srcDt)
static_cast<T *>(rawPtr())); {
case DataType::Float64:
std::copy(static_cast<const double*>(src), static_cast<const double*>(src) + length,
dstT);
break; break;
case DataType::Float32: case DataType::Float32:
std::copy(static_cast<const float*>(src), static_cast<const float*>(src) + length, std::copy(static_cast<const float*>(src), static_cast<const float*>(src) + length,
static_cast<T *>(rawPtr())); dstT);
break; break;
case DataType::Float16: case DataType::Float16:
std::copy(static_cast<const half_float::half*>(src), static_cast<const half_float::half*>(src) + length, std::copy(static_cast<const half_float::half*>(src), static_cast<const half_float::half*>(src) + length,
static_cast<T *>(rawPtr())); dstT);
break; break;
case DataType::Int64: case DataType::Int64:
std::copy(static_cast<const int64_t*>(src), static_cast<const int64_t*>(src) + length, std::copy(static_cast<const int64_t*>(src), static_cast<const int64_t*>(src) + length,
static_cast<T *>(rawPtr())); dstT);
break; break;
case DataType::UInt64: case DataType::UInt64:
std::copy(static_cast<const uint64_t*>(src), static_cast<const uint64_t*>(src) + length, std::copy(static_cast<const uint64_t*>(src), static_cast<const uint64_t*>(src) + length,
static_cast<T *>(rawPtr())); dstT);
break; break;
case DataType::Int32: case DataType::Int32:
std::copy(static_cast<const int32_t*>(src), static_cast<const int32_t*>(src) + length, std::copy(static_cast<const int32_t*>(src), static_cast<const int32_t*>(src) + length,
static_cast<T *>(rawPtr())); dstT);
break; break;
case DataType::UInt32: case DataType::UInt32:
std::copy(static_cast<const uint32_t*>(src), static_cast<const uint32_t*>(src) + length, std::copy(static_cast<const uint32_t*>(src), static_cast<const uint32_t*>(src) + length,
static_cast<T *>(rawPtr())); dstT);
break; break;
case DataType::Int16: case DataType::Int16:
std::copy(static_cast<const int16_t*>(src), static_cast<const int16_t*>(src) + length, std::copy(static_cast<const int16_t*>(src), static_cast<const int16_t*>(src) + length,
static_cast<T *>(rawPtr())); dstT);
break; break;
case DataType::UInt16: case DataType::UInt16:
std::copy(static_cast<const uint16_t*>(src), static_cast<const uint16_t*>(src) + length, std::copy(static_cast<const uint16_t*>(src), static_cast<const uint16_t*>(src) + length,
static_cast<T *>(rawPtr())); dstT);
break; break;
case DataType::Int8: case DataType::Int8:
std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length, std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
static_cast<T *>(rawPtr())); dstT);
break; break;
case DataType::UInt8: case DataType::UInt8:
std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length, std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
static_cast<T *>(rawPtr())); dstT);
break; break;
default: default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type."); AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type.");
break;
} }
} }
void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, DeviceIdx_t>& device) override { void copyFromDevice(const void *src, const std::pair<std::string, DeviceIdx_t>& device, NbElts_t length, NbElts_t offset = 0) override final {
AIDGE_ASSERT(device.first == Backend, "backend must match"); AIDGE_ASSERT(device.first == Backend, "backend must match");
AIDGE_ASSERT(device.second == 0, "device cannot be != 0 for CPU backend"); AIDGE_ASSERT(device.second == 0, "device cannot be != 0 for CPU backend");
copy(src, length); copy(src, length, offset);
} }
void copyFromHost(const void *src, NbElts_t length) override { void copyFromHost(const void *src, NbElts_t length, NbElts_t offset = 0) override final {
copy(src, length); copy(src, length, offset);
} }
void copyToHost(void *dst, NbElts_t length) const override { void copyToHost(void *dst, NbElts_t length, NbElts_t offset = 0) const override final {
AIDGE_ASSERT(length <= size() || length <= mTensor.size(), "copy length is above capacity"); const T* src = static_cast<const T*>(rawPtr(offset));
const T* src = static_cast<const T*>(rawPtr()); AIDGE_ASSERT(length <= (mData.total() * mData.channels()) || length <= mNbElts, "copy length is above capacity");
std::copy(static_cast<const T *>(src), static_cast<const T *>(src) + length, std::copy(src, src + length, static_cast<T *>(dst));
static_cast<T *>(dst));
} }
/// Mutable raw pointer to element `offset`.
/// Allocates the backing cv::Mat on first use.
void *rawPtr(NbElts_t offset = 0) override final {
    lazyInit();
    return mData.ptr<T>() + offset;
}
/// Read-only raw pointer to element `offset`.
/// The storage must already have been allocated (no lazy init on const access).
const void *rawPtr(NbElts_t offset = 0) const override final {
    AIDGE_ASSERT(mData.total() * mData.channels() >= mNbElts, "accessing uninitialized const rawPtr");
    return mData.ptr<T>() + offset;
}
/// Mutable host-side pointer to element `offset`; identical to rawPtr()
/// since this backend stores its data in host memory.
void *hostPtr(NbElts_t offset = 0) override final {
    lazyInit();
    return mData.ptr<T>() + offset;
}
/// Read-only host-side pointer to element `offset`.
/// Requires an initialized, contiguous cv::Mat.
const void *hostPtr(NbElts_t offset = 0) const override {
    AIDGE_ASSERT(mData.total() * mData.channels() >= mNbElts, "accessing uninitialized const hostPtr");
    AIDGE_ASSERT(mData.isContinuous(), "CV Matrix not continuous");
    return mData.ptr<T>() + offset;
}
...@@ -186,28 +199,27 @@ public: ...@@ -186,28 +199,27 @@ public:
private: private:
void lazyInit() { void lazyInit() {
if (size() < mTensor.size()) { if ((mData.total() * mData.channels()) < mNbElts) {
// Need more data, a re-allocation will occur // Need more data, a re-allocation will occur
AIDGE_ASSERT(mData.empty() , "trying to enlarge non-owned data"); AIDGE_ASSERT(mData.empty() , "trying to enlarge non-owned data");
if (mTensor.nbDims() < 3) { if (mDims.size() < 3) {
mData = cv::Mat(((mTensor.nbDims() > 1) ? static_cast<int>(mTensor.dims()[1]) mData = cv::Mat(((mDims.size() > 1) ? static_cast<int>(mDims[0])
: (mTensor.nbDims() > 0) ? 1 : (mDims.size() > 0) ? 1
: 0), : 0),
(mTensor.nbDims() > 0) ? static_cast<int>(mTensor.dims()[0]) : 0, (mDims.size() > 0) ? static_cast<int>(mDims[1]) : 0,
detail::CV_C1_CPP_v<T>); detail::CV_C1_CPP_v<T>);
} else { } else {
std::vector<cv::Mat> channels; std::vector<cv::Mat> channels;
for (std::size_t k = 0; k < mTensor.dims()[2]; ++k) { for (std::size_t k = 0; k < mDims[2]; ++k) {
channels.push_back(cv::Mat(static_cast<int>(mTensor.dims()[1]), channels.push_back(cv::Mat(static_cast<int>(mDims[0]),
static_cast<int>(mTensor.dims()[0]), static_cast<int>(mDims[1]),
detail::CV_C1_CPP_v<T>)); detail::CV_C1_CPP_v<T>));
} }
cv::merge(channels, mData); cv::merge(channels, mData);
} }
} }
} }
}; };
......
...@@ -103,4 +103,31 @@ TEST_CASE("Tensor creation opencv", "[Tensor][OpenCV]") { ...@@ -103,4 +103,31 @@ TEST_CASE("Tensor creation opencv", "[Tensor][OpenCV]") {
REQUIRE_FALSE(x == xFloat); REQUIRE_FALSE(x == xFloat);
} }
} }
}
// Regression check: filling the tensor from a host-side array BEFORE
// attaching the OpenCV backend must preserve both the geometry
// (nbDims/dims/size) and the element values read back through the impl.
SECTION("from const array before backend") {
Tensor x = Array3D<int,2,2,2>{
{
{
{1, 2},
{3, 4}
},
{
{5, 6},
{7, 8}
}
}};
x.setBackend("opencv");
// Geometry is unchanged by the backend switch.
REQUIRE(x.nbDims() == 3);
REQUIRE(x.dims()[0] == 2);
REQUIRE(x.dims()[1] == 2);
REQUIRE(x.dims()[2] == 2);
REQUIRE(x.size() == 8);
// Spot-check corner values to verify element order survived the copy.
REQUIRE(x.get<int>({0,0,0}) == 1);
REQUIRE(x.get<int>({0,0,1}) == 2);
REQUIRE(x.get<int>({0,1,1}) == 4);
REQUIRE(x.get<int>({1,1,1}) == 8);
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment