diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index fc800816b7ab012cd87d6af8d451392b39563e4f..b972c87dcda8f912ff40feef0001b95d5feac71e 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -336,11 +336,7 @@ void init_Tensor(py::module& m){ .def("cpy_transpose", (void (Tensor::*)(const Tensor& src, const std::vector<DimSize_t>& transpose)) &Tensor::copyTranspose, py::arg("src"), py::arg("transpose")) .def("__str__", [](Tensor& b) { - if (b.empty() && b.undefined()) { - return std::string("{}"); - } else { - return b.toString(); - } + return b.toString(); }) .def("__repr__", [](Tensor& b) { return fmt::format("Tensor(dims = {}, dtype = {})", b.dims(), std::string(EnumStrings<DataType>::data[static_cast<int>(b.dataType())])); diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp index 6f60d2f15ce0e561c32d7bc5a7561c2f8d507588..3dcdcc65d0ef40b0443eb5b9662111420ce4fb86 100644 --- a/src/data/Tensor.cpp +++ b/src/data/Tensor.cpp @@ -235,10 +235,11 @@ void Aidge::Tensor::resize(const std::vector<Aidge::DimSize_t>& dims, } std::string Aidge::Tensor::toString() const { - AIDGE_ASSERT( - mImpl && (undefined() || (dims() == std::vector<DimSize_t>({0})) || - (mImpl->hostPtr() != nullptr)), - "tensor should have a valid host pointer"); + + if (!hasImpl() || undefined()) { + // Return no value on no implementation or undefined size + return std::string("{}"); + } // TODO: move lambda elsewhere? auto ptrToString = [](DataType dt, void* ptr, std::size_t idx) { @@ -272,8 +273,10 @@ std::string Aidge::Tensor::toString() const { }; if (dims().empty()) { + // The Tensor is defined with rank 0, hence scalar return ptrToString(mDataType, mImpl->hostPtr(), 0); } + std::string res; std::size_t dim = 0; std::size_t counter = 0; diff --git a/unit_tests/data/Test_Tensor.cpp b/unit_tests/data/Test_Tensor.cpp index c313675e13f9941fa8e044e488827837f7b1a708..4462eb91ed6c6cdfce77b47b6a1a8808eec88423 100644 --- a/unit_tests/data/Test_Tensor.cpp +++ b/unit_tests/data/Test_Tensor.cpp @@ -458,7 +458,7 @@ TEST_CASE("[core/data] Tensor(other)", "[Tensor][extract][zeros][print]") { SECTION("Pretty printing for debug") { Tensor x{}; // Empty Tensor - REQUIRE_THROWS(x.print()); + REQUIRE_NOTHROW(x.print()); // scalar x = Tensor(42); REQUIRE_NOTHROW(x.print());