Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
robin-janssen committed Sep 27, 2024
2 parents d3af4e6 + ce2578d commit f7b3aef
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 31 deletions.
12 changes: 9 additions & 3 deletions codes/surrogates/surrogates.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def save(
if isinstance(value, torch.Tensor):
value = value.cpu().detach().numpy()
if isinstance(value, np.ndarray):
value = value.astype(np.float16)
value = value.astype(np.float32)
setattr(self, attribute, value)

# Save the hyperparameters as a yaml file
Expand Down Expand Up @@ -311,13 +311,19 @@ def load(
"""
if model_dir is None:
model_dict_path = os.path.join(
os.getcwd(), "trained", training_id, surr_name, f"{model_identifier}.pth"
os.getcwd(),
"trained",
training_id,
surr_name,
f"{model_identifier}.pth",
)
else:
model_dict_path = os.path.join(
model_dir, training_id, surr_name, f"{model_identifier}.pth"
)
model_dict = torch.load(model_dict_path, map_location=self.device)
model_dict = torch.load(
model_dict_path, map_location=self.device, weights_only=False
)
self.load_state_dict(model_dict["state_dict"])
for key, value in model_dict["attributes"].items():
# remove self.device from the attributes
Expand Down
24 changes: 19 additions & 5 deletions codes/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def check_and_load_data(
verbose: bool = True,
log: bool = True,
normalisation_mode: str = "standardise",
tolerance: float | None = 1e-20,
):
"""
Check the specified dataset and load the data based on the mode (train or test).
Expand All @@ -29,6 +30,8 @@ def check_and_load_data(
verbose (bool): Whether to print information about the loaded data.
log (bool): Whether to log-transform the data (log10).
normalisation_mode (str): The normalisation mode, either "disable", "minmax", or "standardise".
tolerance (float, optional): Lower bound applied to the train, test and validation data
before the optional log-transformation. Values below the tolerance are set to the
tolerance value. Pass None to disable.
Returns:
tuple: Loaded data and timesteps.
Expand Down Expand Up @@ -66,9 +69,14 @@ def check_and_load_data(
)

# Load data
train_data = np.asarray(f["train"])
test_data = np.asarray(f["test"])
val_data = np.asarray(f["val"])
train_data = np.asarray(f["train"], dtype=np.float32)
test_data = np.asarray(f["test"], dtype=np.float32)
val_data = np.asarray(f["val"], dtype=np.float32)

if tolerance is not None:
train_data = np.where(train_data < tolerance, tolerance, train_data)
test_data = np.where(test_data < tolerance, tolerance, test_data)
val_data = np.where(val_data < tolerance, tolerance, val_data)

# Log transformation
if log:
Expand Down Expand Up @@ -416,8 +424,14 @@ def create_dataset(
train_data = full_data[:n_train]
test_data = full_data[n_train : n_train + n_test]
val_data = full_data[n_train + n_test :]
if any(dim == 0 for shape in (train_data.shape, test_data.shape, val_data.shape) for dim in shape):
raise ValueError("Split data contains zero samples. One of the splits is too small.")
if any(
dim == 0
for shape in (train_data.shape, test_data.shape, val_data.shape)
for dim in shape
):
raise ValueError(
"Split data contains zero samples. One of the splits is too small."
)

if labels is not None:
if not isinstance(labels, list):
Expand Down
3 changes: 2 additions & 1 deletion codes/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,8 @@ def check_training_status(config: dict) -> tuple[str, bool]:
confirmation = input("Overwrite? [y/n]: ")
if confirmation.lower() == "y":
print("Overwriting the saved configuration.")
os.remove(task_list_filepath)
if os.path.exists(task_list_filepath):
os.remove(task_list_filepath)
copy_config = True
else:
print("Continuing training with the previous configuration.")
Expand Down
13 changes: 7 additions & 6 deletions config.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# Global settings for the benchmark
training_id: "delete_me3"
surrogates: ["LatentPoly", "LatentNeuralODE", "FullyConnected", "MultiONet"]
batch_size: [256, 256, 256, 256]
epochs: [2, 2, 2, 2]
training_id: "delete_me4"
surrogates: ["LatentPoly", "FullyConnected", "MultiONet"]
batch_size: [256, 256, 256]
epochs: [2, 2, 2]
dataset:
name: "osu2008"
log10_transform: False
name: "simple_primordial"
log10_transform: True
normalise: "minmax" # "standardise", "minmax", "disable"
use_optimal_params: True
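tolerance: 1e-20 # NOTE(review): YAML 1.1 loaders such as PyYAML read "1e-20" (no decimal point) as a string, not a float; confirm the loader or write 1.0e-20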
devices: ["cuda:1"]
seed: 42
verbose: False
Expand Down
3 changes: 2 additions & 1 deletion datasets/data_sources.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ osu2008: "https://zenodo.org/records/13359976/files/data.hdf5?download=1"
simple_reaction: "https://zenodo.org/records/13624781/files/data.hdf5?download=1"
simple_ode: "https://zenodo.org/records/13624783/files/data.hdf5?download=1"
lotka_volterra: "https://zenodo.org/records/13624788/files/data.hdf5?download=1"
branca24: "https://zenodo.org/records/13624794/files/data.hdf5?download=1"
branca24: "https://zenodo.org/records/13624794/files/data.hdf5?download=1"
simple_primordial: "https://zenodo.org/records/13754361/files/data.hdf5?download=1"
60 changes: 45 additions & 15 deletions test/test_surrogates.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,43 +30,73 @@ def dataloaders(instance):
)
return dataloader_train, dataloader_test, dataloader_val


def test_init(instance):
assert instance.device == DEVICE, f"device is wrong: {instance.device} != {DEVICE}"
assert instance.n_chemicals == N_CHEMICALS, f"n_chemicals is wrong: {instance.n_chemicals} != {N_CHEMICALS}"
assert instance.n_timesteps == N_TIMESTEPS, f"n_timesteps is wrong: {instance.n_timesteps} != {N_TIMESTEPS}"
assert (
instance.n_chemicals == N_CHEMICALS
), f"n_chemicals is wrong: {instance.n_chemicals} != {N_CHEMICALS}"
assert (
instance.n_timesteps == N_TIMESTEPS
), f"n_timesteps is wrong: {instance.n_timesteps} != {N_TIMESTEPS}"
assert instance.config is not None, "config is None"


def test_dataloader(instance, dataloaders):
dataloader_train, dataloader_test, dataloader_val = dataloaders

assert isinstance(dataloader_train, torch.utils.data.DataLoader), "dataloader_train is not a DataLoader"
assert isinstance(dataloader_test, torch.utils.data.DataLoader), "dataloader_test is not a DataLoader"
assert isinstance(dataloader_val, torch.utils.data.DataLoader), "dataloader_val is not a DataLoader"

assert dataloader_train.batch_size == BATCH_SIZE, f"dataloader_train has wrong batch size: {dataloader_train.batch_size} != {BATCH_SIZE}"
assert dataloader_test.batch_size == BATCH_SIZE, f"dataloader_test has wrong batch size: {dataloader_test.batch_size} != {BATCH_SIZE}"
assert dataloader_val.batch_size == BATCH_SIZE, f"dataloader_val has wrong batch size: {dataloader_val.batch_size} != {BATCH_SIZE}"
assert isinstance(
dataloader_train, torch.utils.data.DataLoader
), "dataloader_train is not a DataLoader"
assert isinstance(
dataloader_test, torch.utils.data.DataLoader
), "dataloader_test is not a DataLoader"
assert isinstance(
dataloader_val, torch.utils.data.DataLoader
), "dataloader_val is not a DataLoader"

assert (
dataloader_train.batch_size == BATCH_SIZE
), f"dataloader_train has wrong batch size: {dataloader_train.batch_size} != {BATCH_SIZE}"
assert (
dataloader_test.batch_size == BATCH_SIZE
), f"dataloader_test has wrong batch size: {dataloader_test.batch_size} != {BATCH_SIZE}"
assert (
dataloader_val.batch_size == BATCH_SIZE
), f"dataloader_val has wrong batch size: {dataloader_val.batch_size} != {BATCH_SIZE}"


def test_predict(instance, dataloaders):
dataloader_train, _, _ = dataloaders
predictions, targets = instance.predict(dataloader_train)

assert predictions.shape == torch.Size([3, N_TIMESTEPS, N_CHEMICALS]), f"predictions has wrong shape: {predictions.shape} != [3, {N_TIMESTEPS}, {N_CHEMICALS}]"
assert targets.shape == torch.Size([3, N_TIMESTEPS, N_CHEMICALS]), f"targets has wrong shape: {targets.shape} != [3, {N_TIMESTEPS}, {N_CHEMICALS}]"
assert predictions.shape == torch.Size(
[3, N_TIMESTEPS, N_CHEMICALS]
), f"predictions has wrong shape: {predictions.shape} != [3, {N_TIMESTEPS}, {N_CHEMICALS}]"
assert targets.shape == torch.Size(
[3, N_TIMESTEPS, N_CHEMICALS]
), f"targets has wrong shape: {targets.shape} != [3, {N_TIMESTEPS}, {N_CHEMICALS}]"


def test_fit(instance, dataloaders):
dataloader_train, dataloader_test, _ = dataloaders
instance.fit(dataloader_train, dataloader_test, epochs=2)

assert instance.train_loss.shape == torch.Size([2]), f"train_loss has wrong shape: {instance.train_loss.shape} != [2]"
assert instance.test_loss.shape == torch.Size([2]), f"test_loss has wrong shape: {instance.test_loss.shape} != [2]"
assert instance.MAE.shape == torch.Size([2]), f"MAE has wrong shape: {instance.MAE.shape} != [2]"
assert instance.train_loss.shape == torch.Size(
[2]
), f"train_loss has wrong shape: {instance.train_loss.shape} != [2]"
assert instance.test_loss.shape == torch.Size(
[2]
), f"test_loss has wrong shape: {instance.test_loss.shape} != [2]"
assert instance.MAE.shape == torch.Size(
[2]
), f"MAE has wrong shape: {instance.MAE.shape} != [2]"


def test_save_load(instance, tmp_path):
model_name = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10))
model_name = "".join(
random.choice(string.ascii_uppercase + string.digits) for _ in range(10)
)
instance.save(
model_name=model_name, base_dir=tmp_path, training_id="TestID", data_params={}
)
Expand Down

0 comments on commit f7b3aef

Please sign in to comment.