Resolved bugs, added more parameters, and other minor changes and improvements
voldien committed Feb 27, 2024
1 parent a0c98c0 commit b9a3ed3
Showing 16 changed files with 128 additions and 99 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -113,7 +113,8 @@ experiments/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/


commandline
*/cpkt*
*.png
*.json
*.txt
66 changes: 30 additions & 36 deletions superresolution/SuperResolution.py
@@ -28,7 +28,6 @@
import models.SuperResolutionVDSR
import models.SuperResolutionCNN


from core.common import ParseDefaultArgument, DefaultArgumentParser, setup_tensorflow_strategy
from util.dataProcessing import load_dataset_from_directory, \
configure_dataset_performance, dataset_super_resolution, augment_dataset
@@ -146,6 +145,7 @@ def ssim_loss(y_true, y_pred):
y_true_color = None
y_pred_color = None

#
if args.color_space == 'rgb':
# Remap [-1,1] to [0,1]
y_true_color = ((y_true + 1.0) * 0.5)
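The remapping above exists because tf.image.ssim expects inputs in [0, max_val]. A minimal sketch of an SSIM loss for the 'rgb' branch, assuming network outputs normalized to [-1, 1]; the helper name is illustrative, and the actual ssim_loss in this file also handles other color spaces:

import tensorflow as tf

def ssim_loss_rgb(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
    # Remap [-1, 1] -> [0, 1] so tf.image.ssim sees values in its expected range.
    y_true_01 = (y_true + 1.0) * 0.5
    y_pred_01 = (y_pred + 1.0) * 0.5
    # SSIM is a similarity in [0, 1]; subtract from 1 to obtain a minimizable loss.
    ssim = tf.image.ssim(y_true_01, y_pred_01, max_val=1.0)
    return 1.0 - tf.reduce_mean(ssim)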
@@ -198,25 +198,26 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D
non_augmented_dataset_train = dataset_super_resolution(dataset=training_dataset,
input_size=image_input_size,
output_size=image_output_size)
non_augmented_dataset_validation = None
if validation_dataset:
non_augmented_dataset_validation = dataset_super_resolution(dataset=validation_dataset,
input_size=image_input_size,
output_size=image_output_size)
non_augmented_dataset_validation = non_augmented_dataset_validation.batch(batch_size)

non_augmented_dataset_validation = dataset_super_resolution(dataset=validation_dataset,
input_size=image_input_size,
output_size=image_output_size)
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
non_augmented_dataset_validation = non_augmented_dataset_validation.with_options(options)

non_augmented_dataset_train = configure_dataset_performance(ds=non_augmented_dataset_train, use_cache=False,
cache_path=None, shuffle_size=0)

non_augmented_dataset_train = non_augmented_dataset_train.batch(batch_size)
non_augmented_dataset_validation = non_augmented_dataset_validation.batch(batch_size)

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
non_augmented_dataset_train = non_augmented_dataset_train.with_options(options)

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
non_augmented_dataset_validation = non_augmented_dataset_validation.with_options(options)

# Configure cache, shuffle and performance of the dataset.
training_dataset = configure_dataset_performance(ds=training_dataset, use_cache=args.cache_ram,
cache_path=args.cache_path,
@@ -293,35 +294,24 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D
metrics.append(PSNRMetric())

training_model.compile(optimizer=model_optimizer, loss=loss_fn, metrics=metrics)
# Save copy.
training_model.save(args.model_filepath)

# Create a callback that saves the model weights
checkpoint_path: str = args.checkpoint_dir
# Checkpoint root path.
checkpoint_root_path: str = args.checkpoint_dir

# Create a callback that saves the model weights
if validation_data_ds:
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=(checkpoint_path + "-{epoch:03d}-{val_loss:.4f}.keras"),
monitor='val_loss',
mode='min',
save_best_only=True,
save_weights_only=True,
save_freq='epoch',
verbose=0)
else:
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=(checkpoint_path + "-{epoch:03d}-{loss:.4f}.keras"),
monitor='loss',
mode='min',
save_best_only=True,
save_weights_only=True,
save_freq='epoch',
verbose=0)

# If exists, load weights.
if os.path.exists(checkpoint_path):
training_model.load_weights(checkpoint_path)
checkpoint_path = os.path.join(checkpoint_root_path, "cpkt-{epoch:02d}")
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
monitor='val_loss' if validation_data_ds else 'loss',
mode='min',
save_freq='epoch',
verbose=1)

#
checkpoint = tf.train.Checkpoint(optimizer=model_optimizer, model=training_model)
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir=checkpoint_root_path)
if latest_checkpoint:
status = checkpoint.restore(save_path=latest_checkpoint).assert_consumed()

training_callbacks: list = [tf.keras.callbacks.TerminateOnNaN(), checkpoint_callback]
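The hunk above replaces the two monitor-specific ModelCheckpoint branches with a single callback and adds an object-based tf.train.Checkpoint restore. A minimal sketch of that pattern, assuming a compiled Keras model and optimizer; the helper name and the expect_partial() call are illustrative (the commit itself calls assert_consumed()):

import os
import tensorflow as tf

def build_checkpointing(model: tf.keras.Model,
                        optimizer: tf.keras.optimizers.Optimizer,
                        checkpoint_root: str,
                        has_validation: bool):
    # Per-epoch Keras checkpoint files, e.g. "cpkt-01", "cpkt-02", ...
    filepath = os.path.join(checkpoint_root, "cpkt-{epoch:02d}")
    callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=filepath,
        monitor='val_loss' if has_validation else 'loss',
        mode='min',
        save_freq='epoch',
        verbose=1)

    # Object-based checkpoint that can also restore the optimizer state.
    ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
    latest = tf.train.latest_checkpoint(checkpoint_dir=checkpoint_root)
    if latest:
        ckpt.restore(latest).expect_partial()
    return callback, ckpt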

@@ -348,10 +338,14 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D
graph_output_filepath: str = os.path.join(args.output_dir, "history_graph.png")
training_callbacks.append(GraphHistory(filepath=graph_output_filepath))

# Save copy.
training_model.save(args.model_filepath)

history_result = training_model.fit(x=training_dataset, validation_data=validation_data_ds, verbose='auto',
epochs=args.epochs,
callbacks=training_callbacks)
# Save final model.
#
# training_model.load_weights(checkpoint_path)
training_model.save(args.model_filepath)

# Test model.
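The training path in this file also appends a project-specific PSNRMetric. As a rough, hypothetical stand-in (not the repository's implementation), a Keras metric that averages tf.image.psnr over batches, assuming data normalized to [-1, 1], could look like:

import tensorflow as tf

class PSNRMetricSketch(tf.keras.metrics.Metric):
    # Hypothetical stand-in for the project's PSNRMetric: averages per-image PSNR.
    def __init__(self, name="psnr", **kwargs):
        super().__init__(name=name, **kwargs)
        self.total = self.add_weight(name="total", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # max_val=2.0 is the dynamic range of data normalized to [-1, 1].
        psnr = tf.image.psnr(y_true, y_pred, max_val=2.0)
        self.total.assign_add(tf.reduce_sum(psnr))
        self.count.assign_add(tf.cast(tf.size(psnr), tf.float32))

    def result(self):
        return self.total / self.count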
33 changes: 18 additions & 15 deletions superresolution/UpScaleUtil.py
@@ -24,6 +24,7 @@

sr_logger.addHandler(console_handler)


def save_result_file(argument):
upscale_image, new_cropped_size, full_output_path = argument

@@ -46,7 +47,7 @@ def super_resolution_upscale(argv):
default=None, type=str, help='')

#
parser.add_argument('--input-file', type=str,
parser.add_argument('--input-file', type=str, # action='append', TODO:add later
default=None, dest='input_files')

#
@@ -66,14 +67,13 @@ def super_resolution_upscale(argv):
help='Select Model Weight Path', default=None)

#
parser.add_argument('--device', type=str, dest='',
default=None, help='Select Device')
parser.add_argument('--device', action='append', default=None, required=False,
dest='devices', help='Select the device explicitly that will be used.')

parser.add_argument('--cpu', action='store_true',
default=False,
dest='use_explicit_cpu', help='Explicit use the CPU as the compute device.')


#
parser.add_argument('--verbosity', type=int, dest='accumulate',
default=1,
@@ -97,24 +97,26 @@ def super_resolution_upscale(argv):
sr_logger.setLevel(logging.DEBUG)

# Allow to use multiple GPU
strategy = setup_tensorflow_strategy(args={})
strategy = setup_tensorflow_strategy(args=args)
sr_logger.info('Number of devices: {0}'.format(strategy.num_replicas_in_sync))
with strategy.scope():

# TODO: fix output.
output_path: str = args.save_path
if os.path.isdir(output_path):
pass
if not os.path.exists(output_path):
os.mkdir(output_path)

# TODO: improve extraction of filepaths.
input_filepaths: str = args.input_files
sr_logger.info("File Paths: " + str(input_filepaths))

if os.path.isdir(input_filepaths):
sr_logger.info("Directory Path: " + str(input_filepaths))
all_files = os.listdir(input_filepaths)
base_path = input_filepaths
input_filepaths: list = [os.path.join(
base_path, path) for path in all_files]
else: # Convert to list
sr_logger.info("File Path: " + str(input_filepaths))
input_filepaths: list = [input_filepaths]

batch_size: int = args.batch_size * strategy.num_replicas_in_sync
@@ -149,19 +151,19 @@ def super_resolution_upscale(argv):
# Create a pool of task scheduler.
pool = Pool(processes=16)

for file_path in input_filepaths:
for input_file_path in input_filepaths:

if not os.path.isfile(file_path):
if not os.path.isfile(input_file_path):
continue
sr_logger.info("Starting Image {0}".format(file_path))
sr_logger.info("Starting Image {0}".format(input_file_path))

#
base_filepath: str = os.path.basename(file_path)
base_filepath: str = os.path.basename(input_file_path)
full_output_path: str = os.path.join(
output_path, base_filepath)

# Open File and Convert to RGB Color Space.
input_im: Image = Image.open(file_path)
input_im: Image = Image.open(input_file_path)
input_im: Image = input_im.convert('RGB')

#
@@ -215,7 +217,7 @@ def super_resolution_upscale(argv):
normalized_subimage_color) * (1.0 / 128.0)
elif color_space == 'rgb':
cropped_sub_input_image = (
normalized_subimage_color * 2) - 1
normalized_subimage_color * 2) - 1

# Upscale.
upscale_raw_result = upscale_image_func(upscale_model, cropped_sub_input_image,
@@ -241,6 +243,7 @@ def super_resolution_upscale(argv):
pool.close()
pool.join()


# If running the script as main executable
if __name__ == '__main__':
try:
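The upscaling loop in this file normalizes each cropped sub-image per color space before inference; for the 'rgb' branch it maps [0, 1] values to [-1, 1]. A hedged sketch of that remap and its inverse for writing results back to displayable pixels (function names are illustrative, not from the repository):

import numpy as np
from PIL import Image

def to_model_range(tile: Image.Image) -> np.ndarray:
    # [0, 255] uint8 -> [0, 1] float -> [-1, 1] float, as in the 'rgb' branch above.
    normalized = np.asarray(tile, dtype=np.float32) / 255.0
    return normalized * 2.0 - 1.0

def to_pixels(prediction: np.ndarray) -> Image.Image:
    # Inverse remap [-1, 1] -> [0, 255], clipped to a valid byte range.
    pixels = np.clip((prediction + 1.0) * 0.5 * 255.0, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)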
5 changes: 3 additions & 2 deletions superresolution/core/common.py
@@ -20,7 +20,7 @@ def DefaultArgumentParser() -> argparse.ArgumentParser:
help='number of training element per each batch, during training.')
#
parser.add_argument('--checkpoint-filepath', type=str, dest='checkpoint_dir',
default="./training_checkpoints",
default="training_checkpoints",
help='Set the path the checkpoint will be saved/loaded.')
#
parser.add_argument('--checkpoint-every-epoch', type=int, dest='checkpoint_every_nth_epoch',
@@ -115,7 +115,7 @@ def DefaultArgumentParser() -> argparse.ArgumentParser:
def ParseDefaultArgument(args: dict):
#
tf.config.experimental.enable_tensor_float_32_execution(True)

# Set global precision default policy.
if args.use_float16:
mixed_precision.set_global_policy('mixed_float16')
@@ -159,6 +159,7 @@ def create_virtual_gpu_devices():
print(e)
return []


def setup_tensorflow_strategy(args: dict):
# Configure
if args.use_explicit_cpu:
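Only the first lines of setup_tensorflow_strategy are visible in this diff. A speculative sketch of what such a helper commonly does, picking a single-CPU strategy when --cpu is set and otherwise mirroring across the explicitly requested (or all visible) GPU devices; the exact behavior in core/common.py may differ:

import tensorflow as tf

def setup_tensorflow_strategy_sketch(args) -> tf.distribute.Strategy:
    # Explicit CPU: pin everything to one CPU device.
    if args.use_explicit_cpu:
        return tf.distribute.OneDeviceStrategy(device="/cpu:0")
    # Otherwise mirror across the requested devices, or all visible GPUs when none given.
    devices = args.devices if getattr(args, "devices", None) else None
    return tf.distribute.MirroredStrategy(devices=devices)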
20 changes: 14 additions & 6 deletions superresolution/models/DCSuperResolution.py
@@ -34,6 +34,7 @@ def create_model(self, input_shape, output_shape, **kwargs) -> keras.Model:
# Model Construct Parameters.
regularization: float = kwargs.get("regularization", 0.000001) #
upscale_mode: int = kwargs.get("upscale_mode", 2) #
nr_filters: int = kwargs.get("filters", 64)

#
return create_simple_model(input_shape=input_shape,
@@ -59,48 +60,55 @@ def create_simple_model(input_shape: tuple, output_shape: tuple, regularization:

for i in range(0, int(upscale_mode / 2)):
nrfilters = output_width

#
x = layers.Conv2D(filters=nrfilters, kernel_size=(9, 9), strides=1, padding='same',
use_bias=use_bias,
kernel_initializer=tf.keras.initializers.HeNormal(), bias_initializer=tf.keras.initializers.HeNormal())(x)
kernel_initializer=tf.keras.initializers.HeNormal(),
bias_initializer=tf.keras.initializers.HeNormal())(x)
if use_batch_norm:
x = layers.BatchNormalization(dtype='float32')(x)
x = create_activation(kernel_activation)(x)

#
x = layers.Conv2D(filters=nrfilters / 2, kernel_size=(4, 4), strides=1, padding='same',
use_bias=use_bias,
kernel_initializer=tf.keras.initializers.HeNormal(), bias_initializer=tf.keras.initializers.HeNormal())(x)
kernel_initializer=tf.keras.initializers.HeNormal(),
bias_initializer=tf.keras.initializers.HeNormal())(x)
if use_batch_norm:
x = layers.BatchNormalization(dtype='float32')(x)
x = create_activation(kernel_activation)(x)

#
x = layers.Conv2D(filters=nrfilters / 4, kernel_size=(3, 3), strides=1, padding='same',
use_bias=use_bias,
kernel_initializer=tf.keras.initializers.HeNormal(), bias_initializer=tf.keras.initializers.HeNormal())(x)
kernel_initializer=tf.keras.initializers.HeNormal(),
bias_initializer=tf.keras.initializers.HeNormal())(x)
if use_batch_norm:
x = layers.BatchNormalization(dtype='float32')(x)
x = create_activation(kernel_activation)(x)

# Upscale -
x = layers.Conv2DTranspose(filters=output_width, kernel_size=(5, 5), strides=(
2, 2), use_bias=use_bias, padding='same', kernel_initializer=tf.keras.initializers.HeNormal(), bias_initializer=tf.keras.initializers.HeNormal())(x)
2, 2), use_bias=use_bias, padding='same', kernel_initializer=tf.keras.initializers.HeNormal(),
bias_initializer=tf.keras.initializers.HeNormal())(x)
if use_batch_norm:
x = layers.BatchNormalization(dtype='float32')(x)
x = create_activation(kernel_activation)(x)

#
x = layers.Conv2D(filters=nrfilters, kernel_size=(4, 4), strides=1, padding='same',
use_bias=use_bias,
kernel_initializer=tf.keras.initializers.HeNormal(), bias_initializer=tf.keras.initializers.HeNormal())(x)
kernel_initializer=tf.keras.initializers.HeNormal(),
bias_initializer=tf.keras.initializers.HeNormal())(x)
if use_batch_norm:
x = layers.BatchNormalization(dtype='float32')(x)
x = create_activation(kernel_activation)(x)

# Output to 3 channel output.
x = layers.Conv2DTranspose(filters=output_channels, kernel_size=(9, 9), strides=(
1, 1), padding='same', use_bias=use_bias, kernel_initializer=tf.keras.initializers.HeNormal(), bias_initializer=tf.keras.initializers.HeNormal())(x)
1, 1), padding='same', use_bias=use_bias, kernel_initializer=tf.keras.initializers.HeNormal(),
bias_initializer=tf.keras.initializers.HeNormal())(x)
x = layers.Activation('tanh', dtype='float32')(x)
x = layers.ActivityRegularization(l1=regularization, l2=0)(x)

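create_simple_model repeats the same Conv2D + optional BatchNormalization + activation sequence several times. A small illustrative helper showing that pattern, using a plain string activation in place of the project's create_activation (an assumption made only to keep the sketch self-contained):

import tensorflow as tf
from tensorflow.keras import layers

def conv_block(x, filters: int, kernel_size, use_bias: bool = True,
               use_batch_norm: bool = False, activation: str = "relu"):
    # Conv2D with He initialization, optional float32 BatchNorm, then activation,
    # mirroring the repeated pattern in create_simple_model above.
    x = layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=1,
                      padding='same', use_bias=use_bias,
                      kernel_initializer=tf.keras.initializers.HeNormal(),
                      bias_initializer=tf.keras.initializers.HeNormal())(x)
    if use_batch_norm:
        x = layers.BatchNormalization(dtype='float32')(x)
    return layers.Activation(activation)(x)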
19 changes: 11 additions & 8 deletions superresolution/models/PostDCSuperResolution.py
@@ -12,15 +12,15 @@ def __init__(self):

def load_argument(self) -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(add_help=False, prog="Post Basic SuperResolution",
description="Basic Deep Convolutional Super Resolution")
description="Basic Deep Convolutional Super Resolution")

group = parser.add_argument_group(self.get_name())

#
group.add_argument('--regularization', dest='regularization', required=False,
type=float,
default=0.0001,
help='Set the L1 Regularization applied.')
type=float,
default=0.0001,
help='Set the L1 Regularization applied.')
#
return parser

@@ -53,17 +53,20 @@ def create_post_super_resolution(input_shape: tuple, output_shape: tuple):
filter_size = 2 ** (i + 6)
filter_size = min(filter_size, 1024)

x = layers.Conv2D(filter_size, kernel_size=(3, 3), strides=1, padding='same', kernel_initializer=tf.keras.initializers.HeNormal())(x)
x = layers.Conv2D(filter_size, kernel_size=(3, 3), strides=1, padding='same',
kernel_initializer=tf.keras.initializers.HeNormal())(x)
if batch_norm:
x = layers.BatchNormalization(dtype='float32')(x)
x = layers.ReLU(dtype='float32')(x)

x = layers.Conv2D(filter_size, kernel_size=(3, 3), padding='same', strides=1, kernel_initializer=tf.keras.initializers.HeNormal())(x)
x = layers.Conv2D(filter_size, kernel_size=(3, 3), padding='same', strides=1,
kernel_initializer=tf.keras.initializers.HeNormal())(x)
if batch_norm:
x = layers.BatchNormalization(dtype='float32')(x)
x = layers.ReLU(dtype='float32')(x)

x = layers.Conv2D(filter_size, kernel_size=(3, 3), padding='same', strides=1, kernel_initializer=tf.keras.initializers.HeNormal())(x)
x = layers.Conv2D(filter_size, kernel_size=(3, 3), padding='same', strides=1,
kernel_initializer=tf.keras.initializers.HeNormal())(x)
if batch_norm:
x = layers.BatchNormalization(dtype='float32')(x)
x = layers.ReLU(dtype='float32')(x)
@@ -72,7 +75,7 @@ def create_post_super_resolution(input_shape: tuple, output_shape: tuple):

x = layers.Conv2D(filters=3, kernel_size=(4, 4), strides=(
1, 1), padding='same', kernel_initializer=tf.keras.initializers.HeNormal())(x)
x = layers.Activation('tanh')(x)#TODO: determine
x = layers.Activation('tanh')(x) # TODO: determine

x = layers.add([x, upscale])
x = layers.Activation('tanh')(x)
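The output stage above adds the learned correction to the upscaled input before a final tanh, so the network only needs to predict the residual high-frequency detail. An illustrative sketch of that head, with names not taken from the repository:

import tensorflow as tf
from tensorflow.keras import layers

def residual_output_head(features, upscaled_input, channels: int = 3):
    # Predict a correction, add it to the upscaled input (however it is produced
    # earlier in the model), and squash the sum back to [-1, 1].
    correction = layers.Conv2D(filters=channels, kernel_size=(4, 4), strides=(1, 1),
                               padding='same',
                               kernel_initializer=tf.keras.initializers.HeNormal())(features)
    correction = layers.Activation('tanh')(correction)
    out = layers.add([correction, upscaled_input])
    return layers.Activation('tanh', dtype='float32')(out)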
4 changes: 2 additions & 2 deletions superresolution/models/SuperResolutionAE.py
@@ -93,8 +93,8 @@ def create_dscr_autoencoder_model(input_shape: tuple, output_shape: tuple, use_r
if use_resnet:
if last_sum_layer is not None:
last_sum_layer = layers.Conv2D(filters=filter_size, kernel_size=(1, 1),
kernel_initializer=tf.keras.initializers.GlorotUniform(),
strides=(2, 2))(last_sum_layer)
kernel_initializer=tf.keras.initializers.GlorotUniform(),
strides=(2, 2))(last_sum_layer)
x = layers.add([attach_layer, last_sum_layer])

last_sum_layer = x
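The change above only rewraps the ResNet-style shortcut, where a 1x1 convolution with stride 2 reshapes the previous sum layer so it can be added to the downsampled main branch. A minimal illustrative sketch of that projection-shortcut pattern:

import tensorflow as tf
from tensorflow.keras import layers

def projection_shortcut(skip, main_branch, filters: int):
    # 1x1 strided convolution so the skip tensor matches a main branch that has
    # been downsampled by 2 and widened to `filters` channels, then sum them.
    skip = layers.Conv2D(filters=filters, kernel_size=(1, 1), strides=(2, 2),
                         kernel_initializer=tf.keras.initializers.GlorotUniform())(skip)
    return layers.add([main_branch, skip])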