Skip to content
Snippets Groups Projects
Commit 328ee3e3 authored by Thibault Allenet's avatar Thibault Allenet
Browse files

Merge branch 'tiling' of gitlab.eclipse.org:eclipse/aidge/aidge_backend_cpu into dataloader

parents 895d7a43 ed45d0c4
No related branches found
No related tags found
2 merge requests!50version 0.2.0,!2Scheduler ConnectInput and tensor filling with offset tests
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
#define AIDGE_CPU_OPERATOR_SLICEIMPL_H_ #define AIDGE_CPU_OPERATOR_SLICEIMPL_H_
#include <memory> #include <memory>
#include <tuple>
#include <vector> #include <vector>
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
...@@ -39,7 +38,6 @@ class SliceImplBackward_cpu ...@@ -39,7 +38,6 @@ class SliceImplBackward_cpu
const void*, const void*,
void*)> {}; void*)> {};
class SliceImpl_cpu : public OperatorImpl { class SliceImpl_cpu : public OperatorImpl {
public: public:
SliceImpl_cpu(const Slice_Op& op) : OperatorImpl(op) {} SliceImpl_cpu(const Slice_Op& op) : OperatorImpl(op) {}
...@@ -48,7 +46,6 @@ public: ...@@ -48,7 +46,6 @@ public:
return std::make_unique<SliceImpl_cpu>(op); return std::make_unique<SliceImpl_cpu>(op);
} }
public:
NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final; NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final;
NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final; NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final;
NbElts_t getRequiredMemory(const IOIndex_t outputIdx, NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
...@@ -58,14 +55,12 @@ public: ...@@ -58,14 +55,12 @@ public:
void updateConsummerProducer() override final; void updateConsummerProducer() override final;
void forward() override; void forward() override;
void backward() override; void backward() override;
}; };
namespace { namespace {
static Registrar<Slice_Op> registrarSliceImpl_cpu("cpu", Aidge::SliceImpl_cpu::create); static Registrar<Slice_Op> registrarSliceImpl_cpu("cpu", Aidge::SliceImpl_cpu::create);
} // namespace }
} // namespace Aidge } // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_H_ */ #endif /* AIDGE_CPU_OPERATOR_SLICEIMPL_H_ */
\ No newline at end of file
...@@ -12,57 +12,73 @@ ...@@ -12,57 +12,73 @@
#ifndef AIDGE_CPU_OPERATOR_SLICEIMPL_FORWARD_KERNEL_H_ #ifndef AIDGE_CPU_OPERATOR_SLICEIMPL_FORWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_SLICEIMPL_FORWARD_KERNEL_H_ #define AIDGE_CPU_OPERATOR_SLICEIMPL_FORWARD_KERNEL_H_
#include "aidge/utils/Registrar.hpp"
#include "aidge/operator/Slice.hpp"
#include "aidge/backend/cpu/operator/SliceImpl.hpp"
#include <vector>
#include <cstddef> #include <cstddef>
#include <vector>
#include "aidge/backend/cpu/operator/SliceImpl.hpp"
#include "aidge/data/Data.hpp" #include "aidge/data/Data.hpp"
#include "aidge/operator/Slice.hpp"
#include "aidge/utils/Registrar.hpp"
namespace Aidge { namespace Aidge {
template <class I> template <class I>
void SliceImpl_cpu_forward_kernel(const typename Slice_Op::Attrs& attrs, void SliceImpl_cpu_forward_kernel(const typename Slice_Op::Attrs& attrs,
const std::vector<std::size_t> inputDims, const std::vector<std::size_t> inputDims,
const void* input_, const void* input_,
void* output_) { void* output_) {
std::vector<std::size_t> slicedDims = inputDims;
std::size_t beginning = 0;
DimSize_t nbAxes = std::get<2>(attrs).size();
for (std::size_t i = 0; i < nbAxes; ++i) {
// For each slice operation get the params and cast them to size_t
const std::int64_t axis_ = std::get<2>(attrs)[i];
const std::int64_t start_ = std::get<0>(attrs)[i];
const std::int64_t end_ = std::get<1>(attrs)[i];
const std::size_t axis = axis_ >= 0 ? axis_ : static_cast<std::size_t>(axis_ + static_cast<std::int32_t>(inputDims.size()));
const std::size_t start = start_ >= 0 ? start_ : start_ + inputDims[axis];
const std::size_t end = end_ >= 0 ? end_ : end_ + inputDims[axis];
std::size_t stride = 1;
for (std::size_t j = inputDims.size() - 1; j > axis; --j) stride *= inputDims[j];
beginning += start * stride;
const std::size_t sliceLength = end - start + 1;
slicedDims[axis] = sliceLength;
}
const I* input = static_cast<const I*>(input_) + std::get<0>(attrs); const I* input = static_cast<const I*>(input_) + beginning;
I* output = static_cast<I*>(output_); I* output = static_cast<I*>(output_);
const std::vector<std::size_t> slicedDims = std::get<1>(attrs);
const std::size_t nbDims = slicedDims.size(); const std::size_t nbDims = slicedDims.size();
// for inputDims = {4,5,5,3} & slicedDims = {3,2,2,1}, substractDims = {1,5,5,3} // for inputDims = {4,5,5,3} & slicedDims = {3,2,2,1}, substractDims = {1,5,5,3}
std::vector<std::size_t> substractedDims = std::vector<std::size_t>(nbDims); std::vector<std::size_t> substractedDims = std::vector<std::size_t>(nbDims);
for (std::size_t i = 0; i < nbDims; ++i) { for (std::size_t i = 0; i < nbDims; ++i) {
substractedDims[i] = inputDims[i] - slicedDims[i]; substractedDims[i] = inputDims[i] - slicedDims[i];
} }
// for slicedDims = {3,2,2,1}, prodSlicedDims = {12,4,2,1} // for slicedDims = {3,2,2,1}, prodSlicedDims = {12,4,2,1}
std::vector<std::size_t> prodSlicedDims = std::vector<std::size_t>(nbDims); std::vector<std::size_t> prodSlicedDims = std::vector<std::size_t>(nbDims);
std::vector<std::size_t> prodInputDims = std::vector<std::size_t>(nbDims+1); std::vector<std::size_t> prodInputDims = std::vector<std::size_t>(nbDims + 1);
prodSlicedDims[nbDims - 1] = slicedDims[nbDims - 1]; prodSlicedDims[nbDims - 1] = slicedDims[nbDims - 1];
prodInputDims[nbDims - 1] = inputDims[nbDims - 1]; prodInputDims[nbDims - 1] = inputDims[nbDims - 1];
prodInputDims[nbDims] = 1; prodInputDims[nbDims] = 1;
for (std::size_t i = 2; i <= nbDims; ++i) { for (std::size_t i = 2; i <= nbDims; ++i) {
prodSlicedDims[nbDims - i] = prodSlicedDims[nbDims - i + 1]*slicedDims[nbDims - i]; prodSlicedDims[nbDims - i] = prodSlicedDims[nbDims - i + 1] * slicedDims[nbDims - i];
prodInputDims[nbDims - i] = prodInputDims[nbDims - i + 1]*inputDims[nbDims - i]; prodInputDims[nbDims - i] = prodInputDims[nbDims - i + 1] * inputDims[nbDims - i];
} }
std::size_t j = 0; std::size_t j = 0;
std::size_t i = 0; std::size_t i = 0;
for (; j < prodSlicedDims[0];) { for (; j < prodSlicedDims[0];) {
output[j] = input[i++]; output[j] = input[i++];
++j; ++j;
for (std::size_t idx = nbDims - 1; idx > 0; --idx) { for (std::size_t idx = nbDims - 1; idx > 0; --idx) {
i += j % prodSlicedDims[idx] == 0 ? substractedDims[idx]*prodInputDims[idx+1] : 0; i += j % prodSlicedDims[idx] == 0 ? substractedDims[idx] * prodInputDims[idx + 1] : 0;
} }
} }
} }
namespace { namespace {
// DIM = 1
static Registrar<SliceImplForward_cpu> registrarSliceImplForward_cpu_Float32( static Registrar<SliceImplForward_cpu> registrarSliceImplForward_cpu_Float32(
{DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float>); {DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float>);
static Registrar<SliceImplForward_cpu> registrarSliceImplForward_cpu_Int32( static Registrar<SliceImplForward_cpu> registrarSliceImplForward_cpu_Int32(
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <cmath>
#include <cstdlib>
#include <memory>
#include "aidge/backend/cpu/operator/ConvImpl.hpp"
#include "aidge/backend/cpu/operator/PadImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/MetaOperator.hpp"
#include "aidge/operator/MetaOperatorDefs.hpp"
#include "aidge/operator/Pad.hpp"
using namespace Aidge;
TEST_CASE("[cpu/operator] MetaOperator/PaddedConv(forward)", "[MetaOperator][PaddedConv][CPU]") {
std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(
Array4D<double, 4, 3, 3, 3>{{{{{6.20986394e-01, 1.19775136e-03, 7.22876095e-02},
{1.16492919e-01, 8.21634093e-02, 1.17413265e-01},
{2.23743494e-01, 3.99495413e-01, 5.55552411e-01}},
{{6.64970077e-01, 9.62199940e-01, 4.87531967e-01},
{6.12586558e-01, 8.09918671e-02, 8.40649383e-01},
{4.15264406e-01, 8.28247138e-01, 1.52301135e-01}},
{{1.76992844e-02, 7.78697112e-01, 8.14531592e-01},
{1.36960611e-01, 4.64806728e-01, 4.85150000e-01},
{4.34776520e-01, 9.51740977e-01, 9.05793799e-01}}},
{{{1.71925246e-02, 1.91082720e-01, 3.67982644e-01},
{1.56806559e-01, 6.22280998e-01, 3.15827594e-01},
{6.04359038e-01, 2.83095947e-01, 6.11168892e-01}},
{{2.76942832e-01, 1.89768419e-01, 8.07988176e-01},
{1.67925807e-01, 2.68356150e-01, 6.28875602e-01},
{1.69093357e-04, 9.64788636e-01, 7.29254981e-01}},
{{6.34030122e-01, 1.32087038e-01, 3.33857107e-01},
{7.63047502e-01, 5.12539506e-02, 9.77400493e-01},
{8.06151288e-01, 2.60237147e-01, 3.93729313e-01}}},
{{{5.84605240e-01, 4.74648725e-01, 8.54111741e-01},
{7.10897067e-02, 5.02579011e-01, 3.35236224e-01},
{9.08637408e-01, 8.02903830e-01, 2.83929907e-01}},
{{3.68206999e-01, 9.18579021e-02, 7.33168098e-01},
{1.59875539e-01, 9.13163381e-01, 3.59806060e-01},
{1.41295882e-01, 7.00312185e-01, 5.63728289e-01}},
{{9.39513546e-01, 1.91704891e-01, 1.11454944e-01},
{5.46298282e-01, 2.89698587e-01, 2.62612651e-01},
{1.18554992e-01, 4.32147376e-02, 7.53016994e-01}}},
{{{9.53179175e-01, 2.05041054e-02, 1.11318451e-01},
{8.67878485e-01, 2.93263422e-01, 8.03912714e-01},
{8.93620255e-01, 1.37831128e-01, 3.83640583e-01}},
{{3.96020188e-01, 6.24959320e-01, 1.90709175e-01},
{5.80538620e-01, 6.63031275e-01, 2.07247191e-01},
{5.65672171e-01, 5.57014317e-01, 9.26909496e-01}},
{{3.43901418e-01, 4.47741636e-01, 6.59249367e-01},
{7.34639028e-01, 2.84957200e-02, 9.70225217e-01},
{1.33578790e-02, 6.12054702e-01, 9.36685235e-02}}}}});
std::shared_ptr<Tensor> myBias = std::make_shared<Tensor>(
Array1D<double, 4>{{0.16884905, 0.27994487, 0.57227465, 0.06435205}});
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<double, 2, 3, 5, 5>{
// NCHW
{{{{0.43224481, 0.9047832, 0.18402257, 0.06162838, 0.52490127},
{0.27773404, 0.55402353, 0.9485062, 0.31197083, 0.80328607},
{0.85065842, 0.88226201, 0.54971951, 0.23360494, 0.53907884},
{0.33423098, 0.79564312, 0.80419414, 0.76839638, 0.87248221},
{0.77328729, 0.65749407, 0.47277589, 0.32889198, 0.93970518}},
{{0.66669145, 0.64193351, 0.45315988, 0.32794057, 0.38461822},
{0.72295814, 0.18395073, 0.85909664, 0.30010301, 0.56065865},
{0.34777938, 0.77869746, 0.33159421, 0.19540932, 0.77767906},
{0.5778391, 0.08218411, 0.27758371, 0.99017749, 0.61827997},
{0.10440745, 0.3197831, 0.89157608, 0.12216887, 0.950232}},
{{0.68073443, 0.2681118, 0.51848834, 0.62864493, 0.36717478},
{0.64106244, 0.43779425, 0.02771029, 0.78275231, 0.45693104},
{0.6487417, 0.01603838, 0.73869997, 0.96494221, 0.39588782},
{0.5975827, 0.90913292, 0.55036969, 0.4747373, 0.62460509},
{0.79675124, 0.02807549, 0.53227602, 0.88805927, 0.96646591}}},
{{{0.81851935, 0.21267665, 0.01580692, 0.54907998, 0.89010049},
{0.80165784, 0.55195592, 0.20740314, 0.22782844, 0.89205031},
{0.94217108, 0.58434542, 0.20738313, 0.79065873, 0.9371597},
{0.02254708, 0.95539178, 0.95165758, 0.53736666, 0.49100362},
{0.08018625, 0.69108027, 0.00329741, 0.74565761, 0.30899213}},
{{0.34868638, 0.12792604, 0.37382248, 0.0374756, 0.50653087},
{0.59614405, 0.64820746, 0.31470307, 0.62460364, 0.29253268},
{0.92864889, 0.51014224, 0.08921206, 0.11094072, 0.64691121},
{0.50586371, 0.6686477, 0.72511169, 0.41681783, 0.6325049},
{0.71594137, 0.73382767, 0.36589439, 0.03255165, 0.75006865}},
{{0.6294127, 0.85548534, 0.0902963, 0.28915773, 0.36564289},
{0.95873236, 0.6742374, 0.55679676, 0.6323497, 0.34072958},
{0.49694061, 0.79173045, 0.19738225, 0.14755281, 0.80818177},
{0.02332061, 0.74270703, 0.59415632, 0.08195934, 0.46295434},
{0.71426058, 0.85032931, 0.90750818, 0.28768431, 0.4401146}}}}});
std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(
Array4D<double, 2, 4, 5, 5>{{{{{3.40294218, 3.74021220, 4.02050114, 4.07054710, 2.46286273},
{4.61770582, 6.70517588, 6.50356627, 6.29688787, 3.53332567},
{5.47480106, 5.92094421, 6.64605665, 7.95090199, 4.28721523},
{4.01485729, 6.06748962, 7.52447891, 7.37980652, 5.28401136},
{2.83065438, 3.62033439, 3.56222963, 5.56103945, 3.23335814}},
{{3.30230498, 4.92814112, 4.34710836, 3.96262765, 2.97987890},
{4.49693012, 6.68929291, 5.53603029, 5.68874264, 4.28756475},
{4.20528078, 6.82776880, 6.70569849, 7.12809610, 4.40845442},
{4.31169367, 6.73352146, 6.30962515, 7.45826864, 4.99164438},
{2.18136287, 4.28968000, 4.20080042, 4.89814138, 2.87394023}},
{{3.54787683, 4.35851812, 4.63881302, 4.23359537, 3.16992092},
{5.25099468, 7.54282856, 6.69849157, 5.64309788, 4.56919575},
{4.71914101, 7.52830601, 6.71450949, 7.81113863, 5.84658146},
{4.97893143, 7.39293909, 6.89905310, 8.14430809, 5.62998581},
{2.79735112, 4.80967140, 5.57630205, 5.38828325, 4.57078695}},
{{3.03048635, 5.04540300, 4.21824932, 4.87323284, 2.35113740},
{4.45167351, 6.47721338, 7.40922976, 6.70445728, 3.60700107},
{3.77927423, 6.82826376, 7.41777134, 7.57402420, 5.13131523},
{4.08747244, 7.07994175, 7.57206821, 8.51897335, 5.26987123},
{2.34426999, 4.60127831, 4.86486769, 6.01579571, 3.97803569}}},
{{{3.84700942, 4.25972605, 3.05269003, 3.78043652, 2.08771229},
{6.00459957, 6.05633259, 4.45951605, 4.54089880, 4.03066444},
{5.41579390, 7.29543972, 6.18680000, 5.58812714, 3.45964241},
{6.04531050, 7.70924091, 5.52207708, 5.02131319, 4.09403706},
{3.18092418, 4.45422697, 4.04294252, 3.86577177, 2.18776536}},
{{4.02600670, 4.27603531, 3.81011319, 4.03631020, 2.57254648},
{5.33471155, 5.72588634, 5.12079763, 5.11733150, 3.76836705},
{5.62947607, 5.92492962, 6.24170446, 6.44130468, 3.44276404},
{5.38414621, 6.02679539, 5.88985586, 5.90263271, 3.15044069},
{3.31261086, 4.44371319, 3.47660780, 4.15411520, 1.48961508}},
{{3.95879412, 4.17324543, 3.70114422, 3.27447152, 3.09713888},
{5.78258181, 6.57920837, 4.99913597, 6.20961237, 4.98552179},
{5.84685421, 7.19971228, 6.66386652, 6.68013430, 4.90963316},
{5.24417877, 7.06430531, 6.58512402, 6.02492285, 4.48986387},
{3.64294529, 5.00678444, 5.04760027, 4.72895622, 2.67990756}},
{{3.48610687, 4.12853813, 4.07563591, 3.51327014, 2.44217038},
{4.80529881, 7.33211374, 5.14774036, 4.77281189, 4.44612408},
{5.11703110, 7.55168772, 7.14374542, 6.43696356, 4.10621357},
{5.41270018, 6.85949135, 6.73503923, 5.74601364, 4.46150303},
{3.16612267, 4.38248920, 5.23248482, 4.21292210, 2.86031270}}}}});
std::shared_ptr<Node> myConv = Conv<2>(3, 4, {3, 3}, "myconv");
auto convOp = std::static_pointer_cast<OperatorTensor>(myConv->getOperator());
std::shared_ptr<Node> myPad =
Pad<2>({1, 1, 1, 1}, "myPad", PadBorderType::Constant, 0.0);
auto padOp = std::static_pointer_cast<OperatorTensor>(myPad->getOperator());
convOp->setInput(1, myWeights);
convOp->setInput(2, myBias);
myPad->addChild(myConv, 0, 0);
padOp->setInput(0, myInput);
padOp->setDataType(DataType::Float64);
padOp->setBackend("cpu");
padOp->computeOutputDims();
convOp->setDataType(DataType::Float64);
convOp->setBackend("cpu");
convOp->computeOutputDims();
myPad->forward();
myConv->forward();
convOp -> getOutput(0) -> print();
double* computedOutput = static_cast<double*>(convOp->getOutput(0)->getImpl()->rawPtr());
double* expectedOutput = static_cast<double*>(myOutput->getImpl()->rawPtr());
for (std::size_t i = 0; i < myOutput->size(); ++i) {
REQUIRE(std::abs(computedOutput[i] - expectedOutput[i]) < 1e-5);
}
std::shared_ptr<Node> myPaddedConv =
PaddedConv(3, 4, {3, 3}, "myPaddedConv", {1, 1}, {1, 1, 1, 1});
}
\ No newline at end of file
...@@ -27,14 +27,14 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") { ...@@ -27,14 +27,14 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") {
{0, 1, 2,-3} {0, 1, 2,-3}
}); });
std::shared_ptr<Node> mySlice = Slice(0, {4}); std::shared_ptr<Node> mySlice = Slice({0}, {3}, {0});
auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator()); auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator());
mySlice->getOperator()->associateInput(0,input0); mySlice->getOperator()->associateInput(0,input0);
mySlice->getOperator()->setDataType(DataType::Int32); mySlice->getOperator()->setDataType(DataType::Int32);
mySlice->getOperator()->setBackend("cpu"); mySlice->getOperator()->setBackend("cpu");
op->computeOutputDims(); op->computeOutputDims();
mySlice->forward(); mySlice->forward();
// mySlice->getOperator()->output(0).print();
REQUIRE(*(op->getOutput(0)) == *expectedOutput); REQUIRE(*(op->getOutput(0)) == *expectedOutput);
REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims()); REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims());
REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType()); REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType());
...@@ -54,7 +54,7 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") { ...@@ -54,7 +54,7 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") {
} }
}); });
std::shared_ptr<Node> mySlice = Slice(5, {2,3}); std::shared_ptr<Node> mySlice = Slice({0,5}, {1,7}, {0,1});
auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator()); auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator());
mySlice->getOperator()->associateInput(0,input0); mySlice->getOperator()->associateInput(0,input0);
mySlice->getOperator()->setDataType(DataType::Int32); mySlice->getOperator()->setDataType(DataType::Int32);
...@@ -88,7 +88,7 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") { ...@@ -88,7 +88,7 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") {
} }
}); });
std::shared_ptr<Node> mySlice = Slice(14, {1,1,3}); std::shared_ptr<Node> mySlice = Slice({0,1,4}, {0,1,6}, {0,1,2});
auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator()); auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator());
mySlice->getOperator()->associateInput(0,input0); mySlice->getOperator()->associateInput(0,input0);
mySlice->getOperator()->setDataType(DataType::Int32); mySlice->getOperator()->setDataType(DataType::Int32);
...@@ -151,7 +151,7 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") { ...@@ -151,7 +151,7 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") {
} }
}); });
std::shared_ptr<Node> mySlice = Slice(0, {2,2,2,10}); std::shared_ptr<Node> mySlice = Slice({0,0,0,0}, {1,1,1,9}, {0,1,2,3});
auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator()); auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator());
mySlice->getOperator()->associateInput(0,input0); mySlice->getOperator()->associateInput(0,input0);
mySlice->getOperator()->setDataType(DataType::Int32); mySlice->getOperator()->setDataType(DataType::Int32);
......
...@@ -39,7 +39,7 @@ TEST_CASE("[cpu/operator] Softmax(forward)", "[Softmax][CPU]") { ...@@ -39,7 +39,7 @@ TEST_CASE("[cpu/operator] Softmax(forward)", "[Softmax][CPU]") {
} }
}); });
std::shared_ptr<Node> mySoftmax = Softmax(); std::shared_ptr<Node> mySoftmax = Softmax(1);
auto op = std::static_pointer_cast<OperatorTensor>(mySoftmax -> getOperator()); auto op = std::static_pointer_cast<OperatorTensor>(mySoftmax -> getOperator());
mySoftmax->getOperator()->associateInput(0,input); mySoftmax->getOperator()->associateInput(0,input);
mySoftmax->getOperator()->setDataType(DataType::Float32); mySoftmax->getOperator()->setDataType(DataType::Float32);
...@@ -108,7 +108,7 @@ TEST_CASE("[cpu/operator] Softmax(forward)", "[Softmax][CPU]") { ...@@ -108,7 +108,7 @@ TEST_CASE("[cpu/operator] Softmax(forward)", "[Softmax][CPU]") {
} }
}); });
std::shared_ptr<Node> mySoftmax = Softmax(); std::shared_ptr<Node> mySoftmax = Softmax(1);
auto op = std::static_pointer_cast<OperatorTensor>(mySoftmax -> getOperator()); auto op = std::static_pointer_cast<OperatorTensor>(mySoftmax -> getOperator());
mySoftmax->getOperator()->associateInput(0,input); mySoftmax->getOperator()->associateInput(0,input);
mySoftmax->getOperator()->setDataType(DataType::Float32); mySoftmax->getOperator()->setDataType(DataType::Float32);
......
...@@ -183,26 +183,4 @@ TEST_CASE("[core/recipies] Tiling(transformation)", "[Tiling][Recipies]") { ...@@ -183,26 +183,4 @@ TEST_CASE("[core/recipies] Tiling(transformation)", "[Tiling][Recipies]") {
} }
} }
} }
} } // namespace Aidge
// std::shared_ptr<GraphView> g = Sequential({ \ No newline at end of file
// Conv(3, 16, {3,3}, "conv1"),
// ReLU("relu1"),
// Conv(16, 32, {1,1}, "conv2"),
// Conv(32, 16, {1,1}, "conv3"),
// Conv(16, 10, {3,3}, "conv4"),
// ReLU("relu2")
// });
// for (auto& individualConv : g->match("Conv")) {
// auto tiledConv = horizontalTiling(individualConv);
// g->replace(individualConv, tiledConv);
// }
// }
// SECTION("Create the GraphView with tiled layers") {
// std::shared_ptr<GraphView> g;
// g->addChild(horizontalTiling(Conv()))
// }
// }
// } // namespace Aidge
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment