From ac87b48e65cf9ba8220d71553a948a0263b22fc7 Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Tue, 19 Mar 2024 15:41:04 +0100
Subject: [PATCH] fix onesVector for FCImpl

---
 src/operator/FCImpl.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/operator/FCImpl.cpp b/src/operator/FCImpl.cpp
index 02c74dc..bb107a5 100644
--- a/src/operator/FCImpl.cpp
+++ b/src/operator/FCImpl.cpp
@@ -61,8 +61,8 @@ void Aidge::FCImpl_cuda::forward_(const Tensor& input0, const Tensor& input1, co
     int lda = k;
     int ldb = k;
     int ldc = n;
-    const T alpha = 1.0;
-    const T beta = 0.0;
+    const T alpha = T(1.0);
+    const T beta = T(0.0);
     CHECK_CUBLAS_STATUS(cublasGemm(CudaContext::cublasHandle(),
                                     CUBLAS_OP_T,
                                     CUBLAS_OP_N,
@@ -80,9 +80,9 @@ void Aidge::FCImpl_cuda::forward_(const Tensor& input0, const Tensor& input1, co
 
     if(!noBias){
         T* onesVector;
-        cudaMalloc((void**)&onesVector, outChannels * sizeof(T));
+        cudaMalloc((void**)&onesVector, m * sizeof(T));
         // Fill the vector with ones
-        std::vector<T> onesVec(m, 1.0f);
+        std::vector<T> onesVec(m, T(1.0));
         CHECK_CUDA_STATUS(cudaMemcpy(onesVector,
                                     &onesVec[0],
                                     m * sizeof(T),
-- 
GitLab