Commit 7cf47261 authored by Paul Bethge

use custom train loop

parent d36d2591
@@ -27,6 +27,15 @@ from kapre.composed import get_log_frequency_spectrogram_layer
 def create_dataset_from_set_of_files(ds_dir, languages):
     """Create an audio dataset for wav files in a specified directory and for specified languages
+
+    Args:
+        ds_dir (str): the path to the folder containing one sub-folder per language
+        languages (list): the language folders to be included
+
+    Returns:
+        (tf.data.Dataset, int): the dataset and the number of files it contains
+    """
     # ensure the languages are sorted alphanumerically
     languages = sorted(languages)
@@ -34,6 +43,7 @@ def create_dataset_from_set_of_files(ds_dir, languages):
     # create a file path dataset from all directories specified
     glob_list = [os.path.join(ds_dir, lang, '*.wav') for lang in languages]
     list_ds = tf.data.Dataset.list_files(glob_list)
+    num_files = len(list(list_ds))
 
     # create a dataset yielding audio and categorical label
     def process_path(file_path):
@@ -48,7 +58,7 @@ def create_dataset_from_set_of_files(ds_dir, languages):
         return x, y
 
     labeled_ds = list_ds.map(process_path)
-    return labeled_ds
+    return labeled_ds, num_files
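As a usage sketch, the new two-value return can be consumed like this (the paths and language folder names are hypothetical):

    train_ds, num_train_files = create_dataset_from_set_of_files(
        ds_dir='data/train', languages=['english', 'german'])
    train_ds = train_ds.shuffle(512).batch(32)  # num_train_files counts files, not batches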
def tf_normalize(signal):
@@ -62,6 +72,15 @@ def tf_normalize(signal):
 def get_feature_layer(feature_type, feature_nu, sample_rate):
     """Returns a Keras layer for the given feature type and number of features
+
+    Args:
+        feature_type (str): one of 'stft', 'mel' or 'fbank'
+        feature_nu (int): the number of features to extract
+        sample_rate (int): sampling rate of the input signal
+
+    Returns:
+        keras.layers.Layer: a layer that computes the specified features
+    """
     if feature_type == 'stft':
         m = get_stft_magnitude_layer(n_fft=feature_nu*2, name='stft_deb')
     elif feature_type == 'mel':
@@ -77,15 +96,15 @@ def get_feature_layer(feature_type, feature_nu, sample_rate):
 def write_csv(logging_dir, epoch, logs={}):
     with open(logging_dir, mode='a') as log_file:
         log_file_writer = csv.writer(log_file, delimiter=',')
        # write the header row once, before the first epoch's values
         if epoch == 0:
             row = list(logs.keys())
             row.insert(0, 'epoch')
             log_file_writer.writerow(row)
         row_vals = [round(x, 6) for x in list(logs.values())]
         row_vals.insert(0, epoch)
         log_file_writer.writerow(row_vals)
class CustomCSVCallback(Callback):
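The callback's body is collapsed in this view; a minimal sketch of how such a callback could wrap write_csv (the logging_dir constructor argument is an assumption):

    class CustomCSVCallback(Callback):
        def __init__(self, logging_dir):
            super().__init__()
            self._logging_dir = logging_dir

        def on_epoch_end(self, epoch, logs=None):
            # append one row of rounded metric values per epoch
            write_csv(self._logging_dir, epoch, logs or {})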
@@ -105,10 +124,10 @@ def get_saved_model_function(model, dims=(None, None, 1)):
 def visualize_results(history, config, log_dir):
     epochs = config["num_epochs"]
-    acc = history.history['categorical_accuracy']
-    val_acc = history.history['val_categorical_accuracy']
-    loss = history.history['loss']
-    val_loss = history.history['val_loss']
+    # the custom fit() below returns a plain dict rather than a Keras History object
+    acc = history['categorical_accuracy']
+    val_acc = history['val_categorical_accuracy']
+    loss = history['loss']
+    val_loss = history['val_loss']
 
     epochs_range = range(epochs)
     plt.figure(figsize=(8, 8))
@@ -127,7 +146,180 @@ def visualize_results(history, config, log_dir):
     plt.savefig(os.path.join(log_dir, "training.png"))
 
-def run_epoch(model, dataset, training=False, augmenter=None, optimizer=None, show_progress=False, num_batches=32):
+def fit(model, train_set, optimizer, num_epochs=20, augmenter=None, val_set=None,
+        show_progress=False, num_batches_train=0, num_batches_val=0):
"""Train or validate a model with a given generator.
Args:
model (Keras.Model): keras model
batch_generator (generator): generator object that yields batches of samples
training (bool, optional): Whether to train the model or not. Defaults to False.
optimizer (Keras.Optimizer, optional): optimizer that is applied when training is True. Defaults to None.
augmenter (AudioAugmenter, optional): when given will augment the audio data. Defaults to None.
show_progress (bool, optional): whether to show the progress bar. Defaults to False.
num_batches (int): maximum number of batches (only used for progress bar)
Returns:
[list]: a list of metrics
"""
+    # loss and running metrics
+    loss_fn = CrossLoss()
+    metric_accuracy = CategoricalAccuracy()
+    metric_recall = Recall()
+    metric_precision = Precision()
+    metric_loss = CategoricalCrossentropy()
+
+    def update_states(y_batch, logits):
+        metric_accuracy.update_state(y_batch, logits)
+        metric_precision.update_state(y_batch, logits)
+        metric_recall.update_state(y_batch, logits)
+        metric_loss.update_state(y_batch, logits)
+        accuracy = metric_accuracy.result().numpy()
+        recall = metric_recall.result().numpy()
+        precision = metric_precision.result().numpy()
+        loss = metric_loss.result().numpy()
+        return loss, accuracy, recall, precision
+
+    def reset_states():
+        metric_accuracy.reset_states()
+        metric_recall.reset_states()
+        metric_precision.reset_states()
+        metric_loss.reset_states()
+
+    metrics_names = ['loss', 'acc', 'rec', 'pre']
+    history = {
+        'categorical_accuracy': [0.0] * num_epochs,
+        'val_categorical_accuracy': [0.0] * num_epochs,
+        'loss': [0.0] * num_epochs,
+        'val_loss': [0.0] * num_epochs
+    }
+    for epoch in range(num_epochs):
+        print('______ TRAINING ______')
+        pb = Progbar(num_batches_train, stateful_metrics=metrics_names)
+
+        # iterate over the batches of the training set
+        for x_batch, y_batch in train_set:
+            if augmenter is not None:
+                x_batch = augmenter.augment_audio_array(x_batch.numpy().tolist())
+                x_batch = tf.expand_dims(x_batch, axis=-1)
+            with tf.GradientTape() as tape:
+                logits = model(x_batch, training=True)
+                loss_value = loss_fn(y_batch, logits)
+            # optimize
+            grads = tape.gradient(loss_value, model.trainable_weights)
+            optimizer.apply_gradients(zip(grads, model.trainable_weights))
+            loss, accuracy, recall, precision = update_states(y_batch, logits)
+            if show_progress:
+                values = [('loss', loss), ('acc', accuracy),
+                          ('rec', recall), ('pre', precision)]
+                pb.add(1, values=values)
+        history['categorical_accuracy'][epoch] = accuracy
+        history['loss'][epoch] = loss
+        reset_states()
+        # skip validation entirely when no validation set is given
+        if val_set is not None:
+            print('_____ VALIDATION ______')
+            pb = Progbar(num_batches_val, stateful_metrics=metrics_names)
+
+            # iterate over the batches of the validation set
+            for x_batch, y_batch in val_set:
+                logits = model(x_batch, training=False)
+                loss_value = loss_fn(y_batch, logits)
+                loss, accuracy, recall, precision = update_states(y_batch, logits)
+                if show_progress:
+                    values = [('loss', loss), ('acc', accuracy),
+                              ('rec', recall), ('pre', precision)]
+                    pb.add(1, values=values)
+            history['val_categorical_accuracy'][epoch] = accuracy
+            history['val_loss'][epoch] = loss
+            reset_states()
+
+    return history
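A usage sketch for the new fit (model, the datasets and batch_size are assumed to exist in the caller; Adam is only an example optimizer):

    import math

    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
    history = fit(model, train_set=train_ds, optimizer=optimizer, num_epochs=20,
                  augmenter=None, val_set=val_ds, show_progress=True,
                  num_batches_train=math.ceil(num_train_files / batch_size),
                  num_batches_val=math.ceil(num_val_files / batch_size))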
+
+def run_epoch(model, dataset, training=False, optimizer=None, augmenter=None, show_progress=False, num_batches=32):
+    """Train or validate a model on the batches of a given dataset for one epoch.
+
+    Args:
+        model (keras.Model): the model to run
+        dataset (tf.data.Dataset): dataset yielding batches of (audio, label) samples
+        training (bool, optional): whether to train the model or not. Defaults to False.
+        optimizer (keras.optimizers.Optimizer, optional): optimizer that is applied when training is True. Defaults to None.
+        augmenter (AudioAugmenter, optional): when given, augments the audio data. Defaults to None.
+        show_progress (bool, optional): whether to show a progress bar. Defaults to False.
+        num_batches (int): maximum number of batches (only used for the progress bar)
+
+    Returns:
+        tuple: the epoch's accuracy, loss, recall and precision
+    """
+    if training:
+        print('______ TRAINING ______')
+    else:
+        print('_____ VALIDATION ______')
+
+    # loss and running metrics
+    loss_fn = CrossLoss()
+    metric_accuracy = CategoricalAccuracy()
+    metric_recall = Recall()
+    metric_precision = Precision()
+    metric_loss = CategoricalCrossentropy()
+    metrics_names = ['loss', 'acc', 'rec', 'pre']
+    pb = Progbar(num_batches, stateful_metrics=metrics_names)
+
+    # iterate over the batches of the dataset
+    for x_batch, y_batch in dataset:
+        if training and augmenter is not None:
+            x_batch = augmenter.augment_audio_array(x_batch.numpy().tolist())
+            x_batch = tf.expand_dims(x_batch, axis=-1)
+        with tf.GradientTape() as tape:
+            logits = model(x_batch, training=training)
+            loss_value = loss_fn(y_batch, logits)
+        # update metrics
+        metric_accuracy.update_state(y_batch, logits)
+        metric_precision.update_state(y_batch, logits)
+        metric_recall.update_state(y_batch, logits)
+        metric_loss.update_state(y_batch, logits)
+        # optimize
+        if training:
+            grads = tape.gradient(loss_value, model.trainable_weights)
+            optimizer.apply_gradients(zip(grads, model.trainable_weights))
+        # report loss & metrics
+        accuracy = metric_accuracy.result().numpy()
+        recall = metric_recall.result().numpy()
+        precision = metric_precision.result().numpy()
+        loss = metric_loss.result().numpy()
+        if show_progress:
+            values = [('loss', loss), ('acc', accuracy),
+                      ('rec', recall), ('pre', precision)]
+            pb.add(1, values=values)
+
+    return accuracy, loss, recall, precision
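For reference, run_epoch is called once per epoch and per split; this mirrors the loop that the commit removes from train.py below:

    for epoch in range(num_epochs):
        run_epoch(model, train_ds, training=True, optimizer=optimizer,
                  augmenter=augmenter, show_progress=True, num_batches=32)
        run_epoch(model, val_ds, training=False, show_progress=True, num_batches=32)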
+
+def run_epoch_metrics(model, dataset, metrics, training=False, optimizer=None, augmenter=None, show_progress=False, num_batches=32):
     """Train or validate a model with a given generator.
 
     Args:
@@ -135,6 +327,7 @@ def run_epoch(model, dataset, training=False, augmenter=None, optimizer=None, show_progress=False, num_batches=32):
         dataset (tf.data.Dataset): dataset yielding batches of (audio, label) samples
         training (bool, optional): whether to train the model or not. Defaults to False.
         optimizer (keras.optimizers.Optimizer, optional): optimizer that is applied when training is True. Defaults to None.
+        augmenter (AudioAugmenter, optional): when given, augments the audio data. Defaults to None.
         show_progress (bool, optional): whether to show a progress bar. Defaults to False.
         num_batches (int): maximum number of batches (only used for the progress bar)
......
@@ -24,7 +24,7 @@ from tensorflow.keras.models import load_model
 import src.models as models
 from src.utils.training_utils import CustomCSVCallback, get_saved_model_function, visualize_results
-from src.utils.training_utils import create_dataset_from_set_of_files, tf_normalize, run_epoch
+from src.utils.training_utils import create_dataset_from_set_of_files, tf_normalize, fit
 from src.audio.augment import AudioAugmenter
@@ -57,9 +57,9 @@ def train(config_path, log_dir):
     # print(model.summary())
 
     # load the dataset
-    train_ds = create_dataset_from_set_of_files(
+    train_ds, num_train_files = create_dataset_from_set_of_files(
         ds_dir=train_dir, languages=languages)
-    val_ds = create_dataset_from_set_of_files(
+    val_ds, num_val_files = create_dataset_from_set_of_files(
         ds_dir=val_dir, languages=languages)
 
     # Optional augmentation of the training set
@@ -81,13 +81,15 @@ def train(config_path, log_dir):
     train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)
     val_ds = val_ds.prefetch(tf.data.experimental.AUTOTUNE)
 
-    for epoch in range(num_epochs):
-        run_epoch(model, train_ds, training=True, augmenter=augmenter, optimizer=optimizer, show_progress=True, num_batches=32)
-        run_epoch(model, val_ds, training=False, augmenter=None, optimizer=None, show_progress=True, num_batches=32)
-
-    # Training
-    history = model.fit(x=train_ds, epochs=num_epochs,
-                        callbacks=callbacks, validation_data=val_ds)
+    # # Training
+    # history = model.fit(x=train_ds, epochs=num_epochs,
+    #                     callbacks=callbacks, validation_data=val_ds)
+
+    # Training with the custom loop; integer division gives the batch counts for the progress bars
+    history = fit(model, train_set=train_ds, optimizer=optimizer, num_epochs=num_epochs,
+                  augmenter=augmenter, show_progress=True,
+                  num_batches_train=num_train_files // batch_size,
+                  num_batches_val=num_val_files // batch_size,
+                  val_set=val_ds)
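Downstream, the dict returned by fit plugs straight into the updated visualize_results (config and log_dir come from the surrounding train() function):

    visualize_results(history, config, log_dir)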
# TODO Do evaluation on model with best validation accuracy
......