
On the usage of Batch Norm in Reinforcement Learning

There is no other way to put it: it is bad¹.

For over a month this year I spent time debugging the code of a project I’ve been working on recently. It consists of an environment based on a dynamics simulator I built and a few RL wrappers I’ve been playing with. I was pretty sure the RL code was correct, since I had ported it from a previous project, but for some reason I wasn’t getting good results, so, obviously, I turned my attention to the physics behind the simulator in search of a bug. I stripped down the environment to a very basic problem and still, it wasn’t working. Finally, one day, I noticed that the architecture of my network contained BatchNorm layers. Stupid me for not noticing that!! Removing them allowed me to reach good results and finally move on with the project.

Thinking about it, I should have paid more attention to where I got my code from: I took the RL code from a previous RL project, but I copied and pasted the networks from a previous supervised learning project, where batch norm is usually useful. Anyway, I took the time to actually make sure this was the problem, so I built a small toy example to test how bad batch norm is in RL.

Let us use a very simple environment, which doesn’t do anything at all: the state is always a vector of zeros [0, 0, 0], the action is continuous between -1 and +1, the reward is equal to the action, and an episode consists of 30 steps. So basically, the optimal policy is one that always outputs +1, and the maximum episode return is 30.

Here is the code for the environment:

import gym
import numpy as np
from gym import spaces

class RandomEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.num_steps = 0
        self.max_steps = 30
        # the reward is simply the action that was taken
        self.reward_func = lambda a: a

        self.action_space = spaces.Box(low=-1, high=+1, shape=(1,))
        self.observation_space = spaces.Box(low=-1, high=+1, shape=(3,))

    def reset(self):
        self.num_steps = 0
        return self.get_state()

    def step(self, a):
        self.num_steps += 1
        # extract the scalar action from the 1-dimensional action array
        a = a[0]
        return self.get_state(), self.reward_func(a), self.get_terminal(), {}

    def get_state(self):
        return np.array([0.,0.,0.])

    def get_terminal(self):
        return self.num_steps >= self.max_steps
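
As a quick sanity check of the environment (this snippet is just an illustration of mine, not part of the experiment code), a policy that always outputs +1 should collect a return of exactly 30:

env = RandomEnv()
state = env.reset()
total_reward, done = 0.0, False
while not done:
    # the optimal policy: always output +1
    state, reward, done, info = env.step(np.array([1.0]))
    total_reward += reward
print(total_reward)  # 30.0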

I then tried to solve this environment with both TD3 (Fujimoto et al.) and SAC (Haarnoja et al.), each with two different architectures: one that includes batch norm and one that does not.
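
For reference, here is a minimal sketch of what I mean by the two architectures, assuming PyTorch; the hidden sizes are just an example and not necessarily the ones used in the actual experiment:

import torch.nn as nn

def make_mlp(in_dim, out_dim, hidden_dim=256, use_batch_norm=False):
    layers = [nn.Linear(in_dim, hidden_dim)]
    if use_batch_norm:
        layers.append(nn.BatchNorm1d(hidden_dim))  # the layer under suspicion
    layers.append(nn.ReLU())
    layers.append(nn.Linear(hidden_dim, hidden_dim))
    if use_batch_norm:
        layers.append(nn.BatchNorm1d(hidden_dim))
    layers.append(nn.ReLU())
    layers.append(nn.Linear(hidden_dim, out_dim))
    return nn.Sequential(*layers)

The actor and critic networks of both TD3 and SAC are then built with use_batch_norm set to True or False, with everything else kept identical.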

Well, after running the same experiment 5 times, the results perfectly confirmed my suspicion, with both TD3 and SAC.

[Figure: SAC results (left) and TD3 results (right)]

Mean, min, and max for both SAC and TD3 trained on the environment above, with both architectures. Each experiment was run with 5 different random seeds.

I have no definitive explanation for this. I initially thought it was because, compared to supervised learning, the dataset is not static, so its statistics change over time; but here the state is constant and there is technically no exploration, so that explanation no longer holds, as each batch should contain similar transition tuples. I wrote this as a reminder to myself not to use batch norm in RL problems, but I would also be happy to understand why this is happening.
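
Just to make the “statistics change” point concrete (this is only an illustration, not an explanation of the failure above): in training mode, a BatchNorm layer normalizes each element using the statistics of whatever batch it happens to be in, so the same input can map to different outputs:

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(1)
bn.train()  # training mode: normalize with batch statistics

x = torch.tensor([[1.0], [2.0], [3.0]])
y = torch.tensor([[1.0], [2.0], [10.0]])

# the first element is 1.0 in both batches, but it is normalized differently
print(bn(x)[0].item())  # roughly -1.22
print(bn(y)[0].item())  # roughly -0.83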

The whole code is available here.


  1. in my experience ↩︎