-
-
Notifications
You must be signed in to change notification settings - Fork 606
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
1367: RNN update to drop CUDNN, fix LSTM bug and output type stability r=CarloLucibello a=jeremiedb PR related to #1114 #1360 #1365 Some experiment for RNN handling. Hidden state of each cell structure was dropped as they weren't needed (AFAIK, only needed for size inference for CUDNN, but bias size could be used as a substitute to cells' `h` there as well). Looked to drop dependence on CUDNN entirely, so it's a pure Flux/CUDA.jl. File `src/cuda/curnnjl` no longer used. No modifications were made to the cell computations. Initial test seems to show decent performance, but yet to benchmark. Pending issue: despite having dropped completely the CUDNN dependency, there's still an instability issue that seems present when running on GPU. This is illustrated in the test at lines 1-50 of file `test\rnn-test-jdb.jl`. If that test runs on CPU, it goes well thorugh the 100 iterations. However, the same on GPU will thow NAs after couple dozens of iterations. My only hypothesis so far: when performing the iteration over the sequence through `m.(x)` or `map(rnn, x)`, is the order of the execution safe? Ie: is it possible that there isn't a `sync()` on the CUDA side between those seq steps, which may mess up the state? ### PR Checklist - [x] Tests are added - [ ] Entry in NEWS.md - [ ] Documentation, if applicable - [ ] Final review from `@dhairyagandhi96` (for API changes). Co-authored-by: jeremiedb <[email protected]> Co-authored-by: jeremie.db <[email protected]>
- Loading branch information
Showing
7 changed files
with
40 additions
and
132 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,4 +5,5 @@ | |
docs/build/ | ||
docs/site/ | ||
deps | ||
# Manifest.toml | ||
.vscode | ||
# Manifest.toml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,12 @@ | ||
module CUDAint | ||
|
||
using ..CUDA | ||
|
||
using CUDA: CUDNN | ||
include("curnn.jl") | ||
|
||
import ..Flux: Flux | ||
import Zygote | ||
using Zygote: @adjoint | ||
|
||
include("cudnn.jl") | ||
|
||
end |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters