Skip to content
Snippets Groups Projects
Commit 9debe864 authored by Grégoire Kubler's avatar Grégoire Kubler
Browse files

feat : operator squeeze

parent ad93e4b2
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!194Feat/operator squeeze operator unsqueeze
"""
Copyright (c) 2023 CEA-List
This program and the accompanying materials are made available under the
terms of the Eclipse Public License 2.0 which is available at
http://www.eclipse.org/legal/epl-2.0.
SPDX-License-Identifier: EPL-2.0
"""
import unittest
import aidge_core
from aidge_core import Log
import numpy as np
from numpy import testing as npt
class TestSqueeze(unittest.TestCase):
"""
Test squeeze operator
"""
def setUp(self):
############DEFINING INPUT AND OUTPUTS FOR TESTS
axes_to_squeeze_0 = [0]
axes_to_squeeze_many = [0, 1, 4]
axes_to_squeeze_all = []
axes_to_squeeze_error = [1, 2, 4, 5, 10, 3, 42, 127, 12, 3, 4, 1, 4, 50]
squeeze_dim_0 = aidge_core.Squeeze(axes_to_squeeze_0, name="squeeze_dim_0")
squeeze_many = aidge_core.Squeeze(axes_to_squeeze_many, name="squeeze_many")
squeeze_all = aidge_core.Squeeze(axes_to_squeeze_all, name="squeeze_all")
squeeze_error = aidge_core.Squeeze(axes_to_squeeze_error, name="squeeze_error")
input_1_data_shape = np.array([1, 2, 3])
input_2_data_hape = np.array([1, 1, 3, 3, 1, 9])
input_3_data_shape = np.array([1])
input_4_data_shape = np.array([1, 1, 4])
input_axes_0 = axes_to_squeeze_0
input_axes_many = axes_to_squeeze_many
input_axes_all = axes_to_squeeze_all
# input_axes_error = aidge_core.Tensor(axes_to_squeeze_error)
####################### DEFINING TEST RUNS
self.tests_axes_defined_by_attribute = [
(input_1_data_shape, squeeze_dim_0, np.array([2, 3])),
(input_1_data_shape, squeeze_all, np.array([2, 3])),
(input_2_data_hape, squeeze_dim_0, np.array([1, 3, 3, 1, 9])),
(input_2_data_hape, squeeze_many, np.array([3, 3, 9])),
(input_2_data_hape, squeeze_all, np.array([3, 3, 9])),
(input_3_data_shape, squeeze_dim_0, np.array([])),
(input_3_data_shape, squeeze_all, np.array([])),
(input_4_data_shape, squeeze_dim_0, np.array([1, 4])),
(input_4_data_shape, squeeze_all, np.array([4])),
]
# operators are puprposefully chosen with different predefined attribute than the input_axes tensor
self.tests_axes_defined_by_input = [
(input_1_data_shape, input_axes_0, squeeze_error, np.array([2, 3])),
(input_1_data_shape, input_axes_all, squeeze_error, np.array([2, 3])),
(input_2_data_hape, input_axes_0, squeeze_error, np.array([1, 3, 3, 1, 9])),
(input_2_data_hape, input_axes_many, squeeze_error, np.array([3, 3, 9])),
(input_2_data_hape, input_axes_all, squeeze_error, np.array([3, 3, 9])),
(input_3_data_shape, input_axes_0, squeeze_error, np.array([])),
(input_3_data_shape, input_axes_all, squeeze_error, np.array([])),
(input_4_data_shape, input_axes_0, squeeze_error, np.array([1, 4])),
(input_4_data_shape, input_axes_all, squeeze_error, np.array([4])),
]
self.test_error = [
(input_1_data_shape, squeeze_error),
(input_1_data_shape, squeeze_many),
(input_3_data_shape, squeeze_many),
(input_4_data_shape, squeeze_many),
]
return
def tearDown(self):
pass
def test_axes_defined_via_tensor_input(self):
Log.notice("\ntest_axes_defined_via_tensor_input")
for index, (
input_shape,
input_axes_to_squeeze,
squeeze_node_template,
output_shape,
) in enumerate(self.tests_axes_defined_by_input):
test_squeeze_node = squeeze_node_template
test_squeeze_op = test_squeeze_node.get_operator()
print(f"\nTest {index}")
print(f"input shape : {input_shape}")
print(f"input axes: {np.array(input_axes_to_squeeze)}")
print(f"operator : {test_squeeze_node}")
print(f"expected output_shape : {output_shape}")
test_squeeze_op.set_backend("cpu")
test_squeeze_op.set_datatype(aidge_core.dtype.float32)
input_values = np.ones(shape=input_shape, dtype=np.float32)
output_values = np.ones(shape=output_shape, dtype=np.float32)
input_data = aidge_core.Tensor(input_values)
input_data.set_datatype(aidge_core.dtype.float32)
input_data.set_backend("cpu")
input_axes = aidge_core.Tensor(
np.array(input_axes_to_squeeze, dtype=np.float32)
)
input_axes.set_datatype(aidge_core.dtype.int8)
input_axes.set_backend("cpu")
test_squeeze_op.set_input(0, input_data)
test_squeeze_op.set_input(1, input_axes)
self.assertEqual(test_squeeze_op.forward_dims(True), True)
test_squeeze_op.forward()
squeeze_output = test_squeeze_op.get_output(0)
npt.assert_array_equal(
squeeze_output.dims(),
output_shape,
err_msg=f"SQUEEZE FAILURE : expected result differs from output size\n\toperator : {test_squeeze_node}\n\tinput.shape : {input_shape.shape}",
)
npt.assert_array_almost_equal(
np.array(squeeze_output, dtype=np.float32),
output_values,
7,
err_msg=f"SQUEEZE FAILURE : output tensor values differs from expected values\n\toperator : {test_squeeze_node}\n\tinput.shape : {input_shape.shape}",
)
# self.assertEqual(test_squeeze_op.dims_forwarded(), True, "SQUEEZE_FAILURE : dims_forwarded failed.")
return
def test_axes_defined_via_attribute(self):
Log.notice("\ntest_axes_defined_via_attribute")
for index, (input_shape, squeeze_node_template, output_shape) in enumerate(
self.tests_axes_defined_by_attribute
):
test_squeeze_node = squeeze_node_template
test_squeeze_op = test_squeeze_node.get_operator()
print(f"\nTest {index}")
print(f"input size : {input_shape.shape}")
print(f"operator : {test_squeeze_node}")
print(f"expected output_shape : {output_shape}")
test_squeeze_node.get_operator().set_backend("cpu")
input_values = np.ones(shape=input_shape, dtype=np.float32)
output_values = np.ones(shape=output_shape, dtype=np.float32)
input_data = aidge_core.Tensor(input_values)
input_data.set_datatype(aidge_core.dtype.float32)
input_data.set_backend("cpu")
test_squeeze_op.set_input(0, input_data)
test_squeeze_op.forward_dims()
test_squeeze_op.forward()
squeeze_output = test_squeeze_op.get_output(0)
npt.assert_array_equal(
squeeze_output.dims(),
output_shape,
err_msg=f"SQUEEZE FAILURE : expected result differs from output size\n\toperator : {test_squeeze_node}\n\tinput.shape : {input_shape.shape}",
)
npt.assert_array_almost_equal(
np.array(squeeze_output, dtype=np.float32),
output_values,
7,
err_msg=f"SQUEEZE FAILURE : output tensor values differs from expected values\n\toperator : {test_squeeze_node}\n\tinput.shape : {input_shape.shape}",
)
return
def test_error(self):
for input_shape, squeeze_node_template in self.test_error:
test_squeeze_node = squeeze_node_template
test_squeeze_op = test_squeeze_node.get_operator()
input_values = np.ones(shape=input_shape)
input_data = aidge_core.Tensor(input_values)
input_data.set_datatype(aidge_core.dtype.float32)
input_data.set_backend("cpu")
test_squeeze_op.set_input(0, input_data)
with self.assertRaises((RuntimeError, AssertionError)):
test_squeeze_op.forward_dims()
test_squeeze_op.forward()
return
if __name__ == "__main__":
unittest.main()
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_SQUEEZE_H_
#define AIDGE_CORE_OPERATOR_SQUEEZE_H_
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <string>
#include <vector>
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
/**
* @brief implementation of the operator squeeze.
* @note Since this operator implementation is agnostic to the backend it is
* located here instead of in aidge_backend_cpu/cuda.
*/
class Squeeze_OpImpl : public OperatorImpl {
public:
Squeeze_OpImpl(const Operator &op, const std::string &backend = "")
: OperatorImpl(op, backend) {}
void forward() override;
};
enum class SqueezeAttr {
/**
* @brief axes to squeeze, if left empty all 1 sized
* dimensions will be removed.
*/
Axes
};
/**
* @brief This operator has as purpose to remove dummy dimensions around given
* axes.
* input#0 : Tensor to squeeze
* input#1 Optionnal : 1D tensor that lists the axes to squeeze
* @note the axes to squeeze can either be given via attribute or via input #1,
* for the sake of simplicity of the example unders, the axes to squeeze are
* given via attribute
* @example Calling squeeze(1) on a tensor of dimensions (2,1,3,4) will result
* in a tensor of dim (2,3,4).
* @example Calling squeeze(1) on a tensor of dimensions (1,2,3,4) will result
* in a tensor of dim (1,2,3,4).
* @example Calling squeeze() with no argument will result in the removal of
* every 1-sized dimension in the tensor.
*/
class Squeeze_Op
: public OperatorTensor,
public Registrable<Squeeze_Op, std::string,
std::shared_ptr<OperatorImpl>(const Squeeze_Op &)> {
public:
static const std::string
Type; // name of the type of the operation (Here "Squeeze")
private:
using Attributes_ = StaticAttributes<SqueezeAttr, std::vector<int8_t>>;
template <SqueezeAttr e> using attr = typename Attributes_::template attr<e>;
const std::shared_ptr<Attributes_> mAttributes;
public:
/**
* @brief constructor for Squeeze op
* @param[in] axes around which perform the operation
*/
Squeeze_Op(const std::vector<int8_t> &axes = {})
: OperatorTensor(Type, {InputCategory::Data, InputCategory::OptionalData},
1),
mAttributes(
std::make_shared<Attributes_>(attr<SqueezeAttr::Axes>(axes))) {
mImpl = std::make_shared<Squeeze_OpImpl>(*this);
}
/**
* @brief Copy-constructor. Copy the operator attributes and its output
* tensor(s), but not its input tensors (the new operator has no input
* associated).
* @param op Operator to copy.
*/
Squeeze_Op(const Squeeze_Op &op)
: OperatorTensor(op), mAttributes(op.mAttributes) {
if (!op.backend().empty()) {
SET_IMPL_MACRO(Squeeze_Op, *this, op.backend());
} else {
mImpl = std::make_shared<Squeeze_OpImpl>(*this);
}
}
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::MatMul_Op
*/
std::shared_ptr<Operator> clone() const override final {
return std::make_shared<Squeeze_Op>(*this);
}
/**
* @brief Compute dimensions for the output Tensor
*/
bool forwardDims(bool allowDataDependency = false) override final;
bool dimsForwarded() const override final;
void setBackend(const std::string &name,
DeviceIdx_t device = 0) override final;
inline std::shared_ptr<Attributes> attributes() const override {
return mAttributes;
}
/**
* @brief axes to squeeze, if left empty all 1 sized
* dimensions will be removed.
*/
inline std::vector<int8_t> &axes() const noexcept {
return mAttributes->template getAttr<SqueezeAttr::Axes>();
}
static const std::vector<std::string> getInputsName() {
return {"data_input", "axes_to_squeeze"};
}
static const std::vector<std::string> getOutputsName() {
return {"squeezed"};
}
};
// helper with C-style array instead of std::array for kernel_dims to allow
// automatic template DIM deduction
inline std::shared_ptr<Node> Squeeze(const std::vector<int8_t> axes = {},
const std::string &name = "") {
return std::make_shared<Node>(std::make_shared<Squeeze_Op>(axes), name);
}
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::SqueezeAttr>::data[] = {"Axes"};
}
#endif // AIDGE_CORE_OPERATOR_SQUEEZE_H_
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <memory>
#include <pybind11/pybind11.h>
#include <string>
#include <vector>
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Squeeze.hpp"
#include "aidge/utils/Attributes.hpp"
#include "aidge/utils/Types.h"
namespace py = pybind11;
namespace Aidge {
void init_Squeeze(py::module &m) {
py::class_<Squeeze_Op, std::shared_ptr<Squeeze_Op>, OperatorTensor>(
m, "SqueezeOp", py::multiple_inheritance(),
R"mydelimiter(
Initialize squeeze operator
:param axes : axes to squeeze between [-r;r-1]
with r = input_tensor.nbDims()
& r in [-128 , 127]
:type axes : :py:class: List[Int]
)mydelimiter")
.def("get_inputs_name", &Squeeze_Op::getInputsName)
.def("get_outputs_name", &Squeeze_Op::getOutputsName)
.def("axes", &Squeeze_Op::axes);
// Here we bind the constructor of the Squeeze Node. We add an argument
// for each attribute of the operator (in here we only have 'axes') and
// the last argument is the node's name.
m.def("Squeeze", &Squeeze, py::arg("axes") = std::vector<int8_t>({}),
py::arg("name") = "",
R"mydelimiter(
Initialize a node containing a squeeze operator.
:param axes : axes to squeeze between [-r;r-1]
with r = input_tensor.nbDims()
& r in [-128 , 127]
:type axes : :py:class: List[Int]
:param name : name of the node.
)mydelimiter");
}
} // namespace Aidge
...@@ -63,6 +63,7 @@ void init_Slice(py::module&); ...@@ -63,6 +63,7 @@ void init_Slice(py::module&);
void init_Softmax(py::module&); void init_Softmax(py::module&);
void init_Split(py::module&); void init_Split(py::module&);
void init_Sqrt(py::module&); void init_Sqrt(py::module&);
void init_Squeeze(py::module&);
void init_Sub(py::module&); void init_Sub(py::module&);
void init_Tanh(py::module&); void init_Tanh(py::module&);
void init_Transpose(py::module&); void init_Transpose(py::module&);
...@@ -138,6 +139,7 @@ void init_Aidge(py::module& m) { ...@@ -138,6 +139,7 @@ void init_Aidge(py::module& m) {
init_Softmax(m); init_Softmax(m);
init_Split(m); init_Split(m);
init_Sqrt(m); init_Sqrt(m);
init_Squeeze(m);
init_Sub(m); init_Sub(m);
init_Tanh(m); init_Tanh(m);
init_Transpose(m); init_Transpose(m);
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/operator/Squeeze.hpp"
#include <algorithm>
#include <bitset>
#include <cstdint>
#include <fmt/core.h>
#include <functional>
#include <iterator>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Log.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
const std::string Squeeze_Op::Type = "Squeeze";
bool Squeeze_Op::dimsForwarded() const {
if ((getInput(1) && !getInput(1)->undefined())) {
// output dims are data dependent
return false;
}
return OperatorTensor::dimsForwarded();
}
bool Squeeze_Op::forwardDims(bool allowDataDependency) {
// error checking
if (!inputsAssociated(false) || getInput(0)->undefined()) {
return false;
}
std::shared_ptr<Tensor> fallback;
// Input 1 is axes to squeeze (can also be given via attribute)
if (getInput(1)) {
if (!this->axes().empty()) {
Log::notice("{} : ignoring non-empty axes attribute because input#1 "
"takes precedence",
type());
}
if (!allowDataDependency) {
Log::warn("{} : unable to forwardDims() because output dims are data "
"dependent on input#1",
type());
return false;
}
this->axes().clear(); // If both are provided input would override attrs
this->axes().reserve(getInput(1)->size());
const auto &axes =
getInput(1)->refCastFrom(fallback, NativeType<int8_t>::type, "cpu");
if (axes.nbDims() == 0) {
this->axes().clear();
} else {
AIDGE_ASSERT(
axes.nbDims() == 1,
"Axes input tensor should be of size 1. Received {} dimensions : {}",
axes.nbDims(), axes.dims());
std::copy_n(static_cast<int8_t *>(axes.getImpl()->hostPtr()), axes.size(),
std::back_inserter(this->axes()));
}
}
std::vector<DimSize_t> input_dims = getInput(0)->dims();
std::vector<DimSize_t> output_dims;
output_dims.reserve(input_dims.size());
std::vector<DimIdx_t> axes_rectified_idx;
axes_rectified_idx.reserve(input_dims.size());
if (this->axes().size() == 0) { // squeeze() => squeeze all 1 sized dimensions
Log::debug("this->axes() is empty, all 1 sized dim will be squeezed. If "
"this is an error ensure that the values are properly set via "
"attribute or data input#1.");
std::copy_if(input_dims.begin(), input_dims.end(),
std::back_inserter(output_dims),
[](DimSize_t dim) { return dim != 1; });
} else { // squeeze({N,.....}) => squeeze all specified dimensions that are of
// size 1.
/////// ensure indexes validity and set pythonic negative indexes to their
// positive value
for (const int8_t &axis : this->axes()) {
AIDGE_ASSERT(axis >= static_cast<int8_t>(-input_dims.size()) &&
axis < static_cast<int8_t>(input_dims.size()),
"{} : Axis index OutOfBounds error, expected value "
"within size limits of input tensor : "
"[-{},{}), got {}.",
type(), input_dims.size(), input_dims.size() - 1, axis);
auto temp =
static_cast<DimIdx_t>(axis >= 0 ? axis : axis + input_dims.size());
if (axes_rectified_idx.end() == std::find(axes_rectified_idx.begin(),
axes_rectified_idx.end(),
temp)) {
axes_rectified_idx.push_back(temp);
}
}
// Create output_dims
// speeds up binary search
std::sort(axes_rectified_idx.begin(), axes_rectified_idx.end());
DimSize_t i = 0;
std::copy_if(
input_dims.begin(), input_dims.end(), std::back_inserter(output_dims),
[&axes_rectified_idx, &i, &input_dims](DimSize_t dim) {
// if current dim index is found in axes to squeeze
// we ensure that this axis is 1 sized, otherwise an error is thrown
bool ok = true;
if (std::binary_search(axes_rectified_idx.begin(),
axes_rectified_idx.end(), i)) {
AIDGE_ASSERT(dim == 1,
"{} : Tried to squeeze axis nb {} of a tensor of dim "
"{}. Dim to squeeze has to be 1-sized, got size {}."
"Axes to squeeze : {}",
__func__, i, input_dims, input_dims[i],
axes_rectified_idx);
ok = false;
}
i++; // Incrementing counter since there is no enumerate
// fctn (until C++23)
return ok;
});
}
mOutputs[0]->resize(output_dims);
return true;
}
void Squeeze_Op::setBackend(const std::string &name,
Aidge::DeviceIdx_t device) {
if (Registrar<Squeeze_Op>::exists({name})) {
SET_IMPL_MACRO(Squeeze_Op, *this, name);
} else {
mImpl = std::make_shared<Squeeze_OpImpl>(*this);
}
mOutputs[0]->setBackend(name, device);
}
void Aidge::Squeeze_OpImpl::forward() {
const Squeeze_Op &op_ = static_cast<const Squeeze_Op &>(mOp);
// Check if input is provided
AIDGE_ASSERT(op_.getInput(0), "Squeeze : missing input 0");
op_.getOutput(0)->getImpl()->copy(op_.getInput(0)->getImpl()->rawPtr(),
op_.getInput(0)->size());
}
} // namespace Aidge
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/operator/Squeeze.hpp"
#include <aidge/utils/Types.h>
#include <algorithm>
#include <array>
#include <catch2/catch_test_macros.hpp>
#include <catch2/generators/catch_generators_random.hpp>
#include <chrono>
#include <cmath>
#include <cstddef> // std::size_t
#include <cstdint> // std::uint16_t
#include <fmt/core.h>
#include <iostream>
#include <iterator>
#include <memory>
#include <numeric> // std::accumulate
#include <ostream>
#include <random> // std::random_device, std::mt19937, std::uniform_real_distribution
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/TensorUtils.hpp"
namespace Aidge {
TEST_CASE("[core/operator] Squeeze(forwardDims)", "[Squeeze][forwardDims]") {
Log::setConsoleLevel(Log::Notice);
constexpr std::uint16_t NB_TRIALS = 10;
// Create a random number generator
auto random_seed = Catch::Generators::Detail::getSeed;
std::mt19937 gen(random_seed());
// Random float distribution between 0 and 1
constexpr int8_t max_nb_dims = 7;
std::uniform_real_distribution<float> tensor_value_dist(0.1f, 1.1f);
std::uniform_int_distribution<std::size_t> tensor_nb_dims_dist(
std::size_t(1), std::size_t(max_nb_dims));
std::uniform_int_distribution<std::size_t> tensor_dims_size_dist(
std::size_t(1), std::size_t(5));
std::uniform_int_distribution<std::size_t> nb_dims_to_squeeze_dist(
std::size_t(1), std::size_t(2));
std::uniform_int_distribution<short> idx_dims_to_squeeze_dist(-9, 8);
std::shared_ptr<Tensor> input_T = std::make_shared<Tensor>();
SECTION("ERROR : Inputs not ready") {
SECTION("unconnected input") {
std::shared_ptr<Node> squeeze_node = Squeeze();
auto op =
std::static_pointer_cast<OperatorTensor>(squeeze_node->getOperator());
REQUIRE_THROWS(op->forwardDims());
}
SECTION("empty tensor") {
// Create the Squeeze Operator
std::shared_ptr<Node> squeeze_node = Squeeze(std::vector<int8_t>({0}));
auto op =
std::static_pointer_cast<OperatorTensor>(squeeze_node->getOperator());
op->associateInput(0, input_T);
CHECK(op->forwardDims() == false);
}
}
SECTION("ERROR : nb_dims_to_squeeze>input.size()") {
constexpr size_t nb_dims_to_squeeze = 100;
std::vector<int8_t> dims_to_squeeze(nb_dims_to_squeeze);
std::generate(dims_to_squeeze.begin(), dims_to_squeeze.end(),
[&gen, &idx_dims_to_squeeze_dist]() {
return idx_dims_to_squeeze_dist(gen);
});
Log::error("dims_to_sqeeze = {}", dims_to_squeeze);
std::shared_ptr<Node> squeeze_node = Squeeze(dims_to_squeeze);
auto op =
std::static_pointer_cast<OperatorTensor>(squeeze_node->getOperator());
// input tensor
const std::size_t nb_dims = tensor_nb_dims_dist(gen);
std::vector<std::size_t> dims_in(nb_dims);
std::generate(dims_in.begin(), dims_in.end(),
[&tensor_dims_size_dist, &gen]() {
return tensor_dims_size_dist(gen);
});
// Test
input_T->resize(dims_in);
op->setInput(0, input_T);
REQUIRE_THROWS(op->forwardDims());
}
SECTION("Compare with reference output") {
SECTION("axes is given via attribute") {
SECTION("Squeeze a 1-sized-axis") {
int8_t nb_dims = 4;
std::shared_ptr<Node> squeeze_node = Squeeze(std::vector<int8_t>({0}));
auto op = std::static_pointer_cast<OperatorTensor>(
squeeze_node->getOperator());
op->associateInput(0, input_T);
std::vector<DimSize_t> dims_in{1, 2, 3, 4};
input_T->resize(dims_in);
CHECK(op->forwardDims());
CHECK(op->getOutput(0)->dims() == std::vector<DimSize_t>({2, 3, 4}));
CHECK((op->getOutput(0)->dims().size()) == 3);
}
SECTION("Squeeze multiple 1-sized axes") {
// test should be successful
std::shared_ptr<Node> squeeze_node =
Squeeze(std::vector<int8_t>({1, -4}));
auto op = std::static_pointer_cast<OperatorTensor>(
squeeze_node->getOperator());
op->associateInput(0, input_T);
std::vector<DimSize_t> dims_in{1, 1, 13, 200};
input_T->resize(dims_in);
CHECK(op->forwardDims());
CHECK(op->getOutput(0)->dims() == std::vector<DimSize_t>{13, 200});
CHECK((op->getOutput(0)->dims().size()) == 2);
}
SECTION("Squeeze a non-1-Sized axis") {
int8_t nb_dims = 4;
std::shared_ptr<Node> squeeze_node = Squeeze(std::vector<int8_t>({3}));
auto op = std::static_pointer_cast<OperatorTensor>(
squeeze_node->getOperator());
op->associateInput(0, input_T);
std::vector<DimSize_t> dims_in{1, 2, 3, 4};
input_T->resize(dims_in);
REQUIRE_THROWS(op->forwardDims());
}
SECTION("Squeeze multiple non-sized-axes") {
std::shared_ptr<Node> squeeze_node =
Squeeze(std::vector<int8_t>({1, -2}));
auto op = std::static_pointer_cast<OperatorTensor>(
squeeze_node->getOperator());
op->associateInput(0, input_T);
std::array<DimSize_t, 3> dims_in{2, 3, 4};
input_T->resize(dims_in);
REQUIRE_THROWS((op->forwardDims()));
}
}
SECTION("axes is given via tensor") {
SECTION("tensor is empty") {
// arguments here should be overriden by axes_T values
std::shared_ptr<Node> myUnsqueeze =
Squeeze(std::vector<std::int8_t>({0, 4}));
auto op = std::static_pointer_cast<OperatorTensor>(
myUnsqueeze->getOperator());
op->associateInput(0, input_T);
auto axes_T =
std::make_shared<Aidge::Tensor>(std::vector<DimSize_t>({}));
axes_T->setDataType(Aidge::DataType::Int8);
axes_T->setBackend("cpu");
std::vector<DimSize_t> dims_in{3, 1, 4, 1, 1, 5};
input_T->resize(dims_in);
op->associateInput(0, input_T);
op->associateInput(1, axes_T);
CHECK(op->forwardDims(true));
CHECK(op->getOutput(0)->dims() == std::vector<DimSize_t>({3, 4, 5}));
}
SECTION("tensor not empty") {
// arguments here should be overriden by axes_T values
std::shared_ptr<Node> myUnsqueeze =
Squeeze(std::vector<std::int8_t>({3, 1}));
auto op = std::static_pointer_cast<OperatorTensor>(
myUnsqueeze->getOperator());
op->associateInput(0, input_T);
auto axes_T =
std::make_shared<Aidge::Tensor>(Aidge::Array1D<int8_t, 2>({0, 3}));
axes_T->setDataType(Aidge::DataType::Int8);
axes_T->setBackend("cpu");
std::vector<DimSize_t> dims_in{1, 3, 4, 1, 5};
input_T->resize(dims_in);
op->associateInput(0, input_T);
op->associateInput(1, axes_T);
CHECK(op->forwardDims(true) == true);
CHECK(op->getOutput(0)->dims() == std::vector<DimSize_t>({3, 4, 5}));
}
}
}
SECTION("Squeeze()") {
// Create the Operator
std::shared_ptr<Node> squeeze_node = Squeeze();
auto op =
std::static_pointer_cast<OperatorTensor>(squeeze_node->getOperator());
op->associateInput(0, input_T);
for (uint16_t trial = 0; trial < NB_TRIALS; ++trial) {
// input tensor
const std::size_t nb_dims = tensor_nb_dims_dist(gen);
std::vector<std::size_t> dims_in(nb_dims);
std::generate(dims_in.begin(), dims_in.end(),
[&gen, &tensor_dims_size_dist]() {
return tensor_dims_size_dist(gen);
});
// output tensor
std::vector<DimSize_t> dims_out;
dims_out.reserve(dims_in.size());
std::copy_if(dims_in.begin(), dims_in.end(), std::back_inserter(dims_out),
[](DimSize_t dim) { return dim != 1; });
// Test
input_T->resize(dims_in);
op->setInput(0, input_T);
CHECK(op->forwardDims() == true);
CHECK(op->getOutput(0)->dims() == dims_out);
int nb_ones = std::count_if(dims_in.begin(), dims_in.end(),
[](int8_t dim) { return dim == 1; });
CHECK((op->getInput(0)->dims().size() -
op->getOutput(0)->dims().size()) == nb_ones);
}
}
SECTION("Squeeze({N,...})") {
int number_of_operation{0};
for (uint16_t trial = 0; trial < NB_TRIALS; ++trial) {
// Create the Operator
size_t nb_dims_to_squeeze = nb_dims_to_squeeze_dist(gen);
std::vector<int8_t> dims_to_squeeze(nb_dims_to_squeeze);
std::generate(dims_to_squeeze.begin(), dims_to_squeeze.end(),
[&gen, &idx_dims_to_squeeze_dist]() {
return idx_dims_to_squeeze_dist(gen);
});
std::shared_ptr<Node> squeeze_node = Squeeze({dims_to_squeeze});
auto op =
std::static_pointer_cast<OperatorTensor>(squeeze_node->getOperator());
op->associateInput(0, input_T);
// input tensor
const std::size_t nb_dims_tensor = tensor_nb_dims_dist(gen);
std::vector<std::size_t> dims_in(nb_dims_tensor);
std::generate(dims_in.begin(), dims_in.end(),
[&gen, &tensor_dims_size_dist]() {
return tensor_dims_size_dist(gen);
});
input_T->resize(dims_in);
op->setInput(0, input_T);
// rectifying indexes
std::transform(dims_to_squeeze.begin(), dims_to_squeeze.end(),
dims_to_squeeze.begin(),
[&nb_dims_tensor](int8_t dim_to_squeeze) {
return dim_to_squeeze < 0
? dim_to_squeeze + nb_dims_tensor
: dim_to_squeeze;
});
std::sort(dims_to_squeeze.begin(), dims_to_squeeze.end());
auto it = std::unique(dims_to_squeeze.begin(), dims_to_squeeze.end());
dims_to_squeeze.erase(it, dims_to_squeeze.end());
// ensuring arguments given to Squeeze are good
bool not_in_bounds = false;
bool dim_to_squeeze_not_1_sized = false;
for (const auto dim_to_squeeze : dims_to_squeeze) {
not_in_bounds = dim_to_squeeze >= nb_dims_tensor;
if (not_in_bounds) {
break;
}
dim_to_squeeze_not_1_sized = dims_in.at(dim_to_squeeze) != 1;
if (dim_to_squeeze_not_1_sized) {
break;
}
}
if (nb_dims_tensor > max_nb_dims || not_in_bounds ||
dim_to_squeeze_not_1_sized) {
REQUIRE_THROWS(op->forwardDims());
} else {
// output tensor
int i = 0;
std::vector<DimSize_t> dims_out;
dims_out.reserve(dims_in.size());
std::copy_if(dims_in.begin(), dims_in.end(),
std::back_inserter(dims_out),
[&dims_to_squeeze, &i](DimSize_t dim) {
bool ok = dim != 1 ||
!std::binary_search(dims_to_squeeze.begin(),
dims_to_squeeze.end(), i);
i++; // incrementing counter since C++ has not enumerate
// fctn (until C++23)
return ok;
});
CHECK(op->forwardDims() == true);
CHECK(op->getOutput(0)->dims() == dims_out);
}
}
}
}
TEST_CASE("[core/operator] Squeeze(forward)", "[Squeeze][forward]") {
Log::setConsoleLevel(Log::Notice);
constexpr std::uint16_t NB_TRIALS = 10;
// Create a random number generator
auto random_seed = Catch::Generators::Detail::getSeed;
std::mt19937 gen(random_seed());
constexpr int8_t max_nb_dims = 7;
std::uniform_real_distribution<float> tensor_value_dist(0.1f, 1.1f);
std::uniform_int_distribution<std::size_t> tensor_nb_dims_dist(
std::size_t(1), std::size_t(max_nb_dims));
std::uniform_int_distribution<std::size_t> tensor_dims_size_dist(
std::size_t(1), std::size_t(5));
std::uniform_int_distribution<std::size_t> nb_dims_to_squeeze_dist(
std::size_t(1), std::size_t(2));
std::uniform_int_distribution<short> idx_dims_to_squeeze_dist(-9, 8);
std::shared_ptr<Tensor> input_T = std::make_shared<Tensor>();
// BENCHMARKING
std::chrono::time_point<std::chrono::system_clock> start;
std::chrono::time_point<std::chrono::system_clock> end;
std::chrono::duration<double, std::micro> duration{};
Log::setConsoleLevel(Log::Notice);
int number_of_operation{0};
for (uint16_t trial = 0; trial < NB_TRIALS; ++trial) {
// Create the Operator
size_t nb_dims_to_squeeze = nb_dims_to_squeeze_dist(gen);
std::vector<int8_t> dims_to_squeeze(nb_dims_to_squeeze);
std::generate(dims_to_squeeze.begin(), dims_to_squeeze.end(),
[&gen, &idx_dims_to_squeeze_dist]() {
return idx_dims_to_squeeze_dist(gen);
});
std::shared_ptr<Node> squeeze_node = Squeeze({dims_to_squeeze});
auto op =
std::static_pointer_cast<OperatorTensor>(squeeze_node->getOperator());
op->setDataType(DataType::Float32);
op->setBackend("cpu");
// input tensor
const std::size_t nb_dims_tensor = tensor_nb_dims_dist(gen);
std::vector<std::size_t> dims_in(nb_dims_tensor);
std::generate(dims_in.begin(), dims_in.end(),
[&gen, &tensor_dims_size_dist]() {
return tensor_dims_size_dist(gen);
});
input_T->resize(dims_in);
op->setInput(0, input_T);
// rectifying indexes
std::transform(dims_to_squeeze.begin(), dims_to_squeeze.end(),
dims_to_squeeze.begin(),
[&nb_dims_tensor](int8_t dim_to_squeeze) {
return dim_to_squeeze < 0 ? dim_to_squeeze + nb_dims_tensor
: dim_to_squeeze;
});
// ensuring arguments given to Squeeze are good
bool not_in_bounds = false;
bool dim_to_squeeze_not_1_sized = false;
for (const auto dim_to_squeeze : dims_to_squeeze) {
not_in_bounds = dim_to_squeeze >= nb_dims_tensor;
if (not_in_bounds) {
break;
}
dim_to_squeeze_not_1_sized = dims_in.at(dim_to_squeeze) != 1;
if (dim_to_squeeze_not_1_sized) {
break;
}
}
if (nb_dims_tensor > max_nb_dims || not_in_bounds ||
dim_to_squeeze_not_1_sized) {
REQUIRE_THROWS(op->forwardDims());
} else {
// output tensor
int i = 0;
std::vector<DimSize_t> dims_out;
dims_out.reserve(dims_in.size());
for (DimIdx_t i = 0; i < dims_in.size(); ++i) {
if (dims_in[i] == 1 &&
std::find(dims_to_squeeze.begin(), dims_to_squeeze.end(), i) !=
dims_to_squeeze.end()) {
continue;
}
dims_out.push_back(dims_in[i]);
}
CHECK(op->forwardDims());
CHECK(op->getOutput(0)->dims() == dims_out);
SECTION("forward") {
// Create the input Tensor
std::shared_ptr<Tensor> input_T = std::make_shared<Tensor>();
input_T->setDataType(DataType::Float32);
input_T->setBackend("cpu");
op->associateInput(0, input_T);
// Create results Tensor
std::shared_ptr<Tensor> result_T = std::make_shared<Tensor>();
result_T->setDataType(DataType::Float32);
result_T->setBackend("cpu");
const std::size_t nb_elems =
std::accumulate(dims_in.cbegin(), dims_in.cend(), std::size_t(1),
std::multiplies<std::size_t>());
float *array_in = new float[nb_elems];
for (std::size_t i = 0; i < nb_elems; ++i) {
float val = tensor_value_dist(gen);
array_in[i] = val;
}
number_of_operation += nb_elems; // Copying all values : 1
// assignation / item in the tensor
// input0
input_T->resize(dims_in);
input_T->getImpl()->setRawPtr(array_in, nb_elems);
result_T->resize(dims_out);
result_T->getImpl()->setRawPtr(array_in, nb_elems);
CHECK(op->forwardDims() == true);
start = std::chrono::system_clock::now();
REQUIRE_NOTHROW(squeeze_node->forward());
end = std::chrono::system_clock::now();
duration +=
std::chrono::duration_cast<std::chrono::microseconds>(end - start);
CHECK(approxEq<float>(*result_T, *(op->getOutput(0))));
CHECK(result_T->nbDims() == op->getOutput(0)->nbDims());
for (DimSize_t i = 0; i < op->getOutput(0)->nbDims(); ++i) {
CHECK(result_T->dims().at(i) == op->getOutput(0)->dims().at(i));
}
CHECK(approxEq<float>(*result_T, *(op->getOutput(0))));
delete[] array_in;
}
std::cout << "Squeeze total execution time : " << duration.count() << "µs"
<< std::endl;
std::cout << "Number of operations : " << number_of_operation
<< std::endl;
std::cout << "Operation / µs = " << number_of_operation / duration.count()
<< std::endl;
}
}
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment