Skip to content
Snippets Groups Projects
Commit c5f707cb authored by Thibault Allenet's avatar Thibault Allenet
Browse files

Add WeightInterleaving Operator

parent a924297c
No related branches found
No related tags found
2 merge requests!318[Upd] release verision 0.5.0,!262Low bit support for ARM Cortex-M export
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_WEIGHTINTERLEAVING_H_
#define AIDGE_CORE_OPERATOR_WEIGHTINTERLEAVING_H_
#include <cassert>
#include <memory>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
class WeightInterleaving_Op :
public OperatorTensor,
public Registrable<WeightInterleaving_Op, // <Op, backend, implementation creation function>
std::string,
std::function<std::shared_ptr<OperatorImpl>(const WeightInterleaving_Op&)>>
{
public:
static const std::string Type;
WeightInterleaving_Op() : OperatorTensor(Type, {InputCategory::Data}, 1) {}
/**
* @brief Copy-constructor.
* @param op WeightInterleaving_Op to copy.
* @details Copies the operator attributes and its output tensor(s), but not
* its input tensors. The new operator has no associated input.
*/
WeightInterleaving_Op(const WeightInterleaving_Op& op);
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::WeightInterleaving_Op
*/
std::shared_ptr<Operator> clone() const override;
bool forwardDims(bool allowDataDependency = false) override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override final;
std::set<std::string> getAvailableBackends() const override;
static const std::vector<std::string> getInputsName(){
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
/**
* @brief Calculates the required size for the 8-bits`compactData` vector.
*
* This function determines the minimum number of bytes needed in `compactData`
* to store `dataSize` elements compacted to `nb_bits` bits each.
*
* @param dataSize The total number of elements in the input data array.
* @param nb_bits The number of bits to use for each compacted element (from 1 to 7).
* @return std::size_t The required size in bytes for `compactData`.
*/
std::size_t compactDataSize(std::size_t dataSize, std::uint8_t nb_bits);
};
std::shared_ptr<Node> WeightInterleaving(const std::string& name = "");
}
#endif /* AIDGE_CORE_OPERATOR_RELU_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/operator/WeightInterleaving.hpp"
#include <memory>
#include <string>
#include <vector>
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
const std::string Aidge::WeightInterleaving_Op::Type = "WeightInterleaving";
/**
* @brief Copy-constructor.
* @param op WeightInterleaving_Op to copy.
* @details Copies the operator attributes and its output tensor(s), but not
* its input tensors. The new operator has no associated input.
*/
Aidge::WeightInterleaving_Op::WeightInterleaving_Op(const WeightInterleaving_Op& op)
: OperatorTensor(op)
{
if (op.mImpl) {
SET_IMPL_MACRO(WeightInterleaving_Op, *this, op.backend());
} else {
mImpl = nullptr;
}
}
std::shared_ptr<Aidge::Operator> Aidge::WeightInterleaving_Op::clone() const {
return std::make_shared<WeightInterleaving_Op>(*this);
}
bool Aidge::WeightInterleaving_Op::forwardDims(bool /*allowDataDependency*/) {
if (inputsAssociated()) {
// check input data format is NHWC
AIDGE_ASSERT((getInput(0)->dataFormat() == DataFormat::NHWC),
"Wrong Input tensor Data Format : {} for WeightInterleaving operator (should be DataFormat::NHWC for STM32).", getInput(0)->dataFormat());
// Take the last dimension of the tensor : It is the Channel dimension in format NHWC
// The weights will be compacted along side the channel dimension only
const DimSize_t& lastDim = getInput(0)->dims().back();
// Compute the last dimension size of the tensor after the weight interleaving compression
// TO DO : implement a mechanism to get the number of bits of the DataType
const DataType& dt = getInput(0)->dataType();
std::uint8_t nbBits = 0;
switch (dt) {
case DataType::Int4:
nbBits=4;
break;
case DataType::Int3:
nbBits=3;
break;
case DataType::Int2:
nbBits=2;
break;
default:
AIDGE_ASSERT(true, "Unsupport type for WeightInterleaving {}", dt);
}
const auto lastDimCompression = compactDataSize(lastDim, nbBits);
std::vector<DimSize_t> outputDims = getInput(0)->dims();
outputDims.back() = lastDimCompression;
// <batch, OutChannels>
mOutputs[0]->resize(outputDims);
return true;
}
return false;
}
void Aidge::WeightInterleaving_Op::setBackend(const std::string& name, DeviceIdx_t device) {
SET_IMPL_MACRO(WeightInterleaving_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
std::set<std::string> Aidge::WeightInterleaving_Op::getAvailableBackends() const {
return Registrar<WeightInterleaving_Op>::getKeys();
}
std::shared_ptr<Aidge::Node> Aidge::WeightInterleaving(const std::string& name) {
return std::make_shared<Node>(std::make_shared<WeightInterleaving_Op>(), name);
}
std::size_t Aidge::WeightInterleaving_Op::compactDataSize(std::size_t dataSize, std::uint8_t nbBits) {
AIDGE_ASSERT(nbBits > 0 && nbBits < 8, "nbBits must be between 1 and 4"); // Ensure valid bit width
// Calculate the number of `nbBits` segments that can fit in an 8-bit byte.
const unsigned int nbSlot = 8 / nbBits;
// Calculate the number of compacted bytes needed to store all data elements.
// The formula (dataSize + nbSlot - 1) / nbSlot effectively rounds up the division, ensuring that any remaining elements that don't fully fill a byte are accounted for.
std::size_t requiredSize = (dataSize + nbSlot - 1) / nbSlot;
return requiredSize;
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment