Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion ravdl/v2/NeuralNetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,17 +112,30 @@ def fit(self, X, y, n_epochs, batch_size, save_model = False):
if save_model:
self.save_model()

def _forward_pass(self, X, training=True):
def _forward_pass(self, X, training=True,return_all_layer_output=False):
""" Calculate the output of the NN """
layer_output = X
all_layer_out={}
for layer in self.layers:
if layer == self.layers[0]:
layer_output = layer._forward_pass(layer_output, input_layer="True", training = training)
if isinstance(layer_output, dict):
layer_output = layer_output['output']
all_layer_out[layer.layer_name]=layer_output
else:
layer_output = layer._forward_pass(layer_output, training = training)
if isinstance(layer_output, dict):
layer_output = layer_output['output']
all_layer_out[layer.layer_name]=layer_output

if return_all_layer_output is True:
return all_layer_out

return layer_output




def _backward_pass(self, loss_grad):
""" Propagate the gradient 'backwards' and update the weights in each layer """
reversed_layers = list(reversed(self.layers))
Expand Down