Skip to content
Snippets Groups Projects
Commit de1a11b9 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Merge branch 'Add_Shape_and_Split' into 'dev'

Add Shape and Split operators

See merge request eclipse/aidge/aidge_core!134
parents 5aa6d261 840a775e
No related branches found
No related tags found
No related merge requests found
@@ -59,9 +59,11 @@
 #include "aidge/operator/ReduceMean.hpp"
 #include "aidge/operator/ReLU.hpp"
 #include "aidge/operator/Reshape.hpp"
+#include "aidge/operator/Shape.hpp"
 #include "aidge/operator/Scaling.hpp"
 #include "aidge/operator/Slice.hpp"
 #include "aidge/operator/Softmax.hpp"
+#include "aidge/operator/Split.hpp"
 #include "aidge/operator/Sqrt.hpp"
 #include "aidge/operator/Sub.hpp"
 #include "aidge/operator/Transpose.hpp"
...
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_SHAPE_H_
#define AIDGE_CORE_OPERATOR_SHAPE_H_
#include <cstdint> // std::int64_t
#include <memory>
#include <string>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
// Default (backend-agnostic) implementation of the Shape operator: its
// forward() only copies the input's dimensions to the output (see
// Shape_OpImpl::forward in Shape.cpp), so no backend kernel is required.
class Shape_OpImpl : public OperatorImpl {
public:
    Shape_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
    void forward() override;
};

// Start/End: inclusive range of dimension indices returned by the operator
// (negative values are normalized in Shape_Op::forwardDims()).
enum class ShapeAttr { Start, End };
class Shape_Op : public OperatorTensor,
                 public Registrable<Shape_Op,
                                    std::string,
                                    std::shared_ptr<OperatorImpl>(const Shape_Op&)>,
                 public StaticAttributes<ShapeAttr, std::int64_t, std::int64_t> {
public:
    static const std::string Type;

    using Attributes_ = StaticAttributes<ShapeAttr, std::int64_t, std::int64_t>;
    template <ShapeAttr e> using attr = typename Attributes_::template attr<e>;

    // Start/End are mandatory: no default construction.
    Shape_Op() = delete;

    /// Build a Shape operator returning the input dimensions whose indices
    /// lie in the inclusive range [start, end].
    Shape_Op(std::int64_t start, std::int64_t end)
        : OperatorTensor(Type, 1, 0, 1),
          Attributes_(attr<ShapeAttr::Start>(start),
                      attr<ShapeAttr::End>(end))
    {
        // The backend-agnostic default implementation is always available.
        mImpl = std::make_shared<Shape_OpImpl>(*this);
    }

    /**
     * @brief Copy-constructor. Copies the operator attributes and its output
     * tensor(s), but not its input tensors (the new operator has no input
     * associated).
     * @param op Operator to copy.
     */
    Shape_Op(const Shape_Op& op)
        : OperatorTensor(op),
          Attributes_(op)
    {
        if (op.backend().empty()) {
            mImpl = std::make_shared<Shape_OpImpl>(*this);
        } else {
            SET_IMPL_MACRO(Shape_Op, *this, op.backend());
        }
    }

    /**
     * @brief Clone the operator using its copy-constructor.
     * @see Operator::Shape_Op
     */
    std::shared_ptr<Operator> clone() const override {
        return std::make_shared<Shape_Op>(*this);
    }

    bool forwardDims(bool /*allowDataDependency*/ = false) override final;

    void setBackend(const std::string& name, DeviceIdx_t device = 0) override;

    static const std::vector<std::string> getInputsName(){
        return {"data_input"};
    }
    static const std::vector<std::string> getOutputsName(){
        return {"data_output"};
    }
};
/// Create a Shape Node returning the input's dimensions in [start, end].
inline std::shared_ptr<Node> Shape(std::int64_t start = 0, std::int64_t end = -1, const std::string& name = "") {
    auto op = std::make_shared<Shape_Op>(start, end);
    return std::make_shared<Node>(op, name);
}
} // namespace Aidge
// String names for ShapeAttr, used by the attribute machinery (e.g. the
// "Start"/"End" lookups in Shape.cpp).
// NOTE(review): an anonymous namespace in a header gives this specialization
// internal linkage in every including TU — presumably the project-wide
// convention for EnumStrings; confirm before changing.
namespace {
template <>
const char *const EnumStrings<Aidge::ShapeAttr>::data[] = {"Start", "End"};
}
#endif /* AIDGE_CORE_OPERATOR_SHAPE_H_ */
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_SPLIT_H_
#define AIDGE_CORE_OPERATOR_SPLIT_H_
#include <cstdint>  // std::int8_t (SplitAttr storage type)
#include <memory>
#include <vector>

#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
// Default (backend-agnostic) implementation of the Split operator: its
// forward() performs the chunk copies directly through the tensors'
// implementations (see Split_OpImpl::forward in Split.cpp), so no
// backend-specific kernel is required.
class Split_OpImpl : public OperatorImpl {
public:
    Split_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
    void forward() override;
};

// Axis: dimension along which the input is split (normalized if negative).
// Split: size of each output chunk along Axis.
enum class SplitAttr { Axis, Split };
class Split_Op
    : public OperatorTensor,
      public Registrable<Split_Op, std::string, std::shared_ptr<OperatorImpl>(const Split_Op &)>,
      public StaticAttributes<SplitAttr, std::int8_t, std::vector<DimSize_t>> {
public:
    static const std::string Type;

    using Attributes_ = StaticAttributes<SplitAttr, std::int8_t, std::vector<DimSize_t>>;
    template <SplitAttr e> using attr = typename Attributes_::template attr<e>;

    // Axis and chunk sizes are mandatory: no default construction.
    Split_Op() = delete;

    /// Build a Split operator cutting its input into nbOutputs chunks along
    /// `axis`. `split` optionally gives each chunk's size; when empty,
    /// forwardDims() derives (near-)equal slices.
    Split_Op(std::int8_t axis, DimSize_t nbOutputs, const std::vector<DimSize_t>& split)
        : OperatorTensor(Type, 2, 0, nbOutputs),
          Attributes_(attr<SplitAttr::Axis>(axis),
                      attr<SplitAttr::Split>(split))
    {
        // The backend-agnostic default implementation is always available.
        mImpl = std::make_shared<Split_OpImpl>(*this);
    }

    /**
     * @brief Copy-constructor. Copies the operator attributes and its output
     * tensor(s), but not its input tensors (the new operator has no input
     * associated).
     * @param op Operator to copy.
     */
    Split_Op(const Split_Op &op)
        : OperatorTensor(op),
          Attributes_(op)
    {
        if (op.backend().empty()) {
            mImpl = std::make_shared<Split_OpImpl>(*this);
        } else {
            SET_IMPL_MACRO(Split_Op, *this, op.backend());
        }
    }

    /**
     * @brief Clone the operator using its copy-constructor.
     * @see Operator::Split_Op
     */
    std::shared_ptr<Operator> clone() const override { return std::make_shared<Split_Op>(*this); }

    bool dimsForwarded() const override final;
    bool forwardDims(bool allowDataDependency = false) override final;

    void setBackend(const std::string &name, DeviceIdx_t device = 0) override;

    static const std::vector<std::string> getInputsName(){
        return {"data_input", "split"};
    }
    static const std::vector<std::string> getOutputsName(){
        return {"data_output_0", "data_output_n"};
    }
};
/**
 * @brief Split the input Tensor into nbOutput chunks along a given axis.
 *
 * When `split` is empty, forwardDims() divides the axis dimension into
 * nbOutput equal slices, the last one absorbing any remainder; otherwise
 * `split` gives the size of each chunk along `axis`.
 *
 * @param nbOutput Number of output Tensors (chunks).
 * @param axis Axis along which the input is split (negative values are
 *             normalized against the input rank).
 * @param split Optional explicit chunk sizes along the split axis.
 * @param name Name of the Operator.
 * @return std::shared_ptr<Node> A Node containing the Operator.
 */
inline std::shared_ptr<Node> Split(DimSize_t nbOutput,
                                   std::int8_t axis = 0,
                                   const std::vector<DimSize_t>& split = {},
                                   const std::string &name = "") {
    return std::make_shared<Node>(std::make_shared<Split_Op>(axis, nbOutput, split), name);
}
} // namespace Aidge
// String names for SplitAttr, used by the attribute machinery (e.g. the
// "Axis"/"Split" lookups in Split.cpp).
// NOTE(review): anonymous namespace in a header — presumably the project-wide
// EnumStrings convention; confirm before changing.
namespace {
template <>
const char *const EnumStrings<Aidge::SplitAttr>::data[] = { "Axis", "Split" };
}
#endif /* AIDGE_CORE_OPERATOR_SPLIT_H_ */
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include <string>
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Shape.hpp"
#include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11;
namespace Aidge {
// Python bindings for the Shape operator: the operator class, its backend
// registration hook, and the Node-level factory.
void init_Shape(py::module& m) {
    py::class_<Shape_Op, std::shared_ptr<Shape_Op>, Attributes, OperatorTensor>(
        m, "ShapeOp", py::multiple_inheritance())
        .def(py::init<std::int64_t, std::int64_t>(), py::arg("start"), py::arg("end"))
        .def_static("get_inputs_name", &Shape_Op::getInputsName)
        .def_static("get_outputs_name", &Shape_Op::getOutputsName)
        .def_static("attributes_name", &Shape_Op::staticGetAttrsName);

    declare_registrable<Shape_Op>(m, "ShapeOp");

    // Defaults mirror the C++ Shape() factory (start=0, end=-1).
    m.def("Shape", &Shape, py::arg("start") = 0, py::arg("end") = -1, py::arg("name") = "");
}
} // namespace Aidge
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include <string>
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Split.hpp"
#include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11;
namespace Aidge {
// Python bindings for the Split operator: the operator class, its backend
// registration hook, and the Node-level factory.
void init_Split(py::module& m) {
    py::class_<Split_Op, std::shared_ptr<Split_Op>, Attributes, OperatorTensor>(m, "SplitOp", py::multiple_inheritance())
        // BUGFIX: this binding previously declared
        // py::init<DimSize_t, std::int8_t, std::vector<DimSize_t>&>
        // while Split_Op's constructor is
        // (std::int8_t axis, DimSize_t nbOutputs, const std::vector<DimSize_t>&),
        // so implicit integer conversions silently swapped nb_outputs and
        // axis. A factory lambda keeps the Python-facing argument order
        // (nb_outputs, axis, split) — consistent with the Split() helper —
        // while forwarding to the constructor correctly.
        .def(py::init([](DimSize_t nbOutputs,
                         std::int8_t axis,
                         const std::vector<DimSize_t>& split) {
                 return std::make_shared<Split_Op>(axis, nbOutputs, split);
             }),
             py::arg("nb_outputs"),
             py::arg("axis"),
             py::arg("split"))
        .def_static("get_inputs_name", &Split_Op::getInputsName)
        .def_static("get_outputs_name", &Split_Op::getOutputsName)
        .def_static("attributes_name", &Split_Op::staticGetAttrsName);

    declare_registrable<Split_Op>(m, "SplitOp");

    // Defaults mirror the C++ Split() factory (axis=0, empty split).
    m.def("Split", &Split, py::arg("nb_outputs"), py::arg("axis") = 0, py::arg("split") = std::vector<DimSize_t>(), py::arg("name") = "");
}
} // namespace Aidge
@@ -52,9 +52,11 @@ void init_ReduceMean(py::module&);
 void init_ReLU(py::module&);
 void init_Reshape(py::module&);
 void init_Scaling(py::module&);
+void init_Shape(py::module&);
 void init_Sigmoid(py::module&);
 void init_Slice(py::module&);
 void init_Softmax(py::module&);
+void init_Split(py::module&);
 void init_Sqrt(py::module&);
 void init_Sub(py::module&);
 void init_Tanh(py::module&);
@@ -120,9 +122,11 @@ void init_Aidge(py::module& m) {
     init_ReLU(m);
     init_Reshape(m);
     init_Scaling(m);
+    init_Shape(m);
     init_Sigmoid(m);
     init_Slice(m);
     init_Softmax(m);
+    init_Split(m);
     init_Sqrt(m);
     init_Sub(m);
     init_Tanh(m);
...
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cstddef>   // std::size_t
#include <cstdint>   // std::int64_t
#include <iterator>  // std::next
#include <string>
#include <vector>

#include "aidge/operator/Shape.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
void Aidge::Shape_OpImpl::forward() {
const Shape_Op& op = dynamic_cast<const Shape_Op&>(mOp);
const auto start = op.template getAttr<std::int64_t>("Start");
const auto end = op.template getAttr<std::int64_t>("End");
op.getOutput(0)->getImpl()->copyCast(std::next(op.getInput(0)->dims().data(),
start),
DataType::UInt64,
end - start + 1);
}
const std::string Aidge::Shape_Op::Type = "Shape";

// Computes the output size from the (normalized) [Start, End] attribute
// range; returns false while the input dims are not yet available.
bool Aidge::Shape_Op::forwardDims(bool /*allowDataDependency*/) {
    // check data input has been associated
    if (!getInput(0)) {
        AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
    }

    if (getInput(0)->empty()) {
        return false;
    }

    // Normalize negative Start/End in place (counting from the last dim),
    // so later calls and Shape_OpImpl::forward() see canonical indices.
    if (this->template getAttr<std::int64_t>("Start") < 0)
        this->template getAttr<std::int64_t>("Start") += static_cast<std::int64_t>(getInput(0)->nbDims());
    if (this->template getAttr<std::int64_t>("End") < 0)
        this->template getAttr<std::int64_t>("End") += static_cast<std::int64_t>(getInput(0)->nbDims());

    const auto start = this->template getAttr<std::int64_t>("Start");
    const auto end = this->template getAttr<std::int64_t>("End");
    const auto nbDims = static_cast<std::int64_t>(getInput(0)->nbDims());

    AIDGE_ASSERT(start < nbDims && end < nbDims, "'Start' and 'End' must be < {}", nbDims);
    // BUGFIX: with inclusive bounds a single-dimension range (start == end)
    // is valid; the previous `roi > 1` check rejected it and — because roi
    // was unsigned — failed to catch end < start at all. Also fixed the
    // "Unvalid" typo in the message.
    AIDGE_ASSERT(end >= start, "Invalid ROI for Shape");

    // End is inclusive, so the output holds end - start + 1 dimension values.
    const DimSize_t roi = static_cast<DimSize_t>(end - start + 1);
    mOutputs[0]->resize({roi});
    return true;
}
void Aidge::Shape_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
    // Use a registered backend kernel when one exists; otherwise fall back
    // to the generic Shape_OpImpl (a plain dimension copy).
    if (!Registrar<Shape_Op>::exists({name})) {
        mImpl = std::make_shared<Shape_OpImpl>(*this);
    } else {
        SET_IMPL_MACRO(Shape_Op, *this, name);
    }
    mOutputs[0]->setBackend(name, device);
}
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/operator/Split.hpp"

#include <algorithm>   // std::copy_n
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>  // std::multiplies
#include <numeric>     // std::accumulate
#include <string>
#include <utility>
#include <vector>

#include <fmt/format.h>

#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
// Generic Split kernel: for each output chunk, copy `stride_pre` contiguous
// runs of `splits[i] * stride_post` elements from the input.
void Aidge::Split_OpImpl::forward() {
    const Split_Op& op = dynamic_cast<const Split_Op&>(mOp);
    const auto axis = op.template getAttr<std::int8_t>("Axis");
    const auto& splits = op.template getAttr<std::vector<DimSize_t>>("Split");
    const auto& dims = op.getInput(0)->dims();

    // Number of independent "rows" before the split axis, and number of
    // contiguous elements per index of the split axis.
    const std::size_t stride_pre = std::accumulate(dims.cbegin(), dims.cbegin() + axis, std::size_t(1), std::multiplies<std::size_t>());
    const std::size_t stride_post = std::accumulate(dims.crbegin(), dims.crbegin() + dims.size() - 1 - axis, std::size_t(1), std::multiplies<std::size_t>());

    for (std::size_t i = 0; i < op.nbOutputs(); ++i)
    {
        // BUGFIX: these offsets were previously declared as DimIdx_t — a
        // narrow dimension-index type — and could overflow as soon as the
        // element offsets exceeded that type's range; use std::size_t.
        const std::size_t chunkIdxOnAxis = std::accumulate(splits.cbegin(), splits.cbegin() + i, std::size_t(0)) * stride_post;
        std::size_t offset = 0;
        for (std::size_t j = 0; j < stride_pre; ++j)
        {
            // Compute chunk position in input tensor
            const std::size_t idx = j * stride_post * dims[axis] + chunkIdxOnAxis;
            // Copy chunk into output
            op.getOutput(i)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(idx),
                                             splits[i] * stride_post, offset);
            offset += splits[i] * stride_post;
        }
    }
}
const std::string Aidge::Split_Op::Type = "Split";

bool Aidge::Split_Op::dimsForwarded() const {
    // When a split-sizes tensor is connected on input #1, the output dims
    // are data dependent and must be recomputed whenever it changes.
    const bool splitSizesProvided = getInput(1) && !getInput(1)->empty();
    return splitSizesProvided ? false : OperatorTensor::dimsForwarded();
}
// Resolves the per-output chunk sizes (from input #1, the Split attribute,
// or an equal division) and resizes every output accordingly.
bool Aidge::Split_Op::forwardDims(bool allowDataDependency) {
    // check inputs have been associated
    if (!getInput(0)) {
        AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
    }

    if (getInput(0)->empty()) {
        return false;
    }

    std::shared_ptr<Tensor> fallback;
    // Optional input #1 (split sizes) takes precedence over the attribute.
    if (getInput(1) && !getInput(1)->empty()) {
        if (!this->template getAttr<SplitAttr::Split>().empty()) {
            Log::notice("Split_Op: ignoring non-empty Split attribute because input#1 takes precedence");
        }

        if (!allowDataDependency) {
            Log::warn("Split_Op: unable to forwardDims() because output dims are data dependent on input#1");
            return false;
        }

        this->template getAttr<SplitAttr::Split>().reserve(getInput(1)->size());
        const auto& splitsTensor = getInput(1)->refCastFrom(fallback, NativeType<DimSize_t>::type, "cpu");
        std::copy_n(static_cast<DimSize_t*>(splitsTensor.getImpl()->hostPtr()),
                    splitsTensor.size(),
                    std::back_inserter(this->template getAttr<SplitAttr::Split>()));
    }

    // Normalize a negative axis against the input rank.
    if (this->template getAttr<std::int8_t>("Axis") < 0)
        this->template getAttr<std::int8_t>("Axis") += static_cast<std::int8_t>(getInput(0)->nbDims());

    DimSize_t dimToSplit = getInput(0)->dims()[this->template getAttr<std::int8_t>("Axis")];
    DimSize_t nbOutput = this->nbOutputs();
    // Fill Split attr if empty
    if (this->template getAttr<SplitAttr::Split>().empty()) {
        // Divide the Axis dimension into (near-)equal slices, the last one
        // absorbing the remainder.
        // BUGFIX: was `dimToSplit > nbOutput`, rejecting the valid equal case
        // (one element per output) contradicting the message; also fixed the
        // "musn't" typo.
        AIDGE_ASSERT(dimToSplit >= nbOutput, "Split_Op: Output number {} mustn't be bigger than dimension {}.", nbOutput, dimToSplit);
        DimSize_t baseSliceSize = dimToSplit / nbOutput;
        DimSize_t remainder = dimToSplit % nbOutput;

        for (DimSize_t i = 0; i < static_cast<DimSize_t>(nbOutput - 1); ++i) {
            this->template getAttr<SplitAttr::Split>().push_back(baseSliceSize);
        }
        this->template getAttr<SplitAttr::Split>().push_back(baseSliceSize + remainder);
    }

    const auto& splits = this->template getAttr<SplitAttr::Split>();  // was a full vector copy
    // BUGFIX: the first placeholder was given the vector itself instead of
    // its element count.
    AIDGE_ASSERT(splits.size() == nbOutput, "Split_Op: number of slices {} must be equal to number of outputs {}", splits.size(), nbOutput);
    // DimSize_t(0) init avoids truncating the sum through an int accumulator.
    const DimSize_t totalSplitSize = std::accumulate(splits.cbegin(), splits.cend(), DimSize_t(0));
    AIDGE_ASSERT(totalSplitSize == dimToSplit, "Split_Op: Total chunks size {} is different from dimension size {}.", totalSplitSize, dimToSplit);

    std::vector<DimSize_t> outDims = getInput(0)->dims();
    for (std::size_t i = 0; i < nbOutput; ++i) {
        outDims[this->template getAttr<std::int8_t>("Axis")] = splits[i];
        mOutputs[i]->resize(outDims);
    }

    return true;
}
void Aidge::Split_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
    // Prefer a registered backend kernel; otherwise keep the generic impl.
    if (!Registrar<Split_Op>::exists({name})) {
        mImpl = std::make_shared<Split_OpImpl>(*this);
    } else {
        SET_IMPL_MACRO(Split_Op, *this, name);
    }
    // Every output tensor follows the operator's backend/device.
    for (std::size_t outIdx = 0; outIdx < this->nbOutputs(); ++outIdx) {
        mOutputs[outIdx]->setBackend(name, device);
    }
}
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Shape.hpp"
#include <cstdint>
#include <memory>
using namespace Aidge;
TEST_CASE("[cpu/operator] Shape(forward)", "[Shape][CPU]") {
    // Shape() with default attributes (start=0, end=-1) must return all
    // four dimensions of the (1,2,3,5) input.
    SECTION("Default attributes") {
        std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array4D<int,1,2,3,5> {
            {
                {
                    {
                        { 1,  2,  3,  4,  5},
                        { 6,  7,  8,  9, 10},
                        {11, 12, 13, 14, 15}
                    },
                    {
                        {16, 17, 18, 19, 20},
                        {21, 22, 23, 24, 25},
                        {26, 27, 28, 29, 30}
                    }
                }
            }
        });
        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array1D<int,4> {
            {1, 2, 3, 5}
        });

        std::shared_ptr<Node> myShape = Shape();
        auto op = std::static_pointer_cast<OperatorTensor>(myShape -> getOperator());
        op->associateInput(0,input);
        op->setDataType(DataType::Int32);
        op->setBackend("cpu");
        myShape->forward();

        REQUIRE(*(op->getOutput(0)) == *expectedOutput);
    }
    // Shape(1, 2) must return only dimensions 1..2 (inclusive), i.e. {2, 3}.
    SECTION("Using attributes") {
        std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array4D<int,1,2,3,5> {
            {
                {
                    {
                        { 1,  2,  3,  4,  5},
                        { 6,  7,  8,  9, 10},
                        {11, 12, 13, 14, 15}
                    },
                    {
                        {16, 17, 18, 19, 20},
                        {21, 22, 23, 24, 25},
                        {26, 27, 28, 29, 30}
                    }
                }
            }
        });
        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array1D<int,2> {
            {2, 3}
        });

        std::shared_ptr<Node> myShape = Shape(1, 2);
        auto op = std::static_pointer_cast<OperatorTensor>(myShape -> getOperator());
        op->associateInput(0,input);
        op->setDataType(DataType::Int32);
        op->setBackend("cpu");
        myShape->forward();

        REQUIRE(*(op->getOutput(0)) == *expectedOutput);
    }
}
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Split.hpp"
using namespace Aidge;
TEST_CASE("[cpu/operator] Split(forward)", "[Split][CPU]") {
    // (1,3,7,2) input: axis 2 (size 7) is the split dimension in every section.
    std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array4D<int,1,3,7,2> {
        {
            {
                {{ 1, 2},{ 3, 4},{ 5, 6},{ 7, 8},{ 9, 10},{11, 12},{13, 14}},
                {{15, 16},{17, 18},{19, 20},{21, 22},{23, 24},{25, 26},{27, 28}},
                {{30, 31},{32, 33},{34, 35},{36, 37},{38, 39},{40, 41},{42, 43}}
            }
        }
    });
    // No explicit split sizes: 7 / 3 gives chunks {2, 2, 3}, the last chunk
    // absorbing the remainder.
    SECTION("Default split") {
        std::shared_ptr<Tensor> expectedOutput0 = std::make_shared<Tensor>(Array4D<int,1,3,2,2> {
            {
                {
                    {{ 1, 2},{ 3, 4}},
                    {{15, 16},{17, 18}},
                    {{30, 31},{32, 33}}
                }
            }
        });
        std::shared_ptr<Tensor> expectedOutput1 = std::make_shared<Tensor>(Array4D<int,1,3,2,2> {
            {
                {
                    {{ 5, 6},{ 7, 8}},
                    {{19, 20},{21, 22}},
                    {{34, 35},{36, 37}}
                }
            }
        });
        std::shared_ptr<Tensor> expectedOutput2 = std::make_shared<Tensor>(Array4D<int,1,3,3,2> {
            {
                {
                    {{ 9, 10},{11, 12},{13, 14}},
                    {{23, 24},{25, 26},{27, 28}},
                    {{38, 39},{40, 41},{42, 43}}
                }
            }
        });

        auto mySplit = Split(DimSize_t(3), int8_t(2)); // Split on axis 2 into 3 outputs
        mySplit->getOperator()->associateInput(0, input);
        mySplit->getOperator()->setBackend("cpu");
        mySplit->getOperator()->setDataType(DataType::Int32);
        mySplit->forward();

        REQUIRE(*std::static_pointer_cast<OperatorTensor>(mySplit->getOperator())->getOutput(0) == *expectedOutput0);
        REQUIRE(*std::static_pointer_cast<OperatorTensor>(mySplit->getOperator())->getOutput(1) == *expectedOutput1);
        REQUIRE(*std::static_pointer_cast<OperatorTensor>(mySplit->getOperator())->getOutput(2) == *expectedOutput2);
    }
    // Explicit chunk sizes {4, 1, 2} summing to the axis size 7.
    SECTION("Split with different chunk size") {
        std::shared_ptr<Tensor> expectedOutput0 = std::make_shared<Tensor>(Array4D<int,1,3,4,2> {
            {
                {
                    {{ 1, 2},{ 3, 4},{ 5, 6},{ 7, 8}},
                    {{15, 16},{17, 18},{19, 20},{21, 22}},
                    {{30, 31},{32, 33},{34, 35},{36, 37}}
                }
            }
        });
        std::shared_ptr<Tensor> expectedOutput1 = std::make_shared<Tensor>(Array4D<int,1,3,1,2> {
            {
                {
                    {{ 9, 10}},
                    {{23, 24}},
                    {{38, 39}}
                }
            }
        });
        std::shared_ptr<Tensor> expectedOutput2 = std::make_shared<Tensor>(Array4D<int,1,3,2,2> {
            {
                {
                    {{11, 12},{13, 14}},
                    {{25, 26},{27, 28}},
                    {{40, 41},{42, 43}}
                }
            }
        });

        auto mySplit = Split(DimSize_t(3), int8_t(2), {DimSize_t(4), DimSize_t(1), DimSize_t(2)});
        mySplit->getOperator()->associateInput(0, input);
        mySplit->getOperator()->setBackend("cpu");
        mySplit->getOperator()->setDataType(DataType::Int32);
        mySplit->forward();

        REQUIRE(*std::static_pointer_cast<OperatorTensor>(mySplit->getOperator())->getOutput(0) == *expectedOutput0);
        REQUIRE(*std::static_pointer_cast<OperatorTensor>(mySplit->getOperator())->getOutput(1) == *expectedOutput1);
        REQUIRE(*std::static_pointer_cast<OperatorTensor>(mySplit->getOperator())->getOutput(2) == *expectedOutput2);
    }
    // Chunk sizes {4, 1, 3} sum to 8 != 7: forwardDims() must reject them.
    SECTION("Split with bad split attribute") {
        auto mySplit = Split(DimSize_t(3), int8_t(2), {DimSize_t(4), DimSize_t(1), DimSize_t(3)});
        mySplit->getOperator()->associateInput(0, input);
        mySplit->getOperator()->setBackend("cpu");
        mySplit->getOperator()->setDataType(DataType::Int32);

        REQUIRE_THROWS(mySplit->forward());
    }
    // 8 outputs cannot be carved out of an axis of size 7.
    SECTION("Split with bad outNumber") {
        auto mySplit = Split(DimSize_t(8), int8_t(2));
        mySplit->getOperator()->associateInput(0, input);
        mySplit->getOperator()->setBackend("cpu");
        mySplit->getOperator()->setDataType(DataType::Int32);

        REQUIRE_THROWS(mySplit->forward());
    }
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment