Commit cc3a7008 authored by Cyril Moineau

Merge branch 'bindLoss' into 'dev'

Update how the loss function works

See merge request !118
parents d0790356 ab4f0f5d
2 merge requests: !119 "0.2.1", !118 "Update how the loss function works"
Pipeline #45162 passed
@@ -554,16 +554,11 @@ public:
inline void print() const { fmt::print("{}\n", toString()); }
std::shared_ptr<Tensor> grad() {
// if (!mGrad && mImpl) {
// mGrad = std::make_shared<Tensor>(mDims);
// mGrad->setDataType(mDataType);
// mGrad->setBackend(mImpl->backend());
// // if (mImpl) mGrad->setBackend(mImpl->backend());
// }
return mGrad;
}
void setGrad(std::shared_ptr<Tensor> newGrad) {
mGrad = newGrad;
}
/**
* @brief Associate the gradient with a Tensor instance and set its implementation
@@ -574,7 +569,7 @@ public:
* @note If Tensor instance and implementation already existed for the gradient
* nothing is done.
*/
void initGradient() {
void initGrad() {
if (!mGrad) {
mGrad = std::make_shared<Tensor>(mDims);
}
......
@@ -105,7 +105,7 @@ public:
void forward() override final;
void backward() override final {
fmt::print("Basic Producer backward() function.\n");
// fmt::print("Basic Producer backward() function.\n");
}
void setOutput(const Aidge::IOIndex_t outputIdx, std::shared_ptr<Aidge::Data>&& data) override {
if (getAttr<ProdAttr::Constant>()) {
......
@@ -54,7 +54,7 @@ public:
/**
* @brief Run the provided Computational Graph with a batch of data
*/
void backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instantiateGrad = true);
void backward(bool instantiateGrad = true);
private:
SchedulingPolicy mSchedulingPolicy;
......
@@ -77,7 +77,9 @@ void init_Tensor(py::module& m){
.def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0, py::arg("copyFrom") = true)
.def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims)
.def("grad", &Tensor::grad)
.def("set_grad", &Tensor::setGrad)
.def("dtype", &Tensor::dataType)
.def("init_grad", &Tensor::initGrad)
.def("size", &Tensor::size)
.def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize)
.def("has_impl", &Tensor::hasImpl)
......
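With `set_grad` and `init_grad` now bound, a gradient can be attached to a Tensor directly from Python. A minimal sketch (not part of this MR; the helper name is hypothetical, and both arguments are assumed to be `aidge_core.Tensor` instances created elsewhere):

```python
def seed_gradient(out_tensor, grad_tensor=None):
    # `out_tensor` / `grad_tensor` are assumed to be aidge_core.Tensor instances.
    if grad_tensor is not None:
        out_tensor.set_grad(grad_tensor)   # attach dL/dOutput explicitly (new binding)
    else:
        out_tensor.init_grad()             # just allocate; no-op if a gradient already exists
    return out_tensor.grad()               # read back through the existing binding
```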
@@ -24,5 +24,6 @@ namespace py = pybind11;
namespace Aidge {
void init_GraphViewHelper(py::module &m) {
m.def("producers", &producers, py::arg("graphview"));
m.def("compile_gradient", &compile_gradient, py::arg("graphview"));
}
} // namespace Aidge
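`compile_gradient` is now reachable from Python as well; per the C++ change further down, it walks the GraphView and calls `initGrad()` on every node output, which is the same step `backward(instanciate_grad=True)` performs internally. A small usage sketch, assuming the bindings are imported as `aidge_core` and `graph_view` is built elsewhere (the wrapper name is hypothetical):

```python
import aidge_core

def prepare_gradients(graph_view):
    # `graph_view` is assumed to be an already built aidge_core.GraphView.
    # Allocates a gradient Tensor for every node output inside the graph.
    aidge_core.compile_gradient(graph_view)
```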
@@ -34,7 +34,7 @@ void init_Scheduler(py::module& m){
py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler")
.def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
.def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("data")=std::vector<Tensor>())
.def("backward", &SequentialScheduler::backward, py::arg("data"), py::arg("instanciate_grad")=true)
.def("backward", &SequentialScheduler::backward, py::arg("instanciate_grad")=true)
;
py::class_<ParallelScheduler, std::shared_ptr<ParallelScheduler>, Scheduler>(m, "ParallelScheduler")
......
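On the Python side the call therefore changes shape: `backward` no longer accepts the list of gradient tensors. A hedged sketch of the new call sequence (not part of this MR; `graph_view`, `input_tensors`, `output_tensor` and `loss_grad` are assumed to be built elsewhere, and the function name is hypothetical):

```python
import aidge_core

def train_step_backward(graph_view, input_tensors, output_tensor, loss_grad):
    # Assumptions: `graph_view` is an aidge_core.GraphView, `input_tensors` a list
    # of aidge_core.Tensor fed to forward(), `output_tensor` the graph output and
    # `loss_grad` a Tensor holding dL/dOutput (e.g. produced by a bound loss).
    scheduler = aidge_core.SequentialScheduler(graph_view)
    scheduler.forward(forward_dims=True, data=input_tensors)

    # Old binding: scheduler.backward(data=[loss_grad], instanciate_grad=True)
    # New binding: the gradient is seeded on the output Tensor beforehand.
    output_tensor.set_grad(loss_grad)
    scheduler.backward(instanciate_grad=True)
```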
@@ -29,7 +29,9 @@ void Aidge::heFiller(std::shared_ptr<Aidge::Tensor> tensor,
: (varianceNorm == Aidge::VarianceNorm::Average)
? (fanIn + fanOut) / 2.0
: fanOut);
AIDGE_ASSERT(n > 0,
"Something went wrong division by zero or square root of "
"negative value.");
const T stdDev(std::sqrt(2.0 / n));
const T mean(varianceNorm == Aidge::VarianceNorm::FanIn ? meanNorm / fanIn
......
@@ -29,6 +29,9 @@ void Aidge::xavierUniformFiller(std::shared_ptr<Aidge::Tensor> tensor,
: (varianceNorm == Aidge::VarianceNorm::Average)
? (fanIn + fanOut) / 2.0
: fanOut);
AIDGE_ASSERT(n > 0,
"Something went wrong division by zero or square root of "
"negative value.");
const T scale(std::sqrt(3.0 / n));
std::uniform_real_distribution<T> uniformDist(-scale, scale);
......
@@ -72,9 +72,6 @@ void Aidge::Producer_Op::forward() {
if (!backend().empty()) {
mImpl->forward();
}
else {
fmt::print("Basic Producer forward() function.\n");
}
runHooks();
}
@@ -51,7 +51,7 @@ void Aidge::compile_gradient(std::shared_ptr<Aidge::GraphView> gv) {
AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator ({}) that doesn't use Tensor.", node->getOperator()->type());
const std::shared_ptr<OperatorTensor> op = std::dynamic_pointer_cast<OperatorTensor>(node -> getOperator());
for (std::size_t o = 0; o < node -> nbOutputs(); ++o) {
op->getOutput(o)->initGradient();
op->getOutput(o)->initGrad();
}
}
}
\ No newline at end of file
}
@@ -73,21 +73,12 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std
}
}
void Aidge::SequentialScheduler::backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instanciateGrad) {
void Aidge::SequentialScheduler::backward(bool instanciateGrad) {
// create and set Grad values
if (instanciateGrad) { compile_gradient(mGraphView); }
const auto& ordered_outputs = mGraphView->getOrderedOutputs();
AIDGE_ASSERT(ordered_outputs.size() == data.size(), "You must provide the \
right number of data objects to run the backward function. \
{} outputs detected for the current GraphView when {} were \
provided.", ordered_outputs.size(), data.size());
for (std::size_t i = 0; i < ordered_outputs.size(); ++i) {
const std::shared_ptr<OperatorTensor> op_ = std::dynamic_pointer_cast<OperatorTensor>(ordered_outputs[i].first->getOperator());
const std::shared_ptr<Tensor> t_grad = op_->getOutput(ordered_outputs[i].second)->grad();
AIDGE_ASSERT(data[i]->dims() == t_grad->dims(), "Wrong gradient size.");
*t_grad = data[i]->clone();
}
// TODO: Check output grad are not empty
// Generate scheduling *only if empty*
// If scheduling was already generated (in one or several steps, i.e. one or
// several successive call to generateScheduling()), do not generate it twice
......
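The loop dropped from `backward()` above is what used to copy the caller-provided gradients into each output's grad tensor (with a size check). Callers that relied on passing `data` can reproduce that step themselves with the bindings touched by this MR. A hedged migration sketch (the helper name is hypothetical; `scheduler`, `output_tensors` and `grads` are assumed to be obtained elsewhere):

```python
def backward_with_grads(scheduler, output_tensors, grads, instanciate_grad=True):
    # `scheduler` is assumed to be an aidge_core.SequentialScheduler whose forward
    # pass has already run; `output_tensors` are the graph outputs and `grads` the
    # matching dL/dOutput Tensors, in the same order.
    assert len(output_tensors) == len(grads), \
        "provide exactly one gradient Tensor per graph output"
    for out_t, grad_t in zip(output_tensors, grads):
        # Unlike the removed C++ loop, set_grad swaps in the whole gradient Tensor
        # instead of cloning into the existing one, so no separate size check here.
        out_t.set_grad(grad_t)
    scheduler.backward(instanciate_grad=instanciate_grad)
```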