Commit e31dca3a authored by Houssem ROUIS

add axis attr to Softmax

parent 8dfe6ba6
Part of 2 merge requests: !50 "version 0.2.0" and !20 "Vit operators"
@@ -24,10 +24,10 @@ namespace Aidge {
 // compute kernel registry for forward and backward
 class SoftmaxImplForward_cpu
-    : public Registrable<SoftmaxImplForward_cpu, std::tuple<DataType, DataType>, void(const DimSize_t, const DimSize_t, const DimSize_t, const void*, void*)> {
+    : public Registrable<SoftmaxImplForward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const std::vector<DimSize_t>&, const void*, void*)> {
 };
 class SoftmaxImplBackward_cpu
-    : public Registrable<SoftmaxImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
+    : public Registrable<SoftmaxImplBackward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const std::vector<DimSize_t>&, const void*, void*)> {
 };
 class SoftmaxImpl_cpu : public OperatorImpl {
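The registry change above replaces the fixed (batchSize, channelSize, featureSize) triple with a generic (axisIdx, inputDims) pair, so every registered kernel must now accept the axis index and the full input shape. As a minimal sketch, a matching float32 registration would look roughly like this, assuming the Registrar pattern used for other aidge_backend_cpu operators (the registrar variable name here is hypothetical):

    namespace {
    // Register the templated float kernel under its (input, output) data-type key.
    static Aidge::Registrar<Aidge::SoftmaxImplForward_cpu> registrarSoftmaxForward_cpu_Float32(
            {Aidge::DataType::Float32, Aidge::DataType::Float32},
            Aidge::SoftmaxImpl_cpu_forward_kernel<float, float>);
    }  // namespace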
@@ -22,30 +22,33 @@
 namespace Aidge {
 template <class I, class O>
-void SoftmaxImpl_cpu_forward_kernel(const DimSize_t batchSize,
-                                    const DimSize_t channelSize,
-                                    const DimSize_t featureSize,
-                                    const void* input_,
-                                    void* output_) {
+void SoftmaxImpl_cpu_forward_kernel(std::size_t axisIdx, const std::vector<DimSize_t>& inputDims, const void* input_, void* output_)
+{
     const I* input = static_cast<const I*>(input_);
     O* output = static_cast<O*>(output_);
-    for (std::size_t batch = 0; batch < batchSize; ++batch) {
-        for (std::size_t feature = 0; feature < featureSize; ++feature) {
-            std::size_t ioIndex = batch*channelSize*featureSize + feature;
-            I sum(0.0);
-            for (std::size_t ch = 0; ch < channelSize; ++ch) {
-                output[ioIndex] = std::exp(input[ioIndex]);
-                sum += output[ioIndex];
-                ioIndex+=featureSize;
-            }
-            ioIndex = batch*channelSize*featureSize + feature;
-            for (std::size_t ch = 0; ch < channelSize; ++ch) {
-                output[ioIndex] /= sum;
-                ioIndex += featureSize;
+    std::size_t postAxisElems = 1;
+    for (std::size_t i = axisIdx + 1; i < inputDims.size(); ++i) {
+        postAxisElems *= inputDims[i];
+    }
+    std::size_t preAxisElems = 1;
+    for (std::size_t i = 0; i < axisIdx; ++i) {
+        preAxisElems *= inputDims[i];
+    }
+    for (std::size_t i = 0; i < preAxisElems; ++i) {
+        for (std::size_t j = 0; j < postAxisElems; ++j) {
+            // Calculate sum of exponentials within the axis
+            I sumExp = 0;
+            for (std::size_t k = 0; k < inputDims[axisIdx]; ++k) {
+                std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j;
+                sumExp += std::exp(input[inIdx]);
+            }
+            // Calculate softmax for the current slice along the axis
+            for (std::size_t k = 0; k < inputDims[axisIdx]; ++k) {
+                std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j;
+                output[inIdx] = std::exp(input[inIdx]) / sumExp;
             }
         }
     }
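The rewritten kernel decomposes a row-major tensor around the softmax axis: preAxisElems is the product of the dimensions before axisIdx, postAxisElems the product of the dimensions after it, and inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j walks the axis with stride postAxisElems. For example, with inputDims = {2, 3, 4} and axisIdx = 1, preAxisElems = 2 and postAxisElems = 4, so each of the 2 * 4 = 8 (i, j) slices normalizes 3 values. A self-contained sketch of the same logic outside the framework (hypothetical helper in plain C++, no Aidge types):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Softmax over `axis` of a row-major tensor, mirroring the kernel's
    // pre/post-axis decomposition. Illustration only, not part of Aidge.
    std::vector<float> softmaxAlongAxis(const std::vector<float>& in,
                                        const std::vector<std::size_t>& dims,
                                        std::size_t axis) {
        std::size_t pre = 1, post = 1;
        for (std::size_t d = 0; d < axis; ++d) pre *= dims[d];
        for (std::size_t d = axis + 1; d < dims.size(); ++d) post *= dims[d];
        std::vector<float> out(in.size());
        for (std::size_t i = 0; i < pre; ++i) {
            for (std::size_t j = 0; j < post; ++j) {
                float sum = 0.0f;
                for (std::size_t k = 0; k < dims[axis]; ++k)
                    sum += std::exp(in[i * dims[axis] * post + k * post + j]);
                for (std::size_t k = 0; k < dims[axis]; ++k) {
                    std::size_t idx = i * dims[axis] * post + k * post + j;
                    out[idx] = std::exp(in[idx]) / sum;
                }
            }
        }
        return out;
    }

Note that, like the kernel above, this sketch applies std::exp to the raw inputs without first subtracting the per-slice maximum, so very large input values can overflow; the usual max-subtraction trick would make it numerically safer.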
@@ -28,20 +28,18 @@ Aidge::NbElts_t Aidge::SoftmaxImpl_cpu::getNbRequiredProtected(const Aidge::IOIn
 void Aidge::SoftmaxImpl_cpu::forward() {
     assert(mOp.getInput(0) && "missing input #0");
-    assert(mOp.getInput(0)->nbDims()>1);
+    // assert(mOp.getInput(0)->nbDims()>1);
     // Find the correct kernel type
     auto kernelFunc = Registrar<SoftmaxImplForward_cpu>::create({
         mOp.getInput(0)->dataType(),
         mOp.getOutput(0)->dataType()});
-    DimSize_t batchSize = mOp.getInput(0)->dims()[0];
-    DimSize_t channelSize = mOp.getInput(0)->dims()[1];
-    DimSize_t featureSize = mOp.getInput(0)->sizeM1()/channelSize;
+    Softmax_Op::Attrs attr = dynamic_cast<const Softmax_Op&>(mOp).getStaticAttributes();
+    const int& axisIdx = static_cast<const int&>(std::get<0>(attr));
     // Call kernel
-    kernelFunc(batchSize,
-               channelSize,
-               featureSize,
+    kernelFunc(axisIdx,
+               mOp.getInput(0)->dims(),
                mOp.getInput(0)->getImpl()->rawPtr(),
                mOp.getOutput(0)->getImpl()->rawPtr());
 }
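With this dispatch, the axis chosen at graph-construction time (std::get<0> of the Softmax_Op attributes) is forwarded, together with the unmodified input shape, to whichever kernel the registrar resolves for the data-type pair. A quick sanity check of the kernel itself, bypassing the operator machinery (a sketch assuming the kernel header above is included; shapes and values are illustrative):

    // All-zero input: softmax along axis 1 of a {2, 3, 4} tensor
    // should yield 1/3 everywhere (each slice has 3 equal entries).
    std::vector<Aidge::DimSize_t> dims{2, 3, 4};
    std::vector<float> in(24, 0.0f), out(24, 0.0f);
    Aidge::SoftmaxImpl_cpu_forward_kernel<float, float>(1, dims, in.data(), out.data());
    // Expect out[i] == 1.0f/3.0f for all 24 elements.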
@@ -39,7 +39,7 @@ TEST_CASE("[cpu/operator] Softmax(forward)") {
         }
     });
-    std::shared_ptr<Node> mySoftmax = Softmax();
+    std::shared_ptr<Node> mySoftmax = Softmax(1);
     mySoftmax->getOperator()->setDatatype(DataType::Float32);
     mySoftmax->getOperator()->setBackend("cpu");
     mySoftmax->getOperator()->associateInput(0,input);
@@ -48,7 +48,7 @@ TEST_CASE("[cpu/operator] Softmax(forward)") {
     float* resPtr = static_cast<float*>(mySoftmax->getOperator()->getOutput(0)->getImpl()->rawPtr());
     float* expectedPtr = static_cast<float*>(expectedOutput->getImpl()->rawPtr());
-    for (std::size_t i = 0; i< 20; ++i) {
+    for (std::size_t i = 0; i< expectedOutput->size(); ++i) {
         REQUIRE(std::abs(resPtr[i]-expectedPtr[i]) < 0.00001);
     }
@@ -107,7 +107,7 @@ TEST_CASE("[cpu/operator] Softmax(forward)") {
         }
     });
-    std::shared_ptr<Node> mySoftmax = Softmax();
+    std::shared_ptr<Node> mySoftmax = Softmax(1);
     mySoftmax->getOperator()->setDatatype(DataType::Float32);
     mySoftmax->getOperator()->setBackend("cpu");
     mySoftmax->getOperator()->associateInput(0,input);
@@ -116,9 +116,8 @@ TEST_CASE("[cpu/operator] Softmax(forward)") {
     float* resPtr = static_cast<float*>(mySoftmax->getOperator()->getOutput(0)->getImpl()->rawPtr());
     float* expectedPtr = static_cast<float*>(expectedOutput->getImpl()->rawPtr());
-    for (std::size_t i = 0; i< 54; ++i) {
+    for (std::size_t i = 0; i< expectedOutput->size(); ++i) {
         REQUIRE(std::abs(resPtr[i]-expectedPtr[i]) < 0.00001);
     }
-    // REQUIRE(*mySoftmax->getOperator()->getOutput(0) == *expectedOutput);
 }
 }
\ No newline at end of file
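Both test hunks make the same two fixes: Softmax() becomes Softmax(1), matching the new mandatory axis argument, and the hard-coded element counts (20 and 54) become expectedOutput->size(), so the comparison loop tracks the fixture's actual size. The repeated tolerance loop could also be factored into a helper; a hypothetical sketch, not part of the commit (needs <cmath> and <cstddef>):

    // Elementwise comparison with absolute tolerance, mirroring the REQUIRE loop.
    bool approxEqual(const float* res, const float* expected, std::size_t n,
                     float eps = 1e-5f) {
        for (std::size_t i = 0; i < n; ++i) {
            if (std::abs(res[i] - expected[i]) >= eps) return false;
        }
        return true;
    }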