Hertz-Lab / Research / Intelligent Museum / language-identification / Commits

Commit e50d899c
authored Jul 20, 2021 by Paul Bethge

    add augmentation, model save and load, fix report

parent 7b15f728
Changes: 1

train_simple.py @ e50d899c
@@ -23,13 +23,13 @@ from tensorflow.keras.metrics import Precision, Recall, CategoricalAccuracy
 from tensorflow.keras.models import load_model
 import src.models as models
-from src.utils.training_utils import CustomCSVCallback, get_saved_model_function
+from src.utils.training_utils import CustomCSVCallback, get_saved_model_function, visualize_results
 from src.utils.training_utils import create_dataset_from_set_of_files, tf_normalize
 from src.audio.augment import AudioAugmenter
 
 
-def train(config_path, log_dir, model_path):
+def train(config_path, log_dir):
 
     # Config
     config = load(open(config_path, "rb"))
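The new signature leaves everything except the log directory to the config file; load(open(config_path, "rb")) suggests the PyYAML loader, given the config_train.yaml default further down. A hypothetical minimal config with the keys this diff reads; the real file in the repo may differ:

import yaml

# Hypothetical config_train.yaml covering the keys this commit reads
# (augment, learning_rate, model, model_path); values are invented.
example_config = """
augment: true
learning_rate: 0.001
model: AttRnn        # class name to look up in src.models
model_path: ""       # non-empty path = resume from a saved model
"""

config = yaml.safe_load(example_config)
assert config["model_path"] == ""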
@@ -45,9 +45,10 @@ def train(config_path, log_dir, model_path):
     augment = config["augment"]
     learning_rate = config["learning_rate"]
     model_name = config["model"]
+    model_path = config["model_path"]
 
     # create or load the model
-    if model_path:
+    if model_path != "":
         model = load_model(model_path)
     else:
         model_class = getattr(models, model_name)
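The create-or-load branch resolves the model class by name with getattr on src.models, which is not part of this diff. A self-contained sketch of the same pattern, substituting tf.keras.applications for the repo's src.models module:

import tensorflow as tf
from tensorflow.keras.models import load_model

def build_or_load(model_path, model_name):
    # An empty model_path in the config means "build from scratch".
    if model_path != "":
        return load_model(model_path)    # resume/retrain a saved model
    # Resolve the class by its config name, like getattr(models, model_name) above.
    model_class = getattr(tf.keras.applications, model_name)
    return model_class(weights=None)     # e.g. model_name = "MobileNetV2"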
@@ -67,13 +68,15 @@ def train(config_path, log_dir, model_path):
     val_ds = val_ds.batch(batch_size)
     # Optional augmentation of the training set
     ## Note: tf.py_function allows to construct a graph but code is executed in python
     if augment:
         augmenter = AudioAugmenter(audio_length_s, sample_rate)
         def process_aug(audio, label):
             audio = augmenter.augment_audio_array(audio)
             return audio, label
-        train_ds.map(process_aug, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+        # train_ds.map(process_aug, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+        aug_wav = lambda x, y: tf.py_function(process_aug, [x, y], tf.float32)
+        train_ds.map(aug_wav, num_parallel_calls=tf.data.experimental.AUTOTUNE)
 
     # normalize audio
     def process(audio, label):
         audio = tf_normalize(audio)
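Two tf.data details worth noting in this hunk: Dataset.map returns a new dataset, so the unassigned train_ds.map(aug_wav, ...) call as committed discards the augmented pipeline; and tf.py_function only returns as many tensors as Tout lists dtypes, so passing a single tf.float32 drops the labels. A corrected, self-contained sketch, with a noise augmenter standing in for AudioAugmenter.augment_audio_array:

import tensorflow as tf

# Toy dataset: eight 1 s mono clips at 16 kHz with integer labels.
train_ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([8, 16000]), tf.zeros([8], dtype=tf.int32)))

def process_aug(audio, label):
    # Runs eagerly inside tf.py_function, so arbitrary Python/NumPy
    # augmentation code (like AudioAugmenter) is allowed here.
    audio = audio + tf.random.normal(tf.shape(audio), stddev=0.01)
    return audio, label

# One dtype per returned tensor, and the mapped dataset must be reassigned.
aug_wav = lambda x, y: tf.py_function(process_aug, [x, y], [tf.float32, tf.int32])
train_ds = train_ds.map(aug_wav, num_parallel_calls=tf.data.experimental.AUTOTUNE)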
@@ -88,47 +91,59 @@ def train(config_path, log_dir, model_path):
     # Training Callbacks
     tensorboard_callback = TensorBoard(log_dir=log_dir, write_images=True)
     csv_logger_callback = CustomCSVCallback(os.path.join(log_dir, "log.csv"))
-    reduce_on_plateau = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=1, verbose=1, min_lr=0.000001, min_delta=0.001)
+    reduce_on_plateau = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, verbose=1, min_lr=0.000001, min_delta=0.001)
-    checkpoint_filename = os.path.join(log_dir, "trained_models", "weights.{epoch:02d}")
+    checkpoint_filename = os.path.join(log_dir, "trained_models", "model.{epoch:02d}")
     model_checkpoint_callback = ModelCheckpoint(checkpoint_filename, save_best_only=True, verbose=1, monitor="val_categorical_accuracy", save_weights_only=False)
-    # early_stopping_callback = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1, mode="min")
+    early_stopping_callback = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1, mode="min")
-    callbacks = [tensorboard_callback, csv_logger_callback, reduce_on_plateau, model_checkpoint_callback,
-                 # early_stopping_callback,
-                 ]
+    # comment callbacks that you don't care about
+    callbacks = [
+                # tensorboard_callback,
+                csv_logger_callback, reduce_on_plateau, model_checkpoint_callback,
+                # early_stopping_callback,
+                ]
 
     # Training
     history = model.fit(x=train_ds, epochs=num_epochs, callbacks=callbacks, validation_data=val_ds)
 
-    # visualize_results(history, config)
-    # Do evaluation on model with best validation accuracy
+    # TODO Do evaluation on model with best validation accuracy
+    visualize_results(history, config, log_dir)
     best_epoch = np.argmax(history.history["val_categorical_accuracy"])
     print("Log files: ", log_dir)
     print("Best epoch: ", best_epoch)
-    return checkpoint_filename.replace("{epoch:02d}", "{:02d}".format(best_epoch))
+    checkpoint_filename.replace("{epoch:02d}", "{:02d}".format(best_epoch))
+    print("Best model at: ", checkpoint_filename)
+
+    return model, best_epoch
 
 
 if __name__ == "__main__":
-    tf.config.list_physical_devices('GPU')
     parser = argparse.ArgumentParser()
-    parser.add_argument('--config', default="config_train.yaml")
-    parser.add_argument('--model_path', default=None, help="Path to a trained model for retraining")
+    parser.add_argument('--config', default="config_train.yaml", help="Path to the required config file.")
     cli_args = parser.parse_args()
+    tf.config.list_physical_devices('GPU')
 
-    # copy models & config for later
     log_dir = os.path.join("logs", datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))
+    print("Logging to {}".format(log_dir))
+
+    # copy models & config for later
     shutil.copytree("src/models", os.path.join(log_dir, "models"))
     shutil.copy(cli_args.config, log_dir)
 
-    model_file_name = train(cli_args.config, log_dir, cli_args.model_path)
-    print("Best model at: ", model_file_name)
-
-    #TODO visualize the training process and save as png
-    #TODO convert model to saved model
+    # train and save the best model as SavedModel
+    model, best_epoch = train(cli_args.config, log_dir)
+    saved_model_path = os.path.join(log_dir, "model_" + str(best_epoch))
+    model.save(saved_model_path, signatures=get_saved_model_function(model))
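One detail in the bookkeeping above: the new code calls checkpoint_filename.replace(...) without using the return value, and Python strings are immutable, so the printed path still contains the {epoch:02d} placeholder. There is also a probable off-by-one, since ModelCheckpoint in TF 2.x formats {epoch} with epoch + 1 while np.argmax is zero-based. A sketch of the intended lookup under those assumptions:

import os
import numpy as np

checkpoint_filename = os.path.join("logs", "trained_models", "model.{epoch:02d}")
val_acc = [0.41, 0.57, 0.53]          # stand-in for history.history["val_categorical_accuracy"]
best_epoch = int(np.argmax(val_acc))  # zero-based index of the best epoch

# str.replace returns a new string, so capture the result; ModelCheckpoint
# writes filenames with epoch + 1, hence the +1 here.
best_path = checkpoint_filename.replace("{epoch:02d}", "{:02d}".format(best_epoch + 1))
print("Best model at: ", best_path)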
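Finally, get_saved_model_function is repo-internal and not shown in this diff; a plausible shape for it is a tf.function with a fixed input signature, which is the kind of object model.save(..., signatures=...) accepts. A minimal sketch under that assumption, with an invented input size and output name:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(16000,))])

@tf.function(input_signature=[tf.TensorSpec([None, 16000], tf.float32)])
def serving_fn(audio):
    # hypothetical stand-in for get_saved_model_function(model)
    return {"scores": model(audio)}

model.save("saved_model_dir", signatures=serving_fn)  # SavedModel format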