-
Olivier BICHLER authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Test_TensorImpl.cpp 3.85 KiB
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include "aidge/data/Tensor.hpp"
#include "aidge/backend/cpu.hpp"
#include "aidge/backend/arrayfire/data/TensorImpl.hpp"
using namespace Aidge;
/**
 * Tests for Tensor objects backed by the "arrayfire" backend.
 *
 * Catch2 re-executes the enclosing SECTION's setup code once for every leaf
 * SECTION, so each leaf below starts from freshly-constructed tensors: a
 * mutation made in one leaf (e.g. x.set<int>() in "get function") does not
 * leak into the others.
 */
TEST_CASE("Tensor creation arrayfire", "[Tensor][arrayfire]") {
    // Data type and backend are selected first, then the values are assigned.
    SECTION("from const array") {
        Tensor x;
        x.setDataType(Aidge::DataType::Int32);
        x.setBackend("arrayfire");
        x = Array3D<int,2,2,2>{
        {
            {
                {1, 2},
                {3, 4}
            },
            {
                {5, 6},
                {7, 8}
            }
        }};

        // Same data type and values as x: used for the equality check.
        Tensor xCopy;
        xCopy.setDataType(Aidge::DataType::Int32);
        xCopy.setBackend("arrayfire");
        xCopy = Array3D<int,2,2,2>{
        {
            {
                {1, 2},
                {3, 4}
            },
            {
                {5, 6},
                {7, 8}
            }
        }};

        // Same values as x but a different data type: must compare unequal to x.
        Tensor xFloat;
        // Set explicitly, like x and xCopy, instead of relying on the
        // Tensor's default data type.
        xFloat.setDataType(Aidge::DataType::Float32);
        xFloat.setBackend("arrayfire");
        xFloat = Array3D<float,2,2,2>{
        {
            {
                {1., 2.},
                {3., 4.}
            },
            {
                {5., 6.},
                {7., 8.}
            }
        }};

        SECTION("Tensor features") {
            REQUIRE(x.nbDims() == 3);
            REQUIRE(x.dims()[0] == 2);
            REQUIRE(x.dims()[1] == 2);
            REQUIRE(x.dims()[2] == 2);
            REQUIRE(x.size() == 8);
        }

        SECTION("arrayfire tensor features") {
            // static_cast performs no runtime type check: this assumes that an
            // Int32 tensor on the "arrayfire" backend is a TensorImpl_arrayfire<int>.
            // The cast is done once and null-checked so a missing impl is
            // reported as a failure rather than crashing the test binary.
            auto* afImpl = static_cast<TensorImpl_arrayfire<int>*>(x.getImpl().get());
            REQUIRE(afImpl != nullptr);
            REQUIRE(afImpl->data().dims(0) == 2);
            REQUIRE(afImpl->data().dims(1) == 2);
            REQUIRE(afImpl->data().dims(2) == 2);
            REQUIRE(afImpl->data().numdims() == 3);
            REQUIRE(afImpl->data().elements() == 8);
            // Element type of the impl is `int`, so size the buffer with it.
            REQUIRE(afImpl->data().bytes() == 8 * sizeof(int));
        }

        SECTION("Access to array") {
            REQUIRE(static_cast<int*>(x.getImpl()->rawPtr())[0] == 1);
            REQUIRE(static_cast<int*>(x.getImpl()->rawPtr())[7] == 8);
        }

        SECTION("get function") {
            REQUIRE(x.get<int>({0,0,0}) == 1);
            REQUIRE(x.get<int>({0,0,1}) == 2);
            REQUIRE(x.get<int>({0,1,1}) == 4);
            REQUIRE(x.get<int>({1,1,0}) == 7);
            x.set<int>({1, 1, 1}, 36);
            REQUIRE(x.get<int>({1,1,1}) == 36);
        }

        SECTION("Pretty printing for debug") {
            REQUIRE_NOTHROW(x.print());
        }

        SECTION("Tensor (in)equality") {
            REQUIRE(x == xCopy);
            REQUIRE_FALSE(x == xFloat);
        }
    }

    // Values are assigned first (default backend), then the tensor is moved
    // to the "arrayfire" backend; shape and contents must survive the switch.
    SECTION("from const array before backend") {
        Tensor x = Array3D<int,2,2,2>{
        {
            {
                {1, 2},
                {3, 4}
            },
            {
                {5, 6},
                {7, 8}
            }
        }};
        x.setBackend("arrayfire");

        REQUIRE(x.nbDims() == 3);
        REQUIRE(x.dims()[0] == 2);
        REQUIRE(x.dims()[1] == 2);
        REQUIRE(x.dims()[2] == 2);
        REQUIRE(x.size() == 8);

        REQUIRE(x.get<int>({0,0,0}) == 1);
        REQUIRE(x.get<int>({0,0,1}) == 2);
        REQUIRE(x.get<int>({0,1,1}) == 4);
        REQUIRE(x.get<int>({1,1,1}) == 8);
    }
}