Stumbling backwards into np.random.seed through jax.
Spoiler: We’ve all been using randomness wrong
You can find the associated notebook for this post, though it's relatively minimal. Feel free to open the link and play with the notebook, but know that running it isn't strictly necessary.
1) Intro
Given my current needs, I think that jax is the best computational tool out there. I hope to write more about jax in the coming months and show you why you should consider trying it out. One important thing to realize is that jax is not a deep learning framework (although it does have autograd built in). First and foremost, jax is a numerical computation library, like numpy.
Over the weekend, I was working on porting some code from pytorch to jax. In the process, I stumbled onto some code that dealt with randomness, and I decided to read more about randomness in the context of numpy. The material I read over the weekend ended up being the motivation behind this blog post. To begin, let's look at how we would deal with randomness in jax:
import jax

# SEED can be any integer of your choosing
key = jax.random.PRNGKey(SEED)
print(key)
# which outputs the following on my run:
# DeviceArray([1076515368, 3893328283], dtype=uint32)
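To make the explicitness concrete, here is a minimal sketch of my own (not from the notebook) showing that the same key always yields the same draws:

import jax

key = jax.random.PRNGKey(0)
# The key is the only source of randomness, so reusing it reproduces the draw.
x = jax.random.normal(key, (3,))
y = jax.random.normal(key, (3,))
# x and y are identical because they were generated from the same key.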
Ironically, I felt like I understood numpy's randomness better after using jax. This blog post lays out what I learned in the process.
i) A little about jax
As mentioned earlier, jax is a computational framework akin to numpy. I'd say the main difference between jax and numpy is that jax was designed to be accelerator agnostic: it runs fast regardless of whether you're on a CPU, GPU, or TPU. I particularly like it because of:
how fast it is when compared to other frameworks (I got a 10X speed boost compared to raw vectorized numpy in a function with lots of dot products).
how easy it is to peek into its internals (admittedly, this is subjective).
how it allows you to implement the equations you see in papers directly. You can write the single-example version of the code and then call vmap to apply it to every row in your array, so you don't need to futz around with manually vectorizing your equations any longer (see the sketch after this list).
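Here is a minimal sketch of that last point; the function and array names are placeholders I made up:

import jax
import jax.numpy as jnp

# Write the equation for a single row, exactly as it reads in the paper.
def score(row, w):
    return jnp.dot(row, w)

X = jnp.ones((100, 8))  # 100 rows of 8 features
w = jnp.ones(8)

# vmap maps `score` over the rows of X; no manual vectorization needed.
scores = jax.vmap(score, in_axes=(0, None))(X, w)
print(scores.shape)  # (100,)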
ii) Could it be the future?
I feel like jax and XLA are the future of computation in Python. Granted, this isn't exactly a hot take - lots of people and companies have begun to move to jax:
DeepMind's AlphaFold model is built in Haiku, a deep-learning oriented library built on top of jax.
Google Brain has also released a deep-learning library called Flax. From what I can tell, teams at Google Brain have begun transitioning over to it.
Hugging Face has also begun releasing models in flax.
Note: Leaving PyTorch behind
In my last blog post PyTorch Gradients, I mentioned publishing a series of posts covering gradients in PyTorch. I fully intend to finish that series, but I’ve more or less abandoned PyTorch.
2) Randomness:
Anyways, on to the meat of this post: over the weekend, I was playing with the idea of porting snnTorch over to jax. I first began by scanning through the tutorials, where I read some material about creating random spike trains. The contents of the tutorial and what spike trains are aren't crucial for this post. Still, it did remind me that jax handles randomness differently from other frameworks, so I thought I should do some deep(er) reading before naively moving code over.
If you look up randomness in jax, one of the first things you'll stumble on is how to generate a key and continually split it. To make a long story short, jax is functional in nature, which means that it is stateless. Being stateless means (among other things) that jax handles randomness explicitly; we have to explicitly pass a key every time we invoke randomness in our code. On the one hand, this makes our code more verbose, but on the other hand, it makes reproducibility far easier.
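A minimal sketch of the generate-and-split pattern (adapted from the usual jax examples, not from the notebook):

import jax

key = jax.random.PRNGKey(0)

# Every random call consumes a key explicitly.
key, subkey = jax.random.split(key)
a = jax.random.normal(subkey, (3,))

# Split again before the next call; a consumed key should never be reused.
key, subkey = jax.random.split(key)
b = jax.random.normal(subkey, (3,))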
i) Statefulness
The following is merely a working example of what “statefulness” means. It is by no means a rigorous definition. Think of being stateful as the following:
class StatefulAdd:
    def __init__(self):
        self.count = 0

    def __call__(self, x):
        # The identity, plus the number of times we have been called before
        result = x + self.count
        self.count += 1
        return result

foo = StatefulAdd()
first = foo(1)   # first  == 1
second = foo(1)  # second == 2
i.e. we can plug in the same value but obtain a different result each time. There's nothing inherently wrong about coding this way (regardless of what the func-ies will say); it can just be harder to reason about.
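For contrast, a stateless version of the same idea threads the count through explicitly (a quick sketch of my own):

# The "state" (count) is passed in and returned explicitly,
# so the same inputs always produce the same outputs.
def stateless_add(x, count):
    return x + count, count + 1

first, count = stateless_add(1, 0)       # first  == 1
second, count = stateless_add(1, count)  # second == 2

This is essentially what jax does with its random keys: the state is something you pass around, not something hidden inside an object.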
Anyways, going back to jax: by enforcing statelessness, we have to be explicit about our random key every time we make a call, and in doing so jax sidesteps the reproducibility issues that plagued TensorFlow 1.x (and probably pytorch too). Although jax isn't perfect on the reproducibility front, I believe it is going in the right direction.
ii) Reproducibility in TF1.X
A couple of Stack Overflow questions give a sense of the situation:
"How to get stable results with TensorFlow, setting random seed" - although, to be fair, there seems to be an official answer for TensorFlow 2 as of 2020.
"Why can't I get reproducible results in Keras even though I set the random seeds?" (asked in 2018), which contains my favorite answer I've seen so far. The answer recommends the following, with this caveat:
In short, to be absolutely sure that you will get reproducible results with your python script on one computer’s/laptop’s CPU then you will have to do the following:
# Seed value
# Apparently you may use different seed values at each stage
seed_value= 0
# 1. Set the `PYTHONHASHSEED` environment variable at a fixed value
import os
os.environ['PYTHONHASHSEED']=str(seed_value)
# 2. Set the `python` built-in pseudo-random generator at a fixed value
import random
random.seed(seed_value)
# 3. Set the `numpy` pseudo-random generator at a fixed value
import numpy as np
np.random.seed(seed_value)
# 4. Set the `tensorflow` pseudo-random generator at a fixed value
import tensorflow as tf
tf.random.set_seed(seed_value)
# for later versions:
# tf.compat.v1.set_random_seed(seed_value)
# 5. Configure a new global `tensorflow` session
from keras import backend as K
session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
K.set_session(sess)
# for later versions:
# session_conf = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
# sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf)
# tf.compat.v1.keras.backend.set_session(sess)
Indeed, a thing of beauty.
3) Reproducibility in numpy
First and foremost, I'd recommend opening the accompanying notebook, specifically the numpy portion, and playing with the code there. NB: the jax portion is trivial and works as you might expect; I included it primarily for completeness.
As you play with the numpy portion, you'll notice that you get new random values every time you call a function from the random module, even though we never explicitly pass in a key. That tells us something is happening under the hood. This "something" looks a lot like generating a new random key on every call. Note that this is not literally what happens under the hood, but it helps tie what we see back to jax and how it handles random state.
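As a small sanity check, mirroring what the notebook does:

import numpy as np

# No key or state is passed in, yet each call returns a new value:
# the module is mutating hidden state between calls.
print(np.random.random())
print(np.random.random())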
i) Example Scenario
You have a program that only crashes once in a while, and you’ve identified the exact function that it crashes on! You’ve even managed to find a specific random seed on which that function works fine, so you’d like to set the state only inside that function and avoid the problem altogether.
Yes, this is a contrived example; sue me.
Statefulness issue illustrated
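The full example lives in the notebook; a minimal sketch of the same idea looks roughly like this (the function name follows the text below, the shapes and seed are my own assumptions):

import numpy as np

SEED = 0

def generate_np_weights():
    return np.random.standard_normal(3)

def new_generate_np_weights():
    # Re-seed *inside* the function, hoping the effect stays local.
    np.random.seed(SEED)
    return np.random.standard_normal(3)

np.random.seed(SEED)
w1 = generate_np_weights()      # 1st call
w2 = generate_np_weights()      # 2nd call
w3 = new_generate_np_weights()  # re-seeds and consumes one draw
w4 = generate_np_weights()      # 3rd call: same values as the 2nd call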
Note here how we have reset the random seed within new_generate_np_weights. If the randomness were only local to the context we are in, we would expect to "continue" the original randomness once we exit the function. Said differently, we would have two "sources" of randomness, the second of which would get garbage collected once new_generate_np_weights returns. However, as we can see on the call labeled "# 3rd call", we have received the same random value as our "# 2nd call".
The global state
Clearly, something "unexpected" is happening. At its core, np.random.seed operates on what is known as a RandomState, which, as we've discussed, is a stateful object. Crucially, as we saw in our code example, there is only one such object: calling seed resets this single global state rather than creating a new, local source of randomness.
Obviously, this is the source of our issues.
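If you want a localized version of the legacy behaviour, you can construct your own RandomState instead of poking the global one (a quick aside of my own; the recommended new-style fix is discussed next):

import numpy as np

# A private RandomState: seeding it has no effect on np.random's global state.
rs = np.random.RandomState(0)
print(rs.standard_normal(3))
print(np.random.standard_normal(3))  # unaffected by rs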
ii) How do we address reproducibility in numpy?
In all honesty, I had previously stumbled on the new best practices for generating random numbers in numpy, but I never bothered to read them. I don't think the reasoning behind the recommendation ever clicked with me, so I never felt a need to change how I was doing things.
However, now that we are clear on the limitations of the existing np.random.seed, we can discuss the recommended way of doing things: the Generator API, created via numpy.random.default_rng. To make a long story short, you create an object which contains all your randomness, and you "extract" whatever you need from it. For example, from the random sampling docs:
from numpy.random import default_rng
rng = default_rng()
vals = rng.standard_normal(10)
more_vals = rng.standard_normal(10)
as opposed to the older method:
from numpy import random
vals = random.standard_normal(10)
more_vals = random.standard_normal(10)
Here we implicitly mutate a shared global object.
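Going back to the earlier scenario, this is also how you would scope a seed to a single function: give that function its own Generator. A sketch, with names carried over from the example above:

import numpy as np

def new_generate_np_weights(seed=0):
    # A local Generator: seeding here cannot disturb anything outside.
    rng = np.random.default_rng(seed)
    return rng.standard_normal(3)

global_rng = np.random.default_rng(42)
a = global_rng.standard_normal(3)
w = new_generate_np_weights()      # uses its own, local stream
b = global_rng.standard_normal(3)  # continues the seed-42 stream, unaffected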
Closing Thoughts:
This was an enlightening topic for me to dive into, and I hope you found reading this useful. I feel like I better understand what numpy does under the hood when we use randomness. I also feel like I better understand the motivation behind numpy's API change recommendation when viewed through the lens of jax.
tl;dr
1) jax handles randomness very well, even if it may be more verbose.
2) Use the new best practices (a Generator via default_rng) if you are dealing with random numbers in numpy.
P.S.
You can generate multiple keys at once with jax.random.split and consume them as needed:
key_array = jax.random.split(key, num=X)
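For example (my own toy usage, with num=4 standing in for X):

import jax

key = jax.random.PRNGKey(0)
key_array = jax.random.split(key, num=4)
# Consume one subkey per random draw.
samples = [jax.random.normal(k, (2,)) for k in key_array]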