
Friendlier API for multiple inputs/outputs #35

Open
cmdupuis3 opened this issue Nov 8, 2021 · 2 comments

cmdupuis3 commented Nov 8, 2021

Currently, there doesn't seem to be a nice way to feed a batch to a model that has multiple inputs/outputs. This list comprehension works, but I think there could be a more elegant solution.

    model.fit([batch[x] for x in (sc.stencil_2D + sc.stencil_3D + sc.vanilla)],
              [batch[x] for x in sc.target])

From here, it would be possible to index with a list and get a list back. So, ideally I'd like to have something like

    model.fit([batch[sc.stencil_2D + sc.stencil_3D + sc.vanilla]],
              [batch[sc.target]])

without the list comprehension boilerplate.
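In the meantime, a tiny helper can hide the boilerplate at the call site. The sketch below is illustrative only: `select` is a hypothetical function, and a plain dict of lists stands in for a real xbatcher batch of DataArrays:

```python
# Hypothetical helper: pull several named variables out of a batch as a list.
# A plain dict of lists stands in for a real xbatcher batch of DataArrays.

def select(batch, names):
    """Return the batch variables named in `names`, in order, as a list."""
    return [batch[name] for name in names]

batch = {"u": [1, 2], "v": [3, 4], "t": [5, 6]}

inputs = select(batch, ["u", "v"])  # -> [[1, 2], [3, 4]]
targets = select(batch, ["t"])      # -> [[5, 6]]
```

With something like this, the `fit` call reads close to the proposal: `model.fit(select(batch, stencil_2D + stencil_3D + vanilla), select(batch, target))`.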

jhamman (Contributor) commented Feb 3, 2022

@cmdupuis3 - I know this has sat for a while but do you think you could expand a bit on your use case? Even better if you could provide a simple demo that articulates how you are creating your batches and what data shapes you expect to pass to your model. For example, it is difficult to tell what sc is supposed to be in your example.

cmdupuis3 (Author) commented Feb 17, 2022

@jhamman Ah, yeah that is a little unclear. sc is just a struct with lists of variable names. Maybe instead, you can think of it like this:

    model.fit([batch[x] for x in (var_list1 + var_list2 + var_list3)],
              [batch[x] for x in var_list4])

(changed to)

    model.fit([batch[var_list1 + var_list2 + var_list3]],
              [batch[var_list4]])
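To make the requested indexing behavior concrete, here is a minimal sketch of a mapping that accepts a list key and returns a list of variables; `MultiKeyBatch` is hypothetical and not part of xbatcher, and plain lists stand in for DataArrays:

```python
# Hypothetical wrapper demonstrating list-key indexing: batch[[...names]]
# returns a list of variables, while a plain string key behaves as usual.

class MultiKeyBatch:
    def __init__(self, data):
        self._data = dict(data)

    def __getitem__(self, key):
        # A list/tuple key selects several variables at once.
        if isinstance(key, (list, tuple)):
            return [self._data[k] for k in key]
        return self._data[key]

batch = MultiKeyBatch({"a": [1], "b": [2], "c": [3]})

batch[["a", "b"]]  # -> [[1], [2]]
batch["c"]         # -> [3]
```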

So in my case, I have something like this (note that I need the squeeze_batch_dim option I added in #39):

    bgen = xb.BatchGenerator(
        ds,
        input_dims={'nlon': nlons, 'nlat': nlats},
        input_overlap={'nlon': halo_size, 'nlat': halo_size},
        squeeze_batch_dim=False,
    )

    for batch in bgen:
        sub = {'nlon': range(halo_size, nlons - halo_size),
               'nlat': range(halo_size, nlats - halo_size)}
        ...
        batch_stencil_2D = [batch[x.name] for x in sc.variable]
        batch_target     = [batch[x.name][sub] for x in sc.target]
        ...
        model.compile(loss='mae', optimizer='Adam', metrics=['mae', 'mse', 'accuracy'])
        model.fit(batch_stencil_2D, batch_target, batch_size=32, epochs=1, verbose=0)
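For anyone unfamiliar with the halo trimming above: each batch carries halo_size ghost points on every side, and the target keeps only the interior. A self-contained sketch of the arithmetic behind `sub` (the axis lengths here are made up):

```python
# Interior-index arithmetic behind `sub`: trim `halo` ghost points from
# each end of an axis of length n. The values below are illustrative only.

def interior(n, halo):
    """Index range of the interior points of an axis of length n."""
    return range(halo, n - halo)

nlons, nlats, halo_size = 32, 16, 2
sub = {'nlon': interior(nlons, halo_size),
       'nlat': interior(nlats, halo_size)}

len(sub['nlon'])  # 32 - 2*2 = 28 interior points
```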
