TL;DR: I designed a machine learning model that is well-suited to learning the physics of the game Pong. I trained that model by showing it data from hundreds of thousands of sequential frames captured during normal gameplay. As a result, the model learned the deceptively complex rules and physics of the game. By feeding control inputs (for the paddles) into the trained model, you can play a game of Pong.
This work is (obviously) not connected to Atari or the original Pong game in any way. I am using the term 'Pong' to describe a Pong-like table tennis video game.
How It Works
I wrote a simple Pong-style game using pygame. As each frame is displayed, the game writes a text file containing the positions of the paddles, the position of the ball, and the state of the user inputs. This data is used to train an artificial neural network, together with supplemental synthetic data, generated by a script, that simulates a paddle miss (which is otherwise a rare event). The architecture is represented in the diagram below:
Whoa! All that to learn Pong? It may seem like overkill, but the physics are deceptively difficult to learn. I started out thinking I'd have this running in a few hours with a simple feedforward network, but it ended up taking months of spare time to get it working. The velocity inversion of the ball at bounces, the paddle hits and misses, and the issue of paddle movement and ball movement inappropriately influencing each other, for instance, were very, very hard for any model I tried to learn. Aside from basic feedforward architectures, I also tried LSTM/GRU layers, convolutional layers, and (as it felt, anyway) just about everything else.
What ended up finally working was a Transformer-based architecture with multiple isolated branches and output heads. In the next section I'll go into a deep dive of the model.
Ultimately, I would like to do another version of this project using images of the game screen as training data, and have the model predict the next image frame. That is actually the direction I started in for this project, but I soon realized that this goal was out of reach because I do not have a GPU. Some people are GPU poor, but I'm GPU broke. This model was trained on a machine with a pair of ten-year-old Xeon CPUs.
Explaining the Model Architecture
The model is defined in the training script.
The model takes in a set of 4 sequential time points containing ball and paddle coordinates and user inputs. I also included a number of engineered features (e.g. ball velocity, distance from each edge, etc.) to aid the model in learning.
train_x.append([paddle1_pos_1, paddle2_pos_1, ball_x_1, ball_y_1, paddle1_vel_1, paddle2_vel_1,
paddle1_pos_2, paddle2_pos_2, ball_x_2, ball_y_2, paddle1_vel_2, paddle2_vel_2,
paddle1_pos_3, paddle2_pos_3, ball_x_3, ball_y_3, paddle1_vel_3, paddle2_vel_3,
paddle1_pos_4, paddle2_pos_4, ball_x_4, ball_y_4, paddle1_vel_4, paddle2_vel_4,
delta_x_1, delta_y_1, dist_left_1, dist_right_1, dist_top_1, dist_bottom_1, coverage_p1_1, coverage_p2_1,
delta_x_2, delta_y_2, dist_left_2, dist_right_2, dist_top_2, dist_bottom_2, coverage_p1_2, coverage_p2_2,
delta_x_3, delta_y_3, dist_left_3, dist_right_3, dist_top_3, dist_bottom_3, coverage_p1_3, coverage_p2_3,
delta_x_4, delta_y_4, dist_left_4, dist_right_4, dist_top_4, dist_bottom_4, coverage_p1_4, coverage_p2_4])
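As a rough illustration of the engineered features listed above, the sketch below derives the ball deltas and edge distances from raw positions. It is not the author's code: the screen size (`WIDTH`, `HEIGHT`), the helper name `engineer_features`, and the choice to measure deltas against the previous frame are all assumptions. The `coverage_p1`/`coverage_p2` features are left out because their definition is not given here.

```python
# Hypothetical screen dimensions; the real game may use different values.
WIDTH, HEIGHT = 640, 480


def engineer_features(prev_ball, ball):
    """Return (delta_x, delta_y, dist_left, dist_right, dist_top, dist_bottom).

    prev_ball and ball are (x, y) tuples for two consecutive frames.
    """
    x, y = ball
    delta_x = x - prev_ball[0]  # per-frame ball velocity, x component
    delta_y = y - prev_ball[1]  # per-frame ball velocity, y component
    dist_left = x               # distance from the left edge
    dist_right = WIDTH - x      # distance from the right edge
    dist_top = y                # distance from the top edge
    dist_bottom = HEIGHT - y    # distance from the bottom edge
    return (delta_x, delta_y, dist_left, dist_right, dist_top, dist_bottom)


features = engineer_features(prev_ball=(100, 200), ball=(104, 197))
print(features)
```

Features like these hand the network quantities it would otherwise have to infer by subtracting coordinates across frames.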
The goal is to learn the physics of ball movement, bounces at the edges of the screen, paddle misses (point scored) or bounces, how to handle user input to adjust paddle positions, and to keep everything within bounds of the screen — basically everything that makes up a game of Pong. This knowledge contained in the model is used to predict the next frame in the game, which then slides into the list of past frames as new predictions are made. So, initially a game is started with a seed of 4 time points of data, then the model does all the work. But that is for the inference stage, so I am getting ahead of myself.
At a high level, the architecture uses branching to separate paddle and ball processing (divide-and-conquer for independent dynamics) to avoid learning inappropriate interactions, temporal modeling via attention (to capture sequence dependencies across frames), and a shared branch for integrating interactions (e.g., paddle-ball collisions for bounces).
A normalization layer scales features to mean=0, variance=1 based on training data statistics, because training would be much slower using absolute coordinates (much larger values). Next, a Gaussian noise layer adds a small amount of random noise to normalized inputs during training only to make the model more robust to unseen data. This helps the model perform better when it is making new predictions based on a time sequence of past predictions (instead of ideal data captured from real games) to prevent small errors from being magnified over time.
normalization_layer = keras.layers.Normalization(axis=-1, name='input_normalization')
normalized_input = normalization_layer(main_input)
# Add some noise to help with model generalization. Leaving the `training`
# argument unset lets Keras apply the noise during fit() only.
noisy_input = keras.layers.GaussianNoise(stddev=0.01)(normalized_input)
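To make the two preprocessing steps concrete, here is a plain-Python sketch of what standardization and training-time noise do to a single feature column. It only loosely mirrors the Keras layers: the helper names (`fit_standardizer`, `standardize`, `add_noise`) and the sample values are assumptions for illustration.

```python
import random
import statistics


def fit_standardizer(column):
    """Compute the mean and population standard deviation of one feature."""
    return statistics.fmean(column), statistics.pstdev(column)


def standardize(value, mean, std, eps=1e-7):
    """Scale a raw value so the feature has zero mean and unit variance."""
    return (value - mean) / max(std, eps)  # eps guards against constant features


def add_noise(value, stddev=0.01, training=True, rng=random):
    """Add zero-mean Gaussian noise while training; pass through otherwise."""
    return value + rng.gauss(0.0, stddev) if training else value


# Hypothetical ball x coordinates in pixels.
ball_x = [0.0, 160.0, 320.0, 480.0, 640.0]
mean, std = fit_standardizer(ball_x)
scaled = [standardize(v, mean, std) for v in ball_x]
print([round(v, 3) for v in scaled])
print(add_noise(scaled[0], rng=random.Random(0)))
```

A noise level of 0.01 is small relative to the unit-variance inputs, so it perturbs each frame slightly without hiding the trajectory.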
Each player paddle has its own branch that is a simple feedforward layer that only looks at the paddle position and associated user input. This allows it to learn without being confused by irrelevant features.
# Branch for the paddle1 features.
paddle1_features = keras.ops.take(noisy_input, indices=[0, 4, 6, 10, 12, 16, 18, 22], axis=1)
paddle1_branch = keras.layers.Dense(64, activation='relu', name='paddle1_1')(paddle1_features)
# Branch for the paddle2 features.
paddle2_features = keras.ops.take(noisy_input, indices=[1, 5, 7, 11, 13, 17, 19, 23], axis=1)
paddle2_branch = keras.layers.Dense(64, activation='relu', name='paddle2_1')(paddle2_features)
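The index lists passed to `keras.ops.take` follow directly from the feature order in `train_x`: six raw values per frame for four frames, then eight engineered values per frame. The sketch below rebuilds them from that layout. The constant and helper names are hypothetical and are only meant to show where the numbers come from.

```python
FRAMES = 4
RAW_PER_FRAME = 6  # paddle1_pos, paddle2_pos, ball_x, ball_y, paddle1_vel, paddle2_vel
ENG_PER_FRAME = 8  # delta_x, delta_y, four edge distances, two coverage values
ENG_START = FRAMES * RAW_PER_FRAME  # engineered features begin at index 24


def raw_indices(offsets):
    """Indices of the given per-frame offsets across all frames."""
    return [f * RAW_PER_FRAME + o for f in range(FRAMES) for o in offsets]


def ball_feature_indices():
    """Ball x/y plus that frame's engineered features, frame by frame."""
    indices = []
    for f in range(FRAMES):
        indices += [f * RAW_PER_FRAME + 2, f * RAW_PER_FRAME + 3]
        start = ENG_START + f * ENG_PER_FRAME
        indices += list(range(start, start + ENG_PER_FRAME))
    return indices


paddle1_indices = raw_indices((0, 4))  # position and velocity of paddle 1
paddle2_indices = raw_indices((1, 5))  # position and velocity of paddle 2
print(paddle1_indices)
print(paddle2_indices)
print(ball_feature_indices())
```

Deriving the indices this way makes it easier to keep the branches in sync if the per-frame layout ever changes.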
All ball-related features are also isolated and fed into a branch of the network. This uses the self-attention of a Transformer to pick out the features that influence ball movement, which proved to be especially useful on edge cases (e.g. bounces) that other types of models just averaged out of existence.
# Branch for ball features, including engineered features.
ball_indices = [2, 3, 24, 25, 26, 27, 28, 29, 30, 31, # Frame 1: x, y, dx, dy, dl, dr, dt, db, c1, c2
8, 9, 32, 33, 34, 35, 36, 37, 38, 39, # Frame 2
14, 15, 40, 41, 42, 43, 44, 45, 46, 47, # Frame 3
20, 21, 48, 49, 50, 51, 52, 53, 54, 55] # Frame 4
ball_features = keras.ops.take(noisy_input, indices=ball_indices, axis=1)
ball_branch = keras.layers.Reshape((4, 10))(ball_features)
ball_branch = keras.layers.MultiHeadAttention(num_heads=4, key_dim=32)(ball_branch, ball_branch)
ball_branch = keras.layers.LayerNormalization()(ball_branch)
ball_branch = keras.layers.Flatten()(ball_branch)
ball_branch = keras.layers.Dense(64, activation=keras.layers.LeakyReLU(negative_slope=0.1), name='ball_1')(ball_branch)
ball_branch = keras.layers.Dense(64, activation=keras.layers.LeakyReLU(negative_slope=0.1), name='ball_2')(ball_branch)
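For readers unfamiliar with self-attention, the sketch below shows the core operation on a toy sequence of four frames. It is a deliberate simplification: a single head with identity projections, whereas Keras's `MultiHeadAttention` learns separate query, key, value, and output projections for each of its heads. The toy trajectory and function names are made up for illustration.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(frames):
    """Single-head scaled dot-product self-attention with identity projections.

    frames is a list of equal-length feature vectors, one per time step.
    Returns (outputs, weights); weights[i][j] is how strongly frame i
    attends to frame j.
    """
    dim = len(frames[0])
    outputs, weights = [], []
    for query in frames:
        scores = [
            sum(q * k for q, k in zip(query, key)) / math.sqrt(dim)
            for key in frames
        ]
        row = softmax(scores)
        weights.append(row)
        # Each output is a weighted average of all frames' feature vectors.
        outputs.append(
            [sum(w * v[d] for w, v in zip(row, frames)) for d in range(dim)]
        )
    return outputs, weights


# Toy ball trajectory (x, y, dx, dy); dy flips sign in the last frame (a bounce).
frames = [
    [0.1, 0.8, 0.1, 0.1],
    [0.2, 0.9, 0.1, 0.1],
    [0.3, 1.0, 0.1, 0.1],
    [0.4, 0.9, 0.1, -0.1],
]
outputs, weights = self_attention(frames)
for row in weights:
    print([round(w, 3) for w in row])
```

Because every frame can weight every other frame differently, a frame where the velocity flips does not have to be blended uniformly with its neighbors.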
After that, another independent branch takes in the output of the ball and paddle branches and merges them for processing by another Transformer. This picks up more complex interactions between the ball and paddle.
# Combine ball with paddle features (for hit/miss detection).
combined_features = keras.layers.Concatenate(name='concatenate_branches')([ball_branch, paddle1_branch, paddle2_branch])
combined_features = keras.layers.Dense(32, activation='relu')(combined_features) # Project to common dimension.
combined_features = keras.layers.Reshape((4, 32))(keras.layers.RepeatVector(4)(combined_features)) # Repeat to create 4 time steps.
pos_encoding = positional_encoding(4, 32)
combined_features += pos_encoding
shared_branch = keras.layers.MultiHeadAttention(num_heads=4, key_dim=32)(combined_features, combined_features)
shared_branch = keras.layers.LayerNormalization()(shared_branch)
shared_branch = keras.layers.Dense(128, activation=keras.layers.LeakyReLU(negative_slope=0.1))(shared_branch)
shared_branch = keras.layers.Dropout(0.3)(shared_branch)
shared_branch = keras.layers.Flatten()(shared_branch)
shared_branch = keras.layers.Dense(64, activation=keras.layers.LeakyReLU(negative_slope=0.1), name='shared_dense')(shared_branch)
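The `positional_encoding` helper called above is not shown in this write-up. One common choice is the sinusoidal encoding from the original Transformer design, sketched below in plain Python. Whether the training script uses this exact scheme is an assumption, and the nested list returned here would still need to be converted to a tensor before being added to `combined_features`.

```python
import math


def positional_encoding(length, depth):
    """Sinusoidal positional encoding as a length x depth nested list.

    Even columns hold sines and odd columns hold cosines, with wavelengths
    that grow geometrically across the feature dimension.
    """
    table = []
    for pos in range(length):
        row = []
        for i in range(depth):
            angle = pos / (10000 ** (2 * (i // 2) / depth))
            row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


encoding = positional_encoding(4, 32)
print(len(encoding), len(encoding[0]))
print([round(v, 3) for v in encoding[1][:4]])
```

Since `RepeatVector(4)` copies the same 32-dimensional vector into all four time steps, an encoding like this is what makes the steps distinguishable to the attention layer.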
Finally, the paddle positions, ball position deltas, and game state (normal/point scored) are predicted for the next frame.
# Output heads.
paddle1_pos_output = keras.layers.Dense(1, activation='linear', name='paddle1_output_1')(paddle1_branch)
paddle2_pos_output = keras.layers.Dense(1, activation='linear', name='paddle2_output_1')(paddle2_branch)
ball_state_output = keras.layers.Dense(2, activation='linear', name='ball_output_2')(shared_branch)
game_state_output = keras.layers.Dense(1, activation='sigmoid', name='game_state_output_2')(shared_branch)
Playing the Model
The inference script is used to play the game. It starts with a set of 4 hard-coded time points of input data, which it uses to make the first prediction. The results of the prediction are used to position elements on the screen. From that point forward, each prediction is appended to the history of frames that the model predicts the next frame from, along with keyboard data that represents the user inputs. This process continues ad infinitum, with all logic being contained in the trained model. The inference script just handles drawing elements to the screen where the network says they should go; it contains no game logic whatsoever.
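The sliding-window loop described above can be sketched as follows. This is a simplified stand-in, not the inference script: `stub_model` is a hypothetical placeholder for the trained network, the dictionary-based frame format is invented for readability, and the 0.5 threshold on the game-state output is an assumption.

```python
from collections import deque


def stub_model(history, inputs):
    """Hypothetical stand-in for the trained network.

    Moves each paddle by its user input and repeats the ball's last delta.
    """
    prev, last = history[-2], history[-1]
    return {
        "paddle1": last["paddle1"] + inputs[0],
        "paddle2": last["paddle2"] + inputs[1],
        "ball_delta": (
            last["ball"][0] - prev["ball"][0],
            last["ball"][1] - prev["ball"][1],
        ),
        "point_scored": 0.0,
    }


def step(history, inputs, model=stub_model):
    """Predict one frame, append it to the window, and drop the oldest frame."""
    pred = model(list(history), inputs)
    last = history[-1]
    frame = {
        "paddle1": pred["paddle1"],
        "paddle2": pred["paddle2"],
        # The ball head predicts deltas, so add them to the last position.
        "ball": (
            last["ball"][0] + pred["ball_delta"][0],
            last["ball"][1] + pred["ball_delta"][1],
        ),
    }
    history.append(frame)  # deque(maxlen=4) discards the oldest frame
    return frame, pred["point_scored"] > 0.5


# Seed the window with 4 hard-coded frames.
history = deque(
    ({"paddle1": 200, "paddle2": 200, "ball": (100 + 4 * t, 150 + 2 * t)}
     for t in range(4)),
    maxlen=4,
)
for _ in range(3):
    frame, scored = step(history, inputs=(-5, 5))
    print(frame, scored)
```

After the fourth step the window holds only predicted frames, which is why robustness to the model's own small errors matters so much during training.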