Skip to content
Snippets Groups Projects
Commit 9884fd42 authored by Alexander Hartelt's avatar Alexander Hartelt
Browse files

Added setting for LR reduction on plateau

parent 942c35cc
No related tags found
1 merge request!9Tfkeras transition
......@@ -128,7 +128,7 @@ class Network:
def train_dataset(self, train_data: Dataset, test__data: Dataset, output, epochs: int = 100,
early_stopping: bool = True, early_stopping_interval: int = 5, tensorboardlogs: bool = True,
augmentation: bool = False,):
augmentation: bool = False, reduce_lr_on_plateu=False):
callbacks = []
train_gen = self.create_dataset_inputs(train_data, augmentation)
test_gen = self.create_dataset_inputs(test__data, data_augmentation=False)
......@@ -173,6 +173,14 @@ class Network:
write_images=False)
callbacks.append(diagnose_cb)
callbacks.append(tensorboard)
if reduce_lr_on_plateu:
redurce_lr_plateau = tf.keras.callbacks.ReduceLROnPlateau(
monitor='val_loss',
factor=0.5,
patience=early_stopping_interval / 2,
min_lr=0.000001,
verbose=1)
callbacks.append(redurce_lr_plateau)
fg = self.model.fit(train_gen,
epochs=epochs,
......
......@@ -25,6 +25,7 @@ class TrainSettings(NamedTuple):
compute_baseline: bool = False
foreground_masks: bool = False
tensorboard: bool = False
reduce_lr_on_plateu = True
class Trainer:
......@@ -56,7 +57,8 @@ class Trainer:
True if self.settings.early_stopping_max_l_rate_drops != 0 else False,
early_stopping_interval=self.settings.early_stopping_max_l_rate_drops,
tensorboardlogs=self.settings.tensorboard,
augmentation=self.settings.data_augmentation)
augmentation=self.settings.data_augmentation, reduce_lr_on_plateu=
self.settings.reduce_lr_on_plateu)
def eval(self) -> None:
if len(self.settings.evaluation_data) > 0:
......@@ -68,15 +70,15 @@ class Trainer:
if __name__ == "__main__":
from pagesegmentation.lib.dataset import DatasetLoader
from pagesegmentation.scripts.generate_image_map import load_image_map_from_file
image_map = load_image_map_from_file('/home/alexander/Bilder/test_datenset/map.json/image_map.json')
image_map = load_image_map_from_file('/home/alexander/Bilder/datenset2/image_map.json')
dataset_loader = DatasetLoader(6, color_map=image_map)
print(dataset_loader.color_map)
train_data = dataset_loader.load_data_from_json(
['/home/alexander/Bilder/test_datenset/t.json'], "train")
['/home/alexander/Bilder/datenset2/t.json'], "train")
test_data = dataset_loader.load_data_from_json(
['/home/alexander/Bilder/test_datenset/t.json'], "test")
['/home/alexander/Bilder/datenset2/t.json'], "test")
eval_data = dataset_loader.load_data_from_json(
['/home/alexander/Bilder/test_datenset/t.json'], "eval")
['/home/alexander/Bilder/datenset2/t.json'], "eval")
settings = TrainSettings(
n_iter=100,
n_classes=len(dataset_loader.color_map),
......@@ -84,14 +86,14 @@ if __name__ == "__main__":
train_data=train_data,
validation_data=test_data,
display=10,
output='/home/alexander/Bilder/test_datenset/',
output='/home/alexander/Bilder/datenset2/',#'/home/alexander/Bilder/test_datenset/', #'/home/alexander/PycharmProjects/PageContent/pagecontent/demo/'
threads=8,
foreground_masks=False,
data_augmentation=True,
tensorboard=True,
n_architecture='mobile_net',
early_stopping_max_l_rate_drops=5,
load='/home/alexander/Bilder/test_datenset/best_model.hdf5'
load=None#'/home/alexander/Bilder/test_datenset/best_model.hdf5'
)
trainer = Trainer(settings)
trainer.train()
......
......@@ -38,12 +38,12 @@ def main():
help="Data used for early stopping"
)
parser.add_argument("--eval", type=str, nargs="*", default=[])
parser.add_argument("--display", type=int, default=100,
help="Display training progress each display iterations.")
parser.add_argument("--foreground_masks", default=False, action="store_true",
help="keep only mask parts that are foreground in binary image")
parser.add_argument("--tensorboard", type=str2bool, default=False,
help="Generate tenlogs for use in tensorboard")
help="Generate tensorboard logs")
parser.add_argument("--reduce_lr_on_plateu", type=str2bool, default=True,
help="Reducing LR when on plateau")
parser.add_argument("--color_map", type=str, required=True,
help="color_map to load")
args = parser.parse_args()
......@@ -83,6 +83,7 @@ def main():
foreground_masks=args.foreground_masks,
data_augmentation=args.data_augmentation,
tensorboard=args.tensorboard,
reduce_lr_on_plateu=args.reduce_lr_on_plateu,
)
trainer = Trainer(settings)
trainer.train()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment