/******************************************************************************** * Copyright (c) 2023 CEA-List * * This program and the accompanying materials are made available under the * terms of the Eclipse Public License 2.0 which is available at * http://www.eclipse.org/legal/epl-2.0. * * SPDX-License-Identifier: EPL-2.0 * ********************************************************************************/ #include <cassert> #include <chrono> // std::chrono::milliseconds #include <numeric> // std::accumulate #include <thread> // std::this_thread::sleep_for #include <vector> #include "aidge/operator/FC.hpp" #include "aidge/utils/Types.h" #include "aidge/backend/cpu/operator/FCImpl.hpp" #include "aidge/backend/cpu/operator/FCImpl_forward_kernels.hpp" void Aidge::FCImpl_cpu::forward() { assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0"); assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(1)) && "missing input #1"); assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(2)) && "missing input #2"); // Find the correct kernel type auto kernelFunc = Registrar<FCImplForward_cpu>::create( {std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(), std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(), std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->dataType(), std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()}); // Call kernel // if (std::static_pointer_cast<Tensor>(mOp.getRawInput(0)->nbDims() == 4) { // kernelFunc( // mOp.getStaticAttributes(), // std::static_pointer_cast<Tensor>(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>(), // std::static_pointer_cast<Tensor>(mOp.getRawInput(0)->getImpl()->rawPtr(), // mOp.mInputs[1]->getImpl()->rawPtr(), // mOp.mInputs[2]->getImpl()->rawPtr(), // mOp.getOutput(0)->getImpl()->rawPtr()); // } // else kernelFunc( dynamic_cast<const FC_Op&>(mOp).getStaticAttributes(), std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims()[0], std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->sizeM1(), std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(), std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->getImpl()->rawPtr(), std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->getImpl()->rawPtr(), std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()); }