Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
Hertz-Lab
Research
Intelligent Museum
language-identification
Commits
22b9e8ca
Commit
22b9e8ca
authored
May 25, 2021
by
Paul Bethge
Browse files
faster augmentation
parent
e59c72ab
Changes
5
Hide whitespace changes
Inline
Side-by-side
src/audio/augment.py
View file @
22b9e8ca
...
...
@@ -9,7 +9,8 @@ This package is published under GNU GPL Version 3.
import
nlpaug.flow
as
flow
import
nlpaug.augmenter.audio
as
naa
from
audio.utils
import
pad_with_silence
from
audio.features
import
normalize
class
AudioAugmenter
(
object
):
def
__init__
(
self
,
fs
):
...
...
@@ -24,18 +25,21 @@ class AudioAugmenter(object):
self
.
_aug_flow
=
flow
.
Sequential
([
shift
,
crop
,
vltp
,
#
vltp,
# speed,
# pitch,
noise
,
])
def
augment_audio_array
(
self
,
signal
,
fs
):
assert
fs
==
self
.
_fs
data
=
signal
.
astype
(
dtype
=
"float32"
)
augmented_data
=
self
.
_aug_flow
.
augment
(
data
)
return
augmented_data
.
astype
(
dtype
=
"float32"
)
def
augment_audio_array
(
self
,
signals
,
fs
):
# assert fs == self._fs
# data = signal.astype(dtype="float32")
augmented_data
=
self
.
_aug_flow
.
augment
(
signals
,
num_thread
=
8
)
data
=
[]
for
x
in
augmented_data
:
x
=
pad_with_silence
(
x
,
5
*
fs
)
data
.
append
(
x
)
return
data
if
__name__
==
"__main__"
:
...
...
src/audio/features.py
View file @
22b9e8ca
...
...
@@ -22,6 +22,8 @@ def normalize(signal):
normalize a float signal to have a maximum absolute value of 1.0
"""
maximum
=
max
(
abs
(
signal
.
max
()),
abs
(
signal
.
min
()))
if
maximum
==
0.0
:
return
signal
return
signal
/
float
(
maximum
)
...
...
src/config_gen.yaml
View file @
22b9e8ca
languages
:
[
"
english"
,
"
farsi"
,
"
french"
,
"
german"
,
"
mandarin"
,
"
spanish"
,
"
polish"
,
"
russian"
]
#languages: ["__noise","english","farsi","french","german","mandarin","spanish","polish","russian"]
languages
:
[
"
__noise"
,
"
english"
,
"
french"
,
"
german"
,
"
spanish"
]
# Training
train_dir
:
"
/data/common_voice_filtered/fiveSec
_validated
/wav/train/"
val_dir
:
"
/data/common_voice_filtered/fiveSec
_validated
/wav/
dev
/"
train_dir
:
"
/data/common_voice_filtered/fiveSec/wav/train/"
val_dir
:
"
/data/common_voice_filtered/fiveSec/wav/
test
/"
audio_length_s
:
5
sample_rate
:
16000
feature_nu
:
80
batch_size
:
128
batch_size
:
64
learning_rate
:
0.001
num_epochs
:
20
model
:
"
attRnnSTFT"
augment
:
Fals
e
augment
:
Tru
e
src/models/attRnnSTFT.py
View file @
22b9e8ca
...
...
@@ -75,7 +75,9 @@ def create_model(config):
attVector
=
Dot
(
axes
=
[
1
,
1
])([
attScores
,
x
])
# [b_s, vec_dim]
x
=
Dense
(
64
,
activation
=
'relu'
)(
attVector
)
x
=
Dropout
(
0.25
)(
x
)
x
=
Dense
(
32
)(
x
)
x
=
Dropout
(
0.25
)(
x
)
output
=
Dense
(
len
(
config
[
"languages"
]),
activation
=
'softmax'
,
name
=
'output'
)(
x
)
...
...
src/train_gen.py
View file @
22b9e8ca
...
...
@@ -31,14 +31,14 @@ def batch_gen(generator, batch_size, augmenter=None, fs=None, desired_audio_leng
while
True
:
try
:
x
,
y
=
next
(
generator
)
if
augmenter
:
x
=
augmenter
.
augment_audio_array
(
x
,
fs
)
x
=
pad_with_silence
(
x
,
desired_audio_length_s
*
fs
)
x
=
normalize
(
x
)
x_batch
.
append
(
x
)
y_batch
.
append
(
y
)
i
+=
1
if
i
==
batch_size
:
if
augmenter
:
x_batch
=
augmenter
.
augment_audio_array
(
x_batch
,
fs
)
for
i
in
range
(
len
(
x_batch
)):
x_batch
[
i
]
=
normalize
(
x_batch
[
i
])
x_arr
=
np
.
asarray
(
x_batch
)
y_arr
=
np
.
asarray
(
y_batch
)
yield
x_arr
,
y_arr
...
...
@@ -148,7 +148,11 @@ def train(config_path, log_dir, model_path):
logs
=
{
'train_acc'
:
train_acc
,
'train_loss'
:
train_loss
,
'val_acc'
:
val_acc
,
'val_loss'
:
val_loss
}
model
.
save
(
os
.
path
.
join
(
log_dir
,
'model'
+
"_"
+
str
(
epoch
+
1
)))
@
tf
.
function
(
input_signature
=
[
tf
.
TensorSpec
((
1
,
audio_length_s
*
fs
,
1
),
dtype
=
tf
.
float32
)])
def
model_predict
(
input_1
):
return
{
'outputs'
:
model
(
input_1
,
training
=
False
)}
model
.
save
(
os
.
path
.
join
(
log_dir
,
'model'
+
"_"
+
str
(
epoch
+
1
)),
signatures
=
{
'serving_default'
:
model_predict
})
write_csv
(
os
.
path
.
join
(
log_dir
,
'log.csv'
),
optimizer
,
epoch
,
logs
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment