Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
HeavisideImpl.cpp 1.87 KiB
/********************************************************************************
 * Copyright (c) 2025 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/backend/cpu/operator/HeavisideImpl.hpp"

#include <stdexcept>

#include "aidge/backend/cpu/operator/HeavisideImpl_kernels.hpp"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/utils/ErrorHandling.hpp"

/**
 * @brief Forward pass of the CPU Heaviside operator.
 *
 * Selects the kernel registered in Registrar<HeavisideImplCpu> for the best
 * match of this implementation's required spec. It then calls that kernel with:
 * - the element count of input #0,
 * - the raw input and output buffers,
 * - the operator's `value()` attribute.
 *
 * NOTE(review): `value()` is presumably the output used when the input is
 * exactly 0. Confirm against HeavisideImpl_kernels.hpp.
 */
template <> void Aidge::HeavisideImplCpu::forward() {
    const Heaviside_Op &op_ = dynamic_cast<const Heaviside_Op &>(mOp);
    const std::shared_ptr<Tensor> input0 = op_.getInput(0);
    const std::shared_ptr<Tensor> output0 = op_.getOutput(0);
    AIDGE_ASSERT(input0, "missing input #0");
    // The kernel writes through the output buffer, so a missing output must
    // fail loudly here instead of being dereferenced inside getCPUPtr.
    AIDGE_ASSERT(output0, "missing output #0");

    // Resolve the registered kernel for the closest matching spec.
    const auto impl =
        Registrar<HeavisideImplCpu>::create(getBestMatch(getRequiredSpec()));

    // Use the pointers validated above rather than looking the tensors up again.
    impl.forward(input0->size(),
                 getCPUPtr(input0),
                 getCPUPtr(output0),
                 op_.value());
}

/**
 * @brief Backward pass of the CPU Heaviside operator (not implemented).
 *
 * Always fails through AIDGE_THROW_OR_ABORT with std::runtime_error. Whether
 * this throws or aborts presumably depends on how the macro is configured
 * in aidge/utils/ErrorHandling.hpp.
 */
template <>
void Aidge::HeavisideImplCpu::backward() {
    AIDGE_THROW_OR_ABORT(std::runtime_error, "Heaviside backward not implemented yet");

    // TODO: Implement a surrogate gradient. The Heaviside step has zero
    // gradient almost everywhere, so training requires a surrogate.
    // The planned approach assumes an Atan surrogate. Remove that assumption
    // by adding an attribute to Heaviside_Op that selects between different
    // surrogate gradients.
    //
    // Intended shape (unreachable until the throw above is removed):
    //
    //   const Heaviside_Op& op_ = dynamic_cast<const Heaviside_Op&>(mOp);
    //   std::shared_ptr<Tensor> in0 = op_.getInput(0);
    //   std::shared_ptr<Tensor> out0 = op_.getOutput(0);
    //
    //   // Backward of Heaviside is computed by the surrogate's kernel.
    //   const auto impl =
    //       Registrar<HeavisideImplCpu>::create(getBestMatch(getRequiredSpec()));
    //   impl.forward(/* ... */);
    //
    // NOTE(review): the Atan surrogate gradient is the derivative of an
    // arctan-based approximation of the step, not arctan itself. Confirm which
    // kernel is meant before wiring this up.
}