Skip to content
Snippets Groups Projects

[Add] Dropout Operator

Merged Marwa ABDELOUINISSE requested to merge maab05/aidge_backend_cpu:feat_183_add_dropout into dev
Compare and
4 files
+ 197
0
Compare changes
  • Side-by-side
  • Inline
Files
4
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_DROPOUTIMPL_H_
#define AIDGE_CPU_OPERATOR_DROPOUTIMPL_H_
#include <memory>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Dropout.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
// Forward and backward templates for CPU kernel functions
class DropoutImplForward_cpu
: public Registrable<DropoutImplForward_cpu,
std::tuple<DataType, DataType>,
void(const typename Dropout_Op::Attrs&,
const std::vector<DimSize_t>&,
const void*,
void*)> {};
class DropoutImplBackward_cpu
: public Registrable<DropoutImplBackward_cpu,
std::tuple<DataType, DataType>,
void(const typename Dropout_Op::Attrs&,
const std::vector<DimSize_t>&,
const void*,
void*)> {};
// CPU implementation for Dropout operator
class DropoutImpl_cpu : public OperatorImpl {
public:
DropoutImpl_cpu(const Dropout_Op& op) : OperatorImpl(op) {}
static std::unique_ptr<DropoutImpl_cpu> create(const Dropout_Op& op) {
return std::make_unique<DropoutImpl_cpu>(op);
}
void forward() override;
};
// Register DropoutImpl_cpu with the backend as "cpu"
namespace {
static Registrar<Dropout_Op> registrarDropoutImpl_cpu("cpu", Aidge::DropoutImpl_cpu::create);
}
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_DROPOUTIMPL_H_ */
\ No newline at end of file
Loading