Skip to content
Snippets Groups Projects
Commit 16f41fab authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed missing changes

parent 9127608f
No related branches found
No related tags found
1 merge request!6Tensor setter getter
Pipeline #32459 passed with warnings
...@@ -27,11 +27,11 @@ namespace Aidge { ...@@ -27,11 +27,11 @@ namespace Aidge {
// compute kernel registry for forward and backward // compute kernel registry for forward and backward
class MatMulImplForward_cpu class MatMulImplForward_cpu
: public Registrable<MatMulImplForward_cpu, std::tuple<DataType, DataType, DataType>, : public Registrable<MatMulImplForward_cpu, std::tuple<DataType, DataType, DataType>,
void(const MatMul_Op::Parameters &, const DimSize_t, const DimSize_t, void(const MatMul_Op::Attrs &, const DimSize_t, const DimSize_t,
const void *, const void *, void *)> {}; const void *, const void *, void *)> {};
class MatMulImplBackward_cpu class MatMulImplBackward_cpu
: public Registrable<MatMulImplBackward_cpu, std::tuple<DataType, DataType, DataType>, : public Registrable<MatMulImplBackward_cpu, std::tuple<DataType, DataType, DataType>,
void(const MatMul_Op::Parameters &, const DimSize_t, const DimSize_t, void(const MatMul_Op::Attrs &, const DimSize_t, const DimSize_t,
const void *, const void *, void *)> {}; const void *, const void *, void *)> {};
class MatMulImpl_cpu : public OperatorImpl { class MatMulImpl_cpu : public OperatorImpl {
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
namespace Aidge { namespace Aidge {
template <class I, class W, class O> template <class I, class W, class O>
void MatMulImpl_cpu_forward_kernel(const MatMul_Op::Parameters& params, const DimSize_t batchSize, const DimSize_t oneInputSize, void MatMulImpl_cpu_forward_kernel(const MatMul_Op::Attrs& attrs, const DimSize_t batchSize, const DimSize_t oneInputSize,
const void* input_, const void* weights_, void* output_) { const void* input_, const void* weights_, void* output_) {
// FIXME: missing MatMul parameters as arguments // FIXME: missing MatMul parameters as arguments
const I* input = static_cast<const I*>(input_); const I* input = static_cast<const I*>(input_);
...@@ -28,14 +28,14 @@ void MatMulImpl_cpu_forward_kernel(const MatMul_Op::Parameters& params, const Di ...@@ -28,14 +28,14 @@ void MatMulImpl_cpu_forward_kernel(const MatMul_Op::Parameters& params, const Di
O* output = static_cast<O*>(output_); O* output = static_cast<O*>(output_);
std::fill(output, output+(batchSize*std::get<0>(params)), O(0)); std::fill(output, output+(batchSize*std::get<0>(attrs)), O(0));
for (std::size_t batch = 0; batch < batchSize; ++batch) { for (std::size_t batch = 0; batch < batchSize; ++batch) {
for (std::size_t out = 0; out < std::get<0>(params); ++out) { for (std::size_t out = 0; out < std::get<0>(attrs); ++out) {
output[out + batch*std::get<0>(params)] = std::inner_product(input + batch*oneInputSize, output[out + batch*std::get<0>(attrs)] = std::inner_product(input + batch*oneInputSize,
input + (batch + 1)*oneInputSize, input + (batch + 1)*oneInputSize,
weights + out*oneInputSize, weights + out*oneInputSize,
output[out + batch*std::get<0>(params)]); output[out + batch*std::get<0>(attrs)]);
} }
} }
} }
......
...@@ -96,7 +96,7 @@ void Aidge::MatMulImpl_cpu::forward() ...@@ -96,7 +96,7 @@ void Aidge::MatMulImpl_cpu::forward()
// Call kernel // Call kernel
// if (mOp.getInput(0)->nbDims() == 4) { // if (mOp.getInput(0)->nbDims() == 4) {
// kernelFunc( // kernelFunc(
// mOp.getParams(), // mOp.getStaticAttributes(),
// std::static_pointer_cast<Tensor>(mOp.getInput(0))->dims<4>(), // std::static_pointer_cast<Tensor>(mOp.getInput(0))->dims<4>(),
// mOp.getInput(0)->getImpl()->rawPtr(), // mOp.getInput(0)->getImpl()->rawPtr(),
// mOp.mInputs[1]->getImpl()->rawPtr(), // mOp.mInputs[1]->getImpl()->rawPtr(),
...@@ -105,7 +105,7 @@ void Aidge::MatMulImpl_cpu::forward() ...@@ -105,7 +105,7 @@ void Aidge::MatMulImpl_cpu::forward()
// } // }
// else // else
kernelFunc( kernelFunc(
mOp.getParams(), mOp.getStaticAttributes(),
mOp.getInput(0)->dims()[0], mOp.getInput(0)->dims()[0],
mOp.getInput(0)->sizeM1(), mOp.getInput(0)->sizeM1(),
mOp.getInput(0)->getImpl()->rawPtr(), mOp.getInput(0)->getImpl()->rawPtr(),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment