# fit method from the NeuralNetwork class; it calls each layer's backward method
def fit(self, X, y, n_epochs=1000):
    self._initialize(X)
    losses = []
    for epoch in range(n_epochs):
        # forward pass: feed the input through every layer in order
        last_output = X
        for layer in self.layers:
            last_output = layer.forward(last_output)
        loss = self.loss(last_output, y)
        losses.append(loss)
        # backward pass: push the loss derivative through the layers in reverse
        last_d = self.dloss(last_output, y)
        for layer in reversed(self.layers):
            last_d = layer.backward(last_d, self.lr)
        print(f"Epoch: {epoch}; Loss: {loss}")
    return losses
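
self.loss and self.dloss are not shown above; for concreteness, here is a minimal sketch of what they might look like, assuming mean squared error (an assumption for illustration, not code from the post):

import numpy as np

# assumed loss pair: mean squared error over all outputs, and its derivative
def mse(y_pred, y_true):
    return np.mean((y_pred - y_true) ** 2)

def dmse(y_pred, y_true):
    # d/dy_pred of mean((y_pred - y_true)**2)
    return 2 * (y_pred - y_true) / y_pred.size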
# backward method for backpropagation in the Dense class
def backward(self, last_derivative, lr):
    """
    Parameters
    - last_derivative : derivative handed back by the layer that follows
      this one in the forward order (or by the loss, for the last layer)
    - lr (learning rate) : scales how large a step is taken on the weights
    """
    w = self.weights
    dloss_1 = self.dactivate(last_derivative)
    d_w = np.dot(self.layer_input.T, dloss_1)
    # note the double negative: -= -x is the same as += x
    self.weights -= -np.dot(lr, d_w)
    return np.dot(w, dloss_1.T)
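
For comparison, here is a self-contained sketch of the textbook backward pass for a dense layer, under the usual convention that rows of the input are samples. The class and its initializer are assumptions for illustration; only the names weights, layer_input, forward, and backward mirror the post:

import numpy as np

class DenseSketch:
    # minimal dense layer illustrating the conventional gradient shapes
    def __init__(self, n_in, n_out):
        self.weights = np.random.randn(n_in, n_out) * 0.01
        self.layer_input = None

    def forward(self, X):
        self.layer_input = X      # cache the input for the backward pass
        return X @ self.weights   # activation omitted for brevity

    def backward(self, last_derivative, lr):
        # gradient of the loss w.r.t. the weights: input^T @ upstream derivative
        d_w = self.layer_input.T @ last_derivative
        # gradient w.r.t. this layer's input, passed to the layer before it
        d_input = last_derivative @ self.weights.T
        # plain gradient descent: subtract, because the gradient points uphill
        self.weights -= lr * d_w
        return d_input

In this convention, the derivative returned to the earlier layer is last_derivative @ weights.T, so its shape matches that earlier layer's output.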
The problem is: the gradient points in the direction of steepest ascent, so I should subtract it from the weights, but when I do that the loss increases. What's wrong? (I use self.weights -= -np.dot(...) because that way the loss decreases, even though it shouldn't.)
{"html5":"htmlmixed","css":"css","javascript":"javascript","php":"php","python":"python","ruby":"ruby","lua":"text\/x-lua","bash":"text\/x-sh","go":"go","c":"text\/x-csrc","cpp":"text\/x-c++src","diff":"diff","latex":"stex","sql":"sql","xml":"xml","apl":"apl","asterisk":"asterisk","c_loadrunner":"text\/x-csrc","c_mac":"text\/x-csrc","coffeescript":"text\/x-coffeescript","csharp":"text\/x-csharp","d":"d","ecmascript":"javascript","erlang":"erlang","groovy":"text\/x-groovy","haskell":"text\/x-haskell","haxe":"text\/x-haxe","html4strict":"htmlmixed","java":"text\/x-java","java5":"text\/x-java","jquery":"javascript","mirc":"mirc","mysql":"sql","ocaml":"text\/x-ocaml","pascal":"text\/x-pascal","perl":"perl","perl6":"perl","plsql":"sql","properties":"text\/x-properties","q":"text\/x-q","scala":"scala","scheme":"text\/x-scheme","tcl":"text\/x-tcl","vb":"text\/x-vb","verilog":"text\/x-verilog","yaml":"text\/x-yaml","z80":"text\/x-z80"}