Run parameter optimization (WOFOST)
Uncomment the following line to install the latest version of cropengine if needed.
In [ ]:
Copied!
# !pip install -U cropengine
# !pip install -U cropengine
Import libraries¶
In [ ]:
Copied!
import os
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from cropengine import WOFOSTCropSimulationBatchRunner
from cropengine.agromanagement import WOFOSTAgroEventBuilder
from cropengine.optimizer import WOFOSTOptimizer
from sklearn.metrics import mean_squared_error
import os
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from cropengine import WOFOSTCropSimulationBatchRunner
from cropengine.agromanagement import WOFOSTAgroEventBuilder
from cropengine.optimizer import WOFOSTOptimizer
from sklearn.metrics import mean_squared_error
Instantiate batch crop simulation engine for WOFOST¶
In [ ]:
Copied!
# Define the model name
MODEL_NAME = "Wofost72_WLP_CWB"
# Define the csv path with 'id', 'latitude', and 'longitude'
locations_csv_path = "test_data/optimizer/location.csv"
# Initialize Engine
batch_runner = WOFOSTCropSimulationBatchRunner(
model_name=MODEL_NAME,
locations_csv_path=locations_csv_path,
workspace_dir="test_output/optimizer_workspace",
)
# Define the model name
MODEL_NAME = "Wofost72_WLP_CWB"
# Define the csv path with 'id', 'latitude', and 'longitude'
locations_csv_path = "test_data/optimizer/location.csv"
# Initialize Engine
batch_runner = WOFOSTCropSimulationBatchRunner(
model_name=MODEL_NAME,
locations_csv_path=locations_csv_path,
workspace_dir="test_output/optimizer_workspace",
)
User inputs¶
In [ ]:
Copied!
# Crop Configuration
models = batch_runner.get_model_options()
crops = batch_runner.get_crop_options(MODEL_NAME)
CROP_NAME = "wheat"
varieties = batch_runner.get_variety_options(MODEL_NAME, CROP_NAME)
CROP_VARIETY = "Winter_wheat_103"
# Timing
crop_start_end = batch_runner.get_crop_start_end_options()
CAMPAIGN_START = "2019-09-01"
CROP_START = "2019-09-25"
CROP_START_TYPE = "sowing"
CROP_END_TYPE = "maturity"
CROP_END = None
CAMPAIGN_END = "2020-09-30"
MAX_DURATION = 365
# Crop Configuration
models = batch_runner.get_model_options()
crops = batch_runner.get_crop_options(MODEL_NAME)
CROP_NAME = "wheat"
varieties = batch_runner.get_variety_options(MODEL_NAME, CROP_NAME)
CROP_VARIETY = "Winter_wheat_103"
# Timing
crop_start_end = batch_runner.get_crop_start_end_options()
CAMPAIGN_START = "2019-09-01"
CROP_START = "2019-09-25"
CROP_START_TYPE = "sowing"
CROP_END_TYPE = "maturity"
CROP_END = None
CAMPAIGN_END = "2020-09-30"
MAX_DURATION = 365
Create agromanagements with user inputs¶
In [ ]:
Copied!
agro_event_builder = WOFOSTAgroEventBuilder()
# Note: Use agro_event_builder.get_..._events_info() to see valid values if unsure
timed_events_info = agro_event_builder.get_timed_events_info()
state_events_info = agro_event_builder.get_state_events_info()
# Build timed events (irrigation)
irrigation_schedule = [
{"event_date": "2020-03-20", "amount": 3.0, "efficiency": 0.7}, # stem elongation
{"event_date": "2020-04-25", "amount": 2.5, "efficiency": 0.7}, # booting/heading
{"event_date": "2020-05-20", "amount": 2.0, "efficiency": 0.7}, # flowering
]
irrigation_events = agro_event_builder.create_timed_events(
signal_type="irrigate", events_list=irrigation_schedule
)
# Build state Events (fertilization based on DVS)
nitrogen_schedule = [
{"threshold": 0.3, "N_amount": 40, "N_recovery": 0.7}, # early vegetative
{"threshold": 0.6, "N_amount": 60, "N_recovery": 0.7}, # stem elongation
{"threshold": 1.0, "N_amount": 40, "N_recovery": 0.7}, # heading
]
nitrogen_events = agro_event_builder.create_state_events(
signal_type="apply_n",
state_var="DVS",
zero_condition="rising",
events_list=nitrogen_schedule,
)
agro_event_builder = WOFOSTAgroEventBuilder()
# Note: Use agro_event_builder.get_..._events_info() to see valid values if unsure
timed_events_info = agro_event_builder.get_timed_events_info()
state_events_info = agro_event_builder.get_state_events_info()
# Build timed events (irrigation)
irrigation_schedule = [
{"event_date": "2020-03-20", "amount": 3.0, "efficiency": 0.7}, # stem elongation
{"event_date": "2020-04-25", "amount": 2.5, "efficiency": 0.7}, # booting/heading
{"event_date": "2020-05-20", "amount": 2.0, "efficiency": 0.7}, # flowering
]
irrigation_events = agro_event_builder.create_timed_events(
signal_type="irrigate", events_list=irrigation_schedule
)
# Build state Events (fertilization based on DVS)
nitrogen_schedule = [
{"threshold": 0.3, "N_amount": 40, "N_recovery": 0.7}, # early vegetative
{"threshold": 0.6, "N_amount": 60, "N_recovery": 0.7}, # stem elongation
{"threshold": 1.0, "N_amount": 40, "N_recovery": 0.7}, # heading
]
nitrogen_events = agro_event_builder.create_state_events(
signal_type="apply_n",
state_var="DVS",
zero_condition="rising",
events_list=nitrogen_schedule,
)
Prepare batch system¶
In [ ]:
Copied!
batch_runner.prepare_batch_system(
campaign_start=CAMPAIGN_START,
campaign_end=CAMPAIGN_END,
crop_start=CROP_START,
crop_end=CROP_END,
crop_name=CROP_NAME,
variety_name=CROP_VARIETY,
max_workers=5,
crop_start_type=CROP_START_TYPE,
crop_end_type=CROP_END_TYPE,
max_duration=MAX_DURATION,
timed_events=[irrigation_events],
state_events=[nitrogen_events],
force_update=False,
force_param_update=True,
crop_overrides=None,
soil_overrides=None,
site_overrides={"WAV": 10}, # Extra site params can be passed as overrides
)
batch_runner.prepare_batch_system(
campaign_start=CAMPAIGN_START,
campaign_end=CAMPAIGN_END,
crop_start=CROP_START,
crop_end=CROP_END,
crop_name=CROP_NAME,
variety_name=CROP_VARIETY,
max_workers=5,
crop_start_type=CROP_START_TYPE,
crop_end_type=CROP_END_TYPE,
max_duration=MAX_DURATION,
timed_events=[irrigation_events],
state_events=[nitrogen_events],
force_update=False,
force_param_update=True,
crop_overrides=None,
soil_overrides=None,
site_overrides={"WAV": 10}, # Extra site params can be passed as overrides
)
Run the simulation first¶
In [ ]:
Copied!
results = batch_runner.run_batch_simulation(max_workers=5)
print(results.shape)
results.head()
results = batch_runner.run_batch_simulation(max_workers=5)
print(results.shape)
results.head()
Plot the results (before optimization)¶
In [ ]:
Copied!
# Ensure 'day' is datetime so the x-axis sorts chronologically
batch_results = results.copy()
batch_results["day"] = pd.to_datetime(batch_results["day"])
# Variables to plot (exclude metadata columns)
vars_to_plot = [
    col
    for col in batch_results.columns
    if col not in ["point_id", "latitude", "longitude", "day"]
]
# Layout: two columns, as many rows as needed to fit every variable
cols = 2
rows = math.ceil(len(vars_to_plot) / cols)
# Colors for point_id groups: one stable colour per simulated location
unique_points = batch_results["point_id"].unique()
palette = sns.color_palette("tab10", len(unique_points))
color_map = {pid: palette[i] for i, pid in enumerate(unique_points)}
fig, axes = plt.subplots(rows, cols, figsize=(14, 3 * rows), sharex=True)
axes = axes.flatten()
# One subplot per output variable, one line per location
for i, var in enumerate(vars_to_plot):
    ax = axes[i]
    for pid in unique_points:
        df_sub = batch_results[batch_results["point_id"] == pid]
        sns.lineplot(
            x=df_sub["day"],
            y=df_sub[var],
            ax=ax,
            label=f"Point {pid}",
            color=color_map[pid],
        )
    ax.set_title(var)
    ax.legend()
# Hide remaining empty subplots (grid may be larger than needed)
for j in range(len(vars_to_plot), len(axes)):
    axes[j].axis("off")
plt.tight_layout()
plt.show()
# Ensure 'day' is datetime
batch_results = results.copy()
batch_results["day"] = pd.to_datetime(batch_results["day"])
# Variables to plot (exclude metadata columns)
vars_to_plot = [
col
for col in batch_results.columns
if col not in ["point_id", "latitude", "longitude", "day"]
]
# Layout
cols = 2
rows = math.ceil(len(vars_to_plot) / cols)
# Colors for point_id groups
unique_points = batch_results["point_id"].unique()
palette = sns.color_palette("tab10", len(unique_points))
color_map = {pid: palette[i] for i, pid in enumerate(unique_points)}
fig, axes = plt.subplots(rows, cols, figsize=(14, 3 * rows), sharex=True)
axes = axes.flatten()
for i, var in enumerate(vars_to_plot):
ax = axes[i]
for pid in unique_points:
df_sub = batch_results[batch_results["point_id"] == pid]
sns.lineplot(
x=df_sub["day"],
y=df_sub[var],
ax=ax,
label=f"Point {pid}",
color=color_map[pid],
)
ax.set_title(var)
ax.legend()
# Hide remaining empty subplots
for j in range(len(vars_to_plot), len(axes)):
axes[j].axis("off")
plt.tight_layout()
plt.show()
Optimize phenology¶
In [ ]:
Copied!
# Load the observed phenology
obs_phenology_df = pd.read_csv("test_data/optimizer/phenology_observed.csv")
phenology_optimizer = WOFOSTOptimizer(
runner=batch_runner, observed_data=obs_phenology_df
)
# Create the loss function for phenology
def loss_fn_phenology(sim_df, obs_df):
    """Mean RMSE (in days) between observed and simulated phenology dates.

    Parameters
    ----------
    sim_df : pd.DataFrame
        Batch simulation output with at least 'point_id', 'day' and 'DVS'
        columns; rows assumed chronological per point.
    obs_df : pd.DataFrame
        Observations with 'id', 'flowering_doy' and 'maturity_doy' columns.

    Returns
    -------
    float
        (flowering RMSE + maturity RMSE) / 2, rounded to 2 decimals.
    """
    # Process observed data
    phenology_obs = obs_df[["id", "flowering_doy", "maturity_doy"]]
    # Work on a copy so the caller's frame is not mutated in place.
    sim_df = sim_df.copy()
    sim_df["day"] = pd.to_datetime(sim_df["day"])

    # DVS is a continuous state variable: an output day may never equal
    # exactly 1.0 / 2.0, and a value may repeat over several days. Take the
    # FIRST day each threshold is crossed instead of exact equality
    # (which could yield an empty merge or duplicated rows).
    def _first_crossing(threshold, col_name):
        # First row at/after the DVS threshold for every point.
        crossed = sim_df[sim_df["DVS"] >= threshold]
        first = crossed.groupby("point_id", as_index=False).first()
        first[col_name] = first["day"].dt.day_of_year
        return first[["point_id", col_name]]

    phenology_sim = pd.merge(
        left=_first_crossing(1, "flowering_doy_sim"),
        right=_first_crossing(2, "maturity_doy_sim"),
        on="point_id",
        how="inner",
    )
    merged_df = pd.merge(
        left=phenology_obs, right=phenology_sim, left_on="id", right_on="point_id"
    )
    # RMSE per stage — numpy equivalent of sqrt(mean_squared_error(...)).
    flowering_loss = np.sqrt(
        np.mean((merged_df["flowering_doy"] - merged_df["flowering_doy_sim"]) ** 2)
    )
    maturity_loss = np.sqrt(
        np.mean((merged_df["maturity_doy"] - merged_df["maturity_doy_sim"]) ** 2)
    )
    total_loss = np.round((flowering_loss + maturity_loss) / 2, 2)
    return total_loss
# Define the search space
def search_space(trial):
    """Sample the thermal-time parameters driving WOFOST phenology.

    TSUM1/TSUM2 are the temperature sums (degree-days) from emergence to
    anthesis and from anthesis to maturity, respectively.
    """
    tsum1 = trial.suggest_int("TSUM1", 100, 1200)
    tsum2 = trial.suggest_int("TSUM2", 100, 1200)
    return {"crop_params": {"TSUM1": tsum1, "TSUM2": tsum2}}
study = phenology_optimizer.optimize(
search_space,
loss_fn_phenology,
n_trials=100,
n_workers=5,
directions=["minimize"],
sampler="TPE",
)
# Load the observed phenology
obs_phenology_df = pd.read_csv("test_data/optimizer/phenology_observed.csv")
phenology_optimizer = WOFOSTOptimizer(
runner=batch_runner, observed_data=obs_phenology_df
)
# Create the loss function for phenology
def loss_fn_phenology(sim_df, obs_df):
    """Mean RMSE (in days) between observed and simulated phenology dates.

    Parameters
    ----------
    sim_df : pd.DataFrame
        Batch simulation output with at least 'point_id', 'day' and 'DVS'
        columns; rows assumed chronological per point.
    obs_df : pd.DataFrame
        Observations with 'id', 'flowering_doy' and 'maturity_doy' columns.

    Returns
    -------
    float
        (flowering RMSE + maturity RMSE) / 2, rounded to 2 decimals.
    """
    # Process observed data
    phenology_obs = obs_df[["id", "flowering_doy", "maturity_doy"]]
    # Work on a copy so the caller's frame is not mutated in place.
    sim_df = sim_df.copy()
    sim_df["day"] = pd.to_datetime(sim_df["day"])

    # DVS is a continuous state variable: an output day may never equal
    # exactly 1.0 / 2.0, and a value may repeat over several days. Take the
    # FIRST day each threshold is crossed instead of exact equality
    # (which could yield an empty merge or duplicated rows).
    def _first_crossing(threshold, col_name):
        # First row at/after the DVS threshold for every point.
        crossed = sim_df[sim_df["DVS"] >= threshold]
        first = crossed.groupby("point_id", as_index=False).first()
        first[col_name] = first["day"].dt.day_of_year
        return first[["point_id", col_name]]

    phenology_sim = pd.merge(
        left=_first_crossing(1, "flowering_doy_sim"),
        right=_first_crossing(2, "maturity_doy_sim"),
        on="point_id",
        how="inner",
    )
    merged_df = pd.merge(
        left=phenology_obs, right=phenology_sim, left_on="id", right_on="point_id"
    )
    # RMSE per stage — numpy equivalent of sqrt(mean_squared_error(...)).
    flowering_loss = np.sqrt(
        np.mean((merged_df["flowering_doy"] - merged_df["flowering_doy_sim"]) ** 2)
    )
    maturity_loss = np.sqrt(
        np.mean((merged_df["maturity_doy"] - merged_df["maturity_doy_sim"]) ** 2)
    )
    total_loss = np.round((flowering_loss + maturity_loss) / 2, 2)
    return total_loss
# Define the search space
def search_space(trial):
    """Sample the thermal-time parameters driving WOFOST phenology.

    TSUM1/TSUM2 are the temperature sums (degree-days) from emergence to
    anthesis and from anthesis to maturity, respectively.
    """
    tsum1 = trial.suggest_int("TSUM1", 100, 1200)
    tsum2 = trial.suggest_int("TSUM2", 100, 1200)
    return {"crop_params": {"TSUM1": tsum1, "TSUM2": tsum2}}
study = phenology_optimizer.optimize(
search_space,
loss_fn_phenology,
n_trials=100,
n_workers=5,
directions=["minimize"],
sampler="TPE",
)
Run the simulation with optimized parameters¶
In [ ]:
Copied!
# Run the simulation with optimized parameters
best_params = phenology_optimizer.get_best_params(study, search_space)
# Update the parameters in the workspace
batch_runner.update_parameters(crop_overrides=best_params["crop_params"])
# Run the simulations with updated parameters
results = batch_runner.run_batch_simulation(max_workers=5)
print(results.shape)
results.head()
# Run the simulation with optimized parameters
best_params = phenology_optimizer.get_best_params(study, search_space)
# Update the parameters in the workspace
batch_runner.update_parameters(crop_overrides=best_params["crop_params"])
# Run the simulations with updated parameters
results = batch_runner.run_batch_simulation(max_workers=5)
print(results.shape)
results.head()
Plot the simulation¶
In [ ]:
Copied!
# Ensure 'day' is datetime
batch_results = results.copy()
batch_results["day"] = pd.to_datetime(batch_results["day"])
# Variables to plot (exclude metadata columns)
vars_to_plot = [
col
for col in batch_results.columns
if col not in ["point_id", "latitude", "longitude", "day"]
]
# Layout
cols = 2
rows = math.ceil(len(vars_to_plot) / cols)
# Colors for point_id groups
unique_points = batch_results["point_id"].unique()
palette = sns.color_palette("tab10", len(unique_points))
color_map = {pid: palette[i] for i, pid in enumerate(unique_points)}
fig, axes = plt.subplots(rows, cols, figsize=(14, 3 * rows), sharex=True)
axes = axes.flatten()
for i, var in enumerate(vars_to_plot):
ax = axes[i]
for pid in unique_points:
df_sub = batch_results[batch_results["point_id"] == pid]
sns.lineplot(
x=df_sub["day"],
y=df_sub[var],
ax=ax,
label=f"Point {pid}",
color=color_map[pid],
)
ax.set_title(var)
ax.legend()
# Hide remaining empty subplots
for j in range(len(vars_to_plot), len(axes)):
axes[j].axis("off")
plt.tight_layout()
plt.show()
# Ensure 'day' is datetime
batch_results = results.copy()
batch_results["day"] = pd.to_datetime(batch_results["day"])
# Variables to plot (exclude metadata columns)
vars_to_plot = [
col
for col in batch_results.columns
if col not in ["point_id", "latitude", "longitude", "day"]
]
# Layout
cols = 2
rows = math.ceil(len(vars_to_plot) / cols)
# Colors for point_id groups
unique_points = batch_results["point_id"].unique()
palette = sns.color_palette("tab10", len(unique_points))
color_map = {pid: palette[i] for i, pid in enumerate(unique_points)}
fig, axes = plt.subplots(rows, cols, figsize=(14, 3 * rows), sharex=True)
axes = axes.flatten()
for i, var in enumerate(vars_to_plot):
ax = axes[i]
for pid in unique_points:
df_sub = batch_results[batch_results["point_id"] == pid]
sns.lineplot(
x=df_sub["day"],
y=df_sub[var],
ax=ax,
label=f"Point {pid}",
color=color_map[pid],
)
ax.set_title(var)
ax.legend()
# Hide remaining empty subplots
for j in range(len(vars_to_plot), len(axes)):
axes[j].axis("off")
plt.tight_layout()
plt.show()
Optimize yield (TWSO)¶
In [ ]:
Copied!
# Load the observed yield
obs_yield_df = pd.read_csv("test_data/optimizer/yield_observed.csv")
yield_optimizer = WOFOSTOptimizer(runner=batch_runner, observed_data=obs_yield_df)
# Create the loss function for yield
def loss_fn_yield(sim_df, obs_df):
    """RMSE (kg/ha) between observed yield and simulated storage-organ weight.

    Parameters
    ----------
    sim_df : pd.DataFrame
        Batch simulation output with 'point_id', 'DVS' and 'TWSO' columns;
        rows assumed chronological per point.
    obs_df : pd.DataFrame
        Observations with 'id' and 'yield' (t/ha) columns.

    Returns
    -------
    float
        RMSE of simulated TWSO vs observed yield, rounded to 2 decimals.
    """
    # Process observed data (copy: do not mutate the caller's frame).
    obs_df = obs_df[["id", "yield"]].copy()
    obs_df["yield"] = obs_df["yield"] * 1000  # t/ha -> kg/ha
    # TWSO at maturity: first record where DVS >= 2 for each point.
    # NOTE: the original in-place 'day' -> datetime conversion was dead code
    # (the column is never used here) and mutated the caller's DataFrame,
    # so it has been removed.
    sim_at_maturity = (
        sim_df[sim_df["DVS"] >= 2]
        .groupby(by="point_id")
        .first()[["TWSO"]]
        .reset_index()
    )
    merged_df = pd.merge(
        left=obs_df, right=sim_at_maturity, left_on="id", right_on="point_id"
    )
    # numpy equivalent of sqrt(mean_squared_error(...)).
    yield_loss = np.sqrt(np.mean((merged_df["yield"] - merged_df["TWSO"]) ** 2))
    return np.round(yield_loss, 2)
# Define the search space
def search_space(trial):
    """Sample canopy, assimilation, conversion and rooting parameters.

    The SLA and AMAX lookup tables are shifted as a whole by +/- 20% via two
    scaling factors, so the optimizer tunes the curve level, not its shape.
    """
    # Scaling factor for photosynthesis (entire AMAXTB curve up/down).
    amax_factor = trial.suggest_float("amax_factor", 0.8, 1.2)
    # Scaling factor for leaf thickness (entire SLATB curve up/down).
    sla_factor = trial.suggest_float("sla_factor", 0.8, 1.2)

    # LEAF DYNAMICS (source capacity): leaf lifespan in days — higher means
    # a longer green canopy duration.
    span = trial.suggest_float("SPAN", 25.0, 40.0)

    # SLATB (Specific Leaf Area): leaf area built per kg biomass, as
    # (DVS, value) pairs; constant curve, uniformly scaled.
    sla = 0.00212 * sla_factor
    slatb = [0.0, sla, 0.5, sla, 2.0, sla]

    # AMAXTB: max CO2 assimilation rate by DVS. Highly sensitive; drops
    # sharply at maturity (senescence).
    amaxtb = [
        0.0, 35.83 * amax_factor,  # vegetative
        1.0, 35.83 * amax_factor,  # flowering
        1.3, 35.83 * amax_factor,  # early grain filling
        2.0, 4.48 * amax_factor,   # maturity (senescence)
    ]

    # CVO: conversion efficiency to storage organs (harvest-index driver).
    cvo = trial.suggest_float("CVO", 0.65, 0.75)
    # RDMCR: max rooting depth — critical for drought resistance.
    rdmcr = trial.suggest_int("RDMCR", 80, 150)

    return {
        "crop_params": {
            "SPAN": span,
            "SLATB": slatb,
            "AMAXTB": amaxtb,
            "CVO": cvo,
            "RDMCR": rdmcr,
        }
    }
study = yield_optimizer.optimize(
search_space,
loss_fn_yield,
n_trials=1000,
n_workers=5,
directions=["minimize"],
sampler="TPE",
)
# Load the observed yield
obs_yield_df = pd.read_csv("test_data/optimizer/yield_observed.csv")
yield_optimizer = WOFOSTOptimizer(runner=batch_runner, observed_data=obs_yield_df)
# Create the loss function for yield
def loss_fn_yield(sim_df, obs_df):
    """RMSE (kg/ha) between observed yield and simulated storage-organ weight.

    Parameters
    ----------
    sim_df : pd.DataFrame
        Batch simulation output with 'point_id', 'DVS' and 'TWSO' columns;
        rows assumed chronological per point.
    obs_df : pd.DataFrame
        Observations with 'id' and 'yield' (t/ha) columns.

    Returns
    -------
    float
        RMSE of simulated TWSO vs observed yield, rounded to 2 decimals.
    """
    # Process observed data (copy: do not mutate the caller's frame).
    obs_df = obs_df[["id", "yield"]].copy()
    obs_df["yield"] = obs_df["yield"] * 1000  # t/ha -> kg/ha
    # TWSO at maturity: first record where DVS >= 2 for each point.
    # NOTE: the original in-place 'day' -> datetime conversion was dead code
    # (the column is never used here) and mutated the caller's DataFrame,
    # so it has been removed.
    sim_at_maturity = (
        sim_df[sim_df["DVS"] >= 2]
        .groupby(by="point_id")
        .first()[["TWSO"]]
        .reset_index()
    )
    merged_df = pd.merge(
        left=obs_df, right=sim_at_maturity, left_on="id", right_on="point_id"
    )
    # numpy equivalent of sqrt(mean_squared_error(...)).
    yield_loss = np.sqrt(np.mean((merged_df["yield"] - merged_df["TWSO"]) ** 2))
    return np.round(yield_loss, 2)
# Define the search space
def search_space(trial):
    """Sample canopy, assimilation, conversion and rooting parameters.

    The SLA and AMAX lookup tables are shifted as a whole by +/- 20% via two
    scaling factors, so the optimizer tunes the curve level, not its shape.
    """
    # Scaling factor for photosynthesis (entire AMAXTB curve up/down).
    amax_factor = trial.suggest_float("amax_factor", 0.8, 1.2)
    # Scaling factor for leaf thickness (entire SLATB curve up/down).
    sla_factor = trial.suggest_float("sla_factor", 0.8, 1.2)

    # LEAF DYNAMICS (source capacity): leaf lifespan in days — higher means
    # a longer green canopy duration.
    span = trial.suggest_float("SPAN", 25.0, 40.0)

    # SLATB (Specific Leaf Area): leaf area built per kg biomass, as
    # (DVS, value) pairs; constant curve, uniformly scaled.
    sla = 0.00212 * sla_factor
    slatb = [0.0, sla, 0.5, sla, 2.0, sla]

    # AMAXTB: max CO2 assimilation rate by DVS. Highly sensitive; drops
    # sharply at maturity (senescence).
    amaxtb = [
        0.0, 35.83 * amax_factor,  # vegetative
        1.0, 35.83 * amax_factor,  # flowering
        1.3, 35.83 * amax_factor,  # early grain filling
        2.0, 4.48 * amax_factor,   # maturity (senescence)
    ]

    # CVO: conversion efficiency to storage organs (harvest-index driver).
    cvo = trial.suggest_float("CVO", 0.65, 0.75)
    # RDMCR: max rooting depth — critical for drought resistance.
    rdmcr = trial.suggest_int("RDMCR", 80, 150)

    return {
        "crop_params": {
            "SPAN": span,
            "SLATB": slatb,
            "AMAXTB": amaxtb,
            "CVO": cvo,
            "RDMCR": rdmcr,
        }
    }
study = yield_optimizer.optimize(
search_space,
loss_fn_yield,
n_trials=1000,
n_workers=5,
directions=["minimize"],
sampler="TPE",
)
Run the simulation with optimized parameters¶
In [ ]:
Copied!
# Run the simulation with optimized parameters
best_params = yield_optimizer.get_best_params(study, search_space)
# Update the parameters in the workspace
batch_runner.update_parameters(crop_overrides=best_params["crop_params"])
# Run the simulations with updated parameters
results = batch_runner.run_batch_simulation(max_workers=5)
print(results.shape)
results.head()
# Run the simulation with optimized parameters
best_params = yield_optimizer.get_best_params(study, search_space)
# Update the parameters in the workspace
batch_runner.update_parameters(crop_overrides=best_params["crop_params"])
# Run the simulations with updated parameters
results = batch_runner.run_batch_simulation(max_workers=5)
print(results.shape)
results.head()
Plot the simulation¶
In [ ]:
Copied!
# Ensure 'day' is datetime
batch_results = results.copy()
batch_results["day"] = pd.to_datetime(batch_results["day"])
# Variables to plot (exclude metadata columns)
vars_to_plot = [
col
for col in batch_results.columns
if col not in ["point_id", "latitude", "longitude", "day"]
]
# Layout
cols = 2
rows = math.ceil(len(vars_to_plot) / cols)
# Colors for point_id groups
unique_points = batch_results["point_id"].unique()
palette = sns.color_palette("tab10", len(unique_points))
color_map = {pid: palette[i] for i, pid in enumerate(unique_points)}
fig, axes = plt.subplots(rows, cols, figsize=(14, 3 * rows), sharex=True)
axes = axes.flatten()
for i, var in enumerate(vars_to_plot):
ax = axes[i]
for pid in unique_points:
df_sub = batch_results[batch_results["point_id"] == pid]
sns.lineplot(
x=df_sub["day"],
y=df_sub[var],
ax=ax,
label=f"Point {pid}",
color=color_map[pid],
)
ax.set_title(var)
ax.legend()
# Hide remaining empty subplots
for j in range(len(vars_to_plot), len(axes)):
axes[j].axis("off")
plt.tight_layout()
plt.show()
# Ensure 'day' is datetime
batch_results = results.copy()
batch_results["day"] = pd.to_datetime(batch_results["day"])
# Variables to plot (exclude metadata columns)
vars_to_plot = [
col
for col in batch_results.columns
if col not in ["point_id", "latitude", "longitude", "day"]
]
# Layout
cols = 2
rows = math.ceil(len(vars_to_plot) / cols)
# Colors for point_id groups
unique_points = batch_results["point_id"].unique()
palette = sns.color_palette("tab10", len(unique_points))
color_map = {pid: palette[i] for i, pid in enumerate(unique_points)}
fig, axes = plt.subplots(rows, cols, figsize=(14, 3 * rows), sharex=True)
axes = axes.flatten()
for i, var in enumerate(vars_to_plot):
ax = axes[i]
for pid in unique_points:
df_sub = batch_results[batch_results["point_id"] == pid]
sns.lineplot(
x=df_sub["day"],
y=df_sub[var],
ax=ax,
label=f"Point {pid}",
color=color_map[pid],
)
ax.set_title(var)
ax.legend()
# Hide remaining empty subplots
for j in range(len(vars_to_plot), len(axes)):
axes[j].axis("off")
plt.tight_layout()
plt.show()