Commit

robin-janssen committed Sep 27, 2024
2 parents 252487a + f7b3aef commit ee1ea86
Showing 48 changed files with 1,904 additions and 168 deletions.
4 changes: 0 additions & 4 deletions .flake8

This file was deleted.

2 changes: 0 additions & 2 deletions .gitattributes

This file was deleted.

92 changes: 92 additions & 0 deletions .github/workflows/ci.yaml
@@ -0,0 +1,92 @@
name: CI Pipeline

on:
  push:
    branches:
      - main
      - develop
  pull_request:
    branches:
      - main
      - develop

jobs:
  test:
    runs-on: ubuntu-latest

    strategy:
      matrix:
        python-version: ['3.10']

    steps:
      - name: Check out the repository
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Poetry
        run: |
          curl -sSL https://install.python-poetry.org | python3 -
          export PATH="$HOME/.local/bin:$PATH"

      - name: Install project with dev dependencies
        run: |
          poetry install --with dev

      # Step to detect if the commit is a merge commit (skip CI on merge commits)
      - name: Check if commit is a merge commit
        id: check-merge
        run: |
          if git log -1 --pretty=%B | grep -q "Merge pull request"; then
            echo "is-merge=true" >> $GITHUB_ENV
          else
            echo "is-merge=false" >> $GITHUB_ENV
          fi

      # Step to detect changes in pyproject.toml or poetry.lock
      - name: Check for dependency changes
        id: check-deps
        if: env.is-merge == 'false' # Skip if it's a merge commit
        run: |
          if git diff --name-only HEAD~1 | grep -qE 'pyproject.toml|poetry.lock'; then
            echo "dependencies-changed=true" >> $GITHUB_ENV
          else
            echo "dependencies-changed=false" >> $GITHUB_ENV
          fi

      # Step to generate requirements.txt if dependencies have changed
      - name: Generate requirements.txt
        if: env.is-merge == 'false' && env.dependencies-changed == 'true'
        run: |
          poetry export -f requirements.txt --output requirements.txt --without-hashes

      # Commit and push the updated requirements.txt if dependencies have changed
      - name: Commit and push updated requirements.txt
        if: env.is-merge == 'false' && env.dependencies-changed == 'true'
        run: |
          git config --global user.name 'github-actions[bot]'
          git config --global user.email 'github-actions[bot]@users.noreply.github.com'
          git add requirements.txt
          git commit -m "Update requirements.txt [skip ci]"
          git push

      # Run Black (auto-reformat) using Poetry
      - name: Run Black (auto-reformat)
        if: env.is-merge == 'false' # Skip if it's a merge commit
        run: |
          poetry run black .

      # Run isort (auto-reformat) using Poetry
      - name: Run isort (auto-reformat)
        if: env.is-merge == 'false' # Skip if it's a merge commit
        run: |
          poetry run isort .

      # Run pytest using Poetry
      - name: Run pytest
        if: env.is-merge == 'false' # Skip if it's a merge commit
        run: |
          poetry run pytest
9 changes: 8 additions & 1 deletion .gitignore
@@ -20,4 +20,11 @@ optuna_runs/models
optuna_runs/studies
optuna_runs/plots
scripts
.*
build/
dist/
*.egg-info/
data.hdf5
.venv
.pytest_cache
build
*.egg-info
.flake8
4 changes: 4 additions & 0 deletions codes/__init__.py
@@ -0,0 +1,4 @@
from .benchmark import *
from .surrogates import *
from .train import *
from .utils import *
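
This new top-level package re-exports the benchmark, surrogates, train and utils subpackages, which is why the imports in the files below change from the old flat modules (surrogates, utils, data) to the codes.* namespace. A minimal usage sketch, assuming the repository root is importable (for example after poetry install); all names are taken from the changed files in this commit:

# Minimal sketch of the new import layout; assumes the codes package is on the import path.
from codes.surrogates import AbstractSurrogateModel, surrogate_classes
from codes.utils import time_execution  # utilities are now re-exported from codes.utils

# List the surrogate implementations registered in codes/surrogates/surrogate_classes.py.
print([cls.__name__ for cls in surrogate_classes])
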
File renamed without changes.
2 changes: 1 addition & 1 deletion benchmark/bench_fcts.py → codes/benchmark/bench_fcts.py
@@ -8,7 +8,7 @@
from tabulate import tabulate
from torch.utils.data import DataLoader

from data import check_and_load_data
from codes.utils import check_and_load_data

from .bench_plots import (
    inference_time_bar_plot,
File renamed without changes.
6 changes: 3 additions & 3 deletions benchmark/bench_utils.py → codes/benchmark/bench_utils.py
@@ -9,8 +9,8 @@
import torch
import yaml

from surrogates.surrogate_classes import surrogate_classes
from surrogates.surrogates import SurrogateModel
from codes.surrogates import surrogate_classes
from codes.surrogates import SurrogateModel


def check_surrogate(surrogate: str, conf: dict) -> None:
@@ -627,7 +627,7 @@ def get_model_config(surr_name: str, config: dict) -> dict:
return {}

dataset_name = config["dataset"]["name"].lower()
dataset_folder = f"data/{dataset_name}"
dataset_folder = f"datasets/{dataset_name}"
config_file = f"{dataset_folder}/surrogates_config.py"

if os.path.exists(config_file):
@@ -6,11 +6,11 @@
from torch.utils.data import DataLoader, TensorDataset

# Use the below import to adjust the config class to the specific model
from surrogates.DeepONet.deeponet_config import MultiONetBaseConfig
from surrogates.surrogates import AbstractSurrogateModel
from utils import time_execution, worker_init_fn
from .deeponet_config import MultiONetBaseConfig
from codes.surrogates.surrogates import AbstractSurrogateModel
from codes.utils import time_execution, worker_init_fn

from .utils import mass_conservation_loss
from .don_utils import mass_conservation_loss


class BranchNet(nn.Module):
File renamed without changes.
@@ -56,7 +56,7 @@ def get_project_path(relative_path):
"""
Construct the absolute path to a project resource (data or model) based on a relative path.
:param relative_path: A relative path to the resource, e.g., "data/dataset100" or "models/02-28/model.pth".
:param relative_path: A relative path to the resource, e.g., "datasets/dataset100" or "models/02-28/model.pth".
:return: The absolute path to the resource.
"""
import os
File renamed without changes.
6 changes: 3 additions & 3 deletions surrogates/FCNN/fcnn.py → codes/surrogates/FCNN/fcnn.py
@@ -4,9 +4,9 @@
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from surrogates.FCNN.fcnn_config import FCNNBaseConfig
from surrogates.surrogates import AbstractSurrogateModel
from utils import time_execution, worker_init_fn
from codes.surrogates.FCNN.fcnn_config import FCNNBaseConfig
from codes.surrogates.surrogates import AbstractSurrogateModel
from codes.utils import time_execution, worker_init_fn


class FullyConnectedNet(nn.Module):
File renamed without changes.
@@ -5,12 +5,12 @@
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

from surrogates.LatentNeuralODE.latent_neural_ode_config import (
from codes.surrogates.LatentNeuralODE.latent_neural_ode_config import (
    LatentNeuralODEBaseConfig,
)
from surrogates.LatentNeuralODE.utilities import ChemDataset
from surrogates.surrogates import AbstractSurrogateModel
from utils import time_execution, worker_init_fn
from codes.surrogates.LatentNeuralODE.utilities import ChemDataset
from codes.surrogates.surrogates import AbstractSurrogateModel
from codes.utils import time_execution, worker_init_fn


class LatentNeuralODE(AbstractSurrogateModel):
File renamed without changes.
@@ -4,11 +4,11 @@
from torch.optim import Adam
from torch.utils.data import DataLoader

from surrogates.LatentNeuralODE.latent_neural_ode import Decoder, Encoder
from surrogates.LatentNeuralODE.utilities import ChemDataset
from surrogates.LatentPolynomial.latent_poly_config import LatentPolynomialBaseConfig
from surrogates.surrogates import AbstractSurrogateModel
from utils import time_execution
from codes.surrogates.LatentNeuralODE.latent_neural_ode import Decoder, Encoder
from codes.surrogates.LatentNeuralODE.utilities import ChemDataset
from codes.surrogates.LatentPolynomial.latent_poly_config import LatentPolynomialBaseConfig
from codes.surrogates.surrogates import AbstractSurrogateModel
from codes.utils import time_execution


class LatentPoly(AbstractSurrogateModel):
File renamed without changes.
10 changes: 5 additions & 5 deletions surrogates/__init__.py → codes/surrogates/__init__.py
@@ -1,14 +1,14 @@
from surrogates.DeepONet.deeponet import MultiONet, TrunkNet, BranchNet
from surrogates.FCNN.fcnn import FullyConnected, FullyConnectedNet
from surrogates.LatentNeuralODE.latent_neural_ode import (
from .DeepONet.deeponet import MultiONet, TrunkNet, BranchNet
from .FCNN.fcnn import FullyConnected, FullyConnectedNet
from .LatentNeuralODE.latent_neural_ode import (
    LatentNeuralODE,
    ModelWrapper,
    ODE,
    Encoder,
    Decoder,
)
from surrogates.LatentNeuralODE.utilities import ChemDataset
from surrogates.LatentPolynomial.latent_poly import LatentPoly, Polynomial
from .LatentNeuralODE.utilities import ChemDataset
from .LatentPolynomial.latent_poly import LatentPoly, Polynomial

from .surrogate_classes import surrogate_classes
from .surrogates import AbstractSurrogateModel, SurrogateModel
12 changes: 12 additions & 0 deletions codes/surrogates/surrogate_classes.py
@@ -0,0 +1,12 @@
from codes.surrogates.DeepONet.deeponet import MultiONet
from codes.surrogates.FCNN.fcnn import FullyConnected
from codes.surrogates.LatentNeuralODE.latent_neural_ode import LatentNeuralODE
from codes.surrogates.LatentPolynomial.latent_poly import LatentPoly

surrogate_classes = [
    MultiONet,
    FullyConnected,
    LatentNeuralODE,
    LatentPoly,
    # Add any additional surrogate classes here
]
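
The surrogate_classes list above acts as a plain registry of the available surrogate implementations. The repository's own lookup helper is get_surrogate in codes.benchmark.bench_utils (its implementation is not shown in this diff); purely as an illustration, a hypothetical name-based lookup over this registry could look like the following sketch:

# Hypothetical helper, not part of this commit; illustrates querying the registry list.
from codes.surrogates import surrogate_classes


def find_surrogate_class(name: str):
    """Return the registered class whose __name__ matches name, or None if unknown."""
    for cls in surrogate_classes:
        if cls.__name__ == name:
            return cls
    return None


print(find_surrogate_class("MultiONet"))
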
14 changes: 10 additions & 4 deletions surrogates/surrogates.py → codes/surrogates/surrogates.py
@@ -10,7 +10,7 @@
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import create_model_dir
from codes.utils import create_model_dir


class AbstractSurrogateModel(ABC, nn.Module):
@@ -273,7 +273,7 @@ def save(
if isinstance(value, torch.Tensor):
value = value.cpu().detach().numpy()
if isinstance(value, np.ndarray):
value = value.astype(np.float16)
value = value.astype(np.float32)
setattr(self, attribute, value)

# Save the hyperparameters as a yaml file
Expand Down Expand Up @@ -311,13 +311,19 @@ def load(
"""
if model_dir is None:
model_dict_path = os.path.join(
os.getcwd(), "trained", training_id, surr_name, f"{model_identifier}.pth"
os.getcwd(),
"trained",
training_id,
surr_name,
f"{model_identifier}.pth",
)
else:
model_dict_path = os.path.join(
model_dir, training_id, surr_name, f"{model_identifier}.pth"
)
model_dict = torch.load(model_dict_path, map_location=self.device)
model_dict = torch.load(
model_dict_path, map_location=self.device, weights_only=False
)
self.load_state_dict(model_dict["state_dict"])
for key, value in model_dict["attributes"].items():
# remove self.device from the attributes
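
The explicit weights_only=False matters because the checkpoint written by save above stores an attributes dictionary with NumPy arrays and other plain Python objects alongside the state dict, and the stricter weights-only loading that newer PyTorch releases default to can refuse to unpickle such objects. An illustrative sketch of the loading pattern, mirroring the call above; the concrete training_id, surr_name and model_identifier values are hypothetical placeholders:

# Illustrative only; the identifiers below are made-up placeholders.
import os

import torch

model_dict_path = os.path.join(
    os.getcwd(), "trained", "training_1", "MultiONet", "multionet_main.pth"
)
model_dict = torch.load(
    model_dict_path,
    map_location="cpu",
    weights_only=False,  # the checkpoint stores non-tensor attributes, not just weights
)
print(model_dict.keys())  # per save/load above: "state_dict" and "attributes"
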
File renamed without changes.
7 changes: 4 additions & 3 deletions train/train_fcts.py → codes/train/train_fcts.py
@@ -4,9 +4,10 @@

from tqdm import tqdm

from benchmark.bench_utils import get_model_config, get_surrogate
from data import check_and_load_data, get_data_subset
from utils import (
from codes.benchmark.bench_utils import get_model_config, get_surrogate
from codes.utils import (
    check_and_load_data,
    get_data_subset,
    get_progress_bar,
    load_and_save_config,
    make_description,
30 changes: 22 additions & 8 deletions utils/__init__.py → codes/utils/__init__.py
@@ -1,19 +1,33 @@
from .data_utils import (
    check_and_load_data,
    create_dataset,
    create_hdf5_dataset,
    download_data,
    get_data_subset,
    normalize_data,
)
from .utils import (
    read_yaml_config,
    time_execution,
    check_training_status,
    create_model_dir,
    get_progress_bar,
    load_and_save_config,
    set_random_seeds,
    nice_print,
    load_task_list,
    make_description,
    get_progress_bar,
    worker_init_fn,
    nice_print,
    read_yaml_config,
    save_task_list,
    load_task_list,
    check_training_status,
    set_random_seeds,
    time_execution,
    worker_init_fn,
)

__all__ = [
"check_and_load_data",
"create_dataset",
"create_hdf5_dataset",
"download_data",
"get_data_subset",
"normalize_data",
"read_yaml_config",
"time_execution",
"create_model_dir",