Friday, September 30, 2022

Why Backpropagation in RNNs Isn't Effective

To compute the gradient with respect to the previous hidden state (the downstream gradient), the upstream gradient flows back through the tanh non-linearity and is then multiplied by the recurrent weight matrix. Because this downstream gradient keeps flowing backward across time steps, the same computation is repeated at every step (sketched below), and this causes a couple of problems.
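To make the repeated computation concrete, here is a minimal NumPy sketch of the backward pass through a vanilla RNN with hidden update h_t = tanh(W_hh · h_{t-1} + W_xh · x_t). The sizes, random hidden states, and variable names are illustrative assumptions of mine, not taken from the referenced post.

```python
import numpy as np

rng = np.random.default_rng(0)
H, T = 8, 20                            # hidden size, number of time steps (toy values)
W_hh = rng.normal(scale=0.5, size=(H, H))
hs = [np.tanh(rng.normal(size=H)) for _ in range(T)]   # stand-in hidden states

d_h = rng.normal(size=H)                # upstream gradient arriving at the last hidden state
for t in reversed(range(T)):
    d_pre = d_h * (1.0 - hs[t] ** 2)    # backprop through tanh: local derivative is 1 - h_t^2
    d_h = W_hh.T @ d_pre                # backprop through the recurrent weights -> downstream gradient
print("gradient norm after", T, "steps:", np.linalg.norm(d_h))
```

Note how every iteration of the loop applies the same two operations: an elementwise multiply by the tanh derivative and a multiply by W_hh transposed.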

The first problem: because the gradient is multiplied by the same recurrent weight matrix over and over, it gets scaled up or down at every step according to the matrix's largest singular value. If that value is greater than 1 the gradient tends to explode; if it is less than 1 it vanishes.
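To see the scaling effect in isolation, here is a toy sketch of my own (not from the post) that ignores the tanh factor and just multiplies a gradient by W_hh transposed repeatedly. W_hh is built as a scalar times an orthogonal matrix so that every singular value equals sigma, which makes the scaling exact and easy to read off.

```python
import numpy as np

rng = np.random.default_rng(0)
H, T = 8, 50                                   # hidden size, number of time steps (toy values)
Q, _ = np.linalg.qr(rng.normal(size=(H, H)))   # orthogonal matrix
g0 = rng.normal(size=H)                        # initial upstream gradient

for sigma in (1.1, 0.9):
    W_hh = sigma * Q                           # every singular value of W_hh equals sigma
    d_h = g0.copy()
    for _ in range(T):
        d_h = W_hh.T @ d_h                     # one backward step per time step
    print(f"sigma = {sigma}: gradient norm after {T} steps = {np.linalg.norm(d_h):.3e}")
```

Even over a modest 50 steps, sigma = 1.1 grows the norm by roughly 1.1^50 ≈ 117x, while sigma = 0.9 shrinks it by roughly 0.9^50 ≈ 0.005x.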

The second problem: the gradient also has to pass through the tanh non-linearity, which saturates at the extremes. Whenever the pre-activation is large in magnitude, the local derivative of tanh is close to zero, so the gradient is almost wiped out as it passes through. The gradient therefore cannot propagate effectively across long sequences, and optimization becomes ineffective.
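A quick way to see the saturation effect, using a few toy pre-activation values of my choosing: the local derivative of tanh is 1 - tanh(x)^2, which is already tiny once x is a few units away from zero.

```python
import numpy as np

# Local gradient of tanh at a few pre-activation values: near zero it is ~1,
# but in the saturating regions it collapses toward zero, killing the gradient.
for x in (0.0, 2.0, 5.0):
    local_grad = 1.0 - np.tanh(x) ** 2
    print(f"x = {x:4.1f}   tanh(x) = {np.tanh(x):+.4f}   d/dx tanh(x) = {local_grad:.6f}")
```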

The exploding-gradient problem can be mitigated by "clipping" the gradient whenever its norm crosses a chosen threshold. Clipping does nothing for vanishing gradients, however, so RNNs still cannot be used effectively for long sequences.
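As a sketch of what clipping by norm looks like (the threshold of 5.0 and the function name are arbitrary choices of mine, not from the post):

```python
import numpy as np

def clip_by_norm(grad, max_norm=5.0):
    """Rescale grad so its L2 norm does not exceed max_norm (clip-by-norm sketch)."""
    norm = np.linalg.norm(grad)
    if norm > max_norm:
        grad = grad * (max_norm / norm)
    return grad

g = np.array([30.0, -40.0])   # "exploded" gradient with norm 50
print(clip_by_norm(g))        # rescaled to norm 5 -> [ 3. -4.]
```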

References:

https://towardsdatascience.com/backpropagation-in-rnn-explained-bdf853b4e1c2
