I have a question about your comment on the test accuracy of the model with the sigmoid activation function.
(It is under cell 17 in the PyTorch notebook and under cell 18 in the JAX notebook.)
You mention that the result with sigmoid is very poor and that the JAX model coincidentally does train. Could it instead be that the PyTorch model is simply not trained well, and that the JAX result is the correct one?
I re-trained the PyTorch model and found that training stops at epoch 8, because the result at epoch 1 is better than at epochs 2 to 8.
This means the saved model is the one from epoch 1.
I changed the "patience" parameter from 7 to 50 and got a result similar to the JAX one.
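The stop at epoch 8 is what patience-based early stopping predicts: if the best score is reached at epoch 1 and the next 7 epochs never beat it, training ends at epoch 1 + 7 = 8, and the checkpoint from epoch 1 is kept. The sketch below is a minimal, hypothetical version of that rule (`early_stopping_epoch` is an illustrative name, not the tutorial's `train_model`), assuming a score only counts as an improvement if it is strictly better than the best so far.

```python
from typing import List, Tuple


def early_stopping_epoch(val_scores: List[float], patience: int) -> Tuple[int, int]:
    """Return (stop_epoch, best_epoch) for patience-based early stopping.

    val_scores[k] is the validation score after epoch k + 1 (epochs are 1-indexed).
    Training stops once `patience` epochs have passed without beating the best score.
    """
    best_score = float("-inf")
    best_epoch = 0
    for epoch, score in enumerate(val_scores, start=1):
        if score > best_score:
            # New best: this is the checkpoint that would be saved.
            best_score, best_epoch = score, epoch
        elif epoch - best_epoch >= patience:
            # No improvement for `patience` epochs: stop here.
            return epoch, best_epoch
    # Ran out of epochs without triggering the stop.
    return len(val_scores), best_epoch
```

For a made-up score curve that sits on a plateau and only jumps late, such as `[0.3] + [0.2] * 10 + [0.8]`, `patience=7` stops at epoch 8 and keeps epoch 1, while `patience=50` reaches epoch 12 and keeps the late improvement. That is consistent with why raising the patience changed the outcome here.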
Thank you.
Hi, the sigmoid model is indeed a fun one to play around with in this tutorial. :)
I tried a couple of training runs in PyTorch with 50 epochs and noticed that some suddenly start learning, but many continued to fail even after longer training. You need to be a bit lucky that the gradients in the early layers don't cancel each other out too much, so that the network actually starts learning. In JAX, the sigmoid networks tend to move into the learning regime slightly more reliably. At the same time, when you optimize the initialization, add some normalization, or use Adam, the MLP also trains relatively well with sigmoid activation functions. Nonetheless, the point of the sigmoid experiment was to show that one shouldn't use sigmoid as the main hidden activation function in a network, since it brings several drawbacks. So I would recommend using other activation functions rather than trying to over-optimize the sigmoid network :)
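One well-known drawback of sigmoid hidden layers is that its derivative, sigmoid(x) * (1 - sigmoid(x)), never exceeds 0.25. During backpropagation, each sigmoid layer multiplies the gradient by such a factor, so the signal reaching the early layers tends to shrink. The sketch below is a simplified illustration, assuming we ignore the weight matrices entirely and only look at the best-case activation factor. The helper names are hypothetical.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic sigmoid 1 / (1 + exp(-x))."""
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_grad(x: float) -> float:
    """Derivative sigmoid(x) * (1 - sigmoid(x)); its maximum is 0.25 at x = 0."""
    s = sigmoid(x)
    return s * (1.0 - s)


def best_case_activation_factor(num_layers: int) -> float:
    """Largest possible product of sigmoid derivatives along one backward path
    through `num_layers` sigmoid layers (weight matrices are ignored)."""
    return sigmoid_grad(0.0) ** num_layers


# Print how fast the best-case factor decays with depth.
for layers in (1, 2, 4, 8):
    print(layers, best_case_activation_factor(layers))
```

Even in the best case, four sigmoid layers scale the gradient by 0.25**4 ≈ 0.004. Larger weights can partly compensate, which fits the observation that better initialization, normalization or Adam help the sigmoid MLP train.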
> I tried a couple of training runs in PyTorch with 50 epochs and noticed that some suddenly start learning, but many continued to fail even after longer training.

I ran the code below, and all test accuracies were higher than 75%.
Did many of the sigmoid models really fail to learn?
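```python
# set_seed, Sigmoid, BaseNetwork, train_model and device are defined in the tutorial notebook.
# Re-train the sigmoid MLP with 50 different seeds, using a large early-stopping patience.
for i in range(50):
    print(f"Training BaseNetwork with seed {i}")
    set_seed(i)  # seed the random number generators so each run is reproducible
    act_fn = Sigmoid()
    net_actfn = BaseNetwork(act_fn=act_fn).to(device)
    # With overwrite=False, a model already saved under this name is reused instead of re-trained.
    train_model(net_actfn, f"FashionMNIST_sigmoid_{i}", overwrite=False, patience=50)
```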
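To answer "how many runs actually failed?" across seeds, one option is to collect each run's test accuracy and flag the runs that stay near chance level. FashionMNIST has 10 classes, so random guessing gives about 10% accuracy. This is a minimal sketch, assuming the accuracies have already been gathered into a dict keyed by seed as fractions in [0, 1]. The helper `failed_seeds` and the 5-point `margin` are hypothetical choices, not part of the tutorial.

```python
from typing import Dict, List

NUM_CLASSES = 10  # FashionMNIST has 10 classes
CHANCE_ACCURACY = 1.0 / NUM_CLASSES  # expected accuracy of random guessing


def failed_seeds(test_accuracies: Dict[int, float], margin: float = 0.05) -> List[int]:
    """Return the seeds whose test accuracy stays within `margin` of chance level.

    `test_accuracies` maps seed -> test accuracy as a fraction in [0, 1].
    """
    threshold = CHANCE_ACCURACY + margin
    return sorted(seed for seed, acc in test_accuracies.items() if acc <= threshold)
```

Counting `len(failed_seeds(...))` over all 50 seeds gives a direct measure of how often the sigmoid network gets stuck, independent of how the remaining runs perform.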
It is true, though, that the learning starts suddenly.

> So I would recommend using other activation functions rather than trying to over-optimize the sigmoid network :)
Thank you for your great tutorials!