Commit bf5f400b authored by pbethge's avatar pbethge
Browse files

use attRnn with wav2vec2

parent 8cc94f9c
......@@ -8,10 +8,12 @@ This package is published under Simplified BSD License.
"""
from tensorflow.keras.layers import GlobalAveragePooling2D, Dropout, BatchNormalization
from tensorflow.keras.layers import Dense, Permute, Input, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Permute, Dropout, BatchNormalization
from tensorflow.keras.layers import Input, Dense, LSTM, Conv2D, Bidirectional, Flatten
from tensorflow.keras.layers import LayerNormalization, Lambda, Dot, Softmax
from tensorflow.keras.models import Model
from src.utils.training_utils import get_feature_layer
......@@ -31,19 +33,34 @@ def create_model(config):
from tensorflow.keras.models import load_model
model_path = 'wav2vec2'
feature_extractor = load_model(model_path)
feature_extractor.trainable = False
feature_extractor.trainable = True
inputs = Input((input_length), name='input')
model = Sequential()
model.add(inputs)
model.add(feature_extractor)
model.add(BatchNormalization())
model.add(Flatten())
# model.add(GlobalAveragePooling2D())
# model.add(Dropout(0.5))
model.add(Dense(len(config["languages"]),
activation='softmax'
))
x = feature_extractor(inputs)
x = BatchNormalization()(x)
x = Bidirectional(LSTM(64, return_sequences=True))(x) # [b_s, seq_len, vec_dim]
x = Bidirectional(LSTM(64, return_sequences=True))(x) # [b_s, seq_len, vec_dim]
xFirst = Lambda(lambda q: q[:, -1])(x) # [b_s, vec_dim]
query = Dense(128)(xFirst)
# dot product attention
attScores = Dot(axes=[1, 2])([query, x])
attScores = Softmax(name='attSoftmax')(attScores) # [b_s, seq_len]
# rescale sequence
attVector = Dot(axes=[1, 1])([attScores, x]) # [b_s, vec_dim]
x = Dense(64, activation='relu')(attVector)
x = Dropout(0.25)(x)
x = Dense(32)(x)
x = Dropout(0.25)(x)
output = Dense(len(config["languages"]), activation='softmax', name='output')(x)
model = Model(inputs=[inputs], outputs=[output])
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment