Skip to content

Commit

Permalink
Completed restructuring: Merge develop into main
Browse files Browse the repository at this point in the history
  • Loading branch information
robin-janssen committed Sep 24, 2024
2 parents 506fa0d + 4760a00 commit 92a2570
Show file tree
Hide file tree
Showing 38 changed files with 171 additions and 56 deletions.
42 changes: 42 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
name: CI Pipeline

on:
push:
branches:
- main
- develop
pull_request:
branches:
- main
- develop

jobs:
test:
runs-on: ubuntu-latest

strategy:
matrix:
python-version: ['3.10']

steps:
- name: Check out the repository
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install project with dev dependencies
run: |
python -m pip install --upgrade pip
pip install .[dev,tests]
- name: Run Black (auto-reformat)
run: black .

- name: Run isort (auto-reformat)
run: isort .

- name: Run pytest
run: pytest
8 changes: 6 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ optuna_runs/models
optuna_runs/studies
optuna_runs/plots
scripts
.*
build
build/
dist/
*.egg-info/
data.hdf5
.venv
.pytest_cachebuild
*.egg-info
4 changes: 4 additions & 0 deletions codes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .benchmark import *
from .surrogates import *
from .train import *
from .utils import *
File renamed without changes.
2 changes: 1 addition & 1 deletion benchmark/bench_fcts.py → codes/benchmark/bench_fcts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from tabulate import tabulate
from torch.utils.data import DataLoader

from data import check_and_load_data
from codes.dataset import check_and_load_data

from .bench_plots import (
inference_time_bar_plot,
Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions benchmark/bench_utils.py → codes/benchmark/bench_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import torch
import yaml

from surrogates.surrogate_classes import surrogate_classes
from surrogates.surrogates import SurrogateModel
from codes.surrogates import surrogate_classes
from codes.surrogates import SurrogateModel


def check_surrogate(surrogate: str, conf: dict) -> None:
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from torch.utils.data import DataLoader, TensorDataset

# Use the below import to adjust the config class to the specific model
from surrogates.DeepONet.deeponet_config import MultiONetBaseConfig
from surrogates.surrogates import AbstractSurrogateModel
from utils import time_execution, worker_init_fn
from .deeponet_config import MultiONetBaseConfig
from codes.surrogates.surrogates import AbstractSurrogateModel
from codes.utils import time_execution, worker_init_fn

from .utils import mass_conservation_loss
from .don_utils import mass_conservation_loss


class BranchNet(nn.Module):
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
6 changes: 3 additions & 3 deletions surrogates/FCNN/fcnn.py → codes/surrogates/FCNN/fcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from surrogates.FCNN.fcnn_config import FCNNBaseConfig
from surrogates.surrogates import AbstractSurrogateModel
from utils import time_execution, worker_init_fn
from codes.surrogates.FCNN.fcnn_config import FCNNBaseConfig
from codes.surrogates.surrogates import AbstractSurrogateModel
from codes.utils import time_execution, worker_init_fn


class FullyConnectedNet(nn.Module):
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

from surrogates.LatentNeuralODE.latent_neural_ode_config import (
from codes.surrogates.LatentNeuralODE.latent_neural_ode_config import (
LatentNeuralODEBaseConfig,
)
from surrogates.LatentNeuralODE.utilities import ChemDataset
from surrogates.surrogates import AbstractSurrogateModel
from utils import time_execution, worker_init_fn
from codes.surrogates.LatentNeuralODE.utilities import ChemDataset
from codes.surrogates.surrogates import AbstractSurrogateModel
from codes.utils import time_execution, worker_init_fn


class LatentNeuralODE(AbstractSurrogateModel):
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from torch.optim import Adam
from torch.utils.data import DataLoader

from surrogates.LatentNeuralODE.latent_neural_ode import Decoder, Encoder
from surrogates.LatentNeuralODE.utilities import ChemDataset
from surrogates.LatentPolynomial.latent_poly_config import LatentPolynomialBaseConfig
from surrogates.surrogates import AbstractSurrogateModel
from utils import time_execution
from codes.surrogates.LatentNeuralODE.latent_neural_ode import Decoder, Encoder
from codes.surrogates.LatentNeuralODE.utilities import ChemDataset
from codes.surrogates.LatentPolynomial.latent_poly_config import LatentPolynomialBaseConfig
from codes.surrogates.surrogates import AbstractSurrogateModel
from codes.utils import time_execution


class LatentPoly(AbstractSurrogateModel):
Expand Down
File renamed without changes.
10 changes: 5 additions & 5 deletions surrogates/__init__.py → codes/surrogates/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from surrogates.DeepONet.deeponet import MultiONet, TrunkNet, BranchNet
from surrogates.FCNN.fcnn import FullyConnected, FullyConnectedNet
from surrogates.LatentNeuralODE.latent_neural_ode import (
from .DeepONet.deeponet import MultiONet, TrunkNet, BranchNet
from .FCNN.fcnn import FullyConnected, FullyConnectedNet
from .LatentNeuralODE.latent_neural_ode import (
LatentNeuralODE,
ModelWrapper,
ODE,
Encoder,
Decoder,
)
from surrogates.LatentNeuralODE.utilities import ChemDataset
from surrogates.LatentPolynomial.latent_poly import LatentPoly, Polynomial
from .LatentNeuralODE.utilities import ChemDataset
from .LatentPolynomial.latent_poly import LatentPoly, Polynomial

from .surrogate_classes import surrogate_classes
from .surrogates import AbstractSurrogateModel, SurrogateModel
Expand Down
12 changes: 12 additions & 0 deletions codes/surrogates/surrogate_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from codes.surrogates.DeepONet.deeponet import MultiONet
from codes.surrogates.FCNN.fcnn import FullyConnected
from codes.surrogates.LatentNeuralODE.latent_neural_ode import LatentNeuralODE
from codes.surrogates.LatentPolynomial.latent_poly import LatentPoly

surrogate_classes = [
MultiONet,
FullyConnected,
LatentNeuralODE,
LatentPoly,
# Add any additional surrogate classes here
]
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import create_model_dir
from codes.utils import create_model_dir


class AbstractSurrogateModel(ABC, nn.Module):
Expand Down
File renamed without changes.
6 changes: 3 additions & 3 deletions train/train_fcts.py → codes/train/train_fcts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from tqdm import tqdm

from benchmark.bench_utils import get_model_config, get_surrogate
from data import check_and_load_data, get_data_subset
from utils import (
from codes.benchmark.bench_utils import get_model_config, get_surrogate
from codes.dataset import check_and_load_data, get_data_subset
from codes.utils import (
get_progress_bar,
load_and_save_config,
make_description,
Expand Down
File renamed without changes.
File renamed without changes.
10 changes: 5 additions & 5 deletions config.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# Global settings for the benchmark
training_id: "all_models_3"
surrogates: ["LatentPoly", "LatentNeuralODE"]
batch_size: [256, 256]
epochs: [15000, 10000]
training_id: "delete_me2"
surrogates: ["LatentPoly", "LatentNeuralODE", "FullyConnected", "MultiONet"]
batch_size: [256, 256, 256, 256]
epochs: [2, 2, 2, 2]
dataset:
name: "osu2008"
log10_transform: True
normalise: "minmax" # "standardise", "minmax", "disable"
use_optimal_params: True
devices: ["cuda:5", "cuda:6", "cuda:7", "cuda:8"]
devices: ["cuda:1"]
seed: 42
verbose: False

Expand Down
2 changes: 1 addition & 1 deletion data_gen/generate_simple_ode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from scipy.integrate import solve_ivp

from data.data_utils import create_dataset
from codes.dataset.data_utils import create_dataset


def lotka_volterra(t, n):
Expand Down
65 changes: 65 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
[build-system]
requires = [
"setuptools >=61",
"setuptools_scm >=7",
]
build-backend = "setuptools.build_meta"


[project]
name = "CODES"
description = "Benchmarking tool for neural (chemical) ODE surrogate models"
readme = "README.md"
maintainers = [
{ name = "Robin Janssen, Immanuel Sulzer", email = "[email protected]" },
]
dynamic = ["version"]
requires-python = ">=3.8"
license = { text = "GPL-3.0" }
classifiers = [
"Programming Language :: Python :: 3",
"Operating System :: OS Independent",
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
]
dependencies = [
"h5py==3.9.0",
"matplotlib==3.8.4",
"numpy==1.26.4",
"PyYAML==6.0.1",
"scipy==1.14.1",
"tabulate==0.9.0",
"torch==2.3.0",
"torchode==0.2.0",
"tqdm==4.66.4",
]

[project.optional-dependencies]
tests = [
"pytest",
#"pytest-cov",
]
dev = [
"pre-commit",
"black",
"isort",
]
# docs = [
# "sphinx",
# "sphinx_mdinclude",
# "sphinx_rtd_theme",
# ]


[tool.setuptools]
packages = [
"codes",
]


[tool.pytest.ini_options]
testpaths = [
"test",
]
pythonpath = "."
python_functions = "test_*"
python_classes = "Test*"
4 changes: 2 additions & 2 deletions run_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from argparse import ArgumentParser

from benchmark import (
from codes.benchmark import (
check_benchmark,
check_surrogate,
compare_models,
get_surrogate,
run_benchmark,
)
from utils import nice_print, read_yaml_config
from codes.utils import nice_print, read_yaml_config


def main(args):
Expand Down
6 changes: 3 additions & 3 deletions run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

from tqdm import tqdm

from data.data_utils import download_data
from train import create_task_list_for_surrogate, parallel_training, sequential_training
from utils import (
from codes.dataset.data_utils import download_data
from codes.train import create_task_list_for_surrogate, parallel_training, sequential_training
from codes.utils import (
check_training_status,
load_and_save_config,
load_task_list,
Expand Down
12 changes: 0 additions & 12 deletions surrogates/surrogate_classes.py

This file was deleted.

4 changes: 2 additions & 2 deletions test/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from data.data_utils import check_and_load_data, download_data
from codes.dataset.data_utils import check_and_load_data, download_data

paths = glob.glob("data/*/data.hdf5")
dataset_names = [path.split("/")[1] for path in paths]
Expand All @@ -16,6 +16,7 @@ def dataset(request: pytest.FixtureRequest):

def test_check_and_load_data(dataset):
try:
download_data(dataset)
_ = check_and_load_data(dataset)
except Exception as e:
pytest.fail(f"Failed to load data for {dataset}: {e}")
Expand All @@ -32,4 +33,3 @@ def test_download_data(dataset, tmp_path: Path):
def test_download_data_invalid_dataset():
with pytest.raises(ValueError):
download_data("invalid_dataset")

2 changes: 1 addition & 1 deletion test/test_surrogates.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import torch

from surrogates.surrogate_classes import surrogate_classes
from codes import surrogate_classes

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
N_CHEMICALS = 10
Expand Down

0 comments on commit 92a2570

Please sign in to comment.