Commit d506f102 authored by Olivier BICHLER

Make FuseBatchNorm work for any type

parent 42056687
@@ -464,17 +464,29 @@ class Tensor : public Data,
         return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx));
     }
 
+    template <typename expectedType>
+    const expectedType& get(std::size_t idx) const {
+        // TODO : add assert expected Type compatible with datatype
+        // TODO : add assert idx < Size
+        return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx));
+    }
+
     template <typename expectedType>
     expectedType& get(std::vector<std::size_t> coordIdx){
         return get<expectedType>(getIdx(coordIdx));
     }
 
+    template <typename expectedType>
+    const expectedType& get(std::vector<std::size_t> coordIdx) const {
+        return get<expectedType>(getIdx(coordIdx));
+    }
+
     template <typename expectedType>
     void set(std::size_t idx, expectedType value){
         // TODO : add assert expected Type compatible with datatype
         // TODO : add assert idx < Size
-        void* dataPtr = mImpl->getRaw(idx);
-        std::memcpy(dataPtr, &value, sizeof(expectedType));
+        expectedType* dataPtr = static_cast<expectedType*>(mImpl->getRaw(idx));
+        *dataPtr = value;
     }
 
     template <typename expectedType>
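For reference, the typed accessors touched above are used as follows. This is a minimal sketch (the helper function name is illustrative, not part of the commit); note that, per the TODOs, get()/set() still do not check that the requested type matches the tensor's DataType:

    #include "aidge/data/Tensor.hpp"

    // Minimal sketch: double every element of a tensor assumed to hold float data
    // and to already have an allocated backend implementation.
    void doubleAllValues(Aidge::Tensor& t) {
        for (std::size_t i = 0; i < t.size(); ++i) {
            const float v = t.get<float>(i);   // typed read at flat index i
            t.set<float>(i, 2.0f * v);         // typed in-place write (no memcpy since this commit)
        }
    }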
@@ -695,6 +707,20 @@ class Tensor : public Data,
         return refCast(fallback, target.dataType()).ref(fallback, device.first, device.second);
     }
 
+    /**
+     * Return a reference to a Tensor with float32 type on CPU:
+     * - itself, if already with the right characteristics;
+     * - the provided Tensor, overwritten with the copy-casted data.
+     * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
+     * The shared_ptr does not need to be initialized. No new memory allocation
+     * will occur if fallback has already been allocated with the right
+     * type/size/device.
+     * @return Reference to either itself or to fallback.
+     */
+    Tensor& refCastNative(std::shared_ptr<Tensor>& fallback) {
+        return refCast(fallback, DataType::Float32).ref(fallback, "cpu");
+    }
+
 private:
     ///\bug not protected against overflow
     std::size_t computeSize() {
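The fuseBatchNorm changes below rely on this helper together with copyCastFrom() (also used below) to work on a float32/CPU view of a parameter regardless of its original type and backend. A usage sketch, with illustrative names:

    #include <memory>
    #include "aidge/data/Tensor.hpp"

    // Sketch of the fallback pattern: `work` refers either to *param itself
    // (if it is already float32 on CPU) or to a temporary copy held by `fallback`.
    void processParam(const std::shared_ptr<Aidge::Tensor>& param) {
        std::shared_ptr<Aidge::Tensor> fallback;              // may stay empty
        Aidge::Tensor& work = param->refCastNative(fallback);

        // ... read/modify `work` through get<float>() / set<float>() ...

        // Write results back to the original tensor; per the commit's comment,
        // an actual copy only happens if `work` is not the original tensor.
        param->copyCastFrom(work);
    }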
@@ -33,10 +33,11 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
     const std::shared_ptr<BatchNorm_Op<2>> batchOp = std::static_pointer_cast<BatchNorm_Op<2>>(batchnormNode->getOperator());
     const std::shared_ptr<Conv_Op<2>> convOp = std::static_pointer_cast<Conv_Op<2>>(convNode->getOperator());
 
-    const std::shared_ptr<Tensor> scale = batchOp->getInput(1);
-    const std::shared_ptr<Tensor> shift = batchOp->getInput(2);
-    const std::shared_ptr<Tensor> b_mean = batchOp->getInput(3);
-    const std::shared_ptr<Tensor> b_var = batchOp->getInput(4);
+    std::shared_ptr<Tensor> scaleBuf, shiftBuf, b_meanBuf, b_varBuf;
+    const Tensor& scale = batchOp->getInput(1)->refCastNative(scaleBuf);
+    const Tensor& shift = batchOp->getInput(2)->refCastNative(shiftBuf);
+    const Tensor& b_mean = batchOp->getInput(3)->refCastNative(b_meanBuf);
+    const Tensor& b_var = batchOp->getInput(4)->refCastNative(b_varBuf);
 
     const float epsilon = batchOp -> getAttr<float>("Epsilon");
     const DimSize_t convNbOutChannels = convOp -> getAttr<DimSize_t>("OutChannels");
@@ -44,10 +45,10 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
     const std::array<DimSize_t, 2> kernelDims = convOp -> getAttr<std::array<DimSize_t, 2>>("KernelDims");
 
-    assert(scale->size() == convNbOutChannels);
-    assert(shift->size() == convNbOutChannels);
-    assert(b_mean->size() == convNbOutChannels);
-    assert(b_var->size() == convNbOutChannels);
+    assert(scale.size() == convNbOutChannels);
+    assert(shift.size() == convNbOutChannels);
+    assert(b_mean.size() == convNbOutChannels);
+    assert(b_var.size() == convNbOutChannels);
     assert(epsilon > 0.0);
     // TODO : no no_bias attribute ?
@@ -56,9 +57,8 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
     unsigned int count = 0;
     for (std::size_t outChId = 0; outChId < convNbOutChannels; ++outChId) {
-        // TODO: get() assumed dataType is float...
-        if (b_var->get<float>(outChId) > 1.0e-12) {
-            meanVariance += b_var->get<float>(outChId);
+        if (b_var.get<float>(outChId) > 1.0e-12) {
+            meanVariance += b_var.get<float>(outChId);
             ++count;
         }
         else {
@@ -71,39 +71,43 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
         printf("Warning: variance < 1e-12 for all outputs! Is the network correctly trained?\n");
     }
 
-    std::shared_ptr<Tensor> weight = convOp -> getInput(1);
-    std::shared_ptr<Tensor> bias = convOp -> getInput(2);
+    std::shared_ptr<Tensor> weightBuf, biasBuf;
+    Tensor& weight = convOp->getInput(1)->refCastNative(weightBuf);
+    Tensor& bias = convOp->getInput(2)->refCastNative(biasBuf);
 
     for (std::size_t outChId = 0; outChId < convNbOutChannels; ++outChId) {
         // Corrected for zero-variance issue:
         // "A Quantization-Friendly Separable Convolution for MobileNets"
         // https://arxiv.org/pdf/1803.08607.pdf
         // to help post-training quantization
-        const float factor = scale->get<float>(outChId)
-                             / std::sqrt(epsilon + ((b_var->get<float>(outChId) > 1.0e-12 || count == 0)
-                                ? b_var->get<float>(outChId) : meanVariance));
+        const float factor = scale.get<float>(outChId)
+                             / std::sqrt(epsilon + ((b_var.get<float>(outChId) > 1.0e-12 || count == 0)
+                                ? b_var.get<float>(outChId) : meanVariance));
 
         // Weights adjustments
         for (std::size_t channel = 0; channel < channelsSize; ++channel) {
             // TODO : Suppose kerneldims = 2
             for(std::size_t k0 = 0; k0 < kernelDims[0]; ++ k0){
                 for(std::size_t k1 = 0; k1 < kernelDims[1]; ++ k1){
                     std::vector<DimSize_t> currentIdx = {outChId, channel, k0, k1};
-                    // TODO : suppose weights are float
-                    float weightValue = weight->get<float>(currentIdx);
-                    weight->set<float>(currentIdx, weightValue*factor); // Update check it update Conv weights
+                    float weightValue = weight.get<float>(currentIdx);
+                    weight.set<float>(currentIdx, weightValue*factor); // Update check it update Conv weights
                 }
             }
         }
 
         // TODO : check if noBias==true is set, then set biasValue to 0
-        float biasValue = bias->get<float>(outChId);
-        biasValue = shift->get<float>(outChId) + (biasValue - b_mean->get<float>(outChId)) * factor;
-        bias->set<float>(outChId, biasValue);
+        float biasValue = bias.get<float>(outChId);
+        biasValue = shift.get<float>(outChId) + (biasValue - b_mean.get<float>(outChId)) * factor;
+        bias.set<float>(outChId, biasValue);
     }
 
+    // Copy values back to the original tensors (actual copy only if needed)
+    convOp->getInput(1)->copyCastFrom(weight);
+    convOp->getInput(2)->copyCastFrom(bias);
+
     GraphView::replace(std::set<std::shared_ptr<Node>>({
         batchnormNode,
         batchnormNode->input(1).first,
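For reference, the arithmetic being folded is the usual Conv+BatchNorm fusion: for each output channel, every weight is multiplied by factor = scale / sqrt(epsilon + var) and the bias becomes shift + (bias - mean) * factor, with var replaced by the mean variance of the well-behaved channels when it falls below 1e-12 (the quantization-friendly correction cited above). A standalone scalar sketch of that per-channel computation (illustrative names, not code from the commit):

    #include <cmath>

    struct FoldedChannel {
        float weightFactor;  // multiply every conv weight of the channel by this
        float bias;          // replacement conv bias for the channel
    };

    FoldedChannel foldBatchNorm(float scale, float shift, float mean, float var,
                                float epsilon, float convBias) {
        const float factor = scale / std::sqrt(epsilon + var);
        return { factor, shift + (convBias - mean) * factor };
    }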