-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrun_eval.py
55 lines (46 loc) · 1.73 KB
/
run_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from argparse import ArgumentParser
from codes.benchmark import (
check_benchmark,
check_surrogate,
compare_models,
get_surrogate,
run_benchmark,
)
from codes.utils import nice_print, read_yaml_config
def main(args):
    """
    Run the benchmark described by a YAML config file.

    Reads the config from ``args.config``, validates it with ``check_benchmark``,
    runs the benchmark for every surrogate listed under ``config["surrogates"]``
    that ``get_surrogate`` can resolve, and, if ``config["compare"]`` is truthy,
    compares the surrogates that were actually benchmarked.

    Args:
        args (Namespace): The command line arguments; only ``args.config``
            (path to the config file) is read.
    """
    config = read_yaml_config(args.config)
    check_benchmark(config)
    surrogates = config["surrogates"]

    # Metrics are stored only for surrogates that were actually benchmarked, so
    # skipped names never reach the comparison as empty placeholder entries.
    all_metrics = {}

    # Run benchmark for each surrogate model
    for surrogate_name in surrogates:
        surrogate_class = get_surrogate(surrogate_name)
        if surrogate_class is None:
            print(f"Surrogate {surrogate_name} not recognized. Skipping.")
            continue

        nice_print(f"Running benchmark for {surrogate_name}")
        check_surrogate(surrogate_name, config)
        all_metrics[surrogate_name] = run_benchmark(
            surrogate_name, surrogate_class, config
        )

    # Compare models
    if not config["compare"]:
        return

    # Count benchmarked surrogates rather than configured ones: an unrecognized
    # name in the config must not satisfy the two-model minimum.
    if len(all_metrics) < 2:
        nice_print("At least two surrogate models are required to compare.")
    else:
        nice_print("Comparing models")
        compare_models(all_metrics, config)
if __name__ == "__main__":
    # Script entry point: parse the command line and hand the namespace to main().
    cli = ArgumentParser()
    cli.add_argument(
        "--config",
        type=str,
        default="config.yaml",
        help="Path to the config file.",
    )
    main(cli.parse_args())