Skip to content
Snippets Groups Projects
Commit 60d17a2f authored by Thibault Allenet's avatar Thibault Allenet
Browse files

Merge branch 'dev' into dataloader

parents fafe588f 16738da4
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!4Dataloader
Pipeline #38550 passed
......@@ -27,25 +27,26 @@
#include "aidge/utils/Types.h"
namespace Aidge {
enum class GatherAttr { Axis };
enum class GatherAttr { Indices, GatheredShape, Axis };
class Gather_Op : public OperatorTensor,
public Registrable<Gather_Op,
std::string,
std::unique_ptr<OperatorImpl>(const Gather_Op&)>,
public StaticAttributes<GatherAttr, int> {
public StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t> {
public:
static const std::string Type;
Gather_Op() = delete;
using Attributes_ = StaticAttributes<GatherAttr, int>;
using Attributes_ = StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t>;
template <GatherAttr e> using attr = typename Attributes_::template attr<e>;
Gather_Op(int axis)
: OperatorTensor(Type, 2, 0, 1),
Gather_Op(const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis)
: OperatorTensor(Type, 1, 0, 1),
Attributes_(
attr<GatherAttr::Indices>(indices),
attr<GatherAttr::GatheredShape>(gatheredShape),
attr<GatherAttr::Axis>(axis))
{}
......@@ -76,21 +77,21 @@ public:
}
static const std::vector<std::string> getInputsName(){
return {"data_input", "indexes"};
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
inline std::shared_ptr<Node> Gather(int axis = 0, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Gather_Op>(axis), name);
inline std::shared_ptr<Node> Gather( const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis = 0, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Gather_Op>(indices, gatheredShape, axis), name);
}
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::GatherAttr>::data[] = {"Axis"};
const char *const EnumStrings<Aidge::GatherAttr>::data[] = {"Indices", "GatheredShape", "Axis"};
}
#endif /* AIDGE_CORE_OPERATOR_GATHER_H_ */
......@@ -29,17 +29,17 @@ enum class SliceAttr { Starts, Ends, Axes };
class Slice_Op
: public OperatorTensor,
public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>,
public StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>> {
public StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int64_t>> {
public:
static const std::string Type;
Slice_Op() = delete;
using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>>;
using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int64_t>>;
template <SliceAttr e>
using attr = typename Attributes_::template attr<e>;
Slice_Op(const std::vector<std::int32_t>& starts, const std::vector<std::int32_t>& ends, const std::vector<std::int32_t>& axes)
Slice_Op(const std::vector<std::int64_t>& starts, const std::vector<std::int64_t>& ends, const std::vector<std::int64_t>& axes)
: OperatorTensor(Type, 1, 0, 1),
Attributes_(attr<SliceAttr::Starts>(starts),
attr<SliceAttr::Ends>(ends),
......@@ -94,9 +94,9 @@ public:
* @param name Name of the Operator.
* @return std::shared_ptr<Node> A Node containing the Operator.
*/
inline std::shared_ptr<Node> Slice(const std::vector<std::int32_t> starts,
const std::vector<std::int32_t> ends,
const std::vector<std::int32_t> axes,
inline std::shared_ptr<Node> Slice(const std::vector<std::int64_t> starts,
const std::vector<std::int64_t> ends,
const std::vector<std::int64_t> axes,
const std::string &name = "") {
// FIXME: properly handle default w&b initialization in every cases
return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes), name);
......
......@@ -23,6 +23,6 @@ void init_Gather(py::module& m) {
.def("get_inputs_name", &Gather_Op::getInputsName)
.def("get_outputs_name", &Gather_Op::getOutputsName);
m.def("Gather", &Gather, py::arg("axis"), py::arg("name") = "");
m.def("Gather", &Gather, py::arg("indices"), py::arg("gathered_shape"), py::arg("axis"), py::arg("name") = "");
}
} // namespace Aidge
......@@ -9,8 +9,8 @@
*
********************************************************************************/
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>
......@@ -22,18 +22,26 @@ const std::string Aidge::Gather_Op::Type = "Gather";
void Aidge::Gather_Op::computeOutputDims() {
// check inputs have been associated
if (!getInput(0) || !getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
if (!getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected");
}
if (getInput(1)->nbDims()!=2){
AIDGE_THROW_OR_ABORT(std::runtime_error, "Indices input must be a 2D Tensor");
}
if (!getInput(0)->empty()) {
std::vector<DimSize_t> outDims = getInput(0)->dims();
const std::vector<DimSize_t> gatheredShape = this->template getAttr<GatherAttr::GatheredShape>();
// TODO: check indices and gatheredShape
const std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ?
this->template getAttr<GatherAttr::Axis>() :
this->template getAttr<GatherAttr::Axis>() + outDims.size();
outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
if (!gatheredShape.empty())
{
outDims.insert(outDims.cbegin() + static_cast<std::size_t>(axisIdx),
gatheredShape.cbegin(),
gatheredShape.cend());
}
std::vector<DimSize_t> outDims = getInput(0)->dims();
std::vector<DimSize_t> indexesDims = getInput(1)->dims();
int axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?this->template getAttr<GatherAttr::Axis>():this->template getAttr<GatherAttr::Axis>()+outDims.size();
outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), indexesDims.begin(),indexesDims.end());
mOutputs[0]->resize(outDims);
mOutputs[0]->resize(outDims);
}
}
\ No newline at end of file
......@@ -27,31 +27,32 @@ void Aidge::Reshape_Op::computeOutputDims() {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected");
}
std::vector<DimSize_t> outDims;
// variables to handle a negative dimension
bool foundNegativeDimension = false;
std::size_t outSize = 1;
DimIdx_t negativeIndex = 0;
for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i)
{
std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
if (dimSize < 0) {
if (foundNegativeDimension) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator.");
if (!getInput(0)->empty()) {
std::vector<DimSize_t> outDims;
// variables to handle a negative dimension
bool foundNegativeDimension = false;
std::size_t outSize = 1;
DimIdx_t negativeIndex = 0;
for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i)
{
std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
if (dimSize < 0) {
if (foundNegativeDimension) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator.");
}
foundNegativeDimension = true;
dimSize = 1;
negativeIndex = static_cast<DimIdx_t>(i);
}
foundNegativeDimension = true;
dimSize = 1;
negativeIndex = static_cast<DimIdx_t>(i);
outDims.push_back(static_cast<DimSize_t>(dimSize));
outSize *= static_cast<DimSize_t>(dimSize);
}
outDims.push_back(static_cast<DimSize_t>(dimSize));
outSize *= static_cast<DimSize_t>(dimSize);
}
if (foundNegativeDimension) {
outDims[negativeIndex] = (getInput(0) -> size()) / outSize;
}
if (foundNegativeDimension) {
outDims[negativeIndex] = (getInput(0) -> size()) / outSize;
}
mOutputs[0]->resize(outDims);
mOutputs[0]->resize(outDims);
}
}
\ No newline at end of file
......@@ -30,21 +30,23 @@ void Aidge::Slice_Op::computeOutputDims() {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor");
}
DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
const DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
std::vector<DimSize_t> outDims = getInput(0)->dims();
for (std::size_t i = 0; i < nbAxes; ++i) {
// For each slice operation get the params and cast them to size_t
const std::int64_t axis_ = this->template getAttr<SliceAttr::Axes>()[i];
const std::int64_t start_ = this->template getAttr<SliceAttr::Starts>()[i];
const std::int64_t end_ = this->template getAttr<SliceAttr::Ends>()[i];
const std::size_t axis = axis_ >= 0 ? static_cast<std::size_t>(axis_) : axis_ + getInput(0)->nbDims();
const std::size_t start = start_ >= 0 ? static_cast<std::size_t>(start_) : start_ + getInput(0)->dims()[axis];
const std::size_t end = end_ >= 0 ? static_cast<std::size_t>(end_) : end_ + getInput(0)->dims()[axis];
const std::size_t axis = axis_ >= 0 ? static_cast<std::size_t>(axis_) : static_cast<std::size_t>(axis_) + getInput(0)->nbDims();
const std::size_t start = start_ >= 0 ? static_cast<std::size_t>(start_) : static_cast<std::size_t>(start_) + getInput(0)->dims()[axis];
const std::size_t end = end_ >= 0 ? static_cast<std::size_t>(end_) : static_cast<std::size_t>(end_) + getInput(0)->dims()[axis];
const std::size_t sliceLength = end - start + 1;
// Check if slice length is valid
if (sliceLength > getInput(0)->dims()[axis])
{
AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds");
}
outDims[axis] = sliceLength;
}
mOutputs[0]->resize(outDims);
......
......@@ -82,16 +82,16 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
clonedInputs[1] -> addChild(newNode, 0, 1);
clonedInputs[2] -> addChild(newNode, 0, 2);
// Slice for input and each parameter
std::vector<std::int32_t> inputDimsEnd(inputDims[0].first.size());
std::vector<std::int64_t> inputDimsEnd(inputDims[0].first.size());
for (std::size_t dim = 0; dim < inputDimsEnd.size(); ++dim) {
inputDimsEnd[dim] = static_cast<std::int32_t>(inputDims[0].first[dim] + inputDims[0].second[dim]) - 1;
inputDimsEnd[dim] = static_cast<std::int64_t>(inputDims[0].first[dim] + inputDims[0].second[dim]) - 1;
}
std::vector<std::int32_t> inputDimsStart(inputDims[0].first.size());
std::vector<std::int64_t> inputDimsStart(inputDims[0].first.size());
for (std::size_t dim = 0; dim < inputDimsStart.size(); ++dim) {
inputDimsStart[dim] = static_cast<std::int32_t>(inputDims[0].first[dim]);
inputDimsStart[dim] = static_cast<std::int64_t>(inputDims[0].first[dim]);
}
std::vector<std::int32_t> usedDims(inputDimsEnd.size());
std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int32_t>(0));
std::vector<std::int64_t> usedDims(inputDimsEnd.size());
std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int64_t>(0));
auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis]));
slice -> addChild(newNode, 0, 0);
newNode -> addChild(concat, 0, i);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment