Facebook
From RomanTK, 1 Month ago, written in Python.
Embed
Download Paste or View Raw
Hits: 137
  1. import numpy as np
  2. import nnet as net
  3. import matplotlib.pyplot as plt
  4. x = np.array([[-1.0, -0.9, -0.8, -0.7, -0.6, -0.5]])
  5.  
  6. y_t = np.array([[-0.9602, -0.5770, -0.0729, 0.3771, 0.6405, 0.6600]])
  7. L = x.shape[0]
  8. K1 = y_t.shape[0]
  9. w1, b1 = net.nwlog(K1, L)
  10. max_epoch = 20000
  11. err_goal = 0.01
  12. disp_freq = 1000
  13. lr = 0.1
  14. SSE_vec = []
  15. for epoch in range(1, max_epoch+1):
  16.     y1 = net.logsig(np.dot(w1, x), b1)
  17.     e = y_t - y1
  18.     SSE = net.sumsqr(e)
  19.     if np.isnan(SSE):
  20.         break
  21.     SSE_vec.append(SSE)
  22.     if SSE < err_goal:
  23.         break
  24.     d1 = net.deltalog(y1, e)
  25.     dw1, db1 = net.learnbp(x, d1, lr)
  26.     w1 += dw1
  27.     b1 += db1
  28.     if (epoch % disp_freq) == 0:
  29.         plt.clf()
  30.         print("Epoch: ] | SSE: %5.5f " % (epoch, SSE))
  31.         plt.plot(x[0], y_t[0], 'r', x[0], y1[0], 'g')
  32.         plt.grid()
  33.         plt.show()
  34.         plt.pause(1e-2)
  35.  
  36. plt.figure()
  37. plt.plot(SSE_vec)
  38. plt.ylabel('SSE')
  39. plt.yscale('linear')
  40. plt.title('epoch')
  41. plt.grid(True)
  42. plt.show()