Skip to content
Snippets Groups Projects
Commit 142cba14 authored by Maxence Naud's avatar Maxence Naud Committed by Maxence Naud
Browse files

update TensorImplwith capacity() and zeros(), change DataType to dtype in Python

parent 1062621f
No related branches found
No related tags found
2 merge requests!17version 0.0.3,!16UI parameters
Pipeline #49601 failed
......@@ -19,7 +19,7 @@ class test_tensor(unittest.TestCase):
# np_array = np.arange(9).reshape(1,1,3,3)
# # Numpy -> Tensor
# t = aidge_core.Tensor(np_array)
# self.assertEqual(t.dtype(), aidge_core.DataType.Int32)
# self.assertEqual(t.dtype(), aidge_core.dtype.int32)
# for i_t, i_n in zip(t, np_array.flatten()):
# self.assertTrue(i_t == i_n)
# for i,j in zip(t.dims(), np_array.shape):
......@@ -41,7 +41,7 @@ class test_tensor(unittest.TestCase):
# np_array = np.random.rand(1, 1, 3, 3).astype(np.float32)
# # Numpy -> Tensor
# t = aidge_core.Tensor(np_array)
# self.assertEqual(t.dtype(), aidge_core.DataType.Float32)
# self.assertEqual(t.dtype(), aidge_core.dtype.float32)
# for i_t, i_n in zip(t, np_array.flatten()):
# self.assertTrue(i_t == i_n) # TODO : May need to change this to a difference
# for i,j in zip(t.dims(), np_array.shape):
......
......@@ -31,7 +31,7 @@ public:
virtual void setCvMat(const cv::Mat& mat ) = 0;
};
template <class T>
template <class T>
class TensorImpl_opencv : public TensorImpl, public TensorImpl_opencv_ {
private:
// Stores the cv::Mat
......@@ -44,7 +44,7 @@ public:
static constexpr const char *Backend = "opencv";
TensorImpl_opencv() = delete;
TensorImpl_opencv(DeviceIdx_t device, std::vector<DimSize_t> dims)
TensorImpl_opencv(DeviceIdx_t device, std::vector<DimSize_t> dims)
: TensorImpl(Backend, device, dims)
{
mDims = dims;
......@@ -70,6 +70,14 @@ public:
return std::make_unique<TensorImpl_opencv<T>>(device, dims);
}
inline std::size_t capacity() const noexcept override {
return mData.total() * mData.channels();
}
void zeros() override final {
mData = cv::Mat::zeros(mData.size(), mData.type());
}
void resize(std::vector<DimSize_t> dims) override{
mDims = dims;
size_t product = 1;
......@@ -104,7 +112,7 @@ public:
if (length == 0) {
return;
}
T* dstT = static_cast<T *>(rawPtr(offset));
AIDGE_ASSERT(length <= (mData.total() * mData.channels()) || length <= mNbElts, "TensorImpl_opencv<{}>::copyCast(): copy length ({}) is above capacity ({})", typeid(T).name(), length, mNbElts);
switch (srcDt)
......@@ -198,7 +206,7 @@ public:
};
void setCvMat(const cv::Mat& mat) override {mData=mat;}
virtual ~TensorImpl_opencv() = default;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment