Stumbling backwards into np.random.seed through jax.

10 minute read

Spoiler: We’ve all been using randomness wrong

You can find the associated notebook for this post, but it's relatively minimal. Feel free to open the link and play with the notebook, but know that running it isn't strictly necessary.

1) Intro

Given my current needs, I think that jax is the best computational tool out there. I hope to write more about jax in the coming months, and show you why you should consider trying it out. One important thing to realize is that jax is not a deep learning framework (although it does have autograd built-in). First and foremost, jax is a numerical computation library, like numpy.

Over the weekend, I was working on porting some code from pytorch to jax. In the process, I stumbled onto some code that dealt with randomness, and I decided to read more about randomness in the context of numpy. The material I had read over the weekend ended up being the motivation behind this blog post. To begin, let’s look at how we would deal with randomness in jax:

import jax

key = jax.random.PRNGKey(SEED)  # SEED is an integer of your choosing
print(key)

# which outputs the following on my run:
#   DeviceArray([1076515368, 3893328283], dtype=uint32)
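Just to show where that key ends up (a tiny sketch of my own): every random draw takes the key as an explicit argument.

x = jax.random.normal(key, shape=(3,))  # same key, same values, every time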

Ironically, I felt like I understood numpy's randomness better after using jax. This blog post lays out what I learned in the process.

i) A little about jax

As mentioned earlier, jax is a computational framework akin to numpy. I'd say the main difference between jax and numpy is that jax was designed to be accelerator agnostic: it runs fast regardless of whether you're on a CPU, GPU, or TPU. I particularly like it because of:

  • how fast it is when compared to other frameworks (I got a 10X speed boost compared to raw vectorized numpy in a function with lots of dot products).

  • how easy it is to peek into its internals (admittedly, this is subjective).

  • how it allows you to implement the equations you see in papers directly. You can implement the single-example line of code, then call vmap to apply it to all rows in your array (see the sketch right after this list). You don't need to futz around with vectorizing your equations any longer.
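For instance, here's what that last point looks like in practice - a minimal sketch of my own (not from the associated notebook): write the per-row equation once and let vmap handle the batch dimension.

import jax
import jax.numpy as jnp

def row_score(x, w):
    # the "single example" equation, written just as you'd read it in a paper
    return jnp.dot(x, w)

# map over the rows of xs (axis 0) while broadcasting w unchanged
batched_score = jax.vmap(row_score, in_axes=(0, None))

xs = jnp.ones((8, 3))  # 8 rows of 3 features
w = jnp.arange(3.0)
print(batched_score(xs, w).shape)  # (8,)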

ii) Could it be the future?

I feel like jax and XLA are the future of numerical computation in Python. Granted, this isn't exactly a hot take - lots of people and companies have begun to move to jax:

  • DeepMind's AlphaFold model is built in Haiku, a deep-learning oriented library built on top of jax.

  • Google Brain has also released a deep-learning library called Flax. From what I can tell, teams at Google Brain have begun transitioning over to it.

  • Hugging Face has also begun releasing models in Flax.

Note: leaving PyTorch behind

In my last blog post PyTorch Gradients, I mentioned publishing a series of posts covering gradients in PyTorch. I fully intend to finish that series, but I’ve more or less abandoned PyTorch.

2) Randomness

Anyways, on to the meat of this post: over the weekend, I was playing with the idea of porting snnTorch over to jax. I began by scanning through its tutorials, where I read some material about creating random spike trains. The contents of the tutorial, and what spike trains are, aren't crucial for this post. Still, it did remind me that jax handles randomness differently from other frameworks, so I thought I should do some deep(er) reading before naively moving code over.

If you look up randomness in jax, one of the first things you'll stumble on is how to generate a key and then continually split it. To make a long story short, jax is functional in nature, which means that it is stateless. Being stateless means (among other things) that jax handles randomness explicitly; we have to pass in a key explicitly every time we invoke randomness in our code. On the one hand, this makes our code more verbose, but on the other hand, it makes reproducibility far easier.
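In practice, that explicit workflow looks something like this minimal sketch (my own example):

import jax

key = jax.random.PRNGKey(0)

key, subkey = jax.random.split(key)        # derive a fresh subkey, keep the rest
x = jax.random.normal(subkey, shape=(3,))  # consume the subkey exactly once

key, subkey = jax.random.split(key)        # split again before the next draw
y = jax.random.normal(subkey, shape=(3,))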


i) Statefulness

The following is merely a working example of what "statefulness" means; it is by no means a rigorous definition. Roughly, think of being stateful like this:

class StatefulAdd():
    def __init__(self):
        self.count = 0

    def __call__(self, x):
        # the identity, plus the number of times we've been called before
        result = x + self.count
        self.count += 1
        return result

foo = StatefulAdd()
first = foo(1)   # first := 1
second = foo(1)  # second := 2

i.e. we can plug in the same value but obtain a different result each time. There's nothing inherently wrong with coding this way (regardless of what the func-ies will say); it can just be harder to reason about.
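For contrast, here is a minimal stateless sketch (again, my own example): the "state" - the call count - is passed in and returned explicitly, so the same inputs always produce the same outputs. This is essentially what jax asks you to do with its random keys.

def stateless_add(x, count):
    # return the result and the new state; nothing is hidden inside the function
    return x + count, count + 1

first, count = stateless_add(1, 0)       # first := 1, count := 1
second, count = stateless_add(1, count)  # second := 2, count := 2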


Anyways, going back to jax: by enforcing statelessness, we have to be explicit about our random key every time we make a call. In exchange, jax sidesteps the reproducibility issues that plagued TensorFlow 1.x (and probably pytorch too). Although jax isn't perfect on the reproducibility front, I believe it is going in the right direction.

ii) Reproducibility in TF1.x

In short, to be absolutely sure that you get reproducible results with your Python script on a single machine's CPU, you have to do all of the following:

# Seed value
# Apparently you may use different seed values at each stage
seed_value = 0

# 1. Set the `PYTHONHASHSEED` environment variable at a fixed value
import os
os.environ['PYTHONHASHSEED']=str(seed_value)

# 2. Set the `python` built-in pseudo-random generator at a fixed value
import random
random.seed(seed_value)

# 3. Set the `numpy` pseudo-random generator at a fixed value
import numpy as np
np.random.seed(seed_value)

# 4. Set the `tensorflow` pseudo-random generator at a fixed value
import tensorflow as tf
tf.random.set_seed(seed_value)
# for later versions: 
# tf.compat.v1.set_random_seed(seed_value)

# 5. Configure a new global `tensorflow` session
from keras import backend as K
session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
K.set_session(sess)
# for later versions:
# session_conf = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
# sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf)
# tf.compat.v1.keras.backend.set_session(sess)

Indeed, a thing of beauty.

3) Reproducibility in numpy

First and foremost, I'd recommend opening the accompanying notebook (specifically the numpy portion) and playing with the code there. NB: the jax portion is trivial and works as you might expect; I included it primarily for completeness.

As you play with the numpy portion, you'll notice that you get new random values every time you call a function from the random module, without ever explicitly passing in a seed or key. That alone tells us something is happening under the hood.

This "something" looks a lot like a new random key being generated on every call. That's not literally what happens under the hood, but it's a useful way to tie what we observe back to jax and how it handles random state.

i) Example Scenario

You have a program that only crashes once in a while, and you’ve identified the exact function that it crashes on! You’ve even managed to find a specific random seed on which that function works fine, so you’d like to set the state only inside that function and avoid the problem altogether.

Yes, this is a contrived example; sue me.

Statefulness issue illustrated
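The notebook's illustration isn't reproduced here, but the pattern it shows is roughly the following sketch (my own reconstruction; only new_generate_np_weights is named in the original, the rest is illustrative):

import numpy as np

np.random.seed(42)
print(np.random.random())  # 1st call
print(np.random.random())  # 2nd call

def new_generate_np_weights(seed=42):
    # reset the seed "locally", hoping to leave the outer randomness untouched
    np.random.seed(seed)
    return np.random.random()

new_generate_np_weights()
print(np.random.random())  # 3rd call: prints the same value as the 2nd call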

Note here how we have reset the random seed within new_generate_np_weights. If that randomness were local to the function, we would expect to "continue" the original random stream once we exit it. Said differently, we would have two "sources" of randomness, the second of which would get garbage collected once new_generate_np_weights returns. However, as we can see at the line labeled "# 3rd call", we receive the same random value as our "# 2nd call".

The global state

Clearly, something "unexpected" is happening. At its core, np.random.seed (re)seeds a single, module-level RandomState: a stateful object that every np.random.* call shares. So, as we saw in our code example, calling seed inside a function doesn't create a local source of randomness; it resets the one global stream for everything that draws from it.

Obviously, this is the source of our issues.
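If you want to poke at that shared state directly, here is a quick check of my own (not from the notebook) using np.random.get_state:

import numpy as np

np.random.seed(0)
_, key0, pos0 = np.random.get_state()[:3]  # snapshot of the global MT19937 state
key0 = key0.copy()                         # copy, in case we were handed a view

_ = np.random.random()                     # any draw, from anywhere, advances that state
_, key1, pos1 = np.random.get_state()[:3]
print((key0 == key1).all() and pos0 == pos1)  # False: the draw mutated the shared state

np.random.seed(0)                          # reseeding resets the very same object...
_, key2, pos2 = np.random.get_state()[:3]
print((key0 == key2).all() and pos0 == pos2)  # True: ...back to exactly where we started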

ii) How do we address reproducibility in numpy?

In all honesty, I had previously stumbled on the new best practices for generating random numbers in numpy, but I never bothered to read them. The reasoning behind the recommendation never clicked with me, so I never felt a need to change how I was doing things.

However, now that we are clear on the limitations of np.random.seed, we can discuss the recommended way of doing things: the new Generator API (np.random.Generator, created via default_rng). To make a long story short, you create an object which contains all of your randomness, and you "extract" whatever you need from that object. For example, from the random sampling docs:

from numpy.random import default_rng
rng = default_rng()  # pass an integer seed, e.g. default_rng(42), for a reproducible stream
vals = rng.standard_normal(10)
more_vals = rng.standard_normal(10)

as opposed to the older method:

from numpy import random
vals = random.standard_normal(10)
more_vals = random.standard_normal(10)

where every call mutates the hidden global RandomState we discussed above.
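To tie this back to the earlier contrived scenario, here is a minimal sketch (my own, not from the numpy docs) of how the Generator API sidesteps the problem: each function can own an independently seeded generator, so seeding inside a function no longer disturbs anyone else's stream.

import numpy as np

def new_generate_np_weights(shape, seed=42):
    rng = np.random.default_rng(seed)  # local generator, local state
    return rng.standard_normal(shape)

outer_rng = np.random.default_rng(123)
a = outer_rng.standard_normal()      # "1st call"
w = new_generate_np_weights((2, 2))  # does not touch outer_rng's state
b = outer_rng.standard_normal()      # "2nd call": continues the 123-seeded stream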

Closing Thoughts:

This was an enlightening topic for me to dive into, and I hope you found reading this useful. I feel like I better understand what numpy does under the hood when we use randomness. I also feel like I better understand the motivation behind numpy’s API change recommendation when viewed through the lens of jax.

tl;dr

1) jax handles randomness very well, even if it is a bit more verbose.

2) Use the new best practices (the Generator API) if you are dealing with random numbers in numpy.

P.S.

You can generate multiple keys to consume later with jax.random.split:

key_array = jax.random.split(key, num=X)  # X new keys, each meant to be consumed once