Skip to content

Commit

Permalink
MAE
Browse files Browse the repository at this point in the history
***
Add MAE pre-training

***
temp to fix
  • Loading branch information
Miguel Martin committed Mar 15, 2024
1 parent ae885f1 commit d95523a
Show file tree
Hide file tree
Showing 5 changed files with 617 additions and 0 deletions.
19 changes: 19 additions & 0 deletions ego4d/research/mae/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Setup

```
pip install -r ego4d/research/requirements.txt
```

Clone omnivore
```
git clone omnivore
cd omnivore
pip install .
```

## Run

```bash
python3 ego4d/research/mae/train.py --config-name omnimae run_locally
```

60 changes: 60 additions & 0 deletions ego4d/research/mae/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from dataclasses import dataclass
from typing import List, Optional

from ego4d.research.common import SlurmConfig

@dataclass
class Ego4DInput:
video_root_dir: str

@dataclass
class EgoExoInput:
root_dir: str
metadata_path: str

@dataclass
class InputConfig:
ego4d_input: Optional[Ego4DInput]
egoexo_input: Optional[EgoExoInput]


@dataclass
class ModelConfig:
model_name: str


@dataclass
class PreprocessConfig:
# TODO
pass


@dataclass
class TrainConfig:
input_config: InputConfig
model_config: ModelConfig
pre_config: PreprocessConfig

slurm_config: SlurmConfig

checkpoint_dir: str
checkpoint_metric: str
batch_size: int
num_workers: int
prefetch_factor: int

num_epochs: int
accelerator: str
devices: int

tb_log_dir: str
tb_log_name: str

lr: float
beta1: float
beta2: float
wd: float
eps: float

eval_per_iter: int
eval_init: bool
37 changes: 37 additions & 0 deletions ego4d/research/mae/configs/maws.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
input_config:
# TODO: batch stuff here?
reader_class: "PyAvReader"
# reader_class: "TorchAudioStreamReader"
egoexo_input:
root_dir: "/checkpoint/miguelmartin/egoexo_data/dev/"
ego4d_input:
video_root_dir: "/datasets01/ego4d_track2/v1/full_scale"
metadata_path: "/checkpoint/miguelmartin/ego4d_data/ego4d.json"
model_config:
# model_name: "vit_huge_mae_pretraining"
model_name: "TODO"
batch_size: 32
num_workers: 20
prefetch_factor: 4
num_epochs: 1600
accelerator: "gpu"
devices: 1
tb_log_dir: "/private/home/miguelmartin/ego4d/ego4d_public/runs"
tb_log_name: "mae"
lr: 0.001
beta1: 0.9
beta2: 0.98
wd: 0.1
eps: 1.0e-6
eval_per_iter: 500
eval_init: true
checkpoint_dir: "/checkpoint/miguelmartin/maws_mae/checkpoints"
slurm_config:
slurm_log_folder: "slurm_log"
timeout_min: 240
constraint: "volta"
slurm_partition: "pixar"
slurm_array_parallelism: null
gpus_per_node: 1
cpus_per_task: 10
run_locally: false
Loading

0 comments on commit d95523a

Please sign in to comment.