Commit 28bd7cbe authored by Paul Bethge

use tf.dataset and model.fit

parent 5ee3adf5
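This commit drops the hand-written training loop built around AugBatchGenerator and run_epoch in favour of a tf.data input pipeline that is consumed directly by Keras model.fit, with callbacks handling CSV logging, checkpointing and learning-rate scheduling. The snippet below is a minimal, self-contained sketch of that pattern only; the toy tensors, layer sizes and parameter values are placeholders and are not taken from this repository.

# Minimal sketch of the tf.data + model.fit pattern adopted in this commit.
# All data and the model are toy placeholders, not the project's real pipeline.
import numpy as np
import tensorflow as tf

num_classes = 4
x = np.random.randn(256, 16).astype("float32")   # toy features
y = tf.keras.utils.to_categorical(np.random.randint(num_classes, size=256), num_classes)

# build the input pipeline: batch, map (normalization/augmentation), prefetch
ds = tf.data.Dataset.from_tensor_slices((x, y))
ds = ds.batch(32)
ds = ds.map(lambda a, b: (tf.math.l2_normalize(a, axis=-1), b),
            num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation="relu", input_shape=(16,)),
    tf.keras.layers.Dense(num_classes, activation="softmax"),
])
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=[tf.keras.metrics.CategoricalAccuracy()])

# model.fit drives the loop that the previous version of train.py wrote by hand
history = model.fit(ds, epochs=2)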
@@ -13,20 +13,23 @@ import argparse
 from datetime import datetime
 from yaml import load
 import numpy as np
 import tensorflow as tf
-from tensorflow.keras.optimizers import Adam
+from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau
+from tensorflow.keras.optimizers import Adam, RMSprop, SGD
+from tensorflow.keras.losses import CategoricalCrossentropy
+from tensorflow.keras.metrics import Precision, Recall, CategoricalAccuracy
 from tensorflow.keras.models import load_model
 import src.models as models
-from src.utils.training_utils import *
-from src.audio.generators import AugBatchGenerator
+from src.utils.training_utils import CustomCSVCallback, get_saved_model_function, visualize_results
+from src.utils.training_utils import create_dataset_from_set_of_files, tf_normalize
 from src.audio.augment import AudioAugmenter
-# physical_devices = tf.config.experimental.list_physical_devices('GPU')
-# config = tf.config.experimental.set_memory_growth(physical_devices[0], True)
-def train(config_path, log_dir, model_path):
+def train(config_path, log_dir):
     # Config
     config = load(open(config_path, "rb"))
@@ -37,86 +40,110 @@ def train(config_path, log_dir, model_path):
     batch_size = config["batch_size"]
     languages = config["languages"]
     num_epochs = config["num_epochs"]
-    fs = config["sample_rate"]
+    sample_rate = config["sample_rate"]
     audio_length_s = config["audio_length_s"]
     augment = config["augment"]
-    show_progress = config["show_progress"]
+    learning_rate = config["learning_rate"]
+    model_name = config["model"]
+    model_path = config["model_path"]
     # create or load the model
-    if model_path:
+    if model_path != "":
         model = load_model(model_path)
     else:
-        model_class = getattr(models, config["model"])
+        model_class = getattr(models, model_name)
         model = model_class.create_model(config)
-    optimizer = Adam(lr=config["learning_rate"])
-    model.compile()
+    optimizer = Adam(lr=learning_rate)
+    model.compile(optimizer=optimizer,
+                  loss=CategoricalCrossentropy(),
+                  metrics=[Recall(), Precision(), CategoricalAccuracy()])
     print(model.summary())
-    # when using on-the-fly augmentation create the augmentation object
-    augmenter = None
+    # load the dataset
+    train_ds = create_dataset_from_set_of_files(
+        ds_dir=train_dir, languages=languages)
+    val_ds = create_dataset_from_set_of_files(
+        ds_dir=val_dir, languages=languages)
+    train_ds = train_ds.batch(batch_size)
+    val_ds = val_ds.batch(batch_size)
+    # Optional augmentation of the training set
+    ## Note: tf.py_function allows to construct a graph but code is executed in python
     if augment:
-        augmenter = AudioAugmenter(audio_length_s, fs)
-    # create generator objects
-    train_gen_obj = AugBatchGenerator(source=train_dir, target_length_s=audio_length_s,
-                                      languages=languages, batch_size=batch_size,
-                                      augmenter=augmenter)
-    val_gen_obj = AugBatchGenerator(source=val_dir, target_length_s=audio_length_s,
-                                    languages=languages, batch_size=batch_size)
-    # progress bar information
-    if show_progress:
-        train_num_batches = train_gen_obj.count_batches()
-        val_num_batches = val_gen_obj.count_batches()
-    # Training loop
-    best_val_acc = 0.0
-    for epoch in range(1, num_epochs+1):
-        print('===== EPOCH ', str(epoch), ' ======')
-        # train
-        train_results = run_epoch(model, train_gen_obj, training=True, optimizer=optimizer,
-                                  show_progress=show_progress, num_batches=train_num_batches)
-        train_acc, train_loss, train_recall, train_precision = train_results
-        # validate
-        val_results = run_epoch(model, val_gen_obj, training=False,
-                                show_progress=show_progress, num_batches=val_num_batches)
-        val_acc, val_loss, val_recall, val_precision = val_results
-        # log data
-        lr = round(float(get_value(optimizer.learning_rate)), 6)
-        logs = {'epoch': epoch, 'learning_rate': lr,
-                'train_acc': train_acc, 'train_loss': train_loss,
-                'train_rec': train_recall, 'train_pre': train_precision,
-                'val_acc': val_acc, 'val_loss': val_loss,
-                'val_rec': val_recall, 'val_pre': val_precision}
-        write_csv(os.path.join(log_dir, 'log.csv'), epoch, logs)
-        # save model
-        if val_acc > best_val_acc:
-            best_val_acc = val_acc
-            model_name = os.path.join(log_dir, 'model' + "_" + str(epoch))
-            model_predict = get_saved_model_function(model)
-            model.save(model_name, signatures={'serving_default': model_predict})
+        augmenter = AudioAugmenter(audio_length_s, sample_rate)
+        def process_aug(audio, label):
+            audio = augmenter.augment_audio_array(audio)
+            return audio, label
+        aug_wav = lambda x,y: tf.py_function(process_aug, [x,y], tf.float32)
+        train_ds.map(aug_wav, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+    # normalize audio
+    def process(audio, label):
+        audio = tf_normalize(audio)
+        return audio, label
+    train_ds = train_ds.map(process, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+    val_ds = val_ds.map(process, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+    # prefetch data
+    train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)
+    val_ds = val_ds.prefetch(tf.data.experimental.AUTOTUNE)
+    # Training Callbacks
+    tensorboard_callback = TensorBoard(log_dir=log_dir, write_images=True)
+    csv_logger_callback = CustomCSVCallback(os.path.join(log_dir, "log.csv"))
+    reduce_on_plateau = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3,
+                                          verbose=1, min_lr=0.000001, min_delta=0.001)
+    checkpoint_filename = os.path.join(log_dir, "trained_models", "model.{epoch:02d}")
+    model_checkpoint_callback = ModelCheckpoint(checkpoint_filename, save_best_only=True, verbose=1,
+                                                monitor="val_categorical_accuracy",
+                                                save_weights_only=False)
+    early_stopping_callback = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1, mode="min")
+    # comment callbacks that you don't care about
+    callbacks = [
+        # tensorboard_callback,
+        csv_logger_callback,
+        reduce_on_plateau,
+        model_checkpoint_callback,
+        # early_stopping_callback,
+    ]
+    # Training
+    history = model.fit(x=train_ds, epochs=num_epochs,
+                        callbacks=callbacks, validation_data=val_ds)
+    # TODO Do evaluation on model with best validation accuracy
+    visualize_results(history, config, log_dir)
+    best_epoch = np.argmax(history.history["val_categorical_accuracy"])
+    print("Log files: ", log_dir)
+    print("Best epoch: ", best_epoch)
+    checkpoint_filename.replace("{epoch:02d}", "{:02d}".format(best_epoch))
+    print("Best model at: ", checkpoint_filename)
+    return model, best_epoch
 if __name__ == "__main__":
-    tf.config.list_physical_devices('GPU')
     parser = argparse.ArgumentParser()
-    parser.add_argument('--config', default="config_train.yaml")
-    parser.add_argument('--model_path', default=None, help="Path to a trained model for retraining")
+    parser.add_argument('--config', default="config_train.yaml",
+                        help="Path to the required config file.")
     cli_args = parser.parse_args()
+    tf.config.list_physical_devices('GPU')
+    # copy models & config for later
     log_dir = os.path.join("logs", datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))
     print("Logging to {}".format(log_dir))
-    # copy models & config for later
     shutil.copytree("src/models", os.path.join(log_dir, "models"))
     shutil.copy(cli_args.config, log_dir)
-    model_file_name = train(cli_args.config, log_dir, cli_args.model_path)
-    print("Best model at: ", model_file_name)
+    # train and save the best model as SavedModel
+    model, best_epoch = train(cli_args.config, log_dir)
+    saved_model_path = os.path.join(log_dir, "model_" + str(best_epoch))
+    model.save(saved_model_path, signatures=get_saved_model_function(model))
     #TODO visualize the training process and save as png
     #TODO convert model to saved model
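Two small issues in the committed code are worth noting. First, Dataset.map returns a new dataset, so the augmentation branch's train_ds.map(aug_wav, ...) call has no effect unless its result is assigned back, and tf.py_function needs one output dtype per value that process_aug returns. Second, str.replace also returns a new string, so the final "Best model at:" message prints the unformatted "model.{epoch:02d}" template. A hedged fix sketch follows; the names refer to variables inside train() in the full file below, the label dtype is an assumption about what create_dataset_from_set_of_files yields, and the +1 assumes Keras' 1-based epoch numbering in checkpoint filenames.

# Inside the `if augment:` branch of train(): re-assign the mapped dataset and
# declare both output dtypes (tf.int32 for the label is an assumption).
aug_wav = lambda x, y: tf.py_function(process_aug, [x, y], [tf.float32, tf.int32])
train_ds = train_ds.map(aug_wav, num_parallel_calls=tf.data.experimental.AUTOTUNE)

# At the end of train(): keep the formatted path; ModelCheckpoint numbers epochs
# from 1 while np.argmax is 0-based, hence the +1.
best_model_path = checkpoint_filename.replace("{epoch:02d}", "{:02d}".format(best_epoch + 1))
print("Best model at: ", best_model_path)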
"""
:author:
Paul Bethge (bethge@zkm.de)
2021
:License:
This package is published under Simplified BSD License.
"""
import os
import shutil
import argparse
from datetime import datetime
from yaml import load
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam, RMSprop, SGD
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import Precision, Recall, CategoricalAccuracy
from tensorflow.keras.models import load_model
import src.models as models
from src.utils.training_utils import CustomCSVCallback, get_saved_model_function, visualize_results
from src.utils.training_utils import create_dataset_from_set_of_files, tf_normalize
from src.audio.augment import AudioAugmenter
def train(config_path, log_dir):

    # Config
    config = load(open(config_path, "rb"))
    if config is None:
        print("Please provide a config.")
    train_dir = config["train_dir"]
    val_dir = config["val_dir"]
    batch_size = config["batch_size"]
    languages = config["languages"]
    num_epochs = config["num_epochs"]
    sample_rate = config["sample_rate"]
    audio_length_s = config["audio_length_s"]
    augment = config["augment"]
    learning_rate = config["learning_rate"]
    model_name = config["model"]
    model_path = config["model_path"]

    # create or load the model
    if model_path != "":
        model = load_model(model_path)
    else:
        model_class = getattr(models, model_name)
        model = model_class.create_model(config)
    optimizer = Adam(lr=learning_rate)
    model.compile(optimizer=optimizer,
                  loss=CategoricalCrossentropy(),
                  metrics=[Recall(), Precision(), CategoricalAccuracy()])
    print(model.summary())

    # load the dataset
    train_ds = create_dataset_from_set_of_files(
        ds_dir=train_dir, languages=languages)
    val_ds = create_dataset_from_set_of_files(
        ds_dir=val_dir, languages=languages)
    train_ds = train_ds.batch(batch_size)
    val_ds = val_ds.batch(batch_size)

    # Optional augmentation of the training set
    ## Note: tf.py_function allows to construct a graph but code is executed in python
    if augment:
        augmenter = AudioAugmenter(audio_length_s, sample_rate)
        def process_aug(audio, label):
            audio = augmenter.augment_audio_array(audio)
            return audio, label
        aug_wav = lambda x,y: tf.py_function(process_aug, [x,y], tf.float32)
        train_ds.map(aug_wav, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # normalize audio
    def process(audio, label):
        audio = tf_normalize(audio)
        return audio, label
    train_ds = train_ds.map(process, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    val_ds = val_ds.map(process, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # prefetch data
    train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)
    val_ds = val_ds.prefetch(tf.data.experimental.AUTOTUNE)

    # Training Callbacks
    tensorboard_callback = TensorBoard(log_dir=log_dir, write_images=True)
    csv_logger_callback = CustomCSVCallback(os.path.join(log_dir, "log.csv"))
    reduce_on_plateau = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3,
                                          verbose=1, min_lr=0.000001, min_delta=0.001)
    checkpoint_filename = os.path.join(log_dir, "trained_models", "model.{epoch:02d}")
    model_checkpoint_callback = ModelCheckpoint(checkpoint_filename, save_best_only=True, verbose=1,
                                                monitor="val_categorical_accuracy",
                                                save_weights_only=False)
    early_stopping_callback = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1, mode="min")

    # comment callbacks that you don't care about
    callbacks = [
        # tensorboard_callback,
        csv_logger_callback,
        reduce_on_plateau,
        model_checkpoint_callback,
        # early_stopping_callback,
    ]

    # Training
    history = model.fit(x=train_ds, epochs=num_epochs,
                        callbacks=callbacks, validation_data=val_ds)

    # TODO Do evaluation on model with best validation accuracy
    visualize_results(history, config, log_dir)
    best_epoch = np.argmax(history.history["val_categorical_accuracy"])
    print("Log files: ", log_dir)
    print("Best epoch: ", best_epoch)
    checkpoint_filename.replace("{epoch:02d}", "{:02d}".format(best_epoch))
    print("Best model at: ", checkpoint_filename)

    return model, best_epoch


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default="config_train.yaml",
                        help="Path to the required config file.")
    cli_args = parser.parse_args()

    tf.config.list_physical_devices('GPU')

    # copy models & config for later
    log_dir = os.path.join("logs", datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))
    print("Logging to {}".format(log_dir))
    shutil.copytree("src/models", os.path.join(log_dir, "models"))
    shutil.copy(cli_args.config, log_dir)

    # train and save the best model as SavedModel
    model, best_epoch = train(cli_args.config, log_dir)
    saved_model_path = os.path.join(log_dir, "model_" + str(best_epoch))
    model.save(saved_model_path, signatures=get_saved_model_function(model))

    #TODO visualize the training process and save as png
    #TODO convert model to saved model
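For completeness, a hedged usage sketch: the script is started with a YAML config and writes its logs, checkpoints and the exported SavedModel under logs/<timestamp>/. The run directory and the input shape below are placeholders, since the exact serving signature depends on get_saved_model_function and on the model defined in src/models.

# python train.py --config config_train.yaml
#
# Afterwards the exported model can be reloaded for inference; the path and
# input shape here are hypothetical placeholders.
import numpy as np
import tensorflow as tf

saved_model_path = "logs/2021-01-01-12-00-00/model_7"      # hypothetical run directory
loaded = tf.saved_model.load(saved_model_path)
infer = loaded.signatures["serving_default"]

dummy_audio = np.zeros((1, 16000 * 10), dtype=np.float32)  # assumed 10 s of audio at 16 kHz
print(infer(tf.constant(dummy_audio)))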