Commit 16f41fab authored by Olivier BICHLER

Fixed missing changes

parent 9127608f
@@ -27,11 +27,11 @@ namespace Aidge {
 // compute kernel registry for forward and backward
 class MatMulImplForward_cpu
     : public Registrable<MatMulImplForward_cpu, std::tuple<DataType, DataType, DataType>,
-                        void(const MatMul_Op::Parameters &, const DimSize_t, const DimSize_t,
+                        void(const MatMul_Op::Attrs &, const DimSize_t, const DimSize_t,
                              const void *, const void *, void *)> {};
 class MatMulImplBackward_cpu
     : public Registrable<MatMulImplBackward_cpu, std::tuple<DataType, DataType, DataType>,
-                        void(const MatMul_Op::Parameters &, const DimSize_t, const DimSize_t,
+                        void(const MatMul_Op::Attrs &, const DimSize_t, const DimSize_t,
                              const void *, const void *, void *)> {};
 class MatMulImpl_cpu : public OperatorImpl {
@@ -20,7 +20,7 @@
 namespace Aidge {
 template <class I, class W, class O>
-void MatMulImpl_cpu_forward_kernel(const MatMul_Op::Parameters& params, const DimSize_t batchSize, const DimSize_t oneInputSize,
+void MatMulImpl_cpu_forward_kernel(const MatMul_Op::Attrs& attrs, const DimSize_t batchSize, const DimSize_t oneInputSize,
                                    const void* input_, const void* weights_, void* output_) {
     // FIXME: missing MatMul parameters as arguments
     const I* input = static_cast<const I*>(input_);
@@ -28,14 +28,14 @@ void MatMulImpl_cpu_forward_kernel(const MatMul_Op::Parameters& params, const Di
     O* output = static_cast<O*>(output_);
 
-    std::fill(output, output+(batchSize*std::get<0>(params)), O(0));
+    std::fill(output, output+(batchSize*std::get<0>(attrs)), O(0));
 
     for (std::size_t batch = 0; batch < batchSize; ++batch) {
-        for (std::size_t out = 0; out < std::get<0>(params); ++out) {
-            output[out + batch*std::get<0>(params)] = std::inner_product(input + batch*oneInputSize,
+        for (std::size_t out = 0; out < std::get<0>(attrs); ++out) {
+            output[out + batch*std::get<0>(attrs)] = std::inner_product(input + batch*oneInputSize,
                                                                          input + (batch + 1)*oneInputSize,
                                                                          weights + out*oneInputSize,
-                                                                         output[out + batch*std::get<0>(params)]);
+                                                                         output[out + batch*std::get<0>(attrs)]);
         }
     }
 }
@@ -96,7 +96,7 @@ void Aidge::MatMulImpl_cpu::forward()
     // Call kernel
     // if (mOp.getInput(0)->nbDims() == 4) {
     //     kernelFunc(
-    //         mOp.getParams(),
+    //         mOp.getStaticAttributes(),
     //         std::static_pointer_cast<Tensor>(mOp.getInput(0))->dims<4>(),
     //         mOp.getInput(0)->getImpl()->rawPtr(),
     //         mOp.mInputs[1]->getImpl()->rawPtr(),
@@ -105,7 +105,7 @@ void Aidge::MatMulImpl_cpu::forward()
     // }
     // else
     kernelFunc(
-        mOp.getParams(),
+        mOp.getStaticAttributes(),
         mOp.getInput(0)->dims()[0],
         mOp.getInput(0)->sizeM1(),
         mOp.getInput(0)->getImpl()->rawPtr(),
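For context, this commit only renames the operator's static parameter tuple (MatMul_Op::Parameters / getParams()) to the attribute tuple (MatMul_Op::Attrs / getStaticAttributes()); the kernel logic is unchanged. Below is a minimal standalone sketch of the kernel pattern the diff touches, assuming (as std::get<0>(attrs) above suggests) that the first tuple element is the number of output channels. The names FakeMatMulAttrs and matmul_forward_sketch are illustrative only, not Aidge API.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>

// Stand-in for MatMul_Op::Attrs: a tuple whose first element is assumed to be
// the number of output channels (mirroring std::get<0>(attrs) in the diff).
using FakeMatMulAttrs = std::tuple<std::size_t>;

template <class I, class W, class O>
void matmul_forward_sketch(const FakeMatMulAttrs& attrs,
                           std::size_t batchSize,
                           std::size_t oneInputSize,
                           const I* input, const W* weights, O* output) {
    const std::size_t nbOutputs = std::get<0>(attrs);
    // Zero the output buffer, then accumulate one dot product per (batch, out) pair.
    std::fill(output, output + batchSize * nbOutputs, O(0));
    for (std::size_t batch = 0; batch < batchSize; ++batch) {
        for (std::size_t out = 0; out < nbOutputs; ++out) {
            output[out + batch * nbOutputs] =
                std::inner_product(input + batch * oneInputSize,
                                   input + (batch + 1) * oneInputSize,
                                   weights + out * oneInputSize,
                                   output[out + batch * nbOutputs]);
        }
    }
}

int main() {
    const FakeMatMulAttrs attrs{2};                     // 2 output channels
    const std::vector<float> input{1.f, 2.f, 3.f};      // 1 batch of 3 inputs
    const std::vector<float> weights{1.f, 0.f, 0.f,     // weight row 0
                                     0.f, 1.f, 1.f};    // weight row 1
    std::vector<float> output(2);
    matmul_forward_sketch(attrs, 1, 3, input.data(), weights.data(), output.data());
    std::cout << output[0] << " " << output[1] << "\n"; // prints "1 5"
    return 0;
}

Each output element is the inner product of one input row with one weight row, exactly as in the MatMulImpl_cpu_forward_kernel hunk above; only the attribute access goes through the renamed tuple type.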