Commit bfabda5b authored by Maxence Naud

Fixes

- Change the Softmax_Op 'axis' attribute from int to std::size_t
- Rename Transpose_OpImpl to TransposeImpl
- Rename snake_case names to camelCase
- Fix tests according to the AvgPooling_Op changes
parent 678c51ce
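For reference, a minimal sketch of what call sites look like after this change. The include paths are assumptions; the Softmax and Transpose factory signatures are taken from the hunks below.

// Sketch only (not part of the commit): illustrates the updated signatures.
// Include paths are assumptions; the factory helpers appear in the hunks below.
#include <cstddef>
#include <memory>
#include <vector>

#include "aidge/operator/Softmax.hpp"
#include "aidge/operator/Transpose.hpp"

static void buildNodes() {
    // 'axis' is now std::size_t instead of int.
    std::shared_ptr<Aidge::Node> softmax = Aidge::Softmax(std::size_t{1}, "softmax0");

    // The parameter was renamed from output_dims_order to outputDimsOrder (camelCase).
    const std::vector<Aidge::DimSize_t> outputDimsOrder{0, 2, 3, 1};
    std::shared_ptr<Aidge::Node> transpose = Aidge::Transpose(outputDimsOrder, "transpose0");
}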
@@ -27,8 +27,8 @@ class test_attributes(unittest.TestCase):
         out_channels = 8
         k_dims = [2, 2]
         conv_op = aidge_core.Conv2D(in_channels , out_channels, k_dims).get_operator()
-        self.assertEqual(conv_op.get_attr("InChannels"), in_channels)
-        self.assertEqual(conv_op.get_attr("OutChannels"), out_channels)
+        self.assertEqual(conv_op.in_channels(), in_channels)
+        self.assertEqual(conv_op.out_channels(), out_channels)
         self.assertEqual(conv_op.get_attr("KernelDims"), k_dims)

     def test_fc(self):
@@ -30,16 +30,16 @@ class Softmax_Op : public OperatorTensor,
                 public Registrable<Softmax_Op,
                                    std::string,
                                    std::shared_ptr<OperatorImpl>(const Softmax_Op&)>,
-                public StaticAttributes<SoftmaxAttr, int> {
+                public StaticAttributes<SoftmaxAttr, std::size_t> {
 public:
     static const std::string Type;

     Softmax_Op() = delete;

-    using Attributes_ = StaticAttributes<SoftmaxAttr, int>;
+    using Attributes_ = StaticAttributes<SoftmaxAttr, std::size_t>;
     template <SoftmaxAttr e> using attr = typename Attributes_::template attr<e>;

-    Softmax_Op(int axis)
+    Softmax_Op(std::size_t axis)
         : OperatorTensor(Type, 1, 0, 1),
           Attributes_(attr<SoftmaxAttr::AxisIdx>(axis)) {}
@@ -76,7 +76,7 @@ public:
     }
 };

-inline std::shared_ptr<Node> Softmax(int axis, const std::string& name = "") {
+inline std::shared_ptr<Node> Softmax(std::size_t axis, const std::string& name = "") {
     return std::make_shared<Node>(std::make_shared<Softmax_Op>(axis), name);
 }
 } // namespace Aidge
@@ -26,9 +26,9 @@
 #include "aidge/utils/Types.h"

 namespace Aidge {
-class Transpose_OpImpl : public OperatorImpl {
+class TransposeImpl : public OperatorImpl {
 public:
-    Transpose_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
+    TransposeImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
     void forward() override;
 };
@@ -47,11 +47,11 @@ class Transpose_Op : public OperatorTensor,
     template <TransposeAttr e>
     using attr = typename Attributes_::template attr<e>;

-    Transpose_Op(const std::vector<DimSize_t> &output_dims_order)
+    Transpose_Op(const std::vector<DimSize_t> &outputDimsOrder)
         : OperatorTensor(Type, 1, 0, 1),
-          Attributes_(attr<TransposeAttr::OutputDimsOrder>(output_dims_order))
+          Attributes_(attr<TransposeAttr::OutputDimsOrder>(outputDimsOrder))
     {
-        mImpl = std::make_shared<Transpose_OpImpl>(*this);
+        mImpl = std::make_shared<TransposeImpl>(*this);
     }

     /**
@@ -66,7 +66,7 @@ class Transpose_Op : public OperatorTensor,
             SET_IMPL_MACRO(Transpose_Op, *this, op.backend());
         }
         else {
-            mImpl = std::make_shared<Transpose_OpImpl>(*this);
+            mImpl = std::make_shared<TransposeImpl>(*this);
         }
     }
@@ -90,9 +90,9 @@ class Transpose_Op : public OperatorTensor,
     }
 };

-inline std::shared_ptr<Node> Transpose(const std::vector<DimSize_t> &output_dims_order,
+inline std::shared_ptr<Node> Transpose(const std::vector<DimSize_t> &outputDimsOrder,
                                        const std::string& name = "") {
-    return std::make_shared<Node>(std::make_shared<Transpose_Op>(output_dims_order), name);
+    return std::make_shared<Node>(std::make_shared<Transpose_Op>(outputDimsOrder), name);
 }
 } // namespace Aidge
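The renamed TransposeImpl is the backend-agnostic fallback created in the constructor and in the SET_IMPL_MACRO else-branches above when no backend-specific implementation is registered. Below is a hypothetical sketch of another implementation following the same pattern; the class name, its forward() behaviour, and the include path are illustrative assumptions, not part of the commit.

// Hypothetical sketch (not part of the commit): a minimal OperatorImpl subclass
// following the same pattern as TransposeImpl above.
#include <memory>
#include <string>

#include "aidge/backend/OperatorImpl.hpp"  // include path is an assumption

class MyCustomImpl : public Aidge::OperatorImpl {
public:
    MyCustomImpl(const Aidge::Operator& op, const std::string& backend = "")
        : OperatorImpl(op, backend) {}

    void forward() override {
        // A real implementation reads the operator's inputs through mOp and writes
        // its outputs, as Aidge::TransposeImpl::forward() does in Transpose.cpp.
    }
};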
@@ -119,9 +119,11 @@ void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProducers
             }
         }
-        if (node_ptr == mRootNode || node_ptr->type() != "Producer" || showProducers) {
-            fmt::print(fp.get(), "{}_{}({}){}\n", node_ptr->type(), namePtrTable.at(node_ptr),
-                       givenName, nodeCls);
+        if (node_ptr->type() != "Producer" || showProducers) {
+            // if (node_ptr == mRootNode) {
+                fmt::print(fp.get(), "{}_{}({}){}\n", node_ptr->type(), namePtrTable.at(node_ptr),
+                           givenName, nodeCls);
+            // }
         }
     }
@@ -1412,10 +1414,9 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone
 }

 std::shared_ptr<Aidge::GraphView> Aidge::getConnectedGraphView(std::shared_ptr<Node> node) {
-    std::vector<NodePtr> foundNodes;
-    foundNodes.push_back(node);
+    std::vector<NodePtr> foundNodes{node};

-    for (size_t curNodeIdx = 0; curNodeIdx < foundNodes.size(); ++curNodeIdx) {
+    for (std::size_t curNodeIdx = 0; curNodeIdx < foundNodes.size(); ++curNodeIdx) {
         NodePtr curNode = foundNodes[curNodeIdx];

         for (auto childs : curNode->getOrderedChildren()) {
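The getConnectedGraphView change keeps the index-based worklist traversal: iterating by index rather than by iterator stays valid while push_back() grows the vector during the walk. A generic sketch of that pattern follows; the NodeT type and its neighbours() accessor are placeholder assumptions for illustration.

// Generic sketch (not part of the commit) of the index-based worklist traversal
// used by getConnectedGraphView above. 'neighbours()' is a placeholder accessor.
#include <algorithm>
#include <cstddef>
#include <memory>
#include <vector>

template <typename NodeT>
std::vector<std::shared_ptr<NodeT>> collectReachable(std::shared_ptr<NodeT> start) {
    std::vector<std::shared_ptr<NodeT>> found{start};
    // Indexing stays valid even though push_back() may reallocate the vector,
    // which would invalidate iterators but not indices.
    for (std::size_t idx = 0; idx < found.size(); ++idx) {
        for (const auto& next : found[idx]->neighbours()) {
            if (std::find(found.begin(), found.end(), next) == found.end()) {
                found.push_back(next);
            }
        }
    }
    return found;
}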
@@ -23,7 +23,7 @@
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"

-void Aidge::Transpose_OpImpl::forward() {
+void Aidge::TransposeImpl::forward() {
     const Transpose_Op& op = dynamic_cast<const Transpose_Op&>(mOp);
     const auto inputDims = op.getInput(0)->dims();
     const auto outputDims = op.getOutput(0)->dims();
@@ -83,7 +83,7 @@ void Aidge::Transpose_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t
         SET_IMPL_MACRO(Transpose_Op, *this, name);
     }
     else {
-        mImpl = std::make_shared<Transpose_OpImpl>(*this);
+        mImpl = std::make_shared<TransposeImpl>(*this);
     }
     mOutputs[0]->setBackend(name, device);
 }