import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np

########### Se videon neurala nätverk (repetition)

def sigma(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` element-wise.

    The input is clamped to the interval [-500, 500] before exponentiation
    so that ``np.exp`` cannot overflow for very large magnitudes.

    NaN inputs are propagated as NaN instead of being silently turned into
    a saturated activation, so a diverged network stays visible to the caller.

    Parameters:
        x: scalar or array-like of pre-activation values.

    Returns:
        Sigmoid of ``x`` with the same shape as the input.
    """
    # Float bounds keep the clamp well-defined for every numeric input dtype.
    clipped = np.clip(x, -500.0, 500.0)
    return 1.0 / (1.0 + np.exp(-clipped))

def sigma_derivative(x):
    """Return the derivative of the sigmoid evaluated at ``x``.

    Uses the identity sigma'(x) = sigma(x) * (1 - sigma(x)), so the sigmoid
    itself is only evaluated once.
    """
    activation = sigma(x)
    complement = 1 - activation
    return activation * complement


def forward(x):
    """Propagate one input through both layers of the network.

    Reads the module-level parameters ``W1``/``B1`` (first layer) and
    ``W2``/``B2`` (second layer).

    Parameters:
        x: input array that ``W1`` can be matrix-multiplied with.

    Returns:
        The tuple ``(z1, z2, x, a1, a2)``: the pre-activations of both
        layers, the unchanged input, and the activations of both layers.
    """
    hidden_pre = W1 @ x + B1
    hidden_act = sigma(hidden_pre)
    output_pre = W2 @ hidden_act + B2
    output_act = sigma(output_pre)
    return hidden_pre, output_pre, x, hidden_act, output_act

def predict(x):
    """Return only the output-layer activation that ``forward`` yields for ``x``."""
    # forward() returns (z1, z2, x, a1, a2); the last entry is the output.
    *_, output_activation = forward(x)
    return output_activation


########### Se video Hämta data online med requests
import requests
#text = requests.get('https://fileadmin.cs.lth.se/cs/Education/EDAA55/numpy/trainingdata_images.txt').text
#with open('trainingdata_images.txt', 'w') as f: f.write(text)

#text = requests.get('https://fileadmin.cs.lth.se/cs/Education/EDAA55/numpy/trainingdata_labels.txt').text
#with open('trainingdata_labels.txt', 'w') as f: f.write(text)
# Load the images and their integer labels from the local text files
# (the commented-out requests calls above download them once).
training_data = np.loadtxt('trainingdata_images.txt')
labels = np.loadtxt('trainingdata_labels.txt', dtype=int)
print(training_data.shape, labels.shape, sep='\n')

########### Se video Matrisfunktioner: argmax och reshape
# Keep two views of every sample: a 28x28 picture for plotting and a
# 784x1 column vector as input to the network.
sample_count = len(training_data)
images = np.reshape(training_data, (sample_count, 28, 28))
training_data = np.reshape(training_data, (sample_count, 784, 1))


########### Se video filtrera med villkor
# Keep only the samples whose label is below 2; the same boolean mask is
# applied to all three arrays so they stay aligned.
binary_mask = labels < 2
training_data = training_data[binary_mask]
images = images[binary_mask]
labels = labels[binary_mask]
print(training_data.shape, labels.shape, sep='\n')


########### Se video förbered datan

# Prepare the data: the first `split` samples are used for training and the
# remainder for testing.
# NOTE(review): `split` is hard-coded; if the filtered data set holds no more
# than 12665 samples the test arrays end up empty - confirm against the file.
split = 12665
train_X, test_X = training_data[:split], training_data[split:]
train_y, test_y = labels[:split], labels[split:]
test_images = images[split:]


def onehot(y, num_classes=2):
    """Return the one-hot column vector for class index ``y``.

    With the default two classes, label 0 becomes [1, 0] and label 1
    becomes [0, 1] (as a column).

    Parameters:
        y: integer class index, expected to satisfy ``y < num_classes``.
        num_classes: length of the encoded vector; defaults to 2.

    Returns:
        Integer array of shape ``(num_classes, 1)`` with a single 1 at
        position ``y``.

    Raises:
        IndexError: if ``y`` is out of range for ``num_classes``.
    """
    v = np.zeros(num_classes, dtype=int)
    v[y] = 1
    return v.reshape(num_classes, 1)
# Replace every training label by its one-hot column vector.
train_y = np.array([onehot(label) for label in train_y])

########### Se video Matriser med slumpvärden
# Initialise the weights and biases with standard-normal random values.
# Layer sizes: INPUT values in, N1 hidden units, N2 output units.
INPUT = 784
N1 = 3
N2 = 2
# The draw order (W1, B1, W2, B2) is kept so a seeded run gives the same values.
W1, B1 = np.random.randn(N1, INPUT), np.random.randn(N1, 1)
W2, B2 = np.random.randn(N2, N1), np.random.randn(N2, 1)


########### Se video backpropagation

def backpropagation(y, z1, z2, x, a1, a2):
    """Compute the per-sample update terms for both layers.

    The leading factor ``-2 * (a2 - y)`` makes the returned terms point in
    the direction that reduces the squared difference between ``a2`` and
    ``y``. Reads the module-level weight matrix ``W2``.

    Parameters:
        y: target vector for the sample.
        z1, z2: pre-activations of the first and second layer.
        x: input the sample was propagated from.
        a1, a2: activations of the first and second layer.

    Returns:
        The tuple ``(dw1, dw2, db1, db2)`` with the terms for ``W1``,
        ``W2``, ``B1`` and ``B2`` respectively.
    """
    # Output layer: error signal scaled by the sigmoid slope at z2.
    output_error = -2 * (a2 - y)
    delta_output = output_error * sigma_derivative(z2)

    # Hidden layer: push the output delta back through W2.
    delta_hidden = (W2.T @ delta_output) * sigma_derivative(z1)

    weight_term_2 = delta_output @ a1.T
    weight_term_1 = delta_hidden @ x.T
    return weight_term_1, weight_term_2, delta_hidden, delta_output


########### Se video kod för att träna det neurala nätverket
# Mini-batch training settings.
step_size = 0.001
batch_size = 100

# Accumulators for the per-sample update terms of the current batch.
dW1 = np.zeros((N1, INPUT))
dW2 = np.zeros((N2, N1))
dB1 = np.zeros((N1, 1))
dB2 = np.zeros((N2, 1))

last_index = len(train_X) - 1

for e in range(5):
    print('epoch', e)

    for i, (x, y) in enumerate(zip(train_X, train_y)):
        z1, z2, x, a1, a2 = forward(x)
        dw1, dw2, db1, db2 = backpropagation(y, z1, z2, x, a1, a2)
        dW1 += dw1
        dW2 += dw2
        dB1 += db1
        dB2 += db2

        # Apply the accumulated step after every full batch of `batch_size`
        # samples, and once more at the end of the epoch for the remainder.
        # Testing (i + 1) rather than i keeps the very first update from
        # being taken after a single sample.
        if (i + 1) % batch_size == 0 or i == last_index:
            # backpropagation() already returns terms with the descent sign,
            # so they are added to the parameters.
            W1 += step_size * dW1
            W2 += step_size * dW2
            B1 += step_size * dB1
            B2 += step_size * dB2
            # Reset the accumulators in place for the next batch.
            dW1.fill(0)
            dW2.fill(0)
            dB1.fill(0)
            dB2.fill(0)


# Run the network on every test sample; the predicted class is the index of
# the largest output activation.
outputs = [predict(sample) for sample in test_X]
predictions = [np.argmax(activation) for activation in outputs]


# Show test images 10..19, each titled with the class predicted for it.
for i in range(10, 20):
    plt.imshow(test_images[i])
    plt.title(f'Prediction {predictions[i]}')
    plt.show()



