Skip to content

Commit

Permalink
Merge pull request #445 from yoshitomo-matsubara/dev
Browse files Browse the repository at this point in the history
Replace dst with src
  • Loading branch information
yoshitomo-matsubara authored Mar 18, 2024
2 parents 3799847 + 64aa752 commit 97c4207
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
9 changes: 5 additions & 4 deletions examples/torchvision/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def evaluate(model, data_loader, device, device_ids, distributed, log_freq=1000,
return metric_logger.acc1.global_avg


def train(teacher_model, student_model, dataset_dict, dst_ckpt_file_path,
def train(teacher_model, student_model, dataset_dict, src_ckpt_file_path, dst_ckpt_file_path,
device, device_ids, distributed, world_size, config, args):
logger.info('Start training')
train_config = config['train']
Expand All @@ -129,8 +129,8 @@ def train(teacher_model, student_model, dataset_dict, dst_ckpt_file_path,
device, device_ids, distributed, lr_factor)
best_val_top1_accuracy = 0.0
optimizer, lr_scheduler = training_box.optimizer, training_box.lr_scheduler
if file_util.check_if_exists(dst_ckpt_file_path):
best_val_top1_accuracy, _ = load_ckpt(dst_ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)
if file_util.check_if_exists(src_ckpt_file_path):
best_val_top1_accuracy, _ = load_ckpt(src_ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)

log_freq = train_config['log_freq']
student_model_without_ddp = student_model.module if module_util.check_if_wrapped(student_model) else student_model
Expand Down Expand Up @@ -178,13 +178,14 @@ def main(args):
load_model(teacher_model_config, device, distributed) if teacher_model_config is not None else None
student_model_config =\
models_config['student_model'] if 'student_model' in models_config else models_config['model']
src_ckpt_file_path = student_model_config.get('src_ckpt', None)
dst_ckpt_file_path = student_model_config['dst_ckpt']
student_model = load_model(student_model_config, device, distributed)
if args.log_config:
logger.info(config)

if not args.test_only:
train(teacher_model, student_model, dataset_dict, dst_ckpt_file_path,
train(teacher_model, student_model, dataset_dict, src_ckpt_file_path, dst_ckpt_file_path,
device, device_ids, distributed, world_size, config, args)

student_model_without_ddp =\
Expand Down
9 changes: 5 additions & 4 deletions examples/torchvision/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def evaluate(model, data_loader, iou_types, device, device_ids, distributed, log
return coco_evaluator


def train(teacher_model, student_model, dataset_dict, dst_ckpt_file_path,
def train(teacher_model, student_model, dataset_dict, src_ckpt_file_path, dst_ckpt_file_path,
device, device_ids, distributed, world_size, config, args):
logger.info('Start training')
train_config = config['train']
Expand All @@ -164,8 +164,8 @@ def train(teacher_model, student_model, dataset_dict, dst_ckpt_file_path,
device, device_ids, distributed, lr_factor)
best_val_map = 0.0
optimizer, lr_scheduler = training_box.optimizer, training_box.lr_scheduler
if file_util.check_if_exists(dst_ckpt_file_path):
best_val_map, _ = load_ckpt(dst_ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)
if file_util.check_if_exists(src_ckpt_file_path):
best_val_map, _ = load_ckpt(src_ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)

log_freq = train_config['log_freq']
iou_types = args.iou_types
Expand Down Expand Up @@ -218,13 +218,14 @@ def main(args):
teacher_model = load_model(teacher_model_config, device) if teacher_model_config is not None else None
student_model_config =\
models_config['student_model'] if 'student_model' in models_config else models_config['model']
src_ckpt_file_path = student_model_config.get('src_ckpt', None)
dst_ckpt_file_path = student_model_config['dst_ckpt']
student_model = load_model(student_model_config, device)
if args.log_config:
logger.info(config)

if not args.test_only:
train(teacher_model, student_model, dataset_dict, dst_ckpt_file_path,
train(teacher_model, student_model, dataset_dict, src_ckpt_file_path, dst_ckpt_file_path,
device, device_ids, distributed, world_size, config, args)

student_model_without_ddp =\
Expand Down
9 changes: 5 additions & 4 deletions examples/torchvision/semantic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def evaluate(model, data_loader, device, device_ids, distributed, num_classes,
return seg_evaluator


def train(teacher_model, student_model, dataset_dict, dst_ckpt_file_path,
def train(teacher_model, student_model, dataset_dict, src_ckpt_file_path, dst_ckpt_file_path,
device, device_ids, distributed, world_size, config, args):
logger.info('Start training')
train_config = config['train']
Expand All @@ -118,8 +118,8 @@ def train(teacher_model, student_model, dataset_dict, dst_ckpt_file_path,
device, device_ids, distributed, lr_factor)
best_val_miou = 0.0
optimizer, lr_scheduler = training_box.optimizer, training_box.lr_scheduler
if file_util.check_if_exists(dst_ckpt_file_path):
best_val_miou, _ = load_ckpt(dst_ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)
if file_util.check_if_exists(src_ckpt_file_path):
best_val_miou, _ = load_ckpt(src_ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)

log_freq = train_config['log_freq']
student_model_without_ddp = student_model.module if module_util.check_if_wrapped(student_model) else student_model
Expand Down Expand Up @@ -171,13 +171,14 @@ def main(args):
teacher_model = load_model(teacher_model_config, device) if teacher_model_config is not None else None
student_model_config =\
models_config['student_model'] if 'student_model' in models_config else models_config['model']
src_ckpt_file_path = student_model_config.get('src_ckpt', None)
dst_ckpt_file_path = student_model_config['dst_ckpt']
student_model = load_model(student_model_config, device)
if args.log_config:
logger.info(config)

if not args.test_only:
train(teacher_model, student_model, dataset_dict, dst_ckpt_file_path,
train(teacher_model, student_model, dataset_dict, src_ckpt_file_path, dst_ckpt_file_path,
device, device_ids, distributed, world_size, config, args)

student_model_without_ddp =\
Expand Down

0 comments on commit 97c4207

Please sign in to comment.