Skip to content
Snippets Groups Projects
Commit 7273399a authored by Nathan Thoumine's avatar Nathan Thoumine
Browse files

Add option to choose calibration folder and cache file

parent e0af522d
No related branches found
No related tags found
No related merge requests found
......@@ -32,9 +32,12 @@ typedef struct {
class Graph {
public:
Graph(std::string const& filePath = "",
unsigned int device_id = 0,
int nbbits = -32);
Graph( std::string const& filePath,
std::string const& calibPath,
std::string const& cachePath,
unsigned int device_id,
int nbbits );
~Graph();
void device(unsigned int id);
......@@ -87,6 +90,10 @@ class Graph {
std::vector<void*> _iobuffer;
// Stream
cudaStream_t _stream{nullptr};
// Path to the folder containing samples that your calibration will be built on
std::string _calibPath;
// Path to the calibration cache to load / save your calibration data
std::string _cachePath;
};
......
......@@ -18,10 +18,13 @@ void init_Graph(py::module& m)
;
py::class_<Graph>(m, "Graph")
.def(py::init<std::string, unsigned int, int>(),
.def(py::init<std::string, std::string, std::string, unsigned int, int>(),
py::arg("filepath") = "",
py::arg("device_id") = 0,
py::arg("calibPath") = "./calibration_folder/",
py::arg("cachePath") = "./calibration_cache",
py::arg("device_id") = 0,
py::arg("nb_bits") = -32)
.def("device", &Graph::device, py::arg("id"))
.def("load", &Graph::load, py::arg("filepath"))
.def("save", &Graph::save, py::arg("filepath"))
......
......@@ -5,9 +5,11 @@
#include "IInt8EntropyCalibrator.hpp"
#include <dirent.h>
Graph::Graph(std::string const& filePath,
unsigned int device_id,
int nbbits)
Graph::Graph( std::string const& filePath,
std::string const& calibPath,
std::string const& cachePath,
unsigned int device_id,
int nbbits )
{
// ctor
......@@ -18,6 +20,9 @@ Graph::Graph(std::string const& filePath,
// this->_builderconfig->setMaxWorkspaceSize(MAX_WORKSPACE_SIZE);
this->_builderconfig->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, MAX_WORKSPACE_SIZE);
this->_calibPath = calibPath;
this->_cachePath = cachePath;
CHECK_CUDA_STATUS(cudaStreamCreate(&(this->_stream)));
device(device_id);
......@@ -94,13 +99,13 @@ void Graph::datamode(nvinfer1::DataType datatype)
case nvinfer1::DataType::kINT8: {
if (!this->_builder->platformHasFastInt8()) {
std::cout << "Cannot use INT8 for this platform \nLet default datatype activated." << std::endl;
std::cout << "Cannot use INT8 for this platform \nLet default datatype activated." << std::endl;
return;
}
std::string calibration_folder = "calib/";
std::string calibDir = calibration_folder.empty() ?
"./batches_calib/"
: calibration_folder;
// Mark calibrator as nullptr not to provide an INT8 calibrator
this->_builderconfig->setFlag(nvinfer1::BuilderFlag::kINT8);
std::string calibDir = this->_calibPath;
std::vector<std::string> filesCalib;
struct dirent* pFile;
DIR* pDir = opendir(calibDir.c_str());
......@@ -108,19 +113,19 @@ void Graph::datamode(nvinfer1::DataType datatype)
std::cout << "No directory for batches calibration" << std::endl;
}
else {
while ((pFile = readdir(pDir)) != NULL)
{
while ((pFile = readdir(pDir)) != NULL)
{
if (pFile->d_name[0] != '.') filesCalib.push_back(std::string(calibDir + pFile->d_name));
}
closedir(pDir);
}
closedir(pDir);
}
unsigned int nbCalibFiles = 1;
if(nbCalibFiles == 0) std::cout << "Cannot find calibration files in dir " << calibDir << std::endl;
unsigned int nbInputs = this->_iodescriptors.inputs.size();
BatchStream calibrationStream(1, 3, 32, 32, nbCalibFiles, calibration_folder);
this->_calibrator = new Int8EntropyCalibrator(calibrationStream, 0, "./calibration_cache");
unsigned int batchSize = 1;
BatchStream calibrationStream(batchSize, 3, 32, 32, nbCalibFiles, this->_calibPath);
this->_calibrator = new Int8EntropyCalibrator(calibrationStream, 0, this->_cachePath);
this->_builderconfig->setInt8Calibrator(this->_calibrator);
this->_builderconfig->setFlag(nvinfer1::BuilderFlag::kINT8);
}
break;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment