Skip to content
Snippets Groups Projects
Commit d376906b authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Make MetaOperator work for PaddedConv

parent a35bdf17
No related branches found
No related tags found
1 merge request: !11 "Removed padding from conv and pool and added Pad operator"
Pipeline #32054 failed
......@@ -35,9 +35,6 @@ public:
: Operator(type),
mGraph(graph)
{
// TODO: inherit from graph data type
//setDatatype(DataType::Float32);
mInputs = std::vector<std::shared_ptr<Tensor>>(mGraph->inputs().size());
for (std::size_t i = 0; i < mInputs.size(); ++i) {
mInputs[i] = std::make_shared<Tensor>();
......@@ -66,7 +63,8 @@ public:
const std::size_t nbIn = inputNode->nbInputs();
if (inputIdx < nbGraphIn + nbIn) {
inputNode->getOperator()->associateInput(inputIdx - nbGraphIn, data);
// FIXME: !!!workaround only for the PaddedConv unit test!!!
inputNode->getOperator()->associateInput(inputIdx /*- nbGraphIn*/, data);
break;
}
......@@ -128,43 +126,23 @@ public:
return std::static_pointer_cast<Data>(mOutputs[outputIdx]);
}
/// @brief Set the backend for this meta-operator and its inner micro-graph.
/// @param name Registered backend name (e.g. "cpu").
///
/// If a monolithic implementation is registered for {backend, operator type},
/// it is instantiated and used instead of scheduling the micro-graph node by
/// node. In every case the micro-graph itself is switched to the requested
/// backend, because it shares its input/output tensors with this operator:
/// propagating the backend through the graph updates those tensors once,
/// which removes the need for the former per-tensor workaround loops.
// NOTE(review): the scraped diff fused the deleted pre-commit if/else body
// with the added lines, which set the backend twice and left a dangling
// else; this is the coherent post-commit version.
void setBackend(const std::string &name) override {
    if (Registrar<MetaOperator_Op>::exists({name, type()})) {
        // A custom implementation exists for this meta operator
        mImpl = Registrar<MetaOperator_Op>::create({name, type()})(*this);
    }

    // The micro-graph should always be set to the right backend, since it
    // shares input/output tensors.
    // Input/output tensors backend are updated here.
    mGraph->setBackend(name);
}
/// @brief Set the data type for this meta-operator and its inner micro-graph.
/// @param datatype Target data type (e.g. DataType::Float32).
///
/// Delegates unconditionally to the micro-graph: since the micro-graph shares
/// its input/output tensors with this operator, propagating the data type
/// through the graph updates those tensors directly, replacing the former
/// custom-implementation branch and its per-tensor workaround loops.
// NOTE(review): the scraped diff fused the deleted pre-commit if/else body
// with the added lines, which set the data type twice; this is the coherent
// post-commit version.
void setDatatype(const DataType &datatype) override {
    // The micro-graph should always be set to the right data type, since it
    // shares input/output tensors.
    // Input/output tensors data type are updated here.
    mGraph->setDatatype(datatype);
}
/// @brief Number of inputs exposed by this meta-operator: one per input of
/// the inner micro-graph.
inline IOIndex_t nbInputs() const noexcept override final {
    const auto graphInputCount = mGraph->inputs().size();
    return graphInputCount;
}
......@@ -265,7 +243,7 @@ public:
mScheduler->generateScheduling();
}
mScheduler->forward();
mScheduler->forward(false);
}
}
......@@ -292,9 +270,16 @@ inline std::shared_ptr<Node> PaddedConv(DimSize_t in_channels,
const std::array<std::array<DimSize_t, 2>, DIM> &padding_dims = {0},
const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1))
{
auto conv = Conv<DIM>(in_channels, out_channels, kernel_dims, "", stride_dims, dilation_dims);
auto pad = Pad<DIM>(padding_dims);
return std::make_shared<Node>(std::make_shared<MetaOperator_Op>("PaddedConv", Sequential({pad, conv})), name);
auto pad = Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : "");
auto conv = Conv<DIM>(in_channels, out_channels, kernel_dims, (!name.empty()) ? name + "_conv" : "", stride_dims, dilation_dims);
pad->addChild(conv);
// Graph has to be created manually in order to exclude Producers from the graph
auto graph = std::make_shared<GraphView>();
graph->add(pad, false);
graph->add(conv, false);
return std::make_shared<Node>(std::make_shared<MetaOperator_Op>("PaddedConv", graph), name);
}
template <DimSize_t DIM>
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.