Questions about train.py #54

Open
Mustardburger opened this issue Jul 22, 2023 · 5 comments
Labels
help wanted Extra attention is needed

@Mustardburger
Collaborator

Issue type

Need help

Summary

Some functions in /cellbox/train.py are ambiguous about what task they perform. Understanding them is crucial for reproducing similar results with the PyTorch version of CellBox, so this issue is meant to resolve the ambiguity.

Details

  • Lines 76 to 79 in train.py: are loss_valid_i and loss_valid_mse_i evaluated on one random batch fetched from args.feed_dicts['valid_set'], or on the whole validation set?
  • The eval_model function returns different values in different calls. At lines 101 to 103, it returns both the total loss and the MSE loss over args.n_batches_eval batches of the validation set. At lines 109 to 111, it returns only the MSE loss over args.n_batches_eval batches of the test set. And at line 262, it returns the expression predictions y_hat for the whole test set. Are all of these statements correct?
  • The record_eval.csv file generated after training with the default training arguments and the config file specified in the README (python scripts/main.py -config=configs/Example.random_partition.json) has None in its test_mse column. Is this the expected behaviour?
  • random_pos.csv, generated after training, stores the indices of the perturbation conditions. Does it indicate how the conditions are split into training, validation, and test sets?
  • After each substage, say substage 6, the code generates 6_best.y_hat.loss.csv, containing the expression predictions for all nodes under the perturbation conditions in the test set, but it does not indicate which row corresponds to which perturbation condition. How are this file and random_pos.csv related?
Mustardburger added the help wanted (Extra attention is needed) label on Jul 22, 2023
@DesmondYuan
Collaborator

Lines 76 to 79 in train.py: are loss_valid_i and loss_valid_mse_i evaluated on one random batch fetched from args.feed_dicts['valid_set'], or on the whole validation set?

Great question. The dataset iterator is defined in dataset.py:

# repeat() makes the monitor stream endless; each get_next() yields one shuffled batch
cfg.iter_monitor = tf.compat.v1.data.make_initializable_iterator(
    dataset.repeat().shuffle(buffer_size=1024, reshuffle_each_iteration=True).batch(cfg.batchsize))

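So in train_model(), each iteration samples one batch for training and one batch for monitoring, and the monitor iterator is never exhausted, because we don't know in advance how many monitoring batches will be needed during training. In other words, loss_valid_i and loss_valid_mse_i are computed on a single randomly sampled validation batch, not on the whole validation set.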
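To make the sampling behaviour concrete, here is a minimal stdlib sketch of an endless monitor stream. It is only a stand-in for TF's dataset.repeat().shuffle(...).batch(...): the function endless_shuffled_batches is hypothetical, the data is reshuffled once per pass, and TF's finite shuffle buffer is not modelled. Each training iteration pulls exactly one batch, and the stream never runs out even after many passes over the validation data.

```python
import itertools
import random
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def endless_shuffled_batches(data: Sequence[T], batch_size: int, seed: int = 0) -> Iterator[List[T]]:
    """Hypothetical stand-in for dataset.repeat().shuffle(...).batch(...).

    Simplification: the whole dataset is reshuffled once per pass (TF's
    shuffle buffer is not modelled), and batches are cut from one continuous
    stream, so a batch may straddle two passes.
    """
    if not data:
        raise ValueError("data must not be empty")
    rng = random.Random(seed)

    def stream() -> Iterator[T]:
        while True:  # like repeat(): the stream never ends
            order = list(data)
            rng.shuffle(order)
            yield from order

    samples = stream()
    while True:
        yield [next(samples) for _ in range(batch_size)]


# 10 hypothetical validation conditions, monitored for 100 training iterations.
valid_set = list(range(10))
monitor = endless_shuffled_batches(valid_set, batch_size=4)
seen = [next(monitor) for _ in range(100)]        # one monitor batch per iteration
flat = list(itertools.chain.from_iterable(seen))  # 400 draws, no StopIteration
```

A per-iteration validation loss computed from such a stream is therefore a noisy single-batch estimate rather than a whole-set average.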

@DesmondYuan
Collaborator

The eval_model function returns different values in different calls. At lines 101 to 103, it returns both the total loss and the MSE loss over args.n_batches_eval batches of the validation set. At lines 109 to 111, it returns only the MSE loss over args.n_batches_eval batches of the test set. And at line 262, it returns the expression predictions y_hat for the whole test set. Are all of these statements correct?

We designed eval_model so that the same evaluation logic can be applied to different outputs. The output is passed in as obj_fn, i.e. the operation(s) to evaluate, which can be a scalar metric such as the MSE or a large tensor such as the predicted y_hat. That is why different calls return different things. The naming is poor, and we could probably fix it to avoid this kind of confusion.
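The idea can be sketched in plain Python. This is not CellBox's eval_model: the batch format (pairs of observed y and predicted y_hat), the helper names mse_fn and y_hat_fn, and the aggregation rule (average scalar results, concatenate per-sample results) are all assumptions made for illustration. It shows how one loop can return either a scalar metric or stacked predictions, depending on the obj_fn it is given.

```python
from typing import Callable, Iterable, List, Sequence, Tuple, Union

Sample = Tuple[float, float]        # (observed y, predicted y_hat) for one condition
Output = Union[float, List[float]]  # scalar metric or per-sample values


def eval_model(batches: Iterable[Sequence[Sample]],
               obj_fn: Callable[[Sequence[Sample]], Output],
               n_batches_eval: int) -> Output:
    """Hypothetical evaluation loop: apply obj_fn to up to n_batches_eval batches."""
    results: List[Output] = []
    for i, batch in enumerate(batches):
        if i >= n_batches_eval:
            break
        results.append(obj_fn(batch))
    if not results:
        raise ValueError("no batches were evaluated")
    if isinstance(results[0], float):
        # Scalar metric: plain average of the per-batch values.
        return sum(results) / len(results)
    # Per-sample outputs (e.g. y_hat): concatenate the batches in order.
    return [value for batch_values in results for value in batch_values]


def mse_fn(batch: Sequence[Sample]) -> float:
    return sum((y - y_hat) ** 2 for y, y_hat in batch) / len(batch)


def y_hat_fn(batch: Sequence[Sample]) -> List[float]:
    return [y_hat for _, y_hat in batch]


test_batches = [[(1.0, 0.5), (2.0, 2.0)], [(0.0, 1.0)]]
test_mse = eval_model(test_batches, mse_fn, n_batches_eval=2)      # 0.5625
test_y_hat = eval_model(test_batches, y_hat_fn, n_batches_eval=2)  # [0.5, 2.0, 1.0]
```

Averaging per-batch MSEs equals the MSE over all samples only when batches have equal size; here the per-sample MSE would be 1.25 / 3 ≈ 0.417 rather than 0.5625.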

@DesmondYuan
Collaborator

The record_eval.csv file generated after training with the default training arguments and the config file specified in the README (python scripts/main.py -config=configs/Example.random_partition.json) has None in its test_mse column. Is this the expected behaviour?

Yes, that is intentional. The rationale is to test the model only as the final step, although that is not always practical.

For example, we might repeat the random partition sampling a handful of times (a bootstrap), pick the best model based on the validation set, and only then test the chosen model and report its performance. Skipping the test stage during training and development also saves computational resources.
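That workflow can be sketched as follows. The Run record, the select_and_test helper and the validation numbers are hypothetical; the point is that the test function is called exactly once, on the run with the lowest validation MSE.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass
class Run:
    seed: int         # seed of one random partition / training run
    valid_mse: float  # validation MSE recorded during training


def select_and_test(runs: List[Run], test_fn: Callable[[Run], float]) -> Tuple[Run, float]:
    """Pick the run with the lowest validation MSE, then evaluate only that run on test."""
    if not runs:
        raise ValueError("need at least one run")
    best = min(runs, key=lambda run: run.valid_mse)
    return best, test_fn(best)


calls: List[int] = []


def fake_test_fn(run: Run) -> float:
    calls.append(run.seed)  # record which runs reached the test stage
    return 0.1 * run.seed   # placeholder test MSE


runs = [Run(seed=1, valid_mse=0.30), Run(seed=2, valid_mse=0.12), Run(seed=3, valid_mse=0.25)]
best, test_mse = select_and_test(runs, fake_test_fn)
```

Only the selected run (seed 2 here) ever touches the test set, so the reported test MSE is not used for model selection.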

The final test evaluation step is implemented here:

# Evaluation on test set
t0 = time.perf_counter()
sess.run(model.iter_eval.initializer, feed_dict=args.feed_dicts['test_set'])
loss_test_mse = eval_model(
    sess, model.iter_eval, model.eval_mse_loss,
    args.feed_dicts['test_set'], n_batches_eval=args.n_batches_eval)
append_record("record_eval.csv", [-1, None, None, None, None, None, loss_test_mse, time.perf_counter() - t0])

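As a result, the last row of record_eval.csv holds the test MSE (with -1 in the first column and the elapsed time in the last), while the other columns are None. Conversely, test_mse is None in the rows written during training.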

@DesmondYuan
Collaborator

random_pos.csv, generated after training, stores the indices of the perturbation conditions. Does it indicate how the conditions are split into training, validation, and test sets?

We did not document this well. The split is done here:

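def random_partition(cfg) -> Mapping[str, Any]:
    """random dataset partition"""
    nexp, _ = cfg.pert.shape
    nvalid = int(nexp * cfg.trainset_ratio)    # end of the train + validation block
    ntrain = int(nvalid * cfg.validset_ratio)  # end of the training block
    try:
        # Reuse an existing permutation so repeated runs share the same split.
        # Note: defaultfmt only names fields; values load as float unless dtype=int is given.
        random_pos = np.genfromtxt('random_pos.csv', defaultfmt='%d')
    except Exception:
        # First run: draw a random permutation of all condition indices and save it.
        random_pos = np.random.choice(range(nexp), nexp, replace=False)
        np.savetxt('random_pos.csv', random_pos, fmt='%d')
    dataset = {
        "node_index": cfg.node_index,
        "pert_full": cfg.pert,
        "train_pos": random_pos[:ntrain],
        "valid_pos": random_pos[ntrain:nvalid],
        "test_pos": random_pos[nvalid:]
    }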

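So with the default configs/Example.random_partition.json, the first trainset_ratio = 70% of the shuffled conditions go to training plus validation, and the remaining 30% become the test set. That 70% is further split 80%/20% as specified by validset_ratio = 80%, so in the end we have 56% training + 14% validation + 30% test (up to rounding from int()). In other words, random_pos.csv lists the training conditions first, then the validation conditions, then the test conditions.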
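The split arithmetic, including the int() truncation that makes the percentages approximate for small datasets, can be checked with a stdlib re-implementation of the slicing in random_partition. random_partition_sketch is hypothetical, uses random.Random.sample in place of np.random.choice, and only returns the three index lists.

```python
import random
from typing import Dict, List


def random_partition_sketch(nexp: int, trainset_ratio: float, validset_ratio: float,
                            seed: int = 0) -> Dict[str, List[int]]:
    """Hypothetical mirror of the slicing in random_partition."""
    nvalid = int(nexp * trainset_ratio)    # end of the train + validation block
    ntrain = int(nvalid * validset_ratio)  # end of the training block
    random_pos = random.Random(seed).sample(range(nexp), nexp)  # random permutation
    return {
        "train_pos": random_pos[:ntrain],
        "valid_pos": random_pos[ntrain:nvalid],
        "test_pos": random_pos[nvalid:],
    }


split_100 = random_partition_sketch(100, trainset_ratio=0.7, validset_ratio=0.8)
split_89 = random_partition_sketch(89, trainset_ratio=0.7, validset_ratio=0.8)
sizes_100 = {name: len(pos) for name, pos in split_100.items()}  # 56 / 14 / 30
sizes_89 = {name: len(pos) for name, pos in split_89.items()}    # 49 / 13 / 27
```

With 89 conditions the split comes out at roughly 55% / 15% / 30% rather than exactly 56% / 14% / 30%, and the three blocks always cover every condition exactly once.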

@DesmondYuan
Collaborator

After each substage, say substage 6, the code generates 6_best.y_hat.loss.csv, containing the expression predictions for all nodes under the perturbation conditions in the test set, but it does not indicate which row corresponds to which perturbation condition. How are this file and random_pos.csv related?

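This is related to my previous comment: the rows should correspond to the test subset, i.e. the conditions whose indices come last in random_pos.csv (test_pos = random_pos[nvalid:]).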
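If the rows of the prediction file follow the order of test_pos, they can be labelled with condition indices as below. That row order is an assumption (the reply only says the rows belong to the test subset), so it is worth checking, at minimum, that the row count matches len(test_pos). label_test_rows is a hypothetical helper; random_pos and nvalid are as in random_partition.

```python
from typing import Dict, List, Sequence


def label_test_rows(random_pos: Sequence[int], nvalid: int,
                    y_hat_rows: Sequence[List[float]]) -> Dict[int, List[float]]:
    """Map each prediction row to a condition index, ASSUMING rows follow test_pos order."""
    test_pos = list(random_pos[nvalid:])  # same slice as "test_pos" in random_partition
    if len(test_pos) != len(y_hat_rows):
        raise ValueError(f"{len(y_hat_rows)} rows but {len(test_pos)} test conditions")
    # int() also handles indices loaded as floats (e.g. by np.genfromtxt).
    return {int(pos): list(row) for pos, row in zip(test_pos, y_hat_rows)}


random_pos = [3, 0, 4, 1, 2]           # hypothetical contents of random_pos.csv
y_hat_rows = [[0.1, 0.2], [0.3, 0.4]]  # hypothetical rows of 6_best.y_hat.loss.csv
labelled = label_test_rows(random_pos, nvalid=3, y_hat_rows=y_hat_rows)
```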


3 participants