diff --git a/CMakeLists.txt b/CMakeLists.txt
index b3c6d459dfaf29f5accbc0be4565a3709e9ffd3b..afb882af0d02000e5490f1d2a0c56b4487481be9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -61,7 +61,7 @@ option(PYBIND "python binding" OFF)
 option(WERROR "Warning as error" OFF)
 option(TEST "Enable tests" OFF)
 option(COVERAGE "Enable coverage" OFF)
-option(CUDA "Enable CUDA backend" OFF) # XXX OFF
+option(CUDA "Enable CUDA backend" ON) # XXX should default to OFF
 option(ENABLE_ASAN "Enable ASan (AddressSanitizer) for runtime analysis of memory use (over/underflow, memory leak, ...)" OFF)
 
 ##############################################
@@ -182,7 +182,6 @@ endif()
 
 # Coverage flags for GCC
 if(CMAKE_COMPILER_IS_GNUCXX AND COVERAGE)
-    include(CodeCoverage)
     append_coverage_compiler_flags()
 endif()
 
diff --git a/include/aidge/backend/cpu/operator/FixedQImpl_kernels.hpp b/include/aidge/backend/cpu/operator/FixedQImpl_kernels.hpp
index 9d7a106cdd6b1a0970c87b2a27cc7d6637846b49..935d8f065a5e91729c5c0ff25b13f5ea1234a8b6 100644
--- a/include/aidge/backend/cpu/operator/FixedQImpl_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/FixedQImpl_kernels.hpp
@@ -23,7 +23,7 @@ void FixedQImpl_cpu_forward_kernel(
     std::size_t nbBits,
     float span_,
     bool isOutputUnsigned,
-    std::size_t inputLenght,
+    std::size_t inputLength,
     const void* input_,
     void* output_)
 {
@@ -40,7 +40,7 @@ void FixedQImpl_cpu_forward_kernel(
     const I* input = static_cast<const I*>(input_);
     O* output = static_cast<O*>(output_);
 
-    for (std::size_t i = 0; i < inputLenght; ++i) {
+    for (std::size_t i = 0; i < inputLength; ++i) {
         I clipped = std::max(lower, std::min(input[i], upper));
         output[i] = std::round(clipped / stepSize) * stepSize;
     }
@@ -49,14 +49,14 @@ void FixedQImpl_cpu_forward_kernel(
 
 template <class GI, class GO>
 void FixedQImpl_cpu_backward_kernel(
-    const std::size_t inputLenght,
+    const std::size_t inputLength,
     const void* grad_output_,
     void* grad_input_)
 {
     const GO* grad_output = static_cast<const GO*>(grad_output_);
     GI* grad_input = static_cast<GI*>(grad_input_);
 
-    for (std::size_t i = 0; i < inputLenght; ++i) {
+    for (std::size_t i = 0; i < inputLength; ++i) {
         // Straight Through Estimator
         grad_input[i] = grad_output[i];
     }
diff --git a/include/aidge/quantization/PTQ/PTQMetaOps.hpp b/include/aidge/operator/PTQMetaOps.hpp
similarity index 100%
rename from include/aidge/quantization/PTQ/PTQMetaOps.hpp
rename to include/aidge/operator/PTQMetaOps.hpp
diff --git a/include/aidge/quantization/PTQ/PTQ.hpp b/include/aidge/quantization/PTQ/PTQ.hpp
index 4fc38bc3b959ec8264ddaddbd4673fbe1f75e4ab..bfe671e3556c3af2c367ce7f86708f01c8e3d3b5 100644
--- a/include/aidge/quantization/PTQ/PTQ.hpp
+++ b/include/aidge/quantization/PTQ/PTQ.hpp
@@ -124,11 +124,11 @@ namespace Aidge {
      * @brief Quantize an already normalized (in terms of parameters and activations) network.
      * @param graphView The GraphView to be quantized.
      * @param nbBits The desired number of bits of the quantization.
-     * @param applyRounding Whether to apply the rounding operations or not.
+     * @param noQuant Whether to skip the rounding operations or not.
      * @param optimizeSigns Whether to take account of the IO signs of the operators or not.
     * @param verbose Whether to print the sign map or not.
     */
-    void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, bool applyRounding, bool optimizeSigns, bool verbose);
+    void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, bool noQuant, bool optimizeSigns, bool verbose);
 
     /**
     * @brief Main quantization routine. Performs every step of the quantization pipeline.
@@ -136,12 +136,12 @@ namespace Aidge {
     * @param nbBits The desired number of bits of the quantization.
     * @param inputDataSet The input dataset on which the value ranges are computed.
     * @param clippingMode Type of the clipping optimization. Can be either 'MAX', 'MSE', 'AA' or 'KL'.
-    * @param applyRounding Whether to apply the rounding operations or not.
+    * @param noQuant Whether to skip the rounding operations or not.
     * @param optimizeSigns Whether to take account of the IO signs of the operators or not.
     * @param singleShift Whether to convert the scaling factors into powers of two. If true the approximations are compensated using the previous nodes' weights.
     * @param verbose Whether to print internal information about the quantization process.
     */
-    void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool applyRounding, bool optimizeSigns, bool singleShift, bool useCuda, bool verbose);
+    void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool verbose);
 
     /**
     * @brief Compute the weight ranges of every affine node. Provided for debugging purposes.
diff --git a/include/aidge/quantization/QAT/QAT_LSQ.hpp b/include/aidge/quantization/QAT/QAT_LSQ.hpp
index a44c71b04ca9e9c6a8fba27c615c99b4893d3d8c..922187abca915daa1c00f3949d0d791b0d3e1c39 100644
--- a/include/aidge/quantization/QAT/QAT_LSQ.hpp
+++ b/include/aidge/quantization/QAT/QAT_LSQ.hpp
@@ -22,22 +22,13 @@ namespace Aidge {
 namespace QuantLSQ {
 
 /**
- * @brief Insert the LSQ quantizer nodes in a given GraphView
- * @param graphView The GraphView containing the graph to quantize.
+ * @brief Given a GraphView with parameters properly initialized, insert
+ * the LSQ quantizer nodes and set up the adjustment of their step-sizes.
+ * @param graphView The GraphView containing the network to quantize.
 * @param nbBits Number of quantization bits.
- * @param span Fixed output span of the quantizers.
 */
-void insertQuantizers(std::shared_ptr<GraphView> graphView, std::size_t nbBits, float step_size);
 
-/**
- * @brief Given a GraphView with parameters properly initialized and some calibration data,
- * insert the LSQ quantizer nodes, and adjust their step-sizes.
- * @param graphView The GraphView containing the graph to quantize.
- * @param nbBits Number of quantization bits.
- * @param calibrationData Calibration data used to adjust the spans.
- * @param scale Multiplicative constant applied to the spans.
- */
-void insertAndInitQuantizers(std::shared_ptr<GraphView> graphView, std::size_t nbBits, std::shared_ptr<Tensor> calibrationData);
+void setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits);
 
 } // namespace QuantLSQ
 } // namespace Aidge
diff --git a/python_binding/pybind_PTQ.cpp b/python_binding/pybind_PTQ.cpp
index b5193bddcfe345a1702f02fcc139a4cf5b94a1ce..1de797693468273814f4c5e82a161991648d06ff 100644
--- a/python_binding/pybind_PTQ.cpp
+++ b/python_binding/pybind_PTQ.cpp
@@ -78,7 +78,7 @@ void init_PTQ(py::module &m) {
     :type value_ranges: list of float.
     )mydelimiter");
 
-    m.def("quantize_normalized_network", &quantizeNormalizedNetwork, py::arg("network"), py::arg("nb_bits"), py::arg("no_quant")=false, py::arg("optimize_signs"), py::arg("verbose") = false,
+    m.def("quantize_normalized_network", &quantizeNormalizedNetwork, py::arg("network"), py::arg("nb_bits"), py::arg("no_quantization") = false, py::arg("optimize_signs") = false, py::arg("verbose") = false,
    R"mydelimiter(
    Quantize an already normalized (in terms of parameters and activations) network.
    :param network: The GraphView to be quantized.
@@ -93,7 +93,7 @@ void init_PTQ(py::module &m) {
    :type verbose: bool
    )mydelimiter");
 
-    m.def("quantize_network", &quantizeNetwork ,py::arg("network"), py::arg("nb_bits"), py::arg("input_dataset"), py::arg("clipping_mode") = Clipping::MAX , py::arg("no_quantization") = true, py::arg("optimize_signs") = false, py::arg("single_shift") = false, py::arg("use_cuda") = false, py::arg("verbose") = false,
+    m.def("quantize_network", &quantizeNetwork, py::arg("network"), py::arg("nb_bits"), py::arg("input_dataset"), py::arg("clipping_mode") = Clipping::MAX, py::arg("no_quantization") = false, py::arg("optimize_signs") = false, py::arg("single_shift") = false, py::arg("use_cuda") = false, py::arg("verbose") = false,
    R"mydelimiter(
    Main quantization routine. Performs every step of the quantization pipeline.
    :param network: The GraphView to be quantized.
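As a sanity check of the renamed keyword arguments, here is a minimal Python-side call sketch. The `aidge_quantization` import name, the `model`/`samples` objects, and the Python exposure of the `Clipping` enum are assumptions; the keyword names come from the `py::arg` declarations above.

```python
import aidge_quantization  # assumed import name of this package's binding

# model: an aidge_core.GraphView; samples: a list of aidge_core.Tensor
# (both assumed to be prepared elsewhere)
aidge_quantization.quantize_network(
    network=model,
    nb_bits=8,
    input_dataset=samples,
    clipping_mode=aidge_quantization.Clipping.MAX,  # assumed enum exposure
    no_quantization=False,  # renamed from no_quant / applyRounding
    optimize_signs=True,
    single_shift=False,
    use_cuda=False,
    verbose=False)
```

Note that `optimize_signs` is given a default in the binding above: pybind11 rejects a non-default `py::arg` following a defaulted one at module import time.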
diff --git a/python_binding/pybind_QAT_LSQ.cpp b/python_binding/pybind_QAT_LSQ.cpp
index 206985efe4558a84ce1ed67a1264bd6902213764..4bba3b6baa5eda41a024399eb1be1402c74b2c1b 100644
--- a/python_binding/pybind_QAT_LSQ.cpp
+++ b/python_binding/pybind_QAT_LSQ.cpp
@@ -23,8 +23,6 @@ void init_QAT_LSQ(py::module &m) {
 
     auto mQuantLSQ = m.def_submodule("lsq");
 
-    mQuantLSQ.def("insert_quantizers", &QuantLSQ::insertQuantizers, py::arg("network"), py::arg("nb_bits"), py::arg("step_size"));
-
-    mQuantLSQ.def("insert_and_init_quantizers", &QuantLSQ::insertAndInitQuantizers, py::arg("network"), py::arg("nb_bits"), py::arg("calibration_data"));
+    mQuantLSQ.def("setup_quantizers", &QuantLSQ::setupQuantizers, py::arg("network"), py::arg("nb_bits"));
 }
 } // namespace Aidge
diff --git a/setup.py b/setup.py
index 1bfc0ac515fd8cceeec4cba666addc1e7666fd25..cde7c1e513e8f3092474bddcb57842efced415e6 100644
--- a/setup.py
+++ b/setup.py
@@ -63,7 +63,7 @@ class AidgePkgBuild(build_ext):
         cxx_compiler = os.environ.get("AIDGE_CXX_COMPILER", "g++")
         build_type = os.environ.get("AIDGE_BUILD_TYPE", "Release")
         asan = os.environ.get("AIDGE_ASAN", "OFF")
-        with_cuda = os.environ.get("AIDGE_WITH_CUDA", "OFF")
+        with_cuda = os.environ.get("AIDGE_WITH_CUDA", "ON")  # XXX should default to "OFF"
         cmake_arch = os.environ.get("AIDGE_CMAKE_ARCH", "")
         build_gen = os.environ.get("AIDGE_BUILD_GEN", "")
 
diff --git a/src/PTQ/CLE.cpp b/src/PTQ/CLE.cpp
index 5265d9c9b1326e73ee4080fe5f69fed5047a0dbb..28858d0e3c693a7620bc32806008523e0602faa9 100644
--- a/src/PTQ/CLE.cpp
+++ b/src/PTQ/CLE.cpp
@@ -24,6 +24,12 @@
 #include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/utils/Log.hpp"
 
+#include "aidge/operator/Mul.hpp"
+#include "aidge/operator/ArgMax.hpp"
+#include "aidge/operator/Abs.hpp"
+#include "aidge/operator/Reshape.hpp"
+#include "aidge/operator/Round.hpp"
+
 namespace Aidge
 {
 
@@ -39,27 +45,58 @@ static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node)
 
 static void rescaleTensor(std::shared_ptr<Tensor> tensor, double scaling)
 {
-    // Get the tensor data pointer
-    double * castedTensor = static_cast<double *> (tensor->getImpl()->rawPtr());
-
-    // Rescale the tensor
-    for(std::size_t i = 0; i < tensor->size(); i++)
-        castedTensor[i] *= scaling;
+    auto mulOp = Mul_Op();
+    mulOp.setDataType(tensor->dataType());
+    mulOp.setBackend(tensor->backend());
+
+    std::shared_ptr<Aidge::Tensor> scalingTensor = std::make_shared<Aidge::Tensor>(scaling);
+    scalingTensor->setDataType(tensor->dataType());
+    scalingTensor->setBackend(tensor->backend());
+
+    mulOp.associateInput(0, tensor);
+    mulOp.associateInput(1, scalingTensor);
+
+    mulOp.forward();
+
+    auto outTensor = mulOp.getOutput(0);
+    *tensor = *outTensor;
 }
 
+// TODO : make the retrieval of argmax values backend independent (refCastFrom)
 static double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor)
 {
-    // Get the tensor data pointer and edit it
-    double * castedTensor = static_cast<double*> (tensor->getImpl()->rawPtr());
-
-    // Get the tensor absolute max value
-    double maxValue = 0.0;
-    for(std::size_t i = 0; i < tensor->size(); ++i) {
-        if(std::fabs(castedTensor[i]) > maxValue) {
-            maxValue = std::fabs(castedTensor[i]);
-        }
-    }
-    return maxValue;
+    // get the abs tensor
+    std::shared_ptr<Tensor> absTensor = std::make_shared<Tensor>(tensor->abs());
+
+    // flatten the abs tensor
+    std::int64_t nbElement = tensor->size();
+
+    auto reshapeOp = Reshape_Op({nbElement});
+    reshapeOp.setDataType(tensor->dataType());
+    reshapeOp.setBackend(tensor->backend());
+
+    reshapeOp.associateInput(0, absTensor);
+    reshapeOp.forward();
+    std::shared_ptr<Tensor> flatTensor = reshapeOp.getOutput(0);
+
+    // Get the argmax
+    auto argmaxOp = ArgMax_Op(0, true, false);
+    argmaxOp.setDataType(tensor->dataType());
+    argmaxOp.setBackend(tensor->backend());
+
+    argmaxOp.associateInput(0, flatTensor);
+    argmaxOp.forward();
+    std::shared_ptr<Tensor> argmaxTensor = argmaxOp.getOutput(0);
+
+    // Return the max
+    int maxIndex = std::round(argmaxTensor->get<double>(0));
+
+    return flatTensor->get<double>(maxIndex);
 }
 
 void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetDelta)
@@ -83,22 +120,13 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD
         if (isAffine(node))
             affineNodeVector.push_back(node);
 
-    if (affineNodeVector.empty()) {
-        Log::notice("No affine nodes found in the network. CLE cannot be applied.");
-        return;
-    }
 
     double maxRangeDelta;
-    int iteration = 0;
 
     do
    {
-        ++iteration;
         maxRangeDelta = 0.0;
 
-        //std::cout << " ----- " << std::endl;
-        //for (std::shared_ptr<Node> node : affineNodeVector)
-        //    std::cout << getTensorAbsoluteMax(getWeightTensor(node)) << std::endl;
-
-        for (std::size_t i = 0; i < (affineNodeVector.size() - 1); i++)
+        for (size_t i = 0; i + 1 < affineNodeVector.size(); i++) // i + 1 avoids a size_t underflow when there are no affine nodes
        {
             std::shared_ptr<Node> n1 = affineNodeVector[i];
             std::shared_ptr<Node> n2 = affineNodeVector[i+1];
@@ -120,9 +148,6 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD
            }
        }
    } while (maxRangeDelta > targetDelta);
-
-    Log::notice("CLE completed after {} iterations. Final max range delta: {:.6f}",
-        iteration, maxRangeDelta);
 }
 }
\ No newline at end of file
diff --git a/src/PTQ/Clipping.cpp b/src/PTQ/Clipping.cpp
index 57ad7a836bbb6251a8eeb6da87e3647b4f54afe2..66b0ab36fba7634d7ee350cdccb27895ffa52da1 100644
--- a/src/PTQ/Clipping.cpp
+++ b/src/PTQ/Clipping.cpp
@@ -26,7 +26,7 @@ std::map<std::string, std::vector<int>> computeHistograms(std::map<std::string,
 
     std::shared_ptr<Node> firstNode = retrieveNodeVector(graphView)[0];
 
-    //std::cout << " COMPUTING HISTOGRAMS ... " << std::endl;
+    // Log::debug(" COMPUTING HISTOGRAMS ... ");
"); std::map<std::string, std::vector<int>> histograms; diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index 0e26313475bbbda23a56dcdda52d55a0a5af8204..7c29ee0b9178fbb07f4a2d5edf9f0ad7ac8dcac4 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -12,7 +12,7 @@ #include "aidge/quantization/PTQ/CLE.hpp" #include "aidge/quantization/PTQ/Clipping.hpp" #include "aidge/quantization/PTQ/PTQ.hpp" -#include "aidge/quantization/PTQ/PTQMetaOps.hpp" +#include "aidge/operator/PTQMetaOps.hpp" #include "aidge/data/Tensor.hpp" @@ -28,6 +28,12 @@ #include "aidge/operator/BatchNorm.hpp" #include "aidge/operator/Conv.hpp" +#include "aidge/operator/ArgMax.hpp" +#include "aidge/operator/Abs.hpp" +#include "aidge/operator/Reshape.hpp" +#include "aidge/operator/Round.hpp" + + #include "aidge/recipes/Recipes.hpp" #include "aidge/recipes/QuantRecipes.hpp" @@ -66,51 +72,75 @@ bool checkArchitecture(std::shared_ptr<GraphView> graphView) return true; } -static void fillTensor(std::shared_ptr<Tensor> tensor, double value) +static void rescaleTensor(std::shared_ptr<Tensor> tensor, double scaling) { - // Get the tensor data pointer - double * castedTensor = static_cast <double *> (tensor->getImpl()->rawPtr()); + auto mulOp = Mul_Op(); + mulOp.setDataType(tensor->dataType()); + mulOp.setBackend(tensor->backend()); - // Fill the tensor - for(std::size_t i = 0; i < tensor->size(); i++) - castedTensor[i] = value; -} + std::shared_ptr<Aidge::Tensor> scalingTensor = std::make_shared<Aidge::Tensor>(scaling); + scalingTensor->setDataType(tensor->dataType()); + scalingTensor->setBackend(tensor->backend()); -static void rescaleTensor(std::shared_ptr<Tensor> tensor, double scaling) -{ - // Get the tensor data pointer - double * castedTensor = static_cast <double *> (tensor->getImpl()->rawPtr()); + mulOp.associateInput(0, tensor); + mulOp.associateInput(1, scalingTensor); - // Rescale the tensor - for(std::size_t i = 0; i < tensor->size(); i++) - castedTensor[i] *= scaling; + mulOp.forward(); + + auto outTensor = mulOp.getOutput(0); + *tensor = *outTensor; } static void roundTensor(std::shared_ptr<Tensor> tensor) { - // Get the tensor data pointer - double * castedTensor = static_cast <double *> (tensor->getImpl()->rawPtr()); + auto roundOp = Round_Op(); + roundOp.setDataType(tensor->dataType()); + roundOp.setBackend(tensor->backend()); - // Rescale the tensor - for(std::size_t i = 0; i < tensor->size(); i++) - castedTensor[i] = std::nearbyint(castedTensor[i]);//Round + roundOp.associateInput(0, tensor); + roundOp.forward(); + + auto outTensor = roundOp.getOutput(0); + *tensor = *outTensor; } -static double getTensorAbsoluteMax(std::shared_ptr <Tensor> tensor) +// TODO : make the retreival of argmax values backend independant (refCastFrom) +static double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor) { - // Get the tensor data pointer and edit it - double * castedTensor = static_cast<double*>(tensor->getImpl()->rawPtr()); - - // Get the tensor absolute max value - double maxValue = 0.0f; - for(std::size_t i = 0; i < tensor->size(); ++i) { - if(std::fabs(castedTensor[i]) > maxValue) { - maxValue = std::fabs(castedTensor[i]); - } - } - return maxValue; + // get the abs tensor + + std::shared_ptr<Tensor> absTensor = std::make_shared<Tensor>(tensor->abs()); + + // flatten the abs tensor + + std::int64_t nbElement = tensor->size(); + + auto reshapeOp = Reshape_Op({nbElement}); + reshapeOp.setDataType(tensor->dataType()); + reshapeOp.setBackend(tensor->backend()); + + reshapeOp.associateInput(0, absTensor); + reshapeOp.forward(); + 
+    std::shared_ptr<Tensor> flatTensor = reshapeOp.getOutput(0);
+
+    // Get the argmax
+    auto argmaxOp = ArgMax_Op(0, true, false);
+    argmaxOp.setDataType(tensor->dataType());
+    argmaxOp.setBackend(tensor->backend());
+
+    argmaxOp.associateInput(0, flatTensor);
+    argmaxOp.forward();
+    std::shared_ptr<Tensor> argmaxTensor = argmaxOp.getOutput(0);
+
+    // Return the max
+    int maxIndex = std::round(argmaxTensor->get<double>(0));
+
+    return flatTensor->get<double>(maxIndex);
 }
 
+// TODO : pass nodeVector by reference ...
 static std::vector<std::shared_ptr<Node>> removeMatchingNodes(std::vector<std::shared_ptr<Node>> nodeVector, std::string nodeType)
 {
@@ -185,6 +215,8 @@ void prepareNetwork(std::shared_ptr<GraphView> graphView)
 {
     removeFlatten(graphView);
 
+    sanitizeNodeNames(graphView);
+
     bool containsBatchNorm = false;
     std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
 
@@ -876,50 +908,42 @@ static void insertCompensationNodes(std::shared_ptr<GraphView> graphView, std::u
 
     for (std::shared_ptr<Node> node : nodeVector)
     {
-        // A merging node is always followed by a scaling node at this point ...
+        // A merging node is always followed by a Quantizer node at this point
         if (node->type() == "Quantizer")
         {
+            // check if the Quantizer is a residual one, and insert a compensation node if so ...
             bool prevNodeIsForking = ((node->getParent(0))->getChildren().size() > 1);
             bool prevNodeIsAffine = isAffine(node->getParent(0));
             bool insertNode = prevNodeIsForking || !prevNodeIsAffine;
 
             if (insertNode)
             {
-                // create and insert the multplicative node
+                // create and insert the multiplicative node before the Quantizer
                 std::string mulNodeName = makeUniqueName(node->name() + "_Mul", graphView);
                 std::shared_ptr<Node> mulNode = Mul(mulNodeName);
-
                 mulNode->getOperator()->setDataType(DataType::Float64); // getDataType(parentNode)
                 mulNode->getOperator()->setBackend("cpu");
 
                 graphView->insertParent(node, mulNode, 0, 0, 0);
 
-                // create and insert the producer node
-
-                std::shared_ptr<Tensor> inputTensor = std::static_pointer_cast<Tensor> (mulNode->getOperator()->getRawInput(0));
-                std::shared_ptr<Tensor> coeffTensor = std::make_shared<Tensor>();
+                // Add the coeff producer to the multiplier node
+                std::shared_ptr<Node> coeffProducer = addProducer(mulNode, 1, {1}, "");
+                std::shared_ptr<Tensor> coeffTensor = std::make_shared<Tensor>(signedMax);
+                coeffProducer->getOperator()->setOutput(0, coeffTensor);
 
-                coeffTensor->setDataType(DataType::Float64); // getDataType(parentNode)
-                coeffTensor->setBackend("cpu");
+                coeffProducer->getOperator()->setDataType(DataType::Float64);
+                coeffProducer->getOperator()->setBackend("cpu");
 
-                coeffTensor->resize(inputTensor->dims());
-                fillTensor(coeffTensor, 1);
+                graphView->add(coeffProducer); // needed ?
 
-                std::shared_ptr<Node> producerNode = Producer(coeffTensor, makeUniqueName("coeff", graphView));
-                producerNode->addChild(mulNode);
-                graphView->add(producerNode);
-
-                // rescale the coeffs and edit scaling factor
-
-                fillTensor(coeffTensor, signedMax);
-
-                double currScalingFactor = getScalingFactor(node); // XXX bad naming !
+                // Adapt the scaling factor value accordingly
+                double currScalingFactor = getScalingFactor(node);
                 updateScalingFactor(node, currScalingFactor / signedMax);
-
-                // TODO : double check this !!!
-                //std::cout << getTensorAbsoluteMax(coeffTensor) << std::endl;
             }
         }
     }
@@ -931,7 +955,8 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool
 
     for (std::shared_ptr<Node> node : nodeVector)
     {
-        // Use a metaoperator of type Scaling or MulCompensation instead
+        // TODO : use Compensation nodes instead of Mul nodes
+
         if (isAffine(node) || (node->type() == "Mul"))
         {
             std::shared_ptr<Node> scalingNode = (*node->getChildren().begin());
@@ -940,7 +965,7 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool
 
             double approx = std::pow(2, std::ceil(std::log2(base)));
 
-            updateScalingFactor(scalingNode,approx);
+            updateScalingFactor(scalingNode, approx);
 
             double ratio = base / approx;
 
@@ -954,7 +979,7 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool
                 std::shared_ptr<Tensor> biasTensor = getBiasTensor(node);
                 rescaleTensor(biasTensor, ratio);
                 if (!noQuant)
-                   roundTensor(biasTensor);
+                    roundTensor(biasTensor);
             }
         }
     }
@@ -962,7 +987,6 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView, bool
 
 static void printScalingFactors(std::shared_ptr<GraphView> graphView)
 {
-    Log::info(" === SCALING FACTORS === ");
     for (auto node : retrieveNodeVector(graphView))
         if (node->type() == "Scaling" || node->type() == "Quantizer")
         {
@@ -995,7 +1019,7 @@ static void printRanges(std::shared_ptr<GraphView> graphView, std::map<std::stri
     auto scheduling = scheduler.getStaticScheduling();
     for (auto node : scheduling)
         if (node->type() == "Scaling")
-            fmt::println("{} range = {}", node->name(), valueRanges[node->name()]);
+            Log::info(" {} range = {} ", node->name(), valueRanges[node->name()]);
 }
 
 void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool verbose)
@@ -1024,13 +1048,13 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits,
     Log::info(" Computing the value ranges ...");
     std::map<std::string, double> valueRanges = computeRanges(graphView, inputDataSet, true, useCuda);
 
-    //std::cout << " === RANGES (BEFORE ADJUST) ===" << std::endl;
+    //Log::info(" === RANGES (BEFORE ADJUST) ===");
     //printRanges(graphView, valueRanges);
 
     Log::info(" Optimizing the clipping values ...");
     valueRanges = adjustRanges(clippingMode, valueRanges, nbBits, graphView, inputDataSet, useCuda, verbose);
 
-    //std::cout << " === RANGES (AFTER ADJUST) ===" << std::endl;
+    //Log::info(" === RANGES (AFTER ADJUST) ===");
     //printRanges(graphView, valueRanges);
 
     Log::info(" Normalizing the activations ...");
@@ -1051,14 +1075,15 @@ void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits,
     if (verbose)
         printScalingFactors(graphView);
 
-    //std::cout << " === SCALINGS (BEFORE CAST) ===" << std::endl;
+    //Log::info(" === SCALINGS (BEFORE CAST) ===");
     //printScalingFactors(graphView);
 
     setupDataType(graphView, inputDataSet, initialDataType);
+
     if (useCuda)
         graphView->setBackend("cuda");
 
-    //std::cout << " === SCALINGS (AFTER CAST) ===" << std::endl;
+    //Log::info(" === SCALINGS (AFTER CAST) ===");
     //printScalingFactors(graphView);
 
     Log::info(" Resetting the scheduler ...");
@@ -1098,7 +1123,7 @@ void clearBiases(std::shared_ptr<GraphView> graphView)
 void devPTQ(std::shared_ptr<GraphView> graphView)
 {
     for (std::shared_ptr<Node> node : graphView->getNodes())
-        fmt::println(" UUU : {}", node->name());
+        Log::info(" UUU : {} ", node->name());
 }
 }
diff --git a/src/QAT/QAT_FixedQ.cpp b/src/QAT/QAT_FixedQ.cpp
index 9160b4ae6add5ae0347e008962956dc90c3a36fd..6ada53239f92d19f96dc87e0b91247aa093caecf 100644
--- a/src/QAT/QAT_FixedQ.cpp
+++ b/src/QAT/QAT_FixedQ.cpp
@@ -91,7 +91,7 @@ static std::map<std::string, float> collectInputStats(std::shared_ptr<GraphView>
             const auto op = std::static_pointer_cast<FixedQ_Op>(node->getOperator());
             float inputStd = getTensorStd(op->getInput(0));
             inputStats.insert(std::make_pair(node->name(), inputStd));
-            fmt::println("{} -> {}", node->name(), inputStd);
+            Log::info(" {} -> {} ", node->name(), inputStd);
         }
     }
 
@@ -108,7 +108,7 @@ static std::map<std::string, float> collectParamStats(std::shared_ptr<GraphView>
             const auto op = std::static_pointer_cast<FixedQ_Op>(node->getOperator());
             float paramStd = getTensorStd(op->getInput(1));
             paramStats.insert(std::make_pair(node->name(), paramStd));
-            fmt::println("{} -> {}", node->name(), paramStd);
+            Log::info(" {} -> {} ", node->name(), paramStd);
         }
     }
 
@@ -156,7 +156,7 @@ void QuantFixedQ::devQAT(std::shared_ptr<GraphView> graphView)
     scheduler.generateScheduling();
     auto s = scheduler.getStaticScheduling();
     for (std::shared_ptr<Node> node : s)
-        fmt::println(" name : {}", node->name());
+        Log::info(" name : {} ", node->name());
 }
 
 }
\ No newline at end of file
diff --git a/src/QAT/QAT_LSQ.cpp b/src/QAT/QAT_LSQ.cpp
index 9b51e846df498a9303b7373ae1c86d4b007a96f0..6eae077b060027eb4029f6b59f55376a1674df70 100644
--- a/src/QAT/QAT_LSQ.cpp
+++ b/src/QAT/QAT_LSQ.cpp
@@ -21,193 +21,152 @@
 #include "aidge/graph/Matching.hpp"
 #include "aidge/recipes/QuantRecipes.hpp"
 
-namespace Aidge {
 
-void QuantLSQ::insertQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, float stepSize)
+namespace Aidge {
-{
-    const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|FC#)");
 
-    for (const auto& match : matches)
-    {
-        auto linearNode = match.graph->rootNode();
+static float getTensorAbsMean(std::shared_ptr<Tensor> tensor)
+{
+    auto valueTensor = (*tensor).abs().mean();
+    std::shared_ptr<Tensor> fallback;
+    const Tensor& localTensor = valueTensor.refCastFrom(fallback, DataType::Float32, "cpu");
+    return localTensor.get<float>(0);
+}
 
-        std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
-        std::pair<int, int> unsignedRange = {0, std::pow(2, nbBits) - 1};
+static float getTensorStd(std::shared_ptr<Tensor> tensor)
+{
+    auto valueTensor = (*tensor);
+
+    auto skewedTensor = valueTensor - valueTensor.mean();
+    auto squaredTensor = skewedTensor * skewedTensor;
+    auto varianceTensor = squaredTensor.mean();
 
-        // INPUT QUANTIZERS INSERTION
+    std::shared_ptr<Tensor> fallback;
+    auto localTensor = varianceTensor.refCastFrom(fallback, DataType::Float32, "cpu");
+
+    float variance = localTensor.get<float>(0);
+    return std::sqrt(variance);
+}
 
-        // TODO : double check this, and use createUniqueName()
-        auto inputQuantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView);
-        auto inputQuantizerNode = LSQ(signedRange, inputQuantizerName);
 
-        // Set the step size
+// INIT THE STEP SIZE OF A QUANTIZER NODE
 
-        auto inputStepSizeOp = inputQuantizerNode->getParent(1)->getOperator();
-        auto inputStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
-        inputStepSizeOp->setOutput(0, inputStepSizeTensor);
+static bool initStepSize(std::shared_ptr<Node> quantizer)
+{
+    const auto quantizerOp = std::static_pointer_cast<LSQ_Op>(quantizer->getOperator());
 
-        // Absorb the ReLU when possible ...
+    // This is the formula proposed in the LSQ paper ...
-        // XXX is this safe ???
-        bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]);
-        // bool nodeHasParent = (linearNode->getParents().size() != 0);
+    // float inputAbsMean = getTensorAbsMean(quantizerOp->getInput(0));
+    // float stepSize = 2.0f * (inputAbsMean / std::sqrt(quantizerOp->range().second));
 
-        if (nodeHasParent) {
-            auto parentNode = linearNode->getParents()[0];
-            if (parentNode->type() == "ReLU") {
-                auto inputQuantizerOp = std::static_pointer_cast<LSQ_Op> (inputQuantizerNode->getOperator());
-                inputQuantizerOp->range() = unsignedRange;
-                graphView->replace({parentNode}, {});
-            }
-        }
+    // ... but in practice this heuristic based on the input std seems to work better :
 
-        // We need to handle the case where the linear node is the first one ...
+    float inputStd = getTensorStd(quantizerOp->getInput(0));
+    float stepSize = 8.0f * (inputStd / (quantizerOp->range().second));
 
-        if (nodeHasParent) {
-            graphView->insertParent(linearNode, inputQuantizerNode, 0, 0, 0);
-        } else {
-            inputQuantizerNode->addChild(graphView);
-            graphView->add(inputQuantizerNode);
-        }
+    // TODO : use the scalar constructor
+    auto stepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
 
-        // PARAM QUANTIZERS INSERTION
+    // XXX Manage backend here ?
+    stepSizeTensor->setBackend(quantizerOp->getInput(0)->backend());
+    stepSizeTensor->setDataType(quantizerOp->getInput(0)->dataType());
 
-        // TODO : double check this, and use createUniqueName()
-        auto paramQuantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView);
-        auto paramQuantizerNode = LSQ(signedRange, paramQuantizerName);
-        graphView->insertParent(linearNode, paramQuantizerNode, 1, 0, 0);
+    auto stepSizeProducer = quantizer->getParent(1);
 
-        // Set the step size
+    stepSizeProducer->getOperator()->setOutput(0, stepSizeTensor);
 
-        auto paramStepSizeOp = paramQuantizerNode->getParent(1)->getOperator();
-        auto paramStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
-        paramStepSizeOp->setOutput(0, paramStepSizeTensor);
-    }
+    Log::notice(" [ INIT STEP SIZE = {} ] ", stepSize);
 
+    return false;
 }
 
-static float getTensorAbsMean(std::shared_ptr<Tensor> tensor)
+static void setupInputQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
 {
-    auto backend = tensor->backend();
-    if (backend == "cuda")
-        tensor->setBackend("cpu");
+    const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)");
 
-    float acc = 0;
-    float* castedTensor = static_cast<float *> (tensor->getImpl()->rawPtr());
-    for(std::size_t i = 0; i < tensor->size(); i++)
-        acc += std::abs(castedTensor[i]);
-    acc /= static_cast<float> (tensor->size());
+    for (const auto& match : matches)
+    {
+        auto linearNode = match.graph->rootNode();
 
-    if (backend == "cuda")
-        tensor->setBackend("cuda");
+        // Log::notice(" SET INPUT QUANTIZER : {} ", linearNode->type());
 
-    return acc;
-}
+        std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
+        std::pair<int, int> unsignedRange = {0, std::pow(2, nbBits) - 1};
 
-static std::map<std::string, float> collectInputStats(std::shared_ptr<GraphView> graphView, std::shared_ptr<Tensor> calibrationData, bool useCuda)
-{
-    // Propagate the calibration tensor
+        // Create the input quantizer node
 
-    SequentialScheduler scheduler(graphView);
-    scheduler.resetScheduling();
-    scheduler.forward(true, {calibrationData});
+        auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_i", graphView);
+        auto quantizerNode = LSQ(signedRange, quantizerName);
 
-    // Store the input tensor statistics
+        // Init the step-size lazily, via the node's beforeForward hook
 
-    if (useCuda)
-        graphView->setBackend("cpu");
+        quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); });
 
-    std::map<std::string, float> inputStats;
-    for (auto node : graphView->getNodes())
-    {
-        if (node->type() == "FC" || node->type() == "Conv2D") // TODO: use graph matching !!!
-        {
-            const auto op = std::static_pointer_cast<LSQ_Op>(node->getOperator());
-            float inputAbsMean = getTensorAbsMean(op->getInput(0));
-            inputStats.insert(std::make_pair(node->name(), inputAbsMean));
-            fmt::println("{} -> {}", node->name(), inputAbsMean);
-        }
-    }
+        // Absorb the ReLU when possible ...
 
-    if (useCuda)
-        graphView->setBackend("cuda");
+        bool nodeHasParent = static_cast<bool> (linearNode->getParents()[0]); // XXX is this safe ?
 
-    return inputStats;
-}
+        if (nodeHasParent)
+        {
+            bool allParentsAreReLU = true;
+            for (auto parentNode : linearNode->getParents())
+                if (parentNode->type() != "ReLU" && parentNode->type() != "Producer") // skip the weight and bias Producers, which are also parents of the node
+                    allParentsAreReLU = false;
+
+            if (allParentsAreReLU) {
+                auto quantizerOp = std::static_pointer_cast<LSQ_Op> (quantizerNode->getOperator());
+                quantizerOp->range() = unsignedRange;
+            }
 
-static std::map<std::string, float> collectParamStats(std::shared_ptr<GraphView> graphView, bool useCuda)
-{
-    if (useCuda)
-        graphView->setBackend("cpu");
+            // TODO : remove the ReLUs when possible
+        }
 
-    std::map<std::string, float> paramStats;
-    for (auto node : graphView->getNodes())
-    {
-        if (node->type() == "FC" || node->type() == "Conv2D") // TODO: use graph matching !!!
-        {
-            const auto op = std::static_pointer_cast<LSQ_Op>(node->getOperator());
-            float paramAbsMean = getTensorAbsMean(op->getInput(1));
-            paramStats.insert(std::make_pair(node->name(), paramAbsMean));
-            fmt::println("{} -> {}", node->name(), paramAbsMean);
+        // Insert the quantizer in the graphView ...
+        // (We need to handle the case where the linear node is the first one)
+
+        if (nodeHasParent) {
+            graphView->insertParent(linearNode, quantizerNode, 0, 0, 0);
+        } else {
+            quantizerNode->addChild(graphView);
+            graphView->add(quantizerNode);
+        }
     }
 }
 
-    if (useCuda)
-        graphView->setBackend("cuda");
-
-    return paramStats;
-}
 
-static void adjustQuantizersStepSizes(std::shared_ptr<GraphView> graphView, std::map<std::string, float> inputStats, std::map<std::string, float> paramStats)
-{
-    const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|FC#)");
+// PARAM QUANTIZERS INSERTION
 
-    for (const auto& match : matches)
-    {
-        auto linearNode = match.graph->rootNode();
+static void setupParamQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
+{
+    const auto matches = SinglePassGraphMatching(graphView).match("(Conv2D#|PaddedConv2D#|FC#)");
 
-        // INPUT QUANTIZERS STEP-SIZES
+    std::pair<int, int> signedRange = {-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1};
 
-        auto inputQuantNode = linearNode->getParent(0);
-        auto inputQuantOp = std::static_pointer_cast<LSQ_Op>(inputQuantNode->getOperator());
+    for (const auto& match : matches)
+    {
+        auto linearNode = match.graph->rootNode();
 
-        float absMean = inputStats[linearNode->name()];
-        float stepSize = 2.0f * (absMean / std::sqrt(inputQuantOp->range().second));
+        // Log::notice(" SET PARAM QUANTIZER : {} ", linearNode->type());
 
-        auto inputStepSizeOp = inputQuantNode->getParent(1)->getOperator();
-        // XXX inputStepSizeOp->setOutput(0, std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})));
-        auto inputStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
-        inputStepSizeOp->setOutput(0, inputStepSizeTensor);
+        // TODO : double check this, and use createUniqueName()
+        auto quantizerName = makeUniqueName(linearNode->name() + "_lsq_p", graphView);
+        auto quantizerNode = LSQ(signedRange, quantizerName);
 
-        // PARAM QUANTIZERS STEP-SIZES
+        // Init the step-size lazily, via the node's beforeForward hook
 
-        auto paramQuantNode = linearNode->getParent(1);
-        auto paramQuantOp = std::static_pointer_cast<LSQ_Op>(paramQuantNode->getOperator());
+        quantizerNode->addBeforeForward([quantizerNode](){ return initStepSize(quantizerNode); });
 
-        absMean = paramStats[linearNode->name()];
-        stepSize = 2.0f * (absMean / std::sqrt(paramQuantOp->range().second));
+        // Insert the quantizer in the graphView
 
-        auto paramStepSizeOp = paramQuantNode->getParent(1)->getOperator();
-        // XXX paramStepSizeOp->setOutput(0, std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}})));
-        auto paramStepSizeTensor = std::make_shared<Tensor>(Array1D<float, 1>({{stepSize}}));
-        paramStepSizeOp->setOutput(0, paramStepSizeTensor);
+        graphView->insertParent(linearNode, quantizerNode, 1, 0, 0);
     }
 }
 
-void QuantLSQ::insertAndInitQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits, std::shared_ptr<Tensor> calibrationData)
+void QuantLSQ::setupQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
 {
-    bool useCuda = (calibrationData->backend() == "cuda");
-
-    // Collect the tensor statistics
-    auto inputStats = collectInputStats(graphView, calibrationData, useCuda);
-
-    auto paramStats = collectParamStats(graphView, useCuda);
-
-    // Insert the quantizers
-    insertQuantizers(graphView, nbBits, 1.0);
-
-    // Adjust the quantizers step-sizes
-    adjustQuantizersStepSizes(graphView, inputStats, paramStats);
+    sanitizeNodeNames(graphView);
+    setupInputQuantizers(graphView, nbBits);
+    setupParamQuantizers(graphView, nbBits);
 }
 
 }
\ No newline at end of file
diff --git a/src/backend/cuda/operator/LSQImpl.cpp b/src/backend/cuda/operator/LSQImpl.cpp
index c66bd8a5aa78513b4bcceec83f9c9d87ffed2b11..fa45f211e72f6742b72584aadf2a109c3bdca594 100644
--- a/src/backend/cuda/operator/LSQImpl.cpp
+++ b/src/backend/cuda/operator/LSQImpl.cpp
@@ -52,19 +52,6 @@ void Aidge::LSQImpl_cuda::backward() {
     std::shared_ptr<Tensor> gra_int1 = op_.getInput(1)->grad();
     std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad();
 
-    // XXX
-/*
-    size_t tmp;
-
-    cudaDeviceSetLimit(cudaLimitStackSize, 2048);
-    cudaDeviceGetLimit(&tmp, cudaLimitStackSize );
-    printf(" stack limit = %ld \n", tmp);
-
-    cudaDeviceSetLimit(cudaLimitMallocHeapSize, 100000000);
-    cudaDeviceGetLimit(&tmp, cudaLimitMallocHeapSize);
-    printf(" heap limit = %ld \n", tmp);
-*/
-
     if (gra_int0->size() > mWorkspaceSize) {
         // std::cout << " reallocation " << sizeof(gra_int0) << " " << gra_int0->size() << std::endl;
         if (mWorkspace != nullptr) {
@@ -87,12 +74,7 @@ void Aidge::LSQImpl_cuda::backward() {
         gra_int0->getImpl()->rawPtr(),
         gra_int1->getImpl()->rawPtr(),
         mWorkspace);
-/*
-    gra_int1->setBackend("cpu");
-    float *castedTensor = static_cast<float *> (gra_int1->getImpl()->rawPtr());
-    std::cout << castedTensor[0] << std::endl;
-    gra_int1->setBackend("cuda");
-*/
+
 }
 
 Aidge::LSQImpl_cuda::~LSQImpl_cuda() {
diff --git a/src/PTQ/PTQMetaOps.cpp b/src/operator/PTQMetaOps.cpp
similarity index 98%
rename from src/PTQ/PTQMetaOps.cpp
rename to src/operator/PTQMetaOps.cpp
index 77018c23aee2f1ef6f430389393fd35e97baa0f6..56245da47076d8930ce29ab75e549d97d0d7493d 100644
--- a/src/PTQ/PTQMetaOps.cpp
+++ b/src/operator/PTQMetaOps.cpp
@@ -9,13 +9,12 @@
  *
  ********************************************************************************/
 
-#include "aidge/quantization/PTQ/PTQMetaOps.hpp"
+#include "aidge/operator/PTQMetaOps.hpp"
 
 #include <memory>
 #include <string>
 #include <utility>
 
-//Operator
 #include "aidge/operator/Clip.hpp"
 #include "aidge/operator/Mul.hpp"
 #include "aidge/operator/Round.hpp"
diff --git a/src/recipes/QuantRecipes.cpp b/src/recipes/QuantRecipes.cpp
index 6e1dcdb1b64c0a1e94c74ce66cb71f1a458bca35..f03eb462088b16645fe600769e2a5e2c990f21b6 100644
--- a/src/recipes/QuantRecipes.cpp
+++ b/src/recipes/QuantRecipes.cpp
@@ -9,24 +9,13 @@
  *
  ********************************************************************************/
 
-/*
-#include "aidge/data/Tensor.hpp"
-#include "aidge/graph/GraphView.hpp"
-#include "aidge/graph/Node.hpp"
-#include "aidge/scheduler/SequentialScheduler.hpp"
-#include "aidge/scheduler/Scheduler.hpp"
-#include "aidge/utils/Log.hpp"
-
-#include "aidge/operator/Producer.hpp"
-#include "aidge/operator/Mul.hpp"
-#include "aidge/operator/ReLU.hpp"
-#include "aidge/operator/Scaling.hpp"
-*/
 #include "aidge/operator/Conv.hpp"
 #include "aidge/operator/BatchNorm.hpp"
 //#include "aidge/quantization/PTQ/PTQ.hpp"
 #include "aidge/recipes/QuantRecipes.hpp"
+#include "aidge/graph/Node.hpp"
+
 
 namespace Aidge
 {
@@ -55,14 +44,16 @@ void insertBatchNormNodes(std::shared_ptr<GraphView> graphView)
 {
     for (std::shared_ptr<Node> parentNode : graphView->getNodes())
     {
-        if (parentNode->type() == "Conv2D")
+        // TODO : use graph matching
+
+        if (parentNode->type() == "Conv2D" || parentNode->type() == "PaddedConv2D")
         {
-            std::shared_ptr<Conv_Op<2>> convOperator = std::static_pointer_cast<Conv_Op<2>> (parentNode->getOperator());
-            int nb_channels = convOperator->getInput(1)->dims()[0];
-            fmt::println(" NB CHANNELS = {}", nb_channels); // TODO : remove this ...
+            std::shared_ptr<OperatorTensor> convOperator = std::static_pointer_cast<OperatorTensor> (parentNode->getOperator());
+            int nbChannels = convOperator->getInput(1)->dims()[0];
+            Log::notice(" NB CHANNELS = {} ", nbChannels);
 
             std::string batchnormNodeName = makeUniqueName(parentNode->name() + "_BN", graphView);
-            std::shared_ptr<Node> batchnormNode = BatchNorm<2>(nb_channels, 1e-5, 0.1, false, batchnormNodeName);
+            std::shared_ptr<Node> batchnormNode = BatchNorm<2>(nbChannels, 1e-5, 0.1, false, batchnormNodeName);
 
             batchnormNode->getOperator()->setDataType(DataType::Float32);
             batchnormNode->getOperator()->setBackend("cpu");
@@ -118,6 +109,7 @@ std::string makeUniqueName(std::string baseName, std::shared_ptr<GraphView> grap
     return newName;
 }
 
+
 void sanitizeNodeNames(std::shared_ptr<GraphView> graphView)
 {
     for (std::shared_ptr<Node> node : graphView->getNodes())
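With the LSQ rework above, step sizes are no longer calibrated offline from collected statistics; each quantizer initializes its own step size from the standard deviation of its input on the first forward pass, through the `addBeforeForward` hook. A minimal sketch of the resulting Python-side flow follows (the `aidge_quantization` import name and the `model` object are assumptions; the `lsq` submodule and `setup_quantizers` signature come from the binding above):

```python
import aidge_quantization  # assumed import name of this package's binding

# model: an already-built aidge_core.GraphView (assumed)
# Insert the input and param LSQ quantizers in one call; node names are
# sanitized first, as done by QuantLSQ::setupQuantizers().
aidge_quantization.lsq.setup_quantizers(network=model, nb_bits=4)

# The step sizes are computed lazily: the first forward pass through the
# network triggers initStepSize() on every inserted quantizer, so no
# explicit calibration data needs to be passed to setup_quantizers().
```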