Commit e50d899c authored by Paul Bethge's avatar Paul Bethge
Browse files

add augmentation, model save and load, fix report

parent 7b15f728
......@@ -23,13 +23,13 @@ from tensorflow.keras.metrics import Precision, Recall, CategoricalAccuracy
from tensorflow.keras.models import load_model
import src.models as models
from src.utils.training_utils import CustomCSVCallback, get_saved_model_function
from src.utils.training_utils import CustomCSVCallback, get_saved_model_function, visualize_results
from src.utils.training_utils import create_dataset_from_set_of_files, tf_normalize
from import AudioAugmenter
def train(config_path, log_dir, model_path):
def train(config_path, log_dir):
# Config
config = load(open(config_path, "rb"))
......@@ -45,9 +45,10 @@ def train(config_path, log_dir, model_path):
augment = config["augment"]
learning_rate = config["learning_rate"]
model_name = config["model"]
model_path = config["model_path"]
# create or load the model
if model_path:
if model_path != "":
model = load_model(model_path)
model_class = getattr(models, model_name)
......@@ -67,13 +68,15 @@ def train(config_path, log_dir, model_path):
val_ds = val_ds.batch(batch_size)
# Optional augmentation of the training set
## Note: tf.py_function allows to construct a graph but code is executed in python
if augment:
augmenter = AudioAugmenter(audio_length_s, sample_rate)
def process_aug(audio, label):
audio = augmenter.augment_audio_array(audio)
return audio, label
aug_wav = lambda x,y: tf.py_function(process_aug, [x,y], tf.float32),
# normalize audio
def process(audio, label):
audio = tf_normalize(audio)
......@@ -88,47 +91,59 @@ def train(config_path, log_dir, model_path):
# Training Callbacks
tensorboard_callback = TensorBoard(log_dir=log_dir, write_images=True)
csv_logger_callback = CustomCSVCallback(os.path.join(log_dir, "log.csv"))
reduce_on_plateau = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=1,
reduce_on_plateau = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3,
verbose=1, min_lr=0.000001, min_delta=0.001)
checkpoint_filename = os.path.join(log_dir, "trained_models", "weights.{epoch:02d}")
checkpoint_filename = os.path.join(log_dir, "trained_models", "model.{epoch:02d}")
model_checkpoint_callback = ModelCheckpoint(checkpoint_filename, save_best_only=True, verbose=1,
#early_stopping_callback = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1, mode="min")
early_stopping_callback = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1, mode="min")
callbacks = [tensorboard_callback, csv_logger_callback, reduce_on_plateau, model_checkpoint_callback,
# early_stopping_callback,
# comment callbacks that you don't care about
callbacks = [
# tensorboard_callback,
# early_stopping_callback,
# Training
history =, epochs=num_epochs,
callbacks=callbacks, validation_data=val_ds)
# visualize_results(history, config)
# Do evaluation on model with best validation accuracy
# TODO Do evaluation on model with best validation accuracy
visualize_results(history, config, log_dir)
best_epoch = np.argmax(history.history["val_categorical_accuracy"])
print("Log files: ", log_dir)
print("Best epoch: ", best_epoch)
return checkpoint_filename.replace("{epoch:02d}", "{:02d}".format(best_epoch))
checkpoint_filename.replace("{epoch:02d}", "{:02d}".format(best_epoch))
print("Best model at: ", checkpoint_filename)
return model, best_epoch
if __name__ == "__main__":
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--config', default="config_train.yaml")
parser.add_argument('--model_path', default=None, help="Path to a trained model for retraining")
parser.add_argument('--config', default="config_train.yaml",
help="Path to the required config file.")
cli_args = parser.parse_args()
# copy models & config for later
log_dir = os.path.join("logs","%Y-%m-%d-%H-%M-%S"))
print("Logging to {}".format(log_dir))
# copy models & config for later
shutil.copytree("src/models", os.path.join(log_dir, "models"))
shutil.copy(cli_args.config, log_dir)
model_file_name = train(cli_args.config, log_dir, cli_args.model_path)
print("Best model at: ", model_file_name)
# train and save the best model as SavedModel
model, best_epoch = train(cli_args.config, log_dir)
saved_model_path = os.path.join(log_dir, "model_" + str(best_epoch)), signatures=get_saved_model_function(model))
#TODO visualize the training process and save as png
#TODO convert model to saved model
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment