Skip to content
Snippets Groups Projects
Commit d506f102 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Make FuseBatchNorm work for any type

parent 42056687
No related branches found
No related tags found
1 merge request!57Add Convert operator (a.k.a. Transmitter)
Pipeline #35565 failed
......@@ -464,17 +464,29 @@ class Tensor : public Data,
return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx));
}
/// Read-only access to the element at flat index `idx`, reinterpreted
/// as `expectedType`.
/// @param idx Flat (linear) index into the tensor's storage.
/// @return Const reference into the underlying implementation buffer.
/// @warning No check that `expectedType` matches the tensor's DataType,
///          nor that `idx` is in range — see TODOs below.
template <typename expectedType>
const expectedType& get(std::size_t idx) const {
    // TODO : add assert expected Type compatible with datatype
    // TODO : add assert idx < Size
    // Cast to a *const* pointer: this is the const accessor and must not
    // hand out a path to mutable access on the underlying buffer.
    return *reinterpret_cast<const expectedType *>(mImpl->getRaw(idx));
}
/// Mutable access to the element at multi-dimensional coordinate
/// `coordIdx`, reinterpreted as `expectedType`.
/// @param coordIdx One index per tensor dimension.
/// @return Reference into the underlying implementation buffer.
template <typename expectedType>
expectedType& get(const std::vector<std::size_t>& coordIdx){
    // Take the coordinates by const reference: getIdx() only reads them,
    // so copying the vector on every element access is wasted work.
    return get<expectedType>(getIdx(coordIdx));
}
template <typename expectedType>
const expectedType& get(std::vector<std::size_t> coordIdx) const {
return get<expectedType>(getIdx(coordIdx));
}
/// Overwrite the element at flat index `idx` with `value`, reinterpreted
/// as `expectedType`.
/// @param idx   Flat (linear) index into the tensor's storage.
/// @param value New value; written through a typed pointer into the
///              implementation buffer.
/// @warning No check that `expectedType` matches the tensor's DataType,
///          nor that `idx` is in range — see TODOs below.
template <typename expectedType>
void set(std::size_t idx, expectedType value){
    // TODO : add assert expected Type compatible with datatype
    // TODO : add assert idx < Size
    // static_cast from the implementation's void* is sufficient here; a
    // typed assignment replaces the previous memcpy-based write.
    expectedType* dataPtr = static_cast<expectedType*>(mImpl->getRaw(idx));
    *dataPtr = value;
}
template <typename expectedType>
......@@ -695,6 +707,20 @@ class Tensor : public Data,
return refCast(fallback, target.dataType()).ref(fallback, device.first, device.second);
}
/**
 * Obtain a reference to this Tensor's data as float32 on the CPU:
 * - returns *this when it already has that type and backend;
 * - otherwise fills the provided fallback Tensor with the copy-casted
 *   data and returns a reference to it.
 * @param fallback A shared_ptr to a Tensor that may be overwritten.
 * It does not have to be initialized; no allocation happens if it was
 * already set up with the matching type/size/device.
 * @return Reference to either itself or to fallback.
 */
Tensor& refCastNative(std::shared_ptr<Tensor>& fallback) {
    // Two-step conversion: first the dtype cast, then the move to the
    // CPU backend. Both steps reuse `fallback` when a copy is needed.
    Tensor& asFloat32 = refCast(fallback, DataType::Float32);
    return asFloat32.ref(fallback, "cpu");
}
private:
///\bug not protected against overflow
std::size_t computeSize() {
......
......@@ -33,10 +33,11 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
const std::shared_ptr<BatchNorm_Op<2>> batchOp = std::static_pointer_cast<BatchNorm_Op<2>>(batchnormNode->getOperator());
const std::shared_ptr<Conv_Op<2>> convOp = std::static_pointer_cast<Conv_Op<2>>(convNode->getOperator());
// Materialize the four batch-norm parameter tensors as native (float32,
// CPU) views. Each input gets its OWN fallback buffer: reusing a buffer
// between two refCastNative() calls would make the second call clobber
// the first tensor's casted data (the original code passed b_meanBuf for
// both b_mean and b_var — fixed to b_varBuf here).
std::shared_ptr<Tensor> scaleBuf, shiftBuf, b_meanBuf, b_varBuf;
const Tensor& scale = batchOp->getInput(1)->refCastNative(scaleBuf);
const Tensor& shift = batchOp->getInput(2)->refCastNative(shiftBuf);
const Tensor& b_mean = batchOp->getInput(3)->refCastNative(b_meanBuf);
const Tensor& b_var = batchOp->getInput(4)->refCastNative(b_varBuf);
const float epsilon = batchOp -> getAttr<float>("Epsilon");
const DimSize_t convNbOutChannels = convOp -> getAttr<DimSize_t>("OutChannels");
......@@ -44,10 +45,10 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
const std::array<DimSize_t, 2> kernelDims = convOp -> getAttr<std::array<DimSize_t, 2>>("KernelDims");
assert(scale->size() == convNbOutChannels);
assert(shift->size() == convNbOutChannels);
assert(b_mean->size() == convNbOutChannels);
assert(b_var->size() == convNbOutChannels);
assert(scale.size() == convNbOutChannels);
assert(shift.size() == convNbOutChannels);
assert(b_mean.size() == convNbOutChannels);
assert(b_var.size() == convNbOutChannels);
assert(epsilon > 0.0);
// TODO : no no_bias attribute ?
......@@ -56,9 +57,8 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
unsigned int count = 0;
for (std::size_t outChId = 0; outChId < convNbOutChannels; ++outChId) {
// TODO: get() assumed dataType is float...
if (b_var->get<float>(outChId) > 1.0e-12) {
meanVariance += b_var->get<float>(outChId);
if (b_var.get<float>(outChId) > 1.0e-12) {
meanVariance += b_var.get<float>(outChId);
++count;
}
else {
......@@ -71,39 +71,43 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
printf("Warning: variance < 1e-12 for all outputs! Is the network correctly trained?\n");
}
std::shared_ptr<Tensor> weight = convOp -> getInput(1);
std::shared_ptr<Tensor> bias = convOp -> getInput(2);
std::shared_ptr<Tensor> weightBuf, biasBuf;
Tensor& weight = convOp->getInput(1)->refCastNative(weightBuf);
Tensor& bias = convOp->getInput(2)->refCastNative(biasBuf);
for (std::size_t outChId = 0; outChId < convNbOutChannels; ++outChId) {
// Corrected for zero-variance issue:
// "A Quantization-Friendly Separable Convolution for MobileNets"
// https://arxiv.org/pdf/1803.08607.pdf
// to help post-training quantization
const float factor = scale->get<float>(outChId)
/ std::sqrt(epsilon + ((b_var->get<float>(outChId) > 1.0e-12 || count == 0)
? b_var->get<float>(outChId) : meanVariance));
const float factor = scale.get<float>(outChId)
/ std::sqrt(epsilon + ((b_var.get<float>(outChId) > 1.0e-12 || count == 0)
? b_var.get<float>(outChId) : meanVariance));
// Weights adjustments
for (std::size_t channel = 0; channel < channelsSize; ++channel) {
// TODO : Suppose kerneldims = 2
for(std::size_t k0 = 0; k0 < kernelDims[0]; ++ k0){
for(std::size_t k1 = 0; k1 < kernelDims[1]; ++ k1){
std::vector<DimSize_t> currentIdx = {outChId, channel, k0, k1};
// TODO : suppose weights are float
float weightValue = weight->get<float>(currentIdx);
weight->set<float>(currentIdx, weightValue*factor); // Update check it update Conv weights
float weightValue = weight.get<float>(currentIdx);
weight.set<float>(currentIdx, weightValue*factor); // Update check it update Conv weights
}
}
}
// TODO : check if noBias==true is set, then set biasValue to 0
float biasValue = bias->get<float>(outChId);
float biasValue = bias.get<float>(outChId);
biasValue = shift->get<float>(outChId) + (biasValue - b_mean->get<float>(outChId)) * factor;
biasValue = shift.get<float>(outChId) + (biasValue - b_mean.get<float>(outChId)) * factor;
bias->set<float>(outChId, biasValue);
bias.set<float>(outChId, biasValue);
}
// Copy values back to the original tensors (actual copy only if needed)
convOp->getInput(1)->copyCastFrom(weight);
convOp->getInput(2)->copyCastFrom(bias);
GraphView::replace(std::set<std::shared_ptr<Node>>({
batchnormNode,
batchnormNode->input(1).first,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment