Mathematics of Recurrent Neural Networks: Backpropagation
AI, But Simple Issue #23
Hello from the AI, but simple team! If you enjoy our content, consider supporting us so we can keep doing what we do.
Our newsletter is no longer sustainable to run at no cost, so we’re relying on different measures to cover operational expenses. Thanks again for reading!
This issue is a continuation of last week’s issue, where we went over the forward pass of a simple Recurrent Neural Network (RNN).
If you haven’t already, please read the last issue, as it covers the notation, values, and background needed for this one. It can be found here.
In addition to reading that issue, feel free to check out the issues listed below to get a better feel for forward propagation, backpropagation, and neural network math:
Some knowledge of calculus and linear algebra (such as matrix multiplication) will also help, but we will still explain the process at an understandable level.
Through the backpropagation process, we aim to compute the gradients of the total loss with respect to all the parameters by applying the chain rule backward through time.
We give RNN backpropagation a special name: Backpropagation Through Time (BPTT).
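To see what this means concretely, here is a rough sketch of the chain rule "through time", assuming (as is standard for a simple RNN) that the total loss is the sum of the per-time-step losses and that the same weights are reused at every step:

$$
L = \sum_{t=1}^{T} L_t,
\qquad
\frac{\partial L}{\partial W_{hh}}
= \sum_{t=1}^{T} \sum_{k=1}^{t}
\frac{\partial L_t}{\partial h_t}
\left( \prod_{j=k+1}^{t} \frac{\partial h_j}{\partial h_{j-1}} \right)
\frac{\partial h_k}{\partial W_{hh}}
$$

The product of Jacobians ∂h_j/∂h_{j−1} is what carries the error signal from a later loss L_t back to an earlier time step k; this is the "through time" part of BPTT.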
As a reminder, here is the architecture of our RNN:
During the backward pass of a neural network, the main goal is to update the weights and biases using an optimization method such as gradient descent.
These optimization methods need the derivatives of the loss with respect to the weights and biases, since those derivatives determine the direction and magnitude of the parameter updates.
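As a quick reminder of how those derivatives get used, a plain gradient descent update with learning rate η (a hyperparameter, not something we derive here) looks like this for each parameter:

$$
W_{hy} \leftarrow W_{hy} - \eta \frac{\partial L}{\partial W_{hy}},
\qquad
b_y \leftarrow b_y - \eta \frac{\partial L}{\partial b_y}
$$

with the same form of update applied to W_hh, W_xh, and b_h.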
We need a total of 5 gradients since there are 5 sets of parameters, as seen in the neural network diagram above: the weight matrices W_hy, W_hh, and W_xh, and the bias vectors b_y and b_h.
To find these gradients, we need to calculate the output errors (δy_t) and hidden errors (δh_t) at each time step. Let’s get into the process.
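Before walking through the math, here is a minimal numpy sketch of the overall structure of BPTT, just to show how the five gradients are accumulated backward over the time steps. It assumes a tanh hidden activation and a linear output layer, and it takes the per-step output errors δy_t as given; the exact form of those errors (and of δh_t) is what we derive in the rest of this issue.

```python
import numpy as np

# Minimal BPTT sketch (illustrative, not the exact derivation in this issue).
# Assumptions: h_t = tanh(Wxh x_t + Whh h_{t-1} + bh), y_t = Why h_t + by,
# and dys[t] already holds the output error delta_y_t = dL_t/dy_t.

def bptt(xs, hs, dys, Wxh, Whh, Why):
    """Accumulate gradients for the 5 parameter groups over T time steps.

    xs  : list of T input column vectors x_t
    hs  : list of T+1 hidden states, where hs[0] is the initial hidden state
    dys : list of T output-error column vectors delta_y_t
    """
    T = len(xs)
    dWxh, dWhh, dWhy = np.zeros_like(Wxh), np.zeros_like(Whh), np.zeros_like(Why)
    dbh = np.zeros((Whh.shape[0], 1))
    dby = np.zeros((Why.shape[0], 1))
    dh_next = np.zeros((Whh.shape[0], 1))    # error arriving from step t+1

    for t in reversed(range(T)):
        dy = dys[t]                          # output error at step t
        dWhy += dy @ hs[t + 1].T             # gradients sum over all time steps
        dby += dy
        dh = Why.T @ dy + dh_next            # hidden error: from output + future steps
        dh_raw = (1 - hs[t + 1] ** 2) * dh   # backprop through tanh
        dbh += dh_raw
        dWxh += dh_raw @ xs[t].T
        dWhh += dh_raw @ hs[t].T             # hs[t] is h_{t-1} for this step
        dh_next = Whh.T @ dh_raw             # pass the error back to step t-1

    return dWxh, dWhh, dWhy, dbh, dby
```

Notice that every gradient is a sum of contributions from each time step, and that the hidden error at step t combines the error coming from that step's output with the error flowing back from step t+1; this is exactly the structure we will derive below.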