Introduction
Recently I have been working on a small project called simple_nn for creating and using neural networks. It is a neural network framework written in Julia, similar to PyTorch and tinygrad, both of which inspired this project. I made and continue to work on this framework solely for my own learning, but I believe some aspects of the project are worth writing about for others. In this post, I'll go over using Zygote.jl, a reverse-mode automatic differentiation library written in Julia.
I will focus primarily on using Zygote.jl, and I will assume that you already know the basics of neural networks, gradient descent, and automatic differentiation. If you would like to learn more, I highly recommend checking out the free course Natural Language Processing Demystified by Nitin, which really helped me understand the underlying math behind neural networks and made simple_nn possible.
Basics of Zygote.jl
Using Zygote.jl is quite simple: pass a function and its input(s) to the gradient function, and it returns a tuple holding the gradient of the function with respect to each input:
julia> using Zygote
julia> x = 2;
julia> f(x) = x^2;
julia> ∂f∂x = gradient(f, x)
(4.0,)
Of course, you can also pass more complicated functions with more input variables:
julia> x = 2;
julia> y = 2;
julia> f(x, y) = x^2 + y^2;
julia> grad = gradient(f, x, y)
(4.0, 4.0)
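As a quick sanity check, the result can be compared against the derivative worked out by hand. The sketch below is my own illustration rather than code from simple_nn: the function name g and the inputs x0 and y0 are made up for the example. It also uses Zygote's withgradient, which returns the function value together with the gradient.

```julia
using Zygote

g(x, y) = x^2 + y^2

# Analytically, ∂g/∂x = 2x and ∂g/∂y = 2y.
x0, y0 = 3.0, -1.0
dx, dy = gradient(g, x0, y0)
@assert dx ≈ 2 * x0
@assert dy ≈ 2 * y0

# withgradient returns the value of g alongside the gradient tuple.
res = withgradient(g, x0, y0)
@assert res.val ≈ 10.0
@assert res.grad == (dx, dy)
```

Getting the value and the gradient in one call is handy when you want to log the loss during training without evaluating the function twice.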
The most important pattern used in simple_nn is computing the gradients of a function with respect to a set of parameters wrapped in Params:
julia> x = [1 2; 3 4];
julia> y = [1, 2];
julia> f(x, y) = sum(x.^2 * y);
julia> grad = gradient(Params([x, y])) do
           f(x, y)
       end
Grads(...)
The resulting grad variable is of type Zygote.Grads, and the gradients with respect to x and y can be looked up by indexing it with the arrays themselves:
julia> ∂f∂x = grad[x]
2×2 Matrix{Int64}:
2 8
6 16
julia> ∂f∂y = grad[y]
2-element Vector{Int64}:
10
20
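These numbers can be checked by hand. Since f is the sum over i and j of x[i, j]^2 * y[j], the derivative with respect to x[i, j] is 2 * x[i, j] * y[j], and the derivative with respect to y[j] is the sum of column j of x.^2. The sketch below is only an illustration (it uses floating-point inputs, and the names grads, dx, and dy are mine). It compares Zygote's result against those formulas, and against the explicit gradient(f, x, y) call from earlier.

```julia
using Zygote

x = [1.0 2.0; 3.0 4.0]
y = [1.0, 2.0]
f(x, y) = sum(x.^2 * y)

# Implicit style, as in the post: gradients are looked up by array identity.
grads = gradient(() -> f(x, y), Params([x, y]))

# Hand-derived: ∂f/∂x[i, j] = 2 * x[i, j] * y[j] and ∂f/∂y[j] = Σᵢ x[i, j]^2.
@assert grads[x] ≈ 2 .* x .* y'
@assert grads[y] ≈ vec(sum(x.^2, dims = 1))

# Explicit style: the same values come back as a tuple, one entry per argument.
dx, dy = gradient(f, x, y)
@assert dx ≈ grads[x]
@assert dy ≈ grads[y]
```

The advantage of the Params form is that the function being differentiated does not need to take the parameters as arguments, which is what makes it convenient for a network object that owns its weights.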
You will see how powerful and simple this pattern is to use in the next section.
Using Zygote.jl to Train Neural Networks
As you likely already know, a feedforward neural network is trained through backpropagation, i.e., by computing the gradient of the loss function with respect to the weights and biases of the network.
Using Zygote.jl, calculating and using these gradients is incredibly easy. The following is the function used in simple_nn to compute the gradients and pass them to the optimizer function:
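function Backward!(net::Network, loss_function, x, y)
    # Differentiate the loss with respect to the network's weights and biases.
    grad = gradient(Params([net.weights, net.biases])) do
        loss_function(x, y)
    end
    # Hand the gradients to the configured optimizer, which updates the network.
    net.params["optimizer"](net, grad[net.weights], grad[net.biases])
end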
Here, x is the input to the network, y is the target, and net.params["optimizer"] is the optimizer function that modifies the weights and biases of the network based on the gradients it is given. As you can see, we use the same Params pattern demonstrated earlier to calculate the gradients of the loss function with respect to the network's weights and biases.
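To make the pattern concrete without pulling in all of simple_nn, here is a minimal, self-contained sketch of one training step. Everything in it is a stand-in I made up for illustration: TinyNet, forward, loss, and train_step! are hypothetical names, the network is a single linear layer with a softmax, and the real Network type in simple_nn may look quite different.

```julia
using Zygote

# Hypothetical stand-in for simple_nn's Network type; the real struct may differ.
mutable struct TinyNet
    weights::Matrix{Float64}
    biases::Vector{Float64}
end

# Hypothetical forward pass: a single linear layer followed by a softmax.
function forward(net::TinyNet, x)
    z = net.weights * x .+ net.biases
    e = exp.(z)
    return e ./ sum(e)
end

# Cross-entropy loss between the target y and the network's prediction.
loss(net, x, y) = -sum(y .* log.(forward(net, x)))

function train_step!(net::TinyNet, x, y; learning_rate = 0.1)
    # Track the current weight and bias arrays as implicit parameters.
    grads = gradient(Params([net.weights, net.biases])) do
        loss(net, x, y)
    end
    # Plain gradient descent update; gradients are looked up by array identity.
    net.weights = net.weights .- learning_rate .* grads[net.weights]
    net.biases = net.biases .- learning_rate .* grads[net.biases]
    return net
end

net = TinyNet([0.1 -0.2; 0.3 0.4], [0.0, 0.0])
x, y = [1.0, 2.0], [0.0, 1.0]

loss_before = loss(net, x, y)
train_step!(net, x, y)
loss_after = loss(net, x, y)
@assert loss_after < loss_before
```

Because the update assigns new arrays to net.weights and net.biases, the Params object has to be rebuilt on every step, which is why it is constructed inside the function rather than once up front.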
Just for reference and further clarification, this is how cross-entropy loss is implemented for use with simple_nn. Note that it refers to net from the enclosing scope rather than taking it as an argument:
cross_entropy_loss(x, y) = -sum(y .* log.(Forward(net, x)))
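To see what this formula does in isolation, the sketch below applies it to a prediction vector directly instead of going through Forward. The names cross_entropy, y_pred, and dy_pred and the numbers are hypothetical and chosen only for illustration. With a one-hot target, only the true class contributes to the loss, and the gradient with respect to the prediction is -y ./ y_pred.

```julia
using Zygote

# Hypothetical values: a one-hot target and a predicted probability vector.
y = [0.0, 1.0, 0.0]
y_pred = [0.2, 0.7, 0.1]

# Same formula as cross_entropy_loss, but taking the prediction directly.
cross_entropy(y_pred, y) = -sum(y .* log.(y_pred))

# With a one-hot target, the loss is minus the log-probability of the true class.
@assert cross_entropy(y_pred, y) ≈ -log(0.7)

# Zygote's gradient with respect to the prediction matches -y ./ y_pred.
dy_pred = gradient(p -> cross_entropy(p, y), y_pred)[1]
@assert dy_pred ≈ -y ./ y_pred
```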
And here is the function for gradient descent:
function GradientDescentOptimizer!(net::Network, weight_grad, bias_grad)
    # Step each parameter against its gradient, scaled by the learning rate.
    net.weights = net.weights .- net.params["learning_rate"] * weight_grad
    net.biases = net.biases .- net.params["learning_rate"] * bias_grad
end
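The same update rule can be tried on a toy problem to watch it converge. This is a sketch under my own assumptions, not part of simple_nn: target, objective, and descend are hypothetical names, and the objective is just the squared distance to a fixed vector.

```julia
using Zygote

# Hypothetical toy objective: squared distance from a fixed target vector.
target = [1.0, -2.0, 3.0]
objective(w) = sum((w .- target).^2)

function descend(w; learning_rate = 0.1, steps = 100)
    for _ in 1:steps
        grad = gradient(objective, w)[1]
        # Same rule as GradientDescentOptimizer!: step against the gradient.
        w = w .- learning_rate .* grad
    end
    return w
end

w_final = descend(zeros(3))
@assert isapprox(w_final, target; atol = 1e-6)
```

Here each step shrinks the distance to the target by a constant factor, so a fixed learning rate is enough. Real networks are less forgiving, and the learning rate usually needs some tuning.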
Please see the simple_nn GitHub repo or the example notebook if you would like to learn more.
Conclusion
Zygote.jl provides a simple and powerful way to calculate gradients for use in ML and other applications. Using Zygote.jl in simple_nn makes implementing backpropagation straightforward, while keeping the performance that comes with code written in Julia. I hope this article provided some inspiration to use this powerful Julia library in some of your own work. For a more mature neural network framework, check out Flux.jl, which itself uses Zygote.jl.
Please see my website for more of my work plus links to my socials. I'm always happy to connect and talk about cool stuff!