
Add backward functions for ReLU, Sigmoid and Tanh

Merged: Olivier Antoni requested to merge (removed):dev into dev
6 files changed, +156 −0
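
For reference, the gradients that the new backward kernels compute are the standard ones, writing y for the forward output and ∂L/∂y for the incoming gradient:

    \frac{\partial L}{\partial x} =
    \begin{cases}
        \frac{\partial L}{\partial y} \cdot \mathbf{1}[x > 0]   & \text{ReLU} \\
        \frac{\partial L}{\partial y} \cdot y\,(1 - y)          & \text{Sigmoid} \\
        \frac{\partial L}{\partial y} \cdot (1 - y^{2})         & \text{Tanh}
    \end{cases}

Note that the Sigmoid and Tanh gradients need only the forward output y, while ReLU needs to know where the input (equivalently, the output) was positive; cuDNN's backward API accordingly takes both x and y.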
@@ -36,6 +36,7 @@ private:
     cudnnActivationMode_t mReLUDesc = nullptr;
 #endif
     std::shared_ptr<Tensor> mInputFallback;
+    std::shared_ptr<Tensor> mOutputGradFallback;
 public:
     ReLUImpl_cuda(const ReLU_Op &op) : OperatorImpl(op, "cuda") {}
@@ -46,10 +47,12 @@ public:
 public:
     void forward();
+    void backward();
     ~ReLUImpl_cuda();
 private:
     template <class T> void forward_(const Tensor& input);
+    template <class T> void backward_(const Tensor& output_grad);
 };
 namespace {
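
The diff above shows only the header changes for the ReLU implementation; the Sigmoid and Tanh headers follow the same pattern. For context, the cuDNN entry point that such a backward_() template presumably wraps is cudnnActivationBackward, which computes dx = alpha * f'(x) * dy + beta * dx. The following is a minimal standalone sketch using plain cuDNN rather than the Aidge wrapper classes; the helper name reluBackward, the float-only setup, and the omission of status checks are illustrative assumptions:

    #include <cudnn.h>

    // Sketch: ReLU backward over n floats with plain cuDNN (error checks
    // omitted for brevity). Computes dx = relu'(x) * dy.
    void reluBackward(cudnnHandle_t handle,
                      const float* x,   // forward input (device pointer)
                      const float* y,   // forward output (device pointer)
                      const float* dy,  // gradient w.r.t. output
                      float* dx,        // gradient w.r.t. input (result)
                      int n) {
        cudnnTensorDescriptor_t desc;
        cudnnCreateTensorDescriptor(&desc);
        cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW,
                                   CUDNN_DATA_FLOAT, 1, 1, 1, n);

        cudnnActivationDescriptor_t actDesc;
        cudnnCreateActivationDescriptor(&actDesc);
        cudnnSetActivationDescriptor(actDesc, CUDNN_ACTIVATION_RELU,
                                     CUDNN_NOT_PROPAGATE_NAN, 0.0);

        const float alpha = 1.0f, beta = 0.0f;
        // dx = alpha * f'(x) * dy + beta * dx
        cudnnActivationBackward(handle, actDesc, &alpha,
                                desc, y,    // yDesc, y
                                desc, dy,   // dyDesc, dy
                                desc, x,    // xDesc, x
                                &beta,
                                desc, dx);  // dxDesc, dx

        cudnnDestroyActivationDescriptor(actDesc);
        cudnnDestroyTensorDescriptor(desc);
    }

Swapping CUDNN_ACTIVATION_RELU for CUDNN_ACTIVATION_SIGMOID or CUDNN_ACTIVATION_TANH yields the other two backward passes with no further changes.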