Commit 60ab6b7a authored by Paul Bethge's avatar Paul Bethge
Browse files

fix training

parent b667b402
......@@ -139,9 +139,7 @@ python train.py --config config_train.yaml
## TODO
- evaluate the fairness of the model
- use a voice (instead of audio) activity detector
- rework data loading process (e.g. use TFDataset)
- more automation in the data set creation steps
- use a voice (instead of audio) activity detector
## Further Reading
......
......@@ -43,7 +43,6 @@ def traverse_csv(language, input_dir, output_dir, max_chops,
input_sub_dir = os.path.join(input_dir, lang_abb)
input_sub_dir_clips = os.path.join(input_sub_dir, "clips")
splits = ["train", "dev", "test"]
fast_forward = 0
......
......@@ -17,6 +17,7 @@ from tensorflow.keras.losses import CategoricalCrossentropy as CrossLoss
from tensorflow.keras.metrics import Precision, Recall, CategoricalAccuracy, CategoricalCrossentropy
from tensorflow.keras.utils import Progbar
import scipy.io.wavfile as wav
from kapre.composed import get_stft_magnitude_layer
from kapre.composed import get_melspectrogram_layer
......@@ -35,7 +36,11 @@ def create_dataset_from_set_of_files(ds_dir, languages):
# create a dataset yielding audio and categorical label
def process_path(file_path):
# read the wav file
x = tf.io.read_file(file_path)
x, _ = tf.audio.decode_wav(x)
# x = tf.squeeze(x, axis=-1)
# get label and convert to categorical
label = tf.strings.split(file_path, os.sep)[-2]
y = tf.cast(tf.equal(label, languages), tf.float32)
......@@ -45,6 +50,16 @@ def create_dataset_from_set_of_files(ds_dir, languages):
return labeled_ds
def tf_normalize(signal):
"""
normalize a float signal to have a maximum absolute value of 1.0
"""
highest = tf.math.abs(tf.math.reduce_max(signal))
lowest = tf.math.abs(tf.math.reduce_min(signal))
abs_max = tf.math.maximum(highest, lowest)
return tf.math.divide(signal, abs_max)
def get_feature_layer(feature_type, feature_nu, sample_rate):
if feature_type == 'stft':
m = get_stft_magnitude_layer(n_fft=feature_nu*2, name='stft_deb')
......@@ -59,6 +74,7 @@ def get_feature_layer(feature_type, feature_nu, sample_rate):
return None
return m
def write_csv(logging_dir, epoch, logs={}):
with open(logging_dir, mode='a') as log_file:
log_file_writer = csv.writer(log_file, delimiter=',')
......@@ -74,7 +90,7 @@ class CustomCSVCallback(Callback):
self._logging_dir = logging_dir
self._counter = 0
def on_epoch_end(self, epoch, logs={}):
write_csv(self.logging_dir, epoch, logs)
write_csv(self._logging_dir, epoch, logs)
def get_saved_model_function(model, dims=(None, None, 1)):
......
......@@ -23,11 +23,12 @@ from tensorflow.keras.metrics import Precision, Recall, CategoricalAccuracy
from tensorflow.keras.models import load_model
import src.models as models
from src.utils.training_utils import CustomCSVCallback, get_saved_model_function, create_dataset_from_set_of_files
from src.audio.features import normalize
from src.utils.training_utils import CustomCSVCallback, get_saved_model_function,
from src.utils.training_utils import create_dataset_from_set_of_files, tf_normalize
from src.audio.augment import AudioAugmenter
def train(config_path, log_dir, model_path):
# Config
......@@ -62,8 +63,8 @@ def train(config_path, log_dir, model_path):
ds_dir=train_dir, languages=languages)
val_ds = create_dataset_from_set_of_files(
ds_dir=val_dir, languages=languages)
train_ds.batch(batch_size)
val_ds.batch(batch_size)
train_ds = train_ds.batch(batch_size)
val_ds = val_ds.batch(batch_size)
# Optional augmentation of the training set
if augment:
......@@ -71,11 +72,11 @@ def train(config_path, log_dir, model_path):
def process_aug(audio, label):
audio = augmenter.augment_audio_array(audio)
return audio, label
train_ds.map(process_aug, num_parallel_calls=tf.data.experimental.AUTOTUNE)
# train_ds.map(process_aug, num_parallel_calls=tf.data.experimental.AUTOTUNE)
# normalize audio
def process(audio, label):
audio = normalize(audio)
audio = tf_normalize(audio)
return audio, label
train_ds = train_ds.map(process, num_parallel_calls=tf.data.experimental.AUTOTUNE)
val_ds = val_ds.map(process, num_parallel_calls=tf.data.experimental.AUTOTUNE)
......@@ -84,8 +85,6 @@ def train(config_path, log_dir, model_path):
train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)
val_ds = val_ds.prefetch(tf.data.experimental.AUTOTUNE)
print("Output classes:", train_ds.class_names)
# Training Callbacks
tensorboard_callback = TensorBoard(log_dir=log_dir, write_images=True)
csv_logger_callback = CustomCSVCallback(os.path.join(log_dir, "log.csv"))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment