Skip to content

Commit becfe29

Browse files
committed
Merge remote-tracking branch 'origin/main' into improve-html
2 parents 2a7dd3b + 29aca16 commit becfe29

3 files changed

Lines changed: 82 additions & 31 deletions

File tree

src/compare.jl

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -294,12 +294,17 @@ function compare_with_reference(
294294
isempty(times) && return 0, 0, 0, ""
295295

296296
# Determine which signals to compare: prefer comparisonSignals.txt
297-
sig_file = joinpath(dirname(ref_csv_path), "comparisonSignals.txt")
298-
signals = if isfile(sig_file)
299-
filter(s -> lowercase(s) != "time" && !isempty(s), strip.(readlines(sig_file)))
297+
sig_file = joinpath(dirname(ref_csv_path), "comparisonSignals.txt")
298+
using_sig_file = isfile(sig_file)
299+
signals = if using_sig_file
300+
sigs = filter(s -> lowercase(s) != "time" && !isempty(s), strip.(readlines(sig_file)))
301+
sigs_missing = filter(s -> !haskey(ref_data, s), sigs)
302+
isempty(sigs_missing) || error("Signal(s) listed in comparisonSignals.txt not present in reference CSV: $(join(sigs_missing, ", "))")
303+
sigs
300304
else
301305
filter(k -> lowercase(k) != "time", collect(keys(ref_data)))
302306
end
307+
n_total = length(signals)
303308

304309
# ── Build variable accessor map ──────────────────────────────────────────────
305310
# var_access: normalized name → Int (state index) or MTK symbolic (observed).
@@ -321,18 +326,27 @@ function compare_with_reference(
321326
@warn "Could not enumerate observed variables: $(sprint(showerror, e))"
322327
end
323328

324-
# Clip reference time to the simulation interval
329+
# Verify the simulation covers the expected reference time interval.
330+
# A large gap means the solver stopped early or started late.
331+
isempty(sol.t) && return n_total, 0, 0, ""
325332
t_start = sol.t[1]
326333
t_end = sol.t[end]
334+
ref_t_start = times[1]
335+
ref_t_end = times[end]
336+
if t_start > ref_t_start || t_end < ref_t_end
337+
@error "Simulation interval [$(t_start), $(t_end)] does not cover " *
338+
"reference interval [$(ref_t_start), $(ref_t_end)]"
339+
return n_total, 0, 0, ""
340+
end
341+
342+
# Clip reference time to the simulation interval
327343
valid_mask = (times .>= t_start) .& (times .<= t_end)
328344
t_ref = times[valid_mask]
329-
isempty(t_ref) && return 0, 0, 0, ""
345+
isempty(t_ref) && return n_total, 0, 0, ""
330346

331-
n_total = 0
332347
n_pass = 0
333348
pass_sigs = String[]
334349
fail_sigs = String[]
335-
skip_sigs = String[]
336350
pass_max_abs_error = Dict{String, Float64}()
337351
pass_max_rel_error = Dict{String, Float64}()
338352
fail_ref_vals = Dict{String, Vector{Float64}}()
@@ -341,25 +355,32 @@ function compare_with_reference(
341355
fail_scaled_rel_error = Dict{String, Vector{Float64}}()
342356

343357
for sig in signals
344-
haskey(ref_data, sig) || continue # signal absent from ref CSV entirely
358+
signal_name = _normalize_var(sig)
359+
ref_vals = ref_data[sig][valid_mask]
360+
361+
nan_vec = fill(NaN, length(t_ref))
345362

346-
norm = _normalize_var(sig)
347-
if !haskey(var_access, norm)
348-
push!(skip_sigs, sig)
363+
if !haskey(var_access, signal_name)
364+
push!(fail_sigs, sig)
365+
fail_ref_vals[sig] = ref_vals
366+
fail_sim_vals[sig] = nan_vec
367+
fail_abs_error[sig] = nan_vec
368+
fail_scaled_rel_error[sig] = nan_vec
349369
continue
350370
end
351371

352-
accessor = var_access[norm]
353-
ref_vals = ref_data[sig][valid_mask]
354-
n_total += 1
372+
accessor = var_access[signal_name]
355373

356374
# Interpolate simulation at reference time points.
357375
sim_vals = [_eval_sim(sol, accessor, t) for t in t_ref]
358376

359-
# If evaluation returned NaN (observed-var access failed), treat as skip.
377+
# If evaluation returned NaN (observed-var access failed), treat as fail.
360378
if any(isnan, sim_vals)
361-
n_total -= 1
362-
push!(skip_sigs, sig)
379+
push!(fail_sigs, sig)
380+
fail_ref_vals[sig] = ref_vals
381+
fail_sim_vals[sig] = sim_vals
382+
fail_abs_error[sig] = nan_vec
383+
fail_scaled_rel_error[sig] = nan_vec
363384
continue
364385
end
365386

@@ -401,15 +422,14 @@ function compare_with_reference(
401422
end
402423

403424
# ── Write detail HTML whenever there is anything worth showing ───────────────
404-
if !isempty(fail_sigs) || !isempty(skip_sigs)
425+
if !isempty(fail_sigs)
405426
write_diff_html(model_dir, model;
406427
diff_csv_path = diff_csv,
407428
pass_sigs = pass_sigs,
408-
skip_sigs = skip_sigs,
409429
pass_max_abs_error = pass_max_abs_error,
410430
pass_max_rel_error = pass_max_rel_error,
411431
settings = settings)
412432
end
413433

414-
return n_total, n_pass, length(skip_sigs), diff_csv
434+
return n_total, n_pass, 0, diff_csv
415435
end

src/pipeline.jl

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,11 +194,23 @@ function main(;
194194
result = test_model(omc, model, results_root, ref_root; csv_max_size_mb, settings)
195195
push!(results, result)
196196

197-
phase = result.sim_success ? "SIM OK" :
198-
result.parse_success ? "SIM FAIL" :
199-
result.export_success ? "PARSE FAIL" : "EXPORT FAIL"
200-
cmp_info = result.cmp_total > 0 ?
201-
" cmp=$(result.cmp_pass)/$(result.cmp_total)" : ""
197+
phase = if result.sim_success && result.cmp_total > 0
198+
result.cmp_pass == result.cmp_total ? "CMP OK" : "CMP FAIL"
199+
elseif result.sim_success
200+
"SIM OK"
201+
elseif result.parse_success
202+
"SIM FAIL"
203+
elseif result.export_success
204+
"PARSE FAIL"
205+
else
206+
"EXPORT FAIL"
207+
end
208+
cmp_info = if result.cmp_total > 0
209+
skip_note = result.cmp_skip > 0 ? " skip=$(result.cmp_skip)" : ""
210+
" cmp=$(result.cmp_pass)/$(result.cmp_total)$skip_note"
211+
else
212+
""
213+
end
202214
@info "$phase export=$(round(result.export_time;digits=2))s" *
203215
" parse=$(round(result.parse_time;digits=2))s" *
204216
" sim=$(round(result.sim_time;digits=2))s$cmp_info"

src/simulate.jl

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,21 @@ function run_simulate(ode_prob, model_dir::String, model::String;
3232
# Redirect all library log output (including Symbolics/MTK warnings)
3333
# to the log file so they don't clutter stdout.
3434
sol = Logging.with_logger(logger) do
35-
solve(ode_prob, Rodas5P())
35+
# Overwrite saveat, always use dense output.
36+
solve(ode_prob, Rodas5P(); saveat = Float64[], dense = true)
3637
end
3738
sim_time = time() - t0
3839
if sol.retcode == ReturnCode.Success
39-
sim_success = true
40+
sys = sol.prob.f.sys
41+
n_vars = length(ModelingToolkit.unknowns(sys))
42+
n_obs = length(ModelingToolkit.observed(sys))
43+
if isempty(sol.t)
44+
sim_error = "Simulation produced no time points"
45+
elseif n_vars == 0 && n_obs == 0
46+
sim_error = "Simulation produced no output variables (no states or observed)"
47+
else
48+
sim_success = true
49+
end
4050
else
4151
sim_error = "Solver returned: $(sol.retcode)"
4252
end
@@ -49,21 +59,30 @@ function run_simulate(ode_prob, model_dir::String, model::String;
4959
isempty(sim_error) || println(log_file, "\n--- Error ---\n$sim_error")
5060
close(log_file)
5161

52-
# Write simulation results CSV (time + all state variables)
62+
# Write simulation results CSV (time + state variables + observed variables)
5363
if sim_success && sol !== nothing
5464
short_name = split(model, ".")[end]
5565
sim_csv = joinpath(model_dir, "$(short_name)_sim.csv")
5666
try
57-
sys = sol.prob.f.sys
58-
vars = ModelingToolkit.unknowns(sys)
59-
col_names = [_clean_var_name(string(v)) for v in vars]
67+
sys = sol.prob.f.sys
68+
vars = ModelingToolkit.unknowns(sys)
69+
obs_eqs = ModelingToolkit.observed(sys)
70+
obs_syms = [eq.lhs for eq in obs_eqs]
71+
col_names = vcat(
72+
[_clean_var_name(string(v)) for v in vars],
73+
[_clean_var_name(string(s)) for s in obs_syms],
74+
)
6075
open(sim_csv, "w") do f
6176
println(f, join(["time"; col_names], ","))
6277
for (ti, t) in enumerate(sol.t)
6378
row = [@sprintf("%.10g", t)]
6479
for vi in eachindex(vars)
6580
push!(row, @sprintf("%.10g", sol[vi, ti]))
6681
end
82+
for sym in obs_syms
83+
val = try Float64(sol(t; idxs = sym)) catch; NaN end
84+
push!(row, @sprintf("%.10g", val))
85+
end
6786
println(f, join(row, ","))
6887
end
6988
end

0 commit comments

Comments
 (0)