Run sensitivity analysis (WOFOST)
Uncomment the following line to install the latest version of cropengine if needed.
In [ ]:
Copied!
# !pip install -U cropengine
# !pip install -U cropengine
Import libraries¶
In [ ]:
Copied!
import os
import math
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from cropengine import WOFOSTCropSimulationBatchRunner
from cropengine.agromanagement import WOFOSTAgroEventBuilder
from cropengine.sensitivity import WOFOSTSensitivityAnalyzer
import os
import math
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from cropengine import WOFOSTCropSimulationBatchRunner
from cropengine.agromanagement import WOFOSTAgroEventBuilder
from cropengine.sensitivity import WOFOSTSensitivityAnalyzer
Instantiate batch crop simulation engine for WOFOST¶
In [ ]:
Copied!
# Name of the WOFOST model variant to simulate. The same constant is passed
# to the engine here and to the option look-ups further down.
MODEL_NAME = "Wofost72_WLP_CWB"

# CSV file describing the simulation points; it must provide the columns
# 'id', 'latitude', and 'longitude'.
locations_csv_path = "test_data/sensitivity/location.csv"

# Build the batch engine exactly once; every later step (preparing inputs,
# running the simulations, sensitivity analysis) goes through this object.
batch_runner = WOFOSTCropSimulationBatchRunner(
    model_name=MODEL_NAME,
    locations_csv_path=locations_csv_path,
    workspace_dir="test_output/sensitivity_workspace",
)
User inputs¶
In [ ]:
Copied!
# --- Crop configuration ---------------------------------------------------
# The get_*_options() results are not consumed by the code below; they are
# kept in variables so they can be inspected when choosing the constants.
models = batch_runner.get_model_options()
crops = batch_runner.get_crop_options(MODEL_NAME)
CROP_NAME = "wheat"
varieties = batch_runner.get_variety_options(MODEL_NAME, CROP_NAME)
CROP_VARIETY = "Winter_wheat_103"

# --- Timing ---------------------------------------------------------------
crop_start_end = batch_runner.get_crop_start_end_options()
CAMPAIGN_START = "2019-09-01"
CROP_START = "2019-09-25"
CROP_START_TYPE = "sowing"
CROP_END_TYPE = "maturity"
CROP_END = None  # no explicit end date is supplied alongside CROP_END_TYPE
CAMPAIGN_END = "2020-09-30"
MAX_DURATION = 365  # NOTE(review): presumably in days - confirm with cropengine docs
Create agromanagements with user inputs¶
In [ ]:
Copied!
agro_event_builder = WOFOSTAgroEventBuilder()

# Note: Use agro_event_builder.get_..._events_info() to see valid values if unsure.
# The two look-ups below are kept for inspection only.
timed_events_info = agro_event_builder.get_timed_events_info()
state_events_info = agro_event_builder.get_state_events_info()

# Timed events (irrigation): each entry is applied on a fixed calendar date.
irrigation_schedule = [
    {"event_date": "2020-03-20", "amount": 3.0, "efficiency": 0.7},  # stem elongation
    {"event_date": "2020-04-25", "amount": 2.5, "efficiency": 0.7},  # booting/heading
    {"event_date": "2020-05-20", "amount": 2.0, "efficiency": 0.7},  # flowering
]
irrigation_events = agro_event_builder.create_timed_events(
    signal_type="irrigate",
    events_list=irrigation_schedule,
)

# State events (fertilization): each entry is keyed on a threshold of the
# state variable named below (DVS) rather than on a date.
nitrogen_schedule = [
    {"threshold": 0.3, "N_amount": 40, "N_recovery": 0.7},  # early vegetative
    {"threshold": 0.6, "N_amount": 60, "N_recovery": 0.7},  # stem elongation
    {"threshold": 1.0, "N_amount": 40, "N_recovery": 0.7},  # heading
]
nitrogen_events = agro_event_builder.create_state_events(
    signal_type="apply_n",
    state_var="DVS",
    zero_condition="rising",
    events_list=nitrogen_schedule,
)
Prepare batch system¶
In [ ]:
Copied!
# Prepare the inputs for every location a single time, using the campaign
# settings and the agromanagement events defined in the earlier cells.
batch_runner.prepare_batch_system(
    campaign_start=CAMPAIGN_START,
    campaign_end=CAMPAIGN_END,
    crop_start=CROP_START,
    crop_end=CROP_END,
    crop_name=CROP_NAME,
    variety_name=CROP_VARIETY,
    max_workers=5,
    crop_start_type=CROP_START_TYPE,
    crop_end_type=CROP_END_TYPE,
    max_duration=MAX_DURATION,
    # Each event group built above is passed as one element of a list.
    timed_events=[irrigation_events],
    state_events=[nitrogen_events],
    force_update=False,
    force_param_update=True,
    crop_overrides=None,
    soil_overrides=None,
    site_overrides={"WAV": 10},  # Extra site params can be passed as overrides
)
Run the simulation first¶
In [ ]:
Copied!
# Run the baseline simulation for all locations once and keep the combined
# output; the plotting cell that follows reads it from `results`.
results = batch_runner.run_batch_simulation(max_workers=5)
print(results.shape)
# Last expression of the cell, so the notebook displays the first rows.
results.head()
In [ ]:
Copied!
# Work on a copy so the raw simulation output in `results` stays untouched,
# and make sure 'day' is a datetime column for the x-axis.
batch_results = results.copy()
batch_results["day"] = pd.to_datetime(batch_results["day"])

# Variables to plot: every column that is not a metadata column.
metadata_cols = {"point_id", "latitude", "longitude", "day"}
vars_to_plot = [col for col in batch_results.columns if col not in metadata_cols]

# Layout: two panels per row. Keep at least one row so that plt.subplots()
# does not fail when there is no variable to plot.
cols = 2
rows = max(1, math.ceil(len(vars_to_plot) / cols))

# One fixed colour per location, shared by all panels.
unique_points = batch_results["point_id"].unique()
palette = sns.color_palette("tab10", len(unique_points))
color_map = dict(zip(unique_points, palette))

# Split the table by location once instead of re-filtering it for every
# (variable, location) pair inside the plotting loop.
per_point = {
    pid: batch_results[batch_results["point_id"] == pid] for pid in unique_points
}

# squeeze=False makes `axes` a 2-D array for every grid size, so flatten()
# always yields a flat sequence of Axes.
fig, axes = plt.subplots(
    rows, cols, figsize=(14, 3 * rows), sharex=True, squeeze=False
)
axes = axes.flatten()

for ax, var in zip(axes, vars_to_plot):
    for pid, df_sub in per_point.items():
        sns.lineplot(
            x=df_sub["day"],
            y=df_sub[var],
            ax=ax,
            label=f"Point {pid}",
            color=color_map[pid],
        )
    ax.set_title(var)
    ax.legend()

# Hide remaining empty subplots
for ax in axes[len(vars_to_plot):]:
    ax.axis("off")

plt.tight_layout()
plt.show()
Sensitivity analysis using FAST (Fourier Amplitude Sensitivity Test) method¶
In [ ]:
Copied!
# Instantiate WOFOST Sensitivity Analyzer on top of the prepared batch runner
sa = WOFOSTSensitivityAnalyzer(batch_runner)

# Define the problem: the parameters to vary and the [lower, upper] range of
# each one. `bounds` is listed in the same order as `names`.
problem = {
    "num_vars": 3,
    "names": ["TSUM1", "TSUM2", "SPAN"],
    "bounds": [
        [800, 1200],  # TSUM1
        [900, 1300],  # TSUM2
        [28, 35],  # SPAN
    ],
}

# The notebook asked for 70 workers, which is specific to the machine it was
# written on. Never request more workers than this machine reports CPUs for;
# os.cpu_count() can return None, hence the fallback to 1.
N_WORKERS = min(70, os.cpu_count() or 1)

# Run the analysis once; the plotting cells below read `sa_result`.
# NOTE(review): mode="local" presumably yields one set of indices per
# location (the heatmap cell pivots on 'point_id') - confirm with cropengine.
sa_result = sa.run_analysis(
    problem,
    method="fast",
    n_samples=512,
    target_variable="TWSO",
    mode="local",
    n_workers=N_WORKERS,
)
Plot the data¶
In [ ]:
Copied!
def plot_global(df):
    """Draw a horizontal bar chart of the S1 and ST indices per parameter.

    Args:
        df: DataFrame with the columns "Parameter", "S1" and "ST".

    The frame is reshaped to long form (one row per parameter/index pair) so
    both indices can be drawn side by side with ``hue``. If a parameter occurs
    in several rows, seaborn's barplot aggregates them into a single bar
    (mean by default). The figure is shown immediately; nothing is returned.
    """
    df_melt = df.melt(
        id_vars="Parameter",
        value_vars=["S1", "ST"],
        var_name="Index",
        value_name="Score",
    )

    plt.figure(figsize=(8, 5))
    # Note: set_style changes seaborn's global style, so it also applies to
    # plots created after this call.
    sns.set_style("whitegrid")
    sns.barplot(
        data=df_melt,
        x="Score",
        y="Parameter",
        hue="Index",
        palette={"ST": "#4c72b0", "S1": "#dd8452"},
    )
    plt.title("Global Parameter Sensitivity (Aggregated)", fontsize=14)
    plt.xlabel("Sensitivity Index", fontsize=12)
    plt.ylabel("")
    plt.legend(title="Index Type")
    plt.tight_layout()
    plt.show()


plot_global(sa_result)
In [ ]:
Copied!
def plot_local_heatmap(df):
    """Show the ST index as a location-by-parameter heatmap.

    Args:
        df: DataFrame with the columns "point_id", "Parameter" and "ST".
            Each (point_id, Parameter) pair must occur at most once,
            otherwise ``DataFrame.pivot`` raises ``ValueError``.

    The figure is shown immediately; nothing is returned.
    """
    # Rows are locations, columns are parameters, cells hold the ST value.
    pivot_df = df.pivot(index="point_id", columns="Parameter", values="ST")

    plt.figure(figsize=(8, 6))
    sns.heatmap(pivot_df, annot=True, cmap="YlGnBu", fmt=".2f", linewidths=0.5)
    plt.title("Spatial Variability of Sensitivity (Total Effect ST)", fontsize=14)
    plt.ylabel("Location ID", fontsize=12)
    plt.xlabel("Parameter", fontsize=12)
    plt.tight_layout()
    plt.show()


plot_local_heatmap(sa_result)
In [ ]:
Copied!
def plot_local_distribution(df):
    """Show how the ST index of each parameter is spread over the rows of *df*.

    Args:
        df: DataFrame with the columns "Parameter" and "ST".

    A box plot summarises the ST values per parameter and a strip plot draws
    the individual values on top of it. The figure is shown immediately;
    nothing is returned.
    """
    plt.figure(figsize=(8, 5))
    # hue repeats the y variable so that `palette` colours each box.
    sns.boxplot(data=df, x="ST", y="Parameter", hue="Parameter", palette="vlag")
    # Overlay the raw points; jitter spreads points that would overlap.
    sns.stripplot(data=df, x="ST", y="Parameter", color="black", alpha=0.5, jitter=True)
    plt.title("Stability of Parameters Across Locations", fontsize=14)
    plt.xlabel("Total Sensitivity Index (ST)")
    plt.tight_layout()
    plt.show()


plot_local_distribution(sa_result)