Commit 991a38c1 authored by Houssem ROUIS

add fc operator

parent 7401428e
@@ -15,6 +15,7 @@
 #include "aidge/backend/cuda/data/TensorImpl.hpp"
 #include "aidge/backend/cuda/operator/AvgPoolingImpl.hpp"
 #include "aidge/backend/cuda/operator/ConvImpl.hpp"
+#include "aidge/backend/cuda/operator/FCImpl.hpp"
 #include "aidge/backend/cuda/operator/MaxPoolingImpl.hpp"
 #include "aidge/backend/cuda/operator/ProducerImpl.hpp"
 #include "aidge/backend/cuda/operator/ReLUImpl.hpp"
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_BACKEND_CUDA_OPERATOR_FCIMPL_H_
#define AIDGE_BACKEND_CUDA_OPERATOR_FCIMPL_H_
#include <array>
#include <memory>
#include <tuple>
#include <vector>
#include <cudnn.h>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
namespace Aidge {
class FCImplForward_cuda : public Registrable<FCImplForward_cuda,
                                              std::tuple<DataType>,
                                              void(DimSize_t, DimSize_t, DimSize_t, bool, const void*, const void*, const void*, void*)> {};
class FCImpl_cuda : public OperatorImpl {
private:
    // CuDNN specific variables
public:
    FCImpl_cuda(const FC_Op &op) : OperatorImpl(op) {}

    static std::unique_ptr<FCImpl_cuda> create(const FC_Op &op) {
        return std::make_unique<FCImpl_cuda>(op);
    }

public:
    void forward();
    // ~FCImpl_cuda();

private:
    template <class T> void forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2, bool noBias, DimSize_t outChannels);
};
namespace {
// add cuda backend to FC_Op implementation registry
static Registrar<FC_Op> registrarFCImpl_cuda("cuda", Aidge::FCImpl_cuda::create);
} // namespace
} // namespace Aidge
#endif /* AIDGE_BACKEND_CUDA_OPERATOR_FCIMPL_H_ */
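Aside: the Registrable / Registrar pair above boils down to a function registry keyed on DataType. The sketch below illustrates the idea with hypothetical, self-contained names (fcKernelRegistry, FcKernelRegistrar, getFcKernel); it is not the actual Aidge implementation.

#include <cstddef>
#include <functional>
#include <map>
#include <stdexcept>
#include <utility>

// Hypothetical stand-ins for illustration; Aidge provides its own versions.
enum class DataType { Float32, Float64 };
using FcKernelFn = std::function<void(std::size_t, std::size_t, std::size_t, bool,
                                      const void*, const void*, const void*, void*)>;

// One registry per Registrable type, mapping a key to a kernel.
std::map<DataType, FcKernelFn>& fcKernelRegistry() {
    static std::map<DataType, FcKernelFn> registry;
    return registry;
}

// What a `static Registrar<...>` object does at static-initialization time:
// insert the kernel under its DataType key.
struct FcKernelRegistrar {
    FcKernelRegistrar(DataType key, FcKernelFn fn) { fcKernelRegistry().emplace(key, std::move(fn)); }
};

// What the backend does at dispatch time: look the kernel up by data type.
FcKernelFn getFcKernel(DataType key) {
    const auto it = fcKernelRegistry().find(key);
    if (it == fcKernelRegistry().end())
        throw std::runtime_error("no FC kernel registered for this DataType");
    return it->second;
}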
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CUDA_OPERATOR_FCIMPL_FORWARD_KERNEL_H_
#define AIDGE_CUDA_OPERATOR_FCIMPL_FORWARD_KERNEL_H_
#include <stdexcept>
#include <cfloat>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include "aidge/data/Data.hpp"
#include "aidge/backend/cuda/operator/FCImpl.hpp"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
namespace Aidge {
template<class T>
void fc_forward_cuda(DimSize_t nbInputs, DimSize_t inChannels, DimSize_t outChannels, bool noBias, const void *input, const void *weights, const void *bias, void *output);
namespace {
static Registrar<FCImplForward_cuda> registrarFCImpl2DForward_cuda_Float32({DataType::Float32}, Aidge::fc_forward_cuda<float>);
static Registrar<FCImplForward_cuda> registrarFCImpl2DForward_cuda_Float64({DataType::Float64}, Aidge::fc_forward_cuda<double>);
} // namespace
}
#endif /* AIDGE_CUDA_OPERATOR_FCIMPL_FORWARD_KERNEL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <chrono> // std::chrono::milliseconds
#include <numeric> // std::accumulate
#include <thread> // std::this_thread::sleep_for
#include <vector>
#include <iostream>
#include "aidge/utils/Types.h"
#include "aidge/operator/FC.hpp"
#include "aidge/backend/cuda/data/TensorImpl.hpp"
#include "aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp"
#include "aidge/backend/cuda/operator/FCImpl.hpp"
#include "aidge/backend/cuda/utils/CudaContext.hpp"
void Aidge::FCImpl_cuda::forward() {
    assert(mOp.getRawInput(0) && "missing input #0");
    assert(mOp.getRawInput(1) && "missing input #1");
    assert(mOp.getRawInput(2) && "missing input #2");

    std::shared_ptr<Tensor> inputFallback, input1Fallback, input2Fallback;
    const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(inputFallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
    const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCastFrom(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
    const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));

    const FC_Op& fcOp = static_cast<const FC_Op&>(mOp);
    bool noBias = fcOp.template getAttr<FCAttr::NoBias>();
    DimSize_t outChannels = fcOp.template getAttr<FCAttr::OutChannels>();

    // Dispatch on the output data type (only Float32/Float64 kernels are registered).
    if (std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType() == DataType::Float64) {
        forward_<double>(input0, input1, input2, noBias, outChannels);
    }
    else {
        forward_<float>(input0, input1, input2, noBias, outChannels);
    }
}
template<class T>
void Aidge::FCImpl_cuda::forward_(const Tensor& input0, const Tensor& input1, const Tensor& input2, bool noBias, DimSize_t outChannels)
{
    // Flatten the input: dims()[0] is the batch, everything after collapses to inChannels.
    Aidge::fc_forward_cuda<T>(
        input0.dims()[0],
        input0.size() / input0.dims()[0],
        outChannels,
        noBias,
        input0.getImpl()->rawPtr(),
        input1.getImpl()->rawPtr(),
        input2.getImpl()->rawPtr(),
        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
}
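To make the flattening above concrete: for the 4D test input of dims {2, 3, 5, 5} used later, the kernel sees nbInputs = 2 and inChannels = 3 * 5 * 5 = 75, which is why both the 2D and 4D test inputs below are accepted. A minimal, standalone check of that arithmetic (illustrative only, not part of the commit):

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

int main() {
    const std::vector<std::size_t> dims{2, 3, 5, 5};  // the 4D test input shape
    const std::size_t size = std::accumulate(dims.begin(), dims.end(),
                                             std::size_t{1}, std::multiplies<std::size_t>());
    const std::size_t nbInputs   = dims[0];            // batch: 2
    const std::size_t inChannels = size / dims[0];     // flattened features: 75
    return (nbInputs == 2 && inChannels == 75) ? 0 : 1;
}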
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <stdio.h>
#include "aidge/backend/cuda/operator/FCImpl_CUDA_kernels.hpp"
template<class T>
__global__
void fc_forward_cuda_kernel(std::size_t nbInputs, std::size_t inChannels, std::size_t outChannels, bool noBias, const T* input, const T* weights, const T* bias, T* output)
{
    const std::size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    const std::size_t stride = static_cast<std::size_t>(blockDim.x) * gridDim.x;
    // Grid-stride loop: each thread handles a distinct subset of the batch,
    // so no two threads write the same output element.
    for (std::size_t batch = idx; batch < nbInputs; batch += stride)
    {
        for (std::size_t out = 0; out < outChannels; ++out) {
            T sum = 0;
            // Dot product of one input row with one weight row
            // (weights are row-major [outChannels, inChannels]).
            for (std::size_t in = 0; in < inChannels; ++in) {
                sum += input[batch * inChannels + in] * weights[out * inChannels + in];
            }
            output[batch * outChannels + out] = sum + (noBias ? T(0) : bias[out]);
        }
    }
}
namespace Aidge {
template<class T>
void fc_forward_cuda(DimSize_t nbInputs, DimSize_t inChannels, DimSize_t outChannels, bool noBias, const void* input_, const void* weights_, const void* bias_, void* output_)
{
    const T* input = static_cast<const T*>(input_);
    const T* weights = static_cast<const T*>(weights_);
    const T* bias = static_cast<const T*>(bias_);
    T* output = static_cast<T*>(output_);
    // One thread per batch element, 256 threads per block; the kernel's
    // grid-stride loop covers any remainder if the batch exceeds the grid.
    const dim3 threadsPerBlocks = {256, 1, 1};
    const dim3 blocksPerGrid = {(static_cast<unsigned int>(nbInputs) + threadsPerBlocks.x - 1) / threadsPerBlocks.x, 1, 1};
    fc_forward_cuda_kernel<<<blocksPerGrid, threadsPerBlocks>>>(nbInputs, inChannels, outChannels, noBias, input, weights, bias, output);
    CHECK_CUDA_STATUS(cudaPeekAtLastError());
}
}
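As a cross-check of the kernel's indexing (row-major input [nbInputs, inChannels], weights [outChannels, inChannels], output [nbInputs, outChannels]), a host-side reference of the same computation can be written as below. This is a validation sketch, not part of the commit:

#include <cstddef>
#include <vector>

// Host reference: output[b][o] = sum_i input[b][i] * weights[o][i] + bias[o],
// using the same flat indexing as fc_forward_cuda_kernel.
template <class T>
std::vector<T> fc_forward_reference(std::size_t nbInputs, std::size_t inChannels,
                                    std::size_t outChannels, bool noBias,
                                    const std::vector<T>& input,
                                    const std::vector<T>& weights,
                                    const std::vector<T>& bias) {
    std::vector<T> output(nbInputs * outChannels);
    for (std::size_t b = 0; b < nbInputs; ++b) {
        for (std::size_t o = 0; o < outChannels; ++o) {
            T sum = 0;
            for (std::size_t i = 0; i < inChannels; ++i) {
                sum += input[b * inChannels + i] * weights[o * inChannels + i];
            }
            output[b * outChannels + o] = sum + (noBias ? T(0) : bias[o]);
        }
    }
    return output;
}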
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <array>
#include <cmath>
#include <catch2/catch_test_macros.hpp>
#include "Test_cuda.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/backend/cpu.hpp"
#include "aidge/backend/cuda.hpp"
using namespace Aidge;
TEST_CASE("[gpu/operator] FC(forward)", "[FC][GPU]") {
std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array2D<float, 5, 75>{
{{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1, 2, 3, 4,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1, 2, 3, 4,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1, 2, 3, 4,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1, 2, 3, 4,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1, 2, 3, 4,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}});
std::shared_ptr<Tensor> myBias = std::make_shared<Tensor>(Array1D<float, 5>{{1, 2, 3, 4, 5}});
std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array2D<float, 2, 5>{
{{23601, 23602, 23603, 23604, 23605}, {68601, 68602, 68603, 68604, 68605}}});
myWeights->setBackend("cuda");
myBias->setBackend("cuda");
std::shared_ptr<Node> myFC = FC(75, 5, false, "myfc");
auto op = std::static_pointer_cast<OperatorTensor>(myFC -> getOperator());
op -> associateInput(1, myWeights);
op -> associateInput(2, myBias);
SECTION("2D input") {
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array2D<float, 2, 75>{
{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37,
38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56,
57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74},
{75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,
105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134,
135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149}}});
myInput->setBackend("cuda");
op->associateInput(0, myInput);
op -> setDataType(DataType::Float32);
op -> setBackend("cuda");
op->computeOutputDims();
myFC->forward();
float* computedOutput = new float[myOutput->size()]();
cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * myOutput->size(), cudaMemcpyDeviceToHost);
for (std::size_t i = 0; i < myOutput->size(); i++) {
const float targetOutput = *(static_cast<float*>(myOutput->getImpl()->rawPtr()) + i);
std::cout << "targetOutput " << targetOutput << ", out " << computedOutput[i]<<std::endl;
REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6);
}
delete[] computedOutput;
}
SECTION("4D input") {
std::shared_ptr<Tensor> myInput =
std::make_shared<Tensor>(Array4D<float, 2, 3, 5, 5>{{{{{0, 1, 2, 3, 4},
{5, 6, 7, 8, 9},
{10, 11, 12, 13, 14},
{15, 16, 17, 18, 19},
{20, 21, 22, 23, 24}},
{{25, 26, 27, 28, 29},
{30, 31, 32, 33, 34},
{35, 36, 37, 38, 39},
{40, 41, 42, 43, 44},
{45, 46, 47, 48, 49}},
{{50, 51, 52, 53, 54},
{55, 56, 57, 58, 59},
{60, 61, 62, 63, 64},
{65, 66, 67, 68, 69},
{70, 71, 72, 73, 74}}},
{{{75, 76, 77, 78, 79},
{80, 81, 82, 83, 84},
{85, 86, 87, 88, 89},
{90, 91, 92, 93, 94},
{95, 96, 97, 98, 99}},
{{100, 101, 102, 103, 104},
{105, 106, 107, 108, 109},
{110, 111, 112, 113, 114},
{115, 116, 117, 118, 119},
{120, 121, 122, 123, 124}},
{{125, 126, 127, 128, 129},
{130, 131, 132, 133, 134},
{135, 136, 137, 138, 139},
{140, 141, 142, 143, 144},
{145, 146, 147, 148, 149}}}}});
myInput->setBackend("cuda");
op->associateInput(0, myInput);
op -> setDataType(DataType::Float32);
op -> setBackend("cuda");
op->computeOutputDims();
myFC->forward();
float* computedOutput = new float[myOutput->size()]();
cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * myOutput->size(), cudaMemcpyDeviceToHost);
for (std::size_t i = 0; i < myOutput->size(); i++) {
const float targetOutput = *(static_cast<float*>(myOutput->getImpl()->rawPtr()) + i);
REQUIRE(fabs(computedOutput[i] - targetOutput) < 1e-6);
}
delete[] computedOutput;
}
}
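The expected values in myOutput can be reproduced by hand: each weight row is the pattern 1..15 repeated five times, the first input row 0..74 gives a dot product of 23600 and the second row 75..149 gives 68600, and the per-channel bias 1..5 is then added. A small standalone check of that arithmetic (illustrative only):

#include <cstdio>

int main() {
    // Mirrors the test data: input rows 0..74 and 75..149, weight pattern 1..15.
    for (int row = 0; row < 2; ++row) {
        float sum = 0.f;
        for (int i = 0; i < 75; ++i) {
            sum += static_cast<float>(row * 75 + i) * static_cast<float>(i % 15 + 1);
        }
        for (int out = 0; out < 5; ++out) {
            // Prints 23601..23605, then 68601..68605, matching myOutput.
            std::printf("%g\n", sum + static_cast<float>(out + 1));
        }
    }
    return 0;
}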