Jenanaputra
Published under the MIT license

Autonomous Salinity Control on Coral Nursery Systems

A control system based on Reinforcement Learning (RL)


Story


Code

All dependencies

Python
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env


# Custom Gymnasium environment modelling pond salinity dynamics
class PondSalinityEnv(gym.Env):
    def __init__(self):
        super().__init__()

        # Observation space: salinity in ppt
        self.observation_space = spaces.Box(
            low=0, high=45, shape=(1,), dtype=np.float32
        )

        # Action space: 3 discrete inflow settings, mapped to flow
        # rates in mapping_value() below
        self.action_space = spaces.Discrete(3)

        self.current_step = 0
        self.max_step = 200
        self.salinity = None

        # System parameters
        self.dt = 1.0      # time step [h]
        self.V = 1000.0    # pond volume [m^3]
        self.C_in = 35.0   # inflow salinity [ppt]
        self.S_env = 34.5  # ambient (setpoint) salinity [ppt]
        self.k = 0.05      # relaxation rate toward S_env [1/h]

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        # Use the seeded RNG from gymnasium so episodes are reproducible
        self.salinity = float(self.np_random.uniform(34.25, 34.75))
        self.current_step = 0
        return np.array([self.salinity], dtype=np.float32), {}

    def mapping_value(self, action):
        # Map the discrete action to an inflow rate Q_in [m^3/h]
        return {0: 0.0, 1: 0.1, 2: 0.5}[action]

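    # Plant model used by step() below: a single-state salt mass balance,
    #     dS/dt = Q_in * C_in / V - k * (S - S_env)
    # where the first term is salt carried in by the inflow Q_in and the
    # second term relaxes the pond toward the ambient salinity S_env.
    # step() integrates this with a forward-Euler step of size dt.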
    def step(self, action):
        Q_in = self.mapping_value(action)
        dS = (Q_in * self.C_in / self.V) - self.k * (self.salinity - self.S_env)
        self.salinity += self.dt * dS
        self.salinity = float(np.clip(self.salinity, 0, 45))

        obs = np.array([self.salinity], dtype=np.float32)
        # Reward: negative absolute deviation from the setpoint salinity
        reward = -abs(self.salinity - self.S_env)

        self.current_step += 1
        # Hitting the time limit is a truncation, not a termination,
        # under the Gymnasium API
        terminated = False
        truncated = self.current_step >= self.max_step
        return obs, reward, terminated, truncated, {}

    def render(self):
        print(f"Step={self.current_step} | Salinity={self.salinity:.2f} ppt")

    def close(self):
        pass
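
# Optional: a quick sanity check of the environment with random actions
# before training (illustrative only; the 5-step horizon is arbitrary)
check_env = PondSalinityEnv()
obs, _ = check_env.reset(seed=0)
for _ in range(5):
    obs, reward, terminated, truncated, _ = check_env.step(
        check_env.action_space.sample()
    )
    check_env.render()
check_env.close()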
      
#########################################################      

# for model training
# Vectorize the environment (SB3 can wrap a single env automatically,
# but an explicit vec env keeps the setup uniform)
env = make_vec_env(lambda: PondSalinityEnv(), n_envs=1)

# Define PPO model
model = PPO(
    "MlpPolicy",           # a small MLP policy is enough for a 1-D observation
    env,
    verbose=1,
    learning_rate=3e-4,
    gamma=0.99,            # discount factor
    n_steps=2048,          # rollout length per policy update
    batch_size=64,
    ent_coef=0.01          # entropy bonus to encourage exploration
)

# Train
model.learn(total_timesteps=500_000)

# Save model
model.save("ppo_salinity_1D")
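
# Optional: quick policy evaluation (evaluate_policy is part of
# stable_baselines3; 10 episodes here is an arbitrary illustrative choice)
from stable_baselines3.common.evaluation import evaluate_policy
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean episodic reward: {mean_reward:.2f} +/- {std_reward:.2f}")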

env.close()
#######################################


# for testing the model
# Load model
model = PPO.load("ppo_salinity_1D")

# Create new env instance (not vec_env here)
env = PondSalinityEnv()

# Reset env
obs, _ = env.reset()
print("Initial observation:", obs)

# Logs
salinity_log = []
setpoint_log = []
time_log = []

setpoint = 34.5
n = 200  # equals the environment's max_step, so one full episode

for t in range(n):
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(int(action))

    salinity_log.append(obs[0])   # obs is a shape-(1,) array
    setpoint_log.append(setpoint)
    time_log.append(t)

    if terminated or truncated:
        break

# Plot
plt.figure(figsize=(10,5))
plt.plot(time_log, salinity_log, label="Actual Salinity")
plt.plot(time_log, setpoint_log, 'r--', label="Desired Salinity (Setpoint)")
plt.xlabel("Time step")
plt.ylabel("Salinity (ppt)")
plt.title("Pond Salinity Control with RL Agent (1D state)")
plt.legend()
plt.grid(True)
plt.show()
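
# Optional: summarize tracking error numerically (a simple sketch; MAE and
# RMSE against the 34.5 ppt setpoint are computed from the logs above)
errors = np.array(salinity_log) - setpoint
print(f"MAE:  {np.mean(np.abs(errors)):.4f} ppt")
print(f"RMSE: {np.sqrt(np.mean(errors ** 2)):.4f} ppt")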

########################################

Credits

Jenanaputra
An underwater robotics developer who is passionate about marine things.
