Better MP, multi-gpu, atom graphs, replay, GPS, LS-GFN, and fixes #141

Open · wants to merge 35 commits into trunk
Commits
7dbca12
first throw at refactoring SamplingIterator
bengioe Feb 28, 2024
939cb56
Merge branch 'trunk' into bengioe-better-iterators
bengioe Feb 28, 2024
dfba1ca
changed all iterators to DataSource
bengioe Feb 29, 2024
e5239fb
lots of little fixes, tested all tasks, better device management
bengioe Feb 29, 2024
43dfc2b
style
bengioe Mar 1, 2024
279ecfc
change batch size hyperparameters + fix nested dataclasses
bengioe Mar 7, 2024
2ba251a
Merge branch 'trunk' into bengioe-better-iterators
bengioe Mar 7, 2024
282bbfb
move things around & prevent circular import
bengioe Mar 7, 2024
c3bc6d0
tox
bengioe Mar 7, 2024
b1c5630
fix imports
bengioe Mar 7, 2024
a64a639
replace device references with get_worker_device
bengioe Mar 7, 2024
28bcc59
little fixes
bengioe Mar 7, 2024
4811e7c
a few more stragglers
bengioe Mar 7, 2024
7d32ac1
proof of concept of using shared pinned buffers
bengioe Feb 23, 2024
d4a2a7d
32mb buffer
bengioe Feb 23, 2024
27dfc23
add to DataSource
bengioe Mar 7, 2024
e9f1dc1
various fixes
bengioe Mar 8, 2024
c048e77
major simplification by reusing pickling mechanisms
bengioe Mar 8, 2024
acfe070
memory copy + fixes and doc
bengioe Mar 11, 2024
9454da8
Merge branch 'trunk' into bengioe-mp-with-batch-buffers
bengioe Mar 11, 2024
2b9da70
Merge branch 'trunk' into bengioe-mp-with-batch-buffers
bengioe May 8, 2024
907ffcd
fix global_cfg + opt_Z when there's no Z
bengioe May 8, 2024
60722a7
fix entropy when masks are used
bengioe May 9, 2024
f859640
small fixes
bengioe May 9, 2024
d536233
removing timing prints
bengioe May 9, 2024
6c3beba
C graphs, DDP, logit scaling
bengioe Aug 21, 2024
67f4b62
C mol valence fix, mask-backwards sample, MLE in TB, priority replay,…
bengioe Aug 28, 2024
a1534be
first (bad) attempt
bengioe Aug 29, 2024
4491f6b
working local search, cond_info dict pass, allow no log dir, fix Pad …
bengioe Aug 30, 2024
c5373cf
lstb file
bengioe Aug 30, 2024
7e623bd
yield_only_accepted in LS + load_model_state flag
bengioe Sep 5, 2024
32d4caf
many fixes, frag env options
bengioe Oct 8, 2024
30bd2e3
tox
bengioe Oct 8, 2024
ccefd86
ruff & mypy
bengioe Oct 8, 2024
1669c28
bandit
bengioe Oct 8, 2024
16 changes: 15 additions & 1 deletion docs/implementation_notes.md
@@ -51,4 +51,18 @@ The data used for training GFlowNets can come from a variety of sources. `DataSource`

`DataSource` also covers validation sets, including cases such as:
- Generating new trajectories (w.r.t a fixed dataset of conditioning goals)
- Evaluating the model's likelihood on trajectories from a fixed, offline dataset

## Multiprocessing

We use the multiprocessing features of torch's `DataLoader` to parallelize data generation and featurization. This is done by setting the `DataLoader`'s `num_workers` parameter (via `cfg.num_workers`) to a value greater than 0. Because workers cannot (easily) hold a CUDA context, we have to resort to a number of tricks.
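
As a minimal sketch of how this is wired up (not the trainer's exact call; `make_loader` is an illustrative name), a `DataSource` yields ready-made batches, so automatic batching is disabled:

```python
from torch.utils.data import DataLoader

def make_loader(data_source, num_workers: int) -> DataLoader:
    # batch_size=None disables automatic batching: the DataSource already
    # yields whole batches, and num_workers > 0 moves its iteration into
    # separate worker processes.
    return DataLoader(data_source, batch_size=None, num_workers=num_workers)
```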

Because training models involves sampling from them, the worker processes need to be able to call the models. This is done by passing a wrapped model (and possibly a wrapped replay buffer) to the workers, using `gflownet.utils.multiprocessing_proxy`. These wrappers ensure that model calls are routed to the main process, where the model lives (e.g. on CUDA), and that the returned values are properly serialized and sent back to the worker process. The wrappers are also designed to be API-compatible with models, e.g. `model(input)` or `model.method(input)` will work as expected, regardless of whether `model` is a torch module or a wrapper. Note that it is only possible to call methods on these wrappers; direct attribute access is not supported.
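
A minimal sketch of the idea, assuming one request/result queue pair per worker (`ModelProxy` and `serve_model` are illustrative names, not the actual `gflownet.utils.multiprocessing_proxy` API):

```python
import torch
import torch.multiprocessing as mp

class ModelProxy:
    """Worker-side handle that forwards calls to the process owning the model."""

    def __init__(self, request_queue, result_queue):
        self._requests = request_queue
        self._results = result_queue

    def __call__(self, *args, **kwargs):
        return self._invoke("__call__", args, kwargs)

    def __getattr__(self, name):
        # Attribute access returns a callable, which is why only method
        # calls (not plain attribute reads) are supported.
        return lambda *args, **kwargs: self._invoke(name, args, kwargs)

    def _invoke(self, method, args, kwargs):
        self._requests.put((method, args, kwargs))
        return self._results.get()


def serve_model(model: torch.nn.Module, request_queue, result_queue):
    """Main-process loop: execute requests where the model (and CUDA) lives."""
    device = next(model.parameters()).device
    while True:
        method, args, kwargs = request_queue.get()
        args = [a.to(device) if torch.is_tensor(a) else a for a in args]
        fn = model if method == "__call__" else getattr(model, method)
        out = fn(*args, **kwargs)
        # Move results back to CPU so workers never need a CUDA context.
        result_queue.put(out.cpu() if torch.is_tensor(out) else out)
```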

Note that the workers do not use CUDA and therefore have to work entirely on CPU, but the code is designed to be somewhat agnostic to this fact. By using `get_worker_device`, code can be written without assuming which device it runs on; again, calls such as `model(input)` will work as expected.
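
For example, sampling code can remain device-agnostic (sketch only; the import path of `get_worker_device` and the `encoder`/`graphs` stand-ins are assumptions for illustration):

```python
from gflownet.utils.misc import get_worker_device  # import path assumed

def score_graphs(model, encoder, graphs):
    # get_worker_device returns cpu inside DataLoader workers and the
    # training device (e.g. cuda) in the main process.
    device = get_worker_device()
    x = encoder(graphs).to(device)
    # `model` may be a real torch module or a multiprocessing wrapper;
    # the call looks the same either way.
    return model(x)
```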

On message serialization: naively sending batches of data and results (`Batch` and `GraphActionCategorical`) through multiprocessing queues is fairly inefficient. Torch tries to be smart and uses shared memory for tensors sent through queues, which is unfortunately very slow: creating each shared-memory file has a high fixed cost, and `Data` `Batch`es tend to contain lots of small tensors, a poor fit for shared memory.

We implement two solutions to this problem, in order of preference (a configuration sketch follows this list):
- using `SharedPinnedBuffer`s, which are shared tensors of fixed size (`cfg.mp_buffer_size`), but initialized once and pinned. This is the fastest solution, but requires that the size of the largest possible batch/return value is known in advance. This should work for any message, but has only been tested with `Batch` and `GraphActionCategorical` messages.
- using `cfg.pickle_mp_messages`, which simply serializes messages with `pickle`. This prevents the creation of lots of shared memory files, but is slower than the `SharedPinnedBuffer` solution. This should work for any message that `pickle` can handle.
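
As a rough usage sketch (assuming the dataclass config described in these notes; the exact import path and defaults may differ):

```python
from gflownet.config import Config  # import path assumed

cfg = Config()
cfg.num_workers = 4  # parallelize data generation across 4 workers

# Preferred: a fixed-size shared pinned buffer, here 32 MB; it must be at
# least as large as the largest Batch/GraphActionCategorical message.
cfg.mp_buffer_size = 32 * 1024**2

# Fallback: serialize messages with pickle instead.
# cfg.pickle_mp_messages = True
```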
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -8,7 +8,8 @@ universal = "true"
[tool.bandit]
# B101 tests the use of assert
# B301 and B403 test the use of pickle
skips = ["B101", "B301", "B403"]
# B614 tests the use of torch.load/save
skips = ["B101", "B301", "B403", "B614"]
exclude_dirs = ["tests", ".tox", ".venv"]

[tool.pytest.ini_options]
18 changes: 16 additions & 2 deletions setup.py
@@ -2,7 +2,7 @@
from ast import literal_eval
from subprocess import check_output # nosec - command is hard-coded, no possibility of injection

from setuptools import setup
from setuptools import Extension, setup


def _get_next_version():
@@ -25,4 +25,18 @@ def _get_next_version():
return f"{major}.{minor}.{latest_patch+1}"


setup(name="gflownet", version=_get_next_version())
ext = [
Extension(
name="gflownet._C",
sources=[
"src/C/main.c",
"src/C/data.c",
"src/C/graph_def.c",
"src/C/node_view.c",
"src/C/edge_view.c",
"src/C/degree_view.c",
"src/C/mol_graph_to_Data.c",
],
)
]
setup(name="gflownet", version=_get_next_version(), ext_modules=ext)