# Stumbling backwards into np.random.seed through jax.

**Spoiler**: We’ve all been using randomness wrong

You can find the associated notebook for this post, but it's relatively minimal. Feel free to open the link and play with the notebook, but know that running it isn't strictly necessary.

# 1) Intro

Given my current needs, I think that `jax` is the best computational tool out there. I hope to write more about `jax` in the coming months and show you why you should consider trying it out. One important thing to realize is that `jax` is not a deep learning framework (although it does have autograd built in). First and foremost, `jax` is a numerical computation library, like `numpy`.

Over the weekend, I was working on porting some code from `pytorch` to `jax`. In the process, I stumbled onto some code that dealt with randomness, and I decided to read more about randomness in the context of `numpy`. The material I read over the weekend ended up being the motivation behind this blog post. To begin, let's look at how we would deal with randomness in `jax`:

```
import jax

key = jax.random.PRNGKey(SEED)  # SEED is any integer of your choosing
print(key)
# which outputs the following on my run:
# DeviceArray([1076515368, 3893328283], dtype=uint32)
```
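The key is the whole story: it alone determines what you get back. A minimal sketch (seed 0 is arbitrary):

```python
import jax

k = jax.random.PRNGKey(0)  # seed 0 is arbitrary
x = jax.random.normal(k, (3,))
y = jax.random.normal(k, (3,))
# Same key, same draw: the key alone determines the output
```

Calling `jax.random.normal` twice with the same key returns identical arrays, which is exactly the explicitness this post is about.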

Ironically, I felt like I understood `numpy`'s randomness better after using `jax`. This blog post hopes to lay out what I learned in the process.

## i) A little about jax

As mentioned earlier, `jax` is a computational framework akin to `numpy`. I'd say the main difference between `jax` and `numpy` is that `jax` was designed to be accelerator agnostic. Being accelerator agnostic means that `jax` runs fast regardless of whether you're on a CPU, GPU, or TPU. I particularly like it because of:

- how fast it is when compared to other frameworks (I got a 10X speed boost compared to raw vectorized `numpy` in a function with lots of dot products).
- how easy it is to peek into its internals (admittedly, this is subjective).
- how it allows you to implement the equations you see in papers directly. You can implement the line of code, then call `vmap` to apply it to all rows in your array. You don't need to futz around with vectorizing your equations any longer.
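To make the `vmap` point concrete, here's a minimal sketch (the function and array names are my own):

```python
import jax
import jax.numpy as jnp

def per_row(a, b):
    # Written exactly as a paper would state it: a single dot product
    return jnp.dot(a, b)

A = jnp.arange(6.0).reshape(3, 2)
B = jnp.ones((3, 2))
# vmap lifts the one-row function over the leading axis of both arrays
row_dots = jax.vmap(per_row)(A, B)  # → [1., 5., 9.]
```

You write the math for one example and let `vmap` handle the batching.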

## ii) Could it be the future?

I feel like `jax` and `XLA` are the future of computation in Python. Granted, this isn't exactly a hot take - lots of people and companies have begun to move to `jax`:

- DeepMind's AlphaFold model is built in `haiku`, which is a deep-learning oriented library built on top of `jax`.
- Google Brain has also released a deep-learning library called `flax`. From what I can tell, teams at Google Brain have begun transitioning over to it.
- Huggingface has also begun releasing models in `flax`.

**Note**: leaving PyTorch behind

In my last blog post, PyTorch Gradients, I mentioned publishing a series of posts covering gradients in PyTorch. I fully intend to finish that series, but I've more or less abandoned PyTorch.

# 2) Randomness:

Anyways, on to the meat of this post: over the weekend, I was playing with the idea of porting snnTorch over to `jax`. I first began by scanning through the tutorials, where I read some material about creating random spike trains. The contents of the tutorial and what spike trains are aren't crucial for this post. Still, it did remind me that `jax` handles randomness differently from other frameworks. So, I thought I should do some deep(er) reading before naively moving code over.

If you look up randomness in `jax`, one of the first things you'll stumble on is how to generate a key and continually split it. To make a long story short, `jax` is functional in nature, which means that it is stateless. Being stateless means (among other things) that `jax` handles randomness explicitly; we have to explicitly pass a key every time we invoke randomness in our code. On the one hand, this makes our code more verbose, but on the other hand, it makes reproducibility far easier.
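The standard workflow is to split the key before each random call; a minimal sketch (the seed value 42 is arbitrary):

```python
import jax

key = jax.random.PRNGKey(42)
# One subkey per random call; never reuse a key
key, subkey = jax.random.split(key)
first = jax.random.normal(subkey, ())
key, subkey = jax.random.split(key)
second = jax.random.normal(subkey, ())
# first and second come from different subkeys, so they differ
```

Verbose, yes, but every draw is tied to a key you can see in the code.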

## i) Statefulness

The following is merely a working example of what “statefulness” means. It is by no means a rigorous definition. Think of being stateful as the following:

```
class StatefulAdd:
    def __init__(self):
        self.count = 0

    def __call__(self, x):
        # The identity + the number of times it has previously been called
        result = x + self.count
        self.count += 1
        return result

foo = StatefulAdd()
first = foo(1)   # first := 1
second = foo(1)  # second := 2
```

i.e. we can plug in the same value but obtain different results each time. There's nothing inherently wrong about coding this way (regardless of what the func-ies will say); it can just be harder to reason about.
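The stateless counterpart makes the hidden counter an explicit argument; a minimal sketch:

```python
def stateless_add(x, count):
    # The state goes in and comes out explicitly; calling with the
    # same arguments always returns the same result
    return x + count, count + 1

first, count = stateless_add(1, 0)       # first == 1
second, count = stateless_add(1, count)  # second == 2
```

This is exactly the shape of `jax`'s approach: the random key plays the role of `count`.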

Anyways, going back to `jax`: by enforcing statelessness, we have to be explicit about our random key every time we make a call. In doing so, `jax` sidesteps the reproducibility issues that plagued TensorFlow 1.X (and probably PyTorch too). Although `jax` isn't perfect in the reproducibility aspect, I believe it is going in the right direction.

## ii) Reproducibility in TF1.X

A couple of questions capture the state of things:

- "How to get stable results with TensorFlow, setting random seed" (although, to be fair, there seems to be an official answer for TensorFlow 2 as of 2020).
- "Why can't I get reproducible results in Keras even though I set the random seeds?" (asked in 2018), which contains my favorite answer I've seen so far. The answer gives the following recipe, with the following caveat:

In short, to be absolutely sure that you will get reproducible results with your python script on one computer’s/laptop’s CPU then you will have to do the following:

```
# Seed value
# Apparently you may use different seed values at each stage
seed_value= 0
# 1. Set the `PYTHONHASHSEED` environment variable at a fixed value
import os
os.environ['PYTHONHASHSEED']=str(seed_value)
# 2. Set the `python` built-in pseudo-random generator at a fixed value
import random
random.seed(seed_value)
# 3. Set the `numpy` pseudo-random generator at a fixed value
import numpy as np
np.random.seed(seed_value)
# 4. Set the `tensorflow` pseudo-random generator at a fixed value
import tensorflow as tf
tf.random.set_seed(seed_value)
# for later versions:
# tf.compat.v1.set_random_seed(seed_value)
# 5. Configure a new global `tensorflow` session
from keras import backend as K
session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
K.set_session(sess)
# for later versions:
# session_conf = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
# sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf)
# tf.compat.v1.keras.backend.set_session(sess)
```

Indeed, a thing of beauty.

# 3) Reproducibility in `numpy`

First and foremost, I'd recommend opening the accompanying notebook, specifically the `numpy` portion, and playing with the code there. NB: the `jax` portion is trivial and works as you might expect; I included it primarily for completeness.

As you play with the `numpy` portion, you'll notice how you get new random values every time you call a function from the `random` module. We get new random values without explicitly passing in a key, which tells us something is happening under the hood.

This "something" looks a lot like generating a new random key on every call. Note that this is not what actually happens under the hood, but it helps tie what we see to `jax` and how it handles random state.

## i) Example Scenario

You have a program that only crashes once in a while, and you’ve identified the exact function that it crashes on! You’ve even managed to find a specific random seed on which that function works fine, so you’d like to set the state only inside that function and avoid the problem altogether.

Yes, this is a contrived example; sue me.

### Statefulness issue illustrated

Note here how we have reset the random seed within `new_generate_np_weights`. If the randomness were only local to the context we are in, we would expect to "continue" the original randomness once we exit the function. Said differently, we would have two "sources" of randomness, the second of which would get garbage collected once `new_generate_np_weights` returns. However, as we can see in the call labeled "# 3rd call", we have received the same random value as our "# 2nd call".
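The notebook's code isn't reproduced here, but the behavior can be sketched as follows (the body of `new_generate_np_weights` and the seed values are my assumptions; only the function name comes from the post):

```python
import numpy as np

def new_generate_np_weights():
    # Sketch of the notebook's function: reseeding here silently
    # resets the single *global* stream, not a local copy
    np.random.seed(42)
    return np.random.random(3)

np.random.seed(42)
first = np.random.random(3)   # 1st call
second = np.random.random(3)  # 2nd call
new_generate_np_weights()     # reproduces the 1st call's values
third = np.random.random(3)   # 3rd call: identical to the 2nd call!
```

The function's reseed rewound the one global stream, so the draws outside the function repeat themselves.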

### The global state

Clearly, something "unexpected" is happening. At its core, `np.random.seed` configures what is known as a `RandomState`, which, as we've discussed, is a stateful object. In fact, as we saw in our code example, calling `seed` resets that single global object rather than creating a fresh, local one.

Obviously, this is the source of our issues.
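You can see that hidden `RandomState` at work by seeding an explicit one the same way; a small sketch (seed 123 is arbitrary):

```python
import numpy as np

# np.random.seed configures a module-level RandomState; an explicit
# RandomState seeded identically reproduces the exact same stream
np.random.seed(123)
explicit = np.random.RandomState(123)
assert np.random.random() == explicit.random_sample()
```

The module-level functions are, in effect, methods on one shared `RandomState`.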

## ii) How do we address reproducibility in numpy?

In all honesty, I had previously stumbled on the new best practices for generating random numbers in `numpy`, but I never bothered to read them. I don't think the reasoning behind the recommendation ever clicked with me, so I never felt a need to change how I was doing things.

However, now that we are clear on the limitations of the existing `np.random.seed`, we can discuss the recommended way of doing things: `numpy.random.Generator`. To make a long story short, you create an object which contains all your randomness, and you "extract" whatever you need from this object. For example, from the random sampling docs:

```
from numpy.random import default_rng
rng = default_rng()
vals = rng.standard_normal(10)
more_vals = rng.standard_normal(10)
```

as opposed to the older method

```
from numpy import random
vals = random.standard_normal(10)
more_vals = random.standard_normal(10)
```

where we implicitly mutate a global object.
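Going back to the crash scenario from earlier: a local `Generator` lets you pin the seed inside one function without touching the global stream. A sketch (the function name and seeds are mine):

```python
import numpy as np
from numpy.random import default_rng

def generate_weights():
    # A local Generator: seeding it affects nothing outside this function
    rng = default_rng(0)
    return rng.standard_normal(3)

np.random.seed(42)
a = np.random.random()
generate_weights()            # no side effect on the global stream
b = np.random.random()

# The global seed-42 stream continued uninterrupted:
np.random.seed(42)
ref = np.random.random(2)
assert a == ref[0] and b == ref[1]
```

Each `Generator` carries its own state, so the problem from the stateful example simply cannot occur.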

# Closing Thoughts:

This was an enlightening topic for me to dive into, and I hope you found reading this useful. I feel like I better understand what `numpy` does under the hood when we use randomness. I also feel like I better understand the motivation behind `numpy`'s API change recommendation when viewed through the lens of `jax`.

tl;dr:

1) `jax` handles randomness very well, even if it may be more verbose.
2) Use the new best practices if you are dealing with random numbers in `numpy`.

## P.S.

You can generate multiple keys with `jax.random.split`, which you can then consume one per random call:

```
key_array = jax.random.split(key, num=X)
```
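For example (the choice of `num=3` and the seed are arbitrary):

```python
import jax

key = jax.random.PRNGKey(0)
# split returns an array of 3 fresh keys derived from `key`
keys = jax.random.split(key, num=3)
samples = [jax.random.normal(k, ()) for k in keys]  # one draw per key
```

Handy when a loop needs several independent draws up front.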