Commit 8be6b2fc authored by Paul Bethge's avatar Paul Bethge
Browse files

fix data augmentation

parent e42d138c
......@@ -64,27 +64,30 @@ def train(config_path, log_dir):
ds_dir=train_dir, languages=languages)
val_ds = create_dataset_from_set_of_files(
ds_dir=val_dir, languages=languages)
# NOTE(review): batching was moved to *after* the augmentation/normalization
# map steps (the per-example augmenter expects unbatched audio arrays);
# batching here as well would call .batch() twice on the pipeline.
# Optional augmentation of the training set
## Note: tf.py_function allows to construct a graph but code is executed in python
## Note: tf.py_function allows to construct a graph but code is executed in python (may be slow)
if augment:
augmenter = AudioAugmenter(audio_length_s, sample_rate)
# process a single audio array (note: dataset needs to be batched later on)
def process_aug(audio, label):
audio = augmenter.augment_audio_array(audio)
return audio, label
aug_wav = lambda x,y: tf.py_function(process_aug, [x,y], tf.float32),
# normalize audio
augmented_audio = augmenter.augment_audio(audio.numpy())
tensor_audio = tf.convert_to_tensor(augmented_audio, dtype=tf.float32)
return tensor_audio, label
aug_wav = lambda x,y: tf.py_function(process_aug, [x, y], [tf.float32, tf.float32])
train_ds =,
# normalize audio and expand by one dimension (as required by feature extraction)
def process(audio, label):
audio = tf_normalize(audio)
audio = tf.expand_dims(audio, axis=-1)
return audio, label
train_ds =,
val_ds =,
# prefetch data
# batch and prefetch data
train_ds = train_ds.batch(batch_size)
val_ds = val_ds.batch(batch_size)
train_ds = train_ds.prefetch(
val_ds = val_ds.prefetch(
......@@ -104,7 +107,7 @@ def train(config_path, log_dir):
# tensorboard_callback,
# model_checkpoint_callback,
# early_stopping_callback,
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment