This project presents a solution to the classic Snake Game using Reinforcement Learning (RL) with the Q-learning algorithm and a deep neural network. The goal is to train an agent to play the game effectively, maximizing its score by collecting food while avoiding collisions with itself or the grid boundaries.
The core of the project is a Q-learning agent that uses a deep neural network, `Linear_QNet`, to approximate the Q-values of different state-action pairs. The agent interacts with the game environment, stores its experiences in a replay memory, and updates its Q-network using experience replay and the Bellman equation.
This README provides an overview of the implementation, including details about the `AGENT` class, state representation, experience replay, the training process, and action selection. It also covers the setup requirements and dependencies for running the Snake Game with the RL-based agent.
This project implements a Q-learning algorithm with a deep neural network using PyTorch. The agent learns to play the Snake game by interacting with the environment and updating its Q-values based on the rewards it receives. The Q-values are estimated by a deep neural network called `Linear_QNet`.
The `AGENT` class represents the Q-learning agent. It has the following components:

- `number_of_games`: Number of games played by the agent.
- `epsilon`: Exploration rate for choosing random actions.
- `GAMMA`: Discount factor for future rewards.
- `memory`: Replay memory that stores experience tuples.
- `model`: The Q-network model (`Linear_QNet`) used by the agent.
- `trainer`: The Q-training object (`QTraining`) used to train the model.
- `last_move`: The last move made by the agent.
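A minimal sketch of how these components might be wired together; the constants and layer sizes below are illustrative assumptions, not the project's exact settings, and `Linear_QNet`/`QTraining` are the classes described later in this README:

```python
import random
from collections import deque

import torch

MAX_MEMORY = 100_000   # assumed replay-memory capacity
LR = 0.001             # assumed learning rate


class AGENT:
    def __init__(self):
        self.number_of_games = 0
        self.epsilon = 0                          # exploration rate, decayed over time
        self.GAMMA = 0.9                          # discount factor for future rewards
        self.memory = deque(maxlen=MAX_MEMORY)    # replay memory of experience tuples
        self.model = Linear_QNet(11, 256, 128, 3) # assumed sizes: 11 state features, 3 actions
        self.trainer = QTraining(self.model, lr=LR, gamma=self.GAMMA)
        self.last_move = None

    def remember_brain(self, state, action, reward, next_state, is_done):
        # store one experience tuple in the replay memory
        self.memory.append((state, action, reward, next_state, is_done))
```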
The `get_state` method takes the game state as input and converts it into a numerical representation. It creates an 11-dimensional state vector that encodes features of the game state, such as danger zones, the current direction of movement, and the food location.
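A hedged sketch of what such an 11-dimensional vector could look like; the exact feature order and the way the flags are computed are assumptions based on the description above:

```python
import numpy as np

def build_state_vector(danger_straight, danger_right, danger_left,
                       dir_left, dir_right, dir_up, dir_down,
                       food_left, food_right, food_up, food_down):
    # 11 booleans: 3 danger flags, 4 current-direction flags, 4 food-location flags
    state = [
        danger_straight, danger_right, danger_left,
        dir_left, dir_right, dir_up, dir_down,
        food_left, food_right, food_up, food_down,
    ]
    return np.array(state, dtype=int)
```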
The `remember_brain` method stores the experience tuple `(state, action, reward, next_state, is_done)` in the agent's replay memory.
The `train_long_memory` method performs a batch training step using experience replay. It samples a mini-batch of experiences from the replay memory and updates the Q-network's weights using the Q-training object's `train_step` method.
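A sketch of how that mini-batch step could look, assuming a `BATCH_SIZE` constant and the `memory`/`trainer` attributes listed earlier (shown in isolation; it would be a method of the `AGENT` class):

```python
import random

BATCH_SIZE = 1000  # assumed mini-batch size

def train_long_memory(self):
    # sample a mini-batch once enough experiences are stored,
    # otherwise train on everything collected so far
    if len(self.memory) > BATCH_SIZE:
        mini_batch = random.sample(self.memory, BATCH_SIZE)
    else:
        mini_batch = list(self.memory)

    states, actions, rewards, next_states, dones = zip(*mini_batch)
    self.trainer.train_step(states, actions, rewards, next_states, dones)
```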
The `train_short_memory` method performs a single training step on the current experience tuple. It updates the Q-network's weights without using experience replay.
The `action_state` method selects an action based on the current state. It implements an epsilon-greedy strategy: with some probability the agent chooses a random action (exploration), and otherwise it chooses the action with the highest Q-value predicted by the Q-network (exploitation).
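An illustrative version of that epsilon-greedy selection, assuming three possible actions encoded as a one-hot list and an exploration rate that shrinks as more games are played (the decay schedule below is an assumption):

```python
import random
import torch

def action_state(self, state):
    # assumed schedule: explore less as the agent gains experience
    self.epsilon = 80 - self.number_of_games
    final_move = [0, 0, 0]  # [straight, right turn, left turn]

    if random.randint(0, 200) < self.epsilon:
        # exploration: pick a random action
        move = random.randint(0, 2)
    else:
        # exploitation: pick the action with the highest predicted Q-value
        state_tensor = torch.tensor(state, dtype=torch.float)
        prediction = self.model(state_tensor)
        move = torch.argmax(prediction).item()

    final_move[move] = 1
    self.last_move = final_move
    return final_move
```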
The `train` function is the main training loop. It initializes the game environment, creates an instance of the agent, and iteratively plays the game, collects experiences, and performs training steps. It also keeps track of game scores, updates the best score, and plots training progress using Matplotlib.
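The loop could be sketched roughly as follows; the game class name `SnakeGameAI` and its `play_step`/`reset` interface are assumptions based on the description above, not the project's exact API:

```python
def train():
    agent = AGENT()
    game = SnakeGameAI()  # hypothetical game-environment class
    best_score = 0

    while True:
        old_state = agent.get_state(game)
        move = agent.action_state(old_state)

        # apply the action and observe the outcome
        reward, is_done, score = game.play_step(move)
        new_state = agent.get_state(game)

        # train on this single transition, then store it for replay
        agent.train_short_memory(old_state, move, reward, new_state, is_done)
        agent.remember_brain(old_state, move, reward, new_state, is_done)

        if is_done:
            game.reset()
            agent.number_of_games += 1
            agent.train_long_memory()  # experience-replay batch update
            if score > best_score:
                best_score = score
                agent.model.save()
```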
The deep neural network is used as a function approximator to learn the Q-values of different state-action pairs, allowing the agent to make decisions in an environment with discrete states and actions.
Before running the code, make sure you have the following libraries installed:
- Python 3.x
- PyTorch (version >= 1.0)
- NumPy
The `Linear_QNet` class defines a simple feedforward neural network with three linear layers. The constructor takes the input size, two hidden layer sizes, and the output size as arguments. The network architecture is as follows:

- Input layer: `input_size` neurons
- Hidden layer 1: `hidden_1_size` neurons, followed by a ReLU activation
- Hidden layer 2: `hidden_2_size` neurons, no activation function
- Output layer: `output_size` neurons, representing the Q-values for each action
The `forward` method defines the forward pass of the network: the input `x` is passed through the linear layers, with a ReLU activation between hidden layer 1 and hidden layer 2.
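A minimal sketch of that architecture (the layer sizes are whatever the caller passes in):

```python
import torch.nn as nn
import torch.nn.functional as F

class Linear_QNet(nn.Module):
    def __init__(self, input_size, hidden_1_size, hidden_2_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_1_size)
        self.linear2 = nn.Linear(hidden_1_size, hidden_2_size)
        self.linear3 = nn.Linear(hidden_2_size, output_size)

    def forward(self, x):
        # ReLU between hidden layer 1 and hidden layer 2, no activation after layer 2
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return self.linear3(x)  # Q-values, one per action
```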
The `save` method saves the model's state dictionary to a file named `model.pth` in the `./model` directory. If the directory doesn't exist, it is created first.
The `save_checkpoint` method saves the model's state dictionary in a checkpoint format to a file named `checkpoint.pth` in the `./model` directory.
The `load_checkpoint` method loads a saved checkpoint from `checkpoint.pth` in the `./model` directory, retrieves the state dictionary, and loads it into the model.
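These persistence methods might look something like the following sketch (shown in isolation as methods of `Linear_QNet`); the file and directory names follow the description above, while the exact checkpoint layout is an assumption:

```python
import os
import torch

MODEL_DIR = "./model"

def save(self, file_name="model.pth"):
    # create the target directory if needed, then save the weights
    os.makedirs(MODEL_DIR, exist_ok=True)
    torch.save(self.state_dict(), os.path.join(MODEL_DIR, file_name))

def save_checkpoint(self, file_name="checkpoint.pth"):
    os.makedirs(MODEL_DIR, exist_ok=True)
    checkpoint = {"state_dict": self.state_dict()}  # assumed checkpoint contents
    torch.save(checkpoint, os.path.join(MODEL_DIR, file_name))

def load_checkpoint(self, file_name="checkpoint.pth"):
    checkpoint = torch.load(os.path.join(MODEL_DIR, file_name))
    self.load_state_dict(checkpoint["state_dict"])
```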
The `QTraining` class is responsible for training the Q-network using Q-learning. It takes the model, the learning rate (`lr`), and the discount factor (`gamma`) as arguments during initialization.
The `train_step` method performs a single training step for the Q-network. It takes a tuple of `(state, action, reward, next_state, done)` as input, where:

- `state`: The current state of the environment
- `action`: The action taken by the agent
- `reward`: The reward received after taking the action
- `next_state`: The next state of the environment after taking the action
- `done`: A flag indicating whether the episode is finished
The method calculates the Q-value target using the Bellman equation, updates the Q-values of the chosen actions, and performs backpropagation to minimize the Mean Squared Error (MSE) loss between the predicted Q-values and the target Q-values.
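A simplified single-sample sketch of such a step; the choice of the Adam optimizer and the one-hot action encoding are assumptions, and the real implementation also handles batches of experiences:

```python
import torch
import torch.nn as nn

class QTraining:
    def __init__(self, model, lr, gamma):
        self.model = model
        self.gamma = gamma
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # assumed optimizer
        self.criterion = nn.MSELoss()

    def train_step(self, state, action, reward, next_state, done):
        state = torch.tensor(state, dtype=torch.float)
        next_state = torch.tensor(next_state, dtype=torch.float)
        action = torch.tensor(action, dtype=torch.long)  # one-hot action (assumed encoding)

        # predicted Q-values for the current state
        pred = self.model(state)
        target = pred.detach().clone()

        # Bellman target: r if the episode ended, else r + gamma * max Q(next_state)
        q_new = reward
        if not done:
            q_new = reward + self.gamma * torch.max(self.model(next_state)).item()
        target[torch.argmax(action).item()] = q_new

        # minimize MSE between predicted and target Q-values
        self.optimizer.zero_grad()
        loss = self.criterion(pred, target)
        loss.backward()
        self.optimizer.step()
```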
To train the Q-network, follow these steps:

- Create an instance of the `Linear_QNet` class with the appropriate input, hidden, and output sizes.
- Create an instance of the `QTraining` class, passing the model, learning rate, and discount factor.
- Run multiple episodes of the environment and train the model at each step using the `train_step` method, as in the sketch after this list.
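Putting the steps together, usage might look roughly like this (the sizes and hyperparameters below are placeholder values, not the project's exact settings):

```python
# assumed sizes: 11 state features, two hidden layers, 3 possible actions
model = Linear_QNet(input_size=11, hidden_1_size=256, hidden_2_size=128, output_size=3)
trainer = QTraining(model, lr=0.001, gamma=0.9)

# inside the game loop, after each environment step:
trainer.train_step(state, action, reward, next_state, done)
```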
This is a simple implementation of the classic Snake Game using Python and the Pygame library.
The Snake Game offers the following functionality:
- The snake moves on a grid and the player controls its direction.
- The objective is to eat the food that appears randomly on the grid.
- As the snake eats food, it grows longer.
- The game ends if the snake collides with itself or with the boundaries of the grid.
- The score increases each time the snake eats food.
- The game speed gradually increases as the player scores more points.
To run the Snake Game, ensure that you have Python and Pygame installed on your system. Then execute the Python script `snakeAI.py` for the AI version or `snakehuman.py` for the playable version.
⚠️ updates on the way to make it more playable ⚠️
Controls:
- Use the arrow keys (up, down, left, right) to change the direction of the snake.
- Press the spacebar to pause or resume the game.
The Snake Game requires the following dependencies:
- Python 3.x
- Pygame
Sometimes the AI starts going in circles. One way to fix this is to give it a penalty for returning to a position it has already visited, although that makes training take longer; a rough sketch of such a penalty is shown below.
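A hedged sketch of that idea, assuming the game exposes the snake's head position and the reward is computed per step (the window size and penalty value are arbitrary assumptions):

```python
from collections import deque

recent_positions = deque(maxlen=20)  # assumed window of recently visited cells

def shaped_reward(base_reward, head_position):
    # small penalty for stepping onto a recently visited cell,
    # which discourages tight loops but slows convergence
    penalty = -0.1 if head_position in recent_positions else 0.0
    recent_positions.append(head_position)
    return base_reward + penalty
```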