Commit d34e5fd

robin-janssen committed Sep 26, 2024
2 parents 3cb94e9 + a31f6d0

Showing 19 changed files with 77 additions and 83 deletions.

codes/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1,4 +1,4 @@
 from .benchmark import *
 from .surrogates import *
 from .train import *
-from .utils import *
+from .utils import *

codes/benchmark/bench_fcts.py (2 changes: 1 addition & 1 deletion)
@@ -8,7 +8,7 @@
 from tabulate import tabulate
 from torch.utils.data import DataLoader

-from codes.dataset import check_and_load_data
+from codes.utils import check_and_load_data

 from .bench_plots import (
     inference_time_bar_plot,

codes/benchmark/bench_utils.py (2 changes: 1 addition & 1 deletion)
@@ -627,7 +627,7 @@ def get_model_config(surr_name: str, config: dict) -> dict:
         return {}

     dataset_name = config["dataset"]["name"].lower()
-    dataset_folder = f"data/{dataset_name}"
+    dataset_folder = f"datasets/{dataset_name}"
     config_file = f"{dataset_folder}/surrogates_config.py"

     if os.path.exists(config_file):

codes/dataset/__init__.py (17 changes: 0 additions & 17 deletions)

This file was deleted.

codes/surrogates/DeepONet/don_utils.py (2 changes: 1 addition & 1 deletion)
@@ -56,7 +56,7 @@ def get_project_path(relative_path):
     """
     Construct the absolute path to a project resource (data or model) based on a relative path.

-    :param relative_path: A relative path to the resource, e.g., "data/dataset100" or "models/02-28/model.pth".
+    :param relative_path: A relative path to the resource, e.g., "datasets/dataset100" or "models/02-28/model.pth".
     :return: The absolute path to the resource.
     """
     import os

codes/train/train_fcts.py (3 changes: 2 additions & 1 deletion)
@@ -5,8 +5,9 @@
 from tqdm import tqdm

 from codes.benchmark.bench_utils import get_model_config, get_surrogate
-from codes.dataset import check_and_load_data, get_data_subset
 from codes.utils import (
+    check_and_load_data,
+    get_data_subset,
     get_progress_bar,
     load_and_save_config,
     make_description,

codes/utils/__init__.py (30 changes: 22 additions & 8 deletions)
@@ -1,19 +1,33 @@
+from .data_utils import (
+    check_and_load_data,
+    create_dataset,
+    create_hdf5_dataset,
+    download_data,
+    get_data_subset,
+    normalize_data,
+)
 from .utils import (
-    read_yaml_config,
-    time_execution,
+    check_training_status,
     create_model_dir,
+    get_progress_bar,
     load_and_save_config,
-    set_random_seeds,
-    nice_print,
+    load_task_list,
     make_description,
-    get_progress_bar,
-    worker_init_fn,
+    nice_print,
+    read_yaml_config,
     save_task_list,
-    load_task_list,
-    check_training_status,
+    set_random_seeds,
+    time_execution,
+    worker_init_fn,
 )

 __all__ = [
+    "check_and_load_data",
+    "create_dataset",
+    "create_hdf5_dataset",
+    "download_data",
+    "get_data_subset",
+    "normalize_data",
     "read_yaml_config",
     "time_execution",
     "create_model_dir",

codes/dataset/data_utils.py → codes/utils/data_utils.py (10 changes: 5 additions & 5 deletions)
@@ -36,7 +36,7 @@ def check_and_load_data(
     Raises:
         DatasetError: If the dataset or required data is missing or if the data shape is incorrect.
     """
-    data_dir = "data"
+    data_dir = "datasets"
     dataset_name_lower = dataset_name.lower()

     # Check if dataset exists
@@ -231,7 +231,7 @@ def create_hdf5_dataset(
     test_data: np.ndarray,
     val_data: np.ndarray,
     dataset_name: str,
-    data_dir: str = "data",
+    data_dir: str = "datasets",
     timesteps: np.ndarray | None = None,
     labels: list[str] | None = None,
 ):
@@ -337,7 +337,7 @@ def create_dataset(
         TypeError: If the train_data is not a numpy array or torch tensor.
         ValueError: If the train_data, test_data, and val_data do not have the correct shape.
     """
-    base_dir = "data"
+    base_dir = "datasets"
     dataset_dir = os.path.join(base_dir, name)

     if os.path.exists(dataset_dir):
@@ -440,14 +440,14 @@ def download_data(dataset_name: str, path: str | None = None):
         path (str, optional): The path to save the dataset. If None, the default data directory is used.
     """
     data_path = (
-        os.path.abspath(f"data/{dataset_name.lower()}/data.hdf5")
+        os.path.abspath(f"datasets/{dataset_name.lower()}/data.hdf5")
         if path is None
         else os.path.abspath(path)
     )
     if os.path.isfile(data_path):
         return

-    with open("data/data_sources.yaml", "r", encoding="utf-8") as file:
+    with open("datasets/data_sources.yaml", "r", encoding="utf-8") as file:
         data_sources = yaml.safe_load(file)

     try:

config.yaml (4 changes: 2 additions & 2 deletions)
@@ -1,11 +1,11 @@
 # Global settings for the benchmark
-training_id: "delete_me2"
+training_id: "delete_me3"
 surrogates: ["LatentPoly", "LatentNeuralODE", "FullyConnected", "MultiONet"]
 batch_size: [256, 256, 256, 256]
 epochs: [2, 2, 2, 2]
 dataset:
   name: "osu2008"
-  log10_transform: True
+  log10_transform: False
   normalise: "minmax" # "standardise", "minmax", "disable"
 use_optimal_params: True
 devices: ["cuda:1"]

datasets/data_analysis/__init__.py (8 changes: 8 additions & 0 deletions)
@@ -0,0 +1,8 @@
+__all__ = [
+    "create_hdf5_dataset",
+    "check_and_load_data",
+    "get_data_subset",
+    "create_dataset",
+    "normalize_data",
+    "download_data"
+]

(file name not shown)
@@ -1,11 +1,11 @@
-import os
 import sys
 from argparse import ArgumentParser

-sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.insert(1, "../..")

-from data import check_and_load_data
-from data.data_plots import plot_example_trajectories, plot_example_trajectories_paper
+from codes import check_and_load_data
+
+from .data_plots import plot_example_trajectories, plot_example_trajectories_paper


 def main(args):

(file name not shown)
@@ -1,7 +1,11 @@
+import sys
+
 import matplotlib.pyplot as plt
 import numpy as np

-from benchmark import save_plot
+sys.path.insert(1, "../..")
+
+from codes import save_plot


 def plot_example_trajectories(
@@ -72,7 +76,7 @@ def plot_example_trajectories(
         "example_trajectories.png",
         conf,
         dpi=300,
-        base_dir="data",
+        base_dir="datasets",
         increase_count=False,
     )

@@ -196,7 +200,7 @@ def plot_example_trajectories_paper(
         "example_trajectories_paper.png",
         conf,
         dpi=300,
-        base_dir="data",
+        base_dir="datasets",
         increase_count=False,
     )


(file name not shown)
@@ -1,13 +1,18 @@
 # TODO: move this to an appropriate location

 import os

+# Add codes package to the path (two keys up)
+import sys
 from argparse import ArgumentParser
 from typing import Callable

 import numpy as np
 from scipy.integrate import solve_ivp

-from codes.dataset.data_utils import create_dataset
+sys.path.insert(1, "../..")
+
+from codes.utils.data_utils import create_dataset


 def lotka_volterra(t, n):
@@ -55,7 +60,7 @@ def reaction(t, n):
     array
         Array of the derivatives of the abundances of species s1, s2, s3, s4, s5, and s6.
     """
-    s1, s2, s3, s4, s5, s6 = n[0], n[1], n[2], n[3], n[4], n[5]
+    s1, s2, s3, s4, s5, _ = n[0], n[1], n[2], n[3], n[4], n[5]
     return np.array(
         [
             -0.1 * s1 + 0.1 * s2,
@@ -68,33 +73,6 @@ def reaction(t, n):
     )


-# def func(t, n):
-#     """
-#     Differential equations for a simple ODE system.
-
-#     Parameters
-#     ----------
-#     t : float
-#         Time
-#     n : array
-#         Array of concentrations of species A, B, C, D, and E.
-
-#     Returns
-#     -------
-#     array
-#         Array of the derivatives of the concentrations of species A, B, C, D, and E.
-#     """
-#     k = np.array([0.8, 0.5, 0.2])
-#     return np.array(
-#         [
-#             -k[0] * n[0] - k[2] * n[0] * n[2],
-#             k[0] * n[0] - k[1] * n[1] + 2 * k[2] * n[0] * n[2],
-#             k[1] * n[1] - k[2] * n[0] * n[2],
-#             k[2] * n[0] + k[1] / k[0] * n[1],
-#             k[0] * n[0] / k[1] * n[2] - k[1] * n[0] * n[2],
-#         ]
-#     )
-
 FUNCS = {
     "lotka_volterra": {
         "func": lotka_volterra,
@@ -130,13 +108,15 @@ def create_data(num: int, func: Callable, timesteps: np.ndarray, dim: int):

 def main(args):

-    if os.path.exists(f"data/{args.name}"):
+    # Switch cwd to the root directory
+    os.chdir("../..")
+    if os.path.exists(f"datasets/{args.name}"):
         res = input(
-            f"The data directory 'data/{args.name}' already exists. Press Enter to overwrite it."
+            f"The data directory 'datasets/{args.name}' already exists. Press Enter to overwrite it."
         )
         if res != "":
             return
-        os.system(f"rm -r data/data/{args.name}")
+        os.system(f"rm -r datasets/{args.name}")

     if not FUNCS.get(args.func):
         print(f"Function {args.func} not found")

(file name not shown)
@@ -1,16 +1,20 @@
+import sys
+
 import numpy as np

-from data import create_dataset
+sys.path.insert(1, "../..")
+
+from codes import create_dataset

 if __name__ == "__main__":
     # Create a new dataset
-    train_data = np.load("data/osu2008_old/train_data.npy")
-    test_data = np.load("data/osu2008_old/test_data.npy")
+    train_data = np.load("datasets/osu2008_old/train_data.npy")
+    test_data = np.load("datasets/osu2008_old/test_data.npy")
     full_dataset = np.concatenate((train_data, test_data), axis=0)
     np.random.shuffle(full_dataset)
     labels = None
     create_dataset(
-        "osu2008",
+        "osu2008_test",
         full_dataset,
         timesteps=np.linspace(0, 1, 100),
         labels=labels,

File renamed without changes.
File renamed without changes.
File renamed without changes.

run_training.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@

 from tqdm import tqdm

-from codes.dataset.data_utils import download_data
+from codes.utils.data_utils import download_data
 from codes.train import create_task_list_for_surrogate, parallel_training, sequential_training
 from codes.utils import (
     check_training_status,

test/test_data.py (4 changes: 2 additions & 2 deletions)
@@ -3,9 +3,9 @@

 import pytest

-from codes.dataset.data_utils import check_and_load_data, download_data
+from codes.utils.data_utils import check_and_load_data, download_data

-paths = glob.glob("data/*/data.hdf5")
+paths = glob.glob("datasets/*/data.hdf5")
 dataset_names = [path.split("/")[1] for path in paths]

