Commit 7b15f728 authored by Paul Bethge's avatar Paul Bethge
Browse files

fix csv and visualize functions

parent a30fa88b
......@@ -7,9 +7,12 @@ Paul Bethge (bethge@zkm.de)
This package is published under Simplified BSD License.
"""
import os
import csv
import numpy as np
import matplotlib.pyplot as plt
import scipy.io.wavfile as wav
import tensorflow as tf
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.backend import get_value
......@@ -17,13 +20,11 @@ from tensorflow.keras.losses import CategoricalCrossentropy as CrossLoss
from tensorflow.keras.metrics import Precision, Recall, CategoricalAccuracy, CategoricalCrossentropy
from tensorflow.keras.utils import Progbar
import scipy.io.wavfile as wav
from kapre.composed import get_stft_magnitude_layer
from kapre.composed import get_melspectrogram_layer
from kapre.composed import get_log_frequency_spectrogram_layer
import os
def create_dataset_from_set_of_files(ds_dir, languages):
......@@ -78,7 +79,7 @@ def get_feature_layer(feature_type, feature_nu, sample_rate):
def write_csv(logging_dir, epoch, logs={}):
with open(logging_dir, mode='a') as log_file:
log_file_writer = csv.writer(log_file, delimiter=',')
if epoch == 1:
if epoch == 0:
row = list(logs.keys())
log_file_writer.writerow(row)
row_vals = [round(x, 6) for x in list(logs.values())]
......@@ -88,8 +89,8 @@ def write_csv(logging_dir, epoch, logs={}):
class CustomCSVCallback(Callback):
def __init__(self, logging_dir):
self._logging_dir = logging_dir
self._counter = 0
def on_epoch_end(self, epoch, logs={}):
logs['learning_rate'] = float(get_value(self.model.optimizer.learning_rate))
write_csv(self._logging_dir, epoch, logs)
......@@ -100,6 +101,32 @@ def get_saved_model_function(model, dims=(None, None, 1)):
return model_predict
def visualize_results(history, config, log_dir):
epochs = config["num_epochs"]
acc = history.history['categorical_accuracy']
val_acc = history.history['val_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(epochs)
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(2, 1, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.savefig(os.path.join(log_dir, "training.png"))
##### Old Stuff ######
def run_epoch(model, batch_generator_obj, training=False, optimizer=None, show_progress=False, num_batches=32):
"""Train or validate a model with a given generator.
......@@ -165,27 +192,3 @@ def run_epoch(model, batch_generator_obj, training=False, optimizer=None, show_p
pb.add(1, values=values)
return accuracy, loss, recall, precision
def visualize_results(history, config):
epochs = config["num_epochs"]
acc = history.history['categorical_accuracy']
val_acc = history.history['val_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(epochs)
plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment