Commit e9fdb3b6 authored by Houssem ROUIS

add Step attribute to Slice

parent b97a42bc
@@ -30,25 +30,26 @@ public:
     void forward() override;
 };
 
-enum class SliceAttr { Starts, Ends, Axes };
+enum class SliceAttr { Starts, Ends, Axes, Steps };
 
 class Slice_Op
     : public OperatorTensor,
       public Registrable<Slice_Op, std::string, std::shared_ptr<OperatorImpl>(const Slice_Op &)>,
-      public StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int8_t>> {
+      public StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int8_t>, std::vector<std::int64_t>> {
 public:
     static const std::string Type;
 
     Slice_Op() = delete;
 
-    using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int8_t>>;
+    using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int8_t>, std::vector<std::int64_t>>;
     template <SliceAttr e> using attr = typename Attributes_::template attr<e>;
 
-    Slice_Op(const std::vector<std::int64_t>& starts, const std::vector<std::int64_t>& ends, const std::vector<std::int8_t>& axes)
-        : OperatorTensor(Type, 4, 0, 1),
+    Slice_Op(const std::vector<std::int64_t>& starts, const std::vector<std::int64_t>& ends, const std::vector<std::int8_t>& axes, const std::vector<std::int64_t>& steps)
+        : OperatorTensor(Type, 5, 0, 1),
           Attributes_(attr<SliceAttr::Starts>(starts),
                       attr<SliceAttr::Ends>(ends),
-                      attr<SliceAttr::Axes>(axes))
+                      attr<SliceAttr::Axes>(axes),
+                      attr<SliceAttr::Steps>(steps))
     {
         mImpl = std::make_shared<Slice_OpImpl>(*this);
     }
@@ -83,11 +84,12 @@ public:
     void setBackend(const std::string &name, DeviceIdx_t device = 0) override;
 
     static const std::vector<std::string> getInputsName(){
-        return {"data_input", "starts", "ends", "axes"};
+        return {"data_input", "starts", "ends", "axes", "steps"};
     }
     static const std::vector<std::string> getOutputsName(){
         return {"data_output"};
     }
 };
 
 /**
@@ -98,14 +100,32 @@ public:
 inline std::shared_ptr<Node> Slice(const std::vector<std::int64_t>& starts = {},
                                    const std::vector<std::int64_t>& ends = {},
                                    const std::vector<std::int8_t>& axes = {},
+                                   const std::vector<std::int64_t>& steps = {},
                                    const std::string &name = "") {
-    return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes), name);
+    return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes, steps), name);
 }
 }  // namespace Aidge
 
 namespace {
 template <>
-const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Starts", "Ends", "Axes" };
+const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Starts", "Ends", "Axes", "Steps" };
 }
+
+namespace Aidge {
+class SliceImplForward
+    : public Registrable<SliceImplForward,
+                         std::tuple<DataType>,
+                         void(const Slice_Op::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {};
+template <typename I>
+void Slice_forward_kernel(const Slice_Op::Attrs &attrs, const std::vector<DimSize_t>& inputDims, const void *input_, void *output_);
+
+namespace {
+static Registrar<SliceImplForward> registrarSliceImplForward_Float32(
+    {DataType::Float32}, Slice_forward_kernel<float>);
+static Registrar<SliceImplForward> registrarSliceImplForward_Int32(
+    {DataType::Int32}, Slice_forward_kernel<int>);
+static Registrar<SliceImplForward> registrarSliceImplForward_Int64(
+    {DataType::Float64}, Slice_forward_kernel<double>);
+}
+}
 
 #endif /* AIDGE_CORE_OPERATOR_RELU_H_ */
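For context, a minimal usage sketch of the extended Slice() factory declared above. The shape, values and node name are hypothetical, not part of this commit; the point is only that the new steps argument follows axes and selects every step-th index of the [start, end) range.

// Hypothetical example: slice axis 0 from index 0 (inclusive) to 6 (exclusive) with step 2,
// i.e. keep indices 0, 2 and 4. Omitting `steps` leaves the attribute empty (defaulted later).
#include "aidge/operator/Slice.hpp"

std::shared_ptr<Aidge::Node> mySlice = Aidge::Slice(
    /*starts=*/{0},
    /*ends=*/  {6},
    /*axes=*/  {0},
    /*steps=*/ {2},
    /*name=*/  "slice_every_other");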
@@ -30,6 +30,7 @@ void init_Slice(py::module& m) {
           py::arg("starts") = std::vector<std::int64_t>(),
           py::arg("ends") = std::vector<std::int64_t>(),
           py::arg("axes") = std::vector<std::int8_t>(),
+          py::arg("steps") = std::vector<std::int64_t>(),
           py::arg("name") = "");
 }
 }  // namespace Aidge
@@ -11,6 +11,7 @@
 #include "aidge/operator/Slice.hpp"
 
+#include <algorithm>
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
@@ -25,88 +26,91 @@
 #include "aidge/data/Tensor.hpp"
 #include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/utils/Types.h"
+#include "aidge/data/Data.hpp"
+#include "aidge/utils/Registrar.hpp"
 
-void Aidge::Slice_OpImpl::forward() {
-    const Slice_Op& op = dynamic_cast<const Slice_Op&>(mOp);
-
-    if (!op.getInput(0)) {
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", op.Type);
-    }
-    AIDGE_ASSERT((op.template getAttr<SliceAttr::Starts>().size() == op.template getAttr<SliceAttr::Ends>().size()) &&
-                 (op.template getAttr<SliceAttr::Starts>().size() == op.template getAttr<SliceAttr::Axes>().size()),
-                 "start, end and axes arguments should be the same size.");
-
-    const std::size_t nbDims = op.getInput(0)->nbDims();
-    const std::vector<std::size_t>& inputDims = op.getInput(0)->dims();
-    auto outputDims = op.getInput(0)->dims();
-
-    // compute index of the output's first element
-    // compute output dimension at the same time (may change between two forward calls)
-    std::size_t beginning = 0;
-    const std::size_t nbAxes = op.template getAttr<SliceAttr::Axes>().size();
-    for (std::size_t i = 0; i < nbAxes; ++i) {
-        // For each slice operation get the params and cast them to size_t
-        DimIdx_t axis = op.template getAttr<SliceAttr::Axes>()[i] >= 0 ?
-                        static_cast<DimIdx_t>(op.template getAttr<SliceAttr::Axes>()[i]) :
-                        static_cast<DimIdx_t>(op.template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(inputDims.size()));
-        DimSize_t start = op.template getAttr<SliceAttr::Starts>()[i] >= 0 ?
-                          static_cast<DimSize_t>(op.template getAttr<SliceAttr::Starts>()[i]) :
-                          static_cast<DimSize_t>(op.template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(inputDims[axis]));
-        DimSize_t end = op.template getAttr<SliceAttr::Ends>()[i] >= 0 ?
-                        static_cast<DimSize_t>(op.template getAttr<SliceAttr::Ends>()[i]) :
-                        static_cast<DimSize_t>(op.template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(inputDims[axis]));
-        const std::size_t stridePostAxis = std::accumulate(inputDims.cbegin()+axis+1, inputDims.cend(), std::size_t(1), std::multiplies<std::size_t>());
-        beginning += start * stridePostAxis;
-        const std::size_t sliceLength = end - start;
-        outputDims[axis] = sliceLength;
-    }
-    op.getOutput(0)->resize(outputDims);
-
-    // for inputDims = {4,5,5,3} & outputDims = {3,2,2,1}: substractDims = {1,5,5,3}
-    std::vector<std::size_t> substractedDims = std::vector<std::size_t>(nbDims);
-    for (std::size_t i = 0; i < nbDims; ++i) {
-        substractedDims[i] = inputDims[i] - outputDims[i];
-    }
-
-    // for outputDims = {3,2,2,1}: prodOutputDims = {12,4,2,1}
-    std::vector<std::size_t> prodOutputDims = std::vector<std::size_t>(nbDims);
-    std::vector<std::size_t> prodInputDims = std::vector<std::size_t>(nbDims + 1);
-    prodOutputDims[nbDims - 1] = outputDims[nbDims - 1];
-    prodInputDims[nbDims - 1] = inputDims[nbDims - 1];
-    prodInputDims[nbDims] = 1;
-    for (std::size_t i = 2; i <= nbDims; ++i) {
-        prodOutputDims[nbDims - i] = prodOutputDims[nbDims - i + 1] * outputDims[nbDims - i];
-        prodInputDims[nbDims - i] = prodInputDims[nbDims - i + 1] * inputDims[nbDims - i];
-    }
-
-    std::size_t i = beginning;
-    std::size_t size = 0; // number of elements to copy
-    std::size_t offset = 0;
-    for (std::size_t j = 0; j < prodOutputDims[0];) {
-        ++size;
-        ++i;
-        ++j;
-        bool newChunk = false;
-        for (std::size_t idx = nbDims - 1; idx > 0; --idx) {
-            if (j % prodOutputDims[idx] == 0) {
-                i += substractedDims[idx] * prodInputDims[idx + 1];
-                newChunk = true;
-            }
-        }
-
-        if (newChunk) {
-            op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(beginning), size, offset);
-            beginning = i;
-            offset += size;
-            size = 0;
-        }
-    }
-
-    if (size > 0) {
-        op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(beginning), size, offset);
-    }
-}
+template<class I>
+void Aidge::Slice_forward_kernel(const Slice_Op::Attrs &attrs, const std::vector<DimSize_t>& inputDims, const void *input_, void *output_){
+    const I* input = static_cast<const I*>(input_);
+    I* output = static_cast<I*>(output_);
+
+    const std::size_t nbDims = inputDims.size();
+    std::vector<DimSize_t> dims = inputDims;
+    DimSize_t totalSize = std::accumulate(inputDims.cbegin(), inputDims.cend(), std::size_t(1), std::multiplies<std::size_t>());
+    I* outputAccumulation = new I[totalSize];
+    const I* inputAccumulation = input;
+    const std::size_t nbAxes = std::get<0>(attrs).size();
+    for (std::size_t i = 0; i < nbAxes; ++i) {
+        DimIdx_t axis = std::get<2>(attrs)[i] >= 0 ?
+                        static_cast<DimIdx_t>(std::get<2>(attrs)[i]) :
+                        static_cast<DimIdx_t>(std::get<2>(attrs)[i] + static_cast<DimIdx_t>(inputDims.size()));
+        std::int64_t start = std::get<0>(attrs)[i] >= 0 ?
+                             std::get<0>(attrs)[i] :
+                             std::get<0>(attrs)[i] + static_cast<std::int64_t>(inputDims[axis]);
+        std::int64_t end = std::get<1>(attrs)[i] >= 0 ?
+                           std::get<1>(attrs)[i] :
+                           std::get<1>(attrs)[i] + static_cast<std::int64_t>(inputDims[axis]);
+        std::int64_t step = std::get<3>(attrs)[i];
+
+        std::size_t sliceSize = static_cast<std::size_t>((end - start) / std::abs(step));
+
+        if (i > 0) {
+            outputAccumulation = new I[totalSize];
+        }
+        const std::size_t stride_pre = std::accumulate(dims.cbegin(), dims.cbegin() + axis, 1, std::multiplies<std::size_t>());
+        const std::size_t stride_post = std::accumulate(dims.crbegin(), dims.crbegin() + nbDims - 1 - axis, 1, std::multiplies<std::size_t>());
+        std::int64_t firstElem = step > 0 ? start : end;
+        std::int64_t lastElem = step > 0 ? end : start;
+
+        for (std::size_t outer = 0; outer < stride_pre; outer++)
+        {
+            std::size_t addedSlices = 0;
+            for (std::int64_t inner = firstElem; inner < lastElem; inner += step)
+            {
+                size_t idx = outer * stride_post * dims[axis] + inner * stride_post;
+                size_t idx_out = outer * stride_post * sliceSize + addedSlices * stride_post;
+                if (idx < totalSize) {
+                    std::copy_n(std::next(inputAccumulation, idx), stride_post, std::next(outputAccumulation, idx_out));
+                }
+                addedSlices++;
+            }
+        }
+        totalSize /= dims[axis];
+        totalSize *= sliceSize;
+        dims[axis] = sliceSize;
+
+        if (inputAccumulation != input) {
+            delete[] inputAccumulation;
+        }
+        inputAccumulation = outputAccumulation;
+    }
+    // Copy elements from inputAccumulation to output
+    std::copy_n(inputAccumulation, totalSize, output);
+    // op.getOutput(0)->getImpl()->copy(inputAccumulation, totalSize);
+    if (outputAccumulation) {
+        delete[] outputAccumulation;
+    }
+}
+
+void Aidge::Slice_OpImpl::forward() {
+    const Slice_Op& op = dynamic_cast<const Slice_Op&>(mOp);
+    if (!op.getInput(0)) {
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", op.Type);
+    }
+    AIDGE_ASSERT((op.template getAttr<SliceAttr::Starts>().size() == op.template getAttr<SliceAttr::Ends>().size()) &&
+                 (op.template getAttr<SliceAttr::Starts>().size() == op.template getAttr<SliceAttr::Axes>().size()),
+                 "start, end and axes arguments should be the same size.");
+
+    // Find the correct kernel type
+    auto kernelFunc =
+        Registrar<SliceImplForward>::create({std::static_pointer_cast<Tensor>(op.getRawInput(0))->dataType()});
+
+    // Call kernel
+    kernelFunc(dynamic_cast<const Slice_Op&>(mOp).getStaticAttributes(),
+               std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(),
+               std::static_pointer_cast<Tensor>(op.getInput(0))->getImpl()->hostPtr(),
+               std::static_pointer_cast<Tensor>(op.getOutput(0))->getImpl()->hostPtr());
+}
 
 const std::string Aidge::Slice_Op::Type = "Slice";
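To make the indexing in the new kernel easier to follow, here is a minimal standalone sketch of the same stride walk. It is an illustration, not the Aidge kernel: it slices a single axis of a row-major tensor, assumes a positive step and in-range start/end, and uses std::vector instead of raw buffers. Names such as sliceAxis are hypothetical.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Strided slice of one axis: for each "outer" block, copy one contiguous run of
// stride_post elements per kept index of the sliced axis.
std::vector<int> sliceAxis(const std::vector<int>& in, const std::vector<std::size_t>& dims,
                           std::size_t axis, std::int64_t start, std::int64_t end, std::int64_t step) {
    // Number of blocks before the sliced axis, and size of one contiguous run after it
    // (1 if the sliced axis is the innermost one).
    const std::size_t stride_pre  = std::accumulate(dims.begin(), dims.begin() + axis,
                                                    std::size_t(1), std::multiplies<std::size_t>());
    const std::size_t stride_post = std::accumulate(dims.begin() + axis + 1, dims.end(),
                                                    std::size_t(1), std::multiplies<std::size_t>());
    const std::size_t sliceSize = static_cast<std::size_t>((end - start + step - 1) / step); // kept indices
    std::vector<int> out(stride_pre * sliceSize * stride_post);
    for (std::size_t outer = 0; outer < stride_pre; ++outer) {
        std::size_t added = 0;
        for (std::int64_t inner = start; inner < end; inner += step) {
            const std::size_t src = outer * stride_post * dims[axis] + static_cast<std::size_t>(inner) * stride_post;
            const std::size_t dst = outer * stride_post * sliceSize + added * stride_post;
            std::copy_n(in.begin() + src, stride_post, out.begin() + dst);  // one run per kept index
            ++added;
        }
    }
    return out;
}

int main() {
    // 2x4 tensor {0..7}; keep columns 0 and 2 (start 0, end 4, step 2) -> 2x2 result: 0 2 4 6
    const std::vector<int> in = {0, 1, 2, 3, 4, 5, 6, 7};
    for (int v : sliceAxis(in, {2, 4}, /*axis=*/1, 0, 4, 2)) { std::cout << v << ' '; }
    std::cout << '\n';
    return 0;
}

The kernel above generalises this by running one such pass per entry of Axes, feeding each pass's output buffer (outputAccumulation) back in as the next pass's input (inputAccumulation).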
@@ -127,7 +131,7 @@ bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) {
         AIDGE_ASSERT((mInputs[1]->dataType() == mInputs[2]->dataType()) && (mInputs[1]->dataType() == mInputs[3]->dataType()), "Slice inputs must have the same dataType.");
 
-        this->template getAttr<SliceAttr::Starts>().clear(); // If both are provided input would override attrs
+        this->template getAttr<SliceAttr::Starts>().clear();
         this->template getAttr<SliceAttr::Starts>().reserve(getInput(1)->size());
         this->template getAttr<SliceAttr::Ends>().clear();
         this->template getAttr<SliceAttr::Ends>().reserve(getInput(1)->size());
@@ -179,11 +183,46 @@ bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) {
                               std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
                 break;
             default:
-                AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type());
+                AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Input DataType is not supported.", type());
                 break;
         }
     }
+
+    // Fill Steps attr if empty
+    if (this->template getAttr<SliceAttr::Steps>().empty()) {
+        // In case the input Steps is not provided, default value is 1
+        this->template getAttr<SliceAttr::Steps>() = std::vector<std::int64_t>(getInput(1)->size(), 1);
+        if (getInput(4) && !getInput(4)->empty()) {
+            this->template getAttr<SliceAttr::Steps>().clear();
+            this->template getAttr<SliceAttr::Steps>().reserve(getInput(1)->size());
+            switch (mInputs[1]->dataType()) {
+                case DataType::Float64:
+                    std::copy_n(static_cast<double*>(mInputs[4]->getImpl()->rawPtr()),
+                                getInput(4)->size(),
+                                std::back_inserter(this->template getAttr<SliceAttr::Steps>()));
+                    break;
+                case DataType::Float32:
+                    std::copy_n(static_cast<float*>(mInputs[4]->getImpl()->rawPtr()),
+                                getInput(4)->size(),
+                                std::back_inserter(this->template getAttr<SliceAttr::Steps>()));
+                    break;
+                case DataType::Int64:
+                    std::copy_n(static_cast<std::int64_t*>(mInputs[4]->getImpl()->rawPtr()),
+                                getInput(4)->size(),
+                                std::back_inserter(this->template getAttr<SliceAttr::Steps>()));
+                    break;
+                case DataType::Int32:
+                    std::copy_n(static_cast<std::int32_t*>(mInputs[4]->getImpl()->rawPtr()),
+                                getInput(4)->size(),
+                                std::back_inserter(this->template getAttr<SliceAttr::Steps>()));
+                    break;
+                default:
+                    AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type());
+                    break;
+            }
+        }
+    }
+
     DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
     std::vector<DimSize_t> outDims = getInput(0)->dims();
     for (std::size_t i = 0; i < nbAxes; ++i) {
@@ -197,7 +236,10 @@ bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) {
                               static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) :
                               static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
-        const std::size_t sliceLength = end - start;
+        if (this->template getAttr<SliceAttr::Steps>()[i] == 0) {
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step must be a non-zero value", type());
+        }
+        const std::size_t sliceLength = (end - start) / static_cast<DimSize_t>(std::abs(this->template getAttr<SliceAttr::Steps>()[i]));
         // Check if slice length is valid
         if (sliceLength > getInput(0)->dims()[axis])
         {
...
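As a worked example of the length computation above (values are hypothetical): for an axis of size 6 with start = 0, end = 6 and step = 2, sliceLength = (6 - 0) / |2| = 3, so indices 0, 2 and 4 are kept and the output dimension along that axis becomes 3. A step of 0 is rejected by the check just before the division.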
@@ -106,7 +106,11 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
                 std::vector<std::int8_t> usedDims(inputDimsEnd.size());
                 std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int8_t>(0));
-                auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis]));
+                // Create Slice's Steps attribute
+                std::vector<std::int64_t> steps(inputDimsEnd.size());
+                std::iota(steps.begin(), steps.end(), static_cast<std::int64_t>(1));
+                auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, steps, "Slice_" + std::to_string(currentFirstDims[axis]));
                 slice -> addChild(newNode, 0, 0);
                 newNode -> addChild(concat, 0, i);
...
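For reference, a small self-contained sketch of what this std::iota initialisation yields (the 4-D size is hypothetical): std::iota fills the vector with consecutive values starting from its third argument, so the resulting steps are {1, 2, 3, 4}; a step of 1 on every axis would instead come from the fill constructor shown alongside.

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
    std::vector<std::int64_t> steps(4);
    std::iota(steps.begin(), steps.end(), static_cast<std::int64_t>(1));   // steps == {1, 2, 3, 4}
    std::vector<std::int64_t> unitSteps(4, 1);                             // unitSteps == {1, 1, 1, 1}
    for (auto s : steps) { std::cout << s << ' '; }
    std::cout << '\n';
    for (auto s : unitSteps) { std::cout << s << ' '; }
    std::cout << '\n';
    return 0;
}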
@@ -69,7 +69,7 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") {
         mySlice->getOperator()->setDataType(DataType::Int32);
         mySlice->getOperator()->setBackend("cpu");
         mySlice->forward();
-        // mySlice->getOperator()->output(0).print();
+        op->getOutput(0)->print();
 
         REQUIRE(*(op->getOutput(0)) == *expectedOutput);
         REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims());
         REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType());
@@ -176,7 +176,7 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") {
         mySlice->getOperator()->setDataType(DataType::Int32);
         mySlice->getOperator()->setBackend("cpu");
         mySlice->forward();
-        // mySlice->getOperator()->output(0).print();
+        // op->getOutput(0)->print();
 
         REQUIRE(*(op->getOutput(0)) == *expectedOutput);
         REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims());
         REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType());
@@ -217,13 +217,13 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") {
             }
         });
 
-        std::shared_ptr<Node> mySlice = Slice({0,0,0,0}, {1,1,1,5}, {0,1,2,3});
+        std::shared_ptr<Node> mySlice = Slice({0,0,0,0}, {1,1,1,5}, {0,1,2,3}, {1,1,1,1});
         auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator());
         mySlice->getOperator()->associateInput(0,input0);
         mySlice->getOperator()->setDataType(DataType::Int32);
         mySlice->getOperator()->setBackend("cpu");
         mySlice->forward();
-        // mySlice->getOperator()->output(0).print();
+        // op->getOutput(0)->print();
 
         REQUIRE(*(op->getOutput(0)) == *expectedOutput);
         REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims());
         REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType());
...
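Reading the updated test call above: with a unit step on every axis, the new argument leaves the existing expectation unchanged, since sliceLength per axis is (end - start) / |1|, giving output dims {1, 1, 1, 5}, i.e. the first five elements of the last axis of the input.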