import io
import numpy as np
import pyaudio
import scipy.io.wavfile as wav
import threading
from queue import Queue

import torch
torch.set_num_threads(1)
import torchaudio
torchaudio.set_audio_backend("soundfile")
from utils import *

import tensorflow as tf
from tensorflow.keras.models import load_model

# Tunable settings (thresholds are inclusive: a score >= threshold passes).
# The "threshhold" spelling is kept because other code refers to these names.
vad_threshhold: float = 0.8  # min value of column 1 of validate() output that opens a window
kws_threshhold: float = 0.95  # min top-class KWS score for a keyword to be printed
kws_required_size: int = 4  # chunks appended after the trigger before the window is queued
frame_duration_ms: int = 250  # duration of each microphone read, in milliseconds

#=== Silero VAD ===#
# Load the TorchScript VAD model onto the CPU. The recording loop always feeds
# it CPU tensors (built with torch.from_numpy), so the weights are pinned to the
# CPU as well; without map_location a model serialized from a GPU would fail to
# load on a CPU-only machine, or hit a device mismatch at inference time.
model = torch.jit.load('vad_model.jit', map_location='cpu')
def normalize(sound):
	"""Scale ``sound`` so its largest absolute sample is 1.0, then squeeze it.

	A new array is returned; the caller's buffer is never modified. Input
	that is not floating point is promoted to float32 first (scaling an int
	array in place fails, and ``abs`` of the most negative int16 overflows).
	All-zero or empty input is returned unscaled instead of raising.

	Args:
		sound: array-like of audio samples.

	Returns:
		numpy array holding ``sound / max(|sound|)`` (or ``sound`` unchanged
		when that maximum is 0), with length-1 axes removed.
	"""
	sound = np.asarray(sound)
	if not np.issubdtype(sound.dtype, np.floating):
		sound = sound.astype(np.float32)
	# initial=0 keeps an empty array from raising in the max reduction
	abs_max = np.max(np.abs(sound), initial=0)
	if abs_max > 0:
		# out-of-place division: leave the caller's array untouched
		sound = sound / abs_max
	sound = sound.squeeze()  # depends on the use case
	return sound

#=== Keyword Spotting ===#
# Labels of the KWS model outputs: classes[i] names output unit i, which is how
# the worker maps the argmax index back to a word. Order must not change.
classes = '''
	unknown nine yes no up down left right on off stop go
	zero one two three four five six
	seven eight backward bed bird cat dog
	follow forward happy house learn marvin sheila tree
	visual wow
'''.split()

# KWS process thread
def processAudio(config, q):
	"""Consume recorded audio windows from ``q`` and print spotted keywords.

	Loops forever; used as the target of the daemon worker thread. The Keras
	model is loaded once, inside this thread, from the ``kws_model`` path.

	Args:
		config: currently unused; accepted so the thread's two-item ``args``
			tuple stays valid.
		q: queue of windows, each a list of int16 sample chunks as put by the
			recording loop. ``q.task_done()`` is called exactly once per window,
			even when processing fails, so ``q.join()`` is not left waiting.
	"""
	import traceback

	kws_model = load_model('kws_model')
	while True:
		data = q.get()
		try:
			# Join the chunks with one vectorized copy instead of a per-sample
			# Python loop; force int16 so an empty placeholder chunk ([]) does
			# not change the dtype.
			audio = np.concatenate([np.asarray(chunk, dtype=np.int16) for chunk in data])
			# int16 PCM -> float32 in [-1, 1), with a leading batch axis: (1, n_samples)
			samples = audio.astype(np.float32)[np.newaxis, :] / 32768.0
			out = kws_model(tf.convert_to_tensor(samples))[0].numpy()
			index = int(np.argmax(out))

			if out[index] >= kws_threshhold:
				print(classes[index])

			# if you want to see the data please uncomment
			# wav.write('results/'+classes[index]+'.wav', SAMPLE_RATE, audio)

		except Exception as e:
			# keep the worker alive, but show where the failure happened
			print("Ooopsi: ", e)
			traceback.print_exc()
		finally:
			q.task_done()

# Threading: one daemon worker runs processAudio on the recorded windows, so
# keyword inference happens off the thread that reads the microphone.
some_config = ''
queue = Queue()
worker = threading.Thread(
	target=processAudio,
	args=(some_config, queue),
	daemon=True,
)
worker.start()

# Pyaudio
FORMAT = pyaudio.paInt16
CHANNELS = 1
SAMPLE_RATE = 16000
CHUNK = int(SAMPLE_RATE / 10)
audio = pyaudio.PyAudio()
stream = audio.open(format=FORMAT,
					channels=CHANNELS,
					rate=SAMPLE_RATE,
					input=True,
					frames_per_buffer=CHUNK)
# samples per read: frame_duration_ms worth of audio (e.g. 4000 for 250 ms at 16 kHz)
chunk_size = int(SAMPLE_RATE * frame_duration_ms / 1000.0)

# Recording state:
#   data            -- chunks of the window currently being collected
#   audio_int16     -- most recently read chunk (becomes last_chunk next pass)
#   nu_voice_chunks -- chunks appended since the VAD trigger
#   got_voice       -- True while a window is being collected
data = []
audio_int16 = []
nu_voice_chunks = 0
got_voice = False

print("Started Recording")
try:
	while True:

		# keep the last chunk so nothing gets lost
		last_chunk = audio_int16

		# sample chunk, convert to float and normalize.
		# exception_on_overflow=False: if inference falls behind the device the
		# overflowed samples are dropped instead of aborting with an IOError.
		audio_chunk = stream.read(chunk_size, exception_on_overflow=False)
		audio_int16 = np.frombuffer(audio_chunk, np.int16)
		audio_float32 = audio_int16.astype('float32')
		audio_float32_norm = normalize(audio_float32)

		# get the confidences
		vad_outs = validate(model, torch.from_numpy(audio_float32_norm))[:,1]

		# trigger if voice is detected
		if vad_outs >= vad_threshhold and not got_voice:
			print("I found something")
			got_voice = True
			data = []
			data.append(last_chunk)

		# collect data and analyze
		if got_voice:
			if nu_voice_chunks < kws_required_size:
				data.append(audio_int16)
				nu_voice_chunks += 1
			else:
				# window complete: hand it to the KWS worker (the chunk read on
				# this pass is not part of it)
				got_voice = False
				nu_voice_chunks = 0
				queue.put(data)
except KeyboardInterrupt:
	print("Stopped Recording")
finally:
	# release the audio device on every exit path
	stream.stop_stream()
	stream.close()
	audio.terminate()

# Reached after Ctrl+C (the loop itself never exits). Let the worker finish the
# windows still queued -- but only while it is alive: a worker that died (e.g.
# the KWS model failed to load) never calls task_done, so join() would hang.
if worker.is_alive():
	queue.join()