forked from vgvassilev/clad
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for differentiating switch stmt in the reverse mode AD. (v…
…gvassilev#339) This commit adds support for differentiating switch statements in the reverse mode AD. The basic idea used to differentiate switch statement is that in the forward pass, processing of the statements of the switch statement body always starts from a case/default label and ends at a break statement or at the end of the switch body. Similarly, in the reverse pass, processing of the differentiated statements of the switch statement body will start from the statement just above the break statement that was hit or from the last differentiated statement in the case when no break statement was hit. Thus, we can keep track of which break statement was hit in the forward pass or if no break statement was hit at all in a variable. This information is further used by an auxiliary switch statement in the reverse pass to jump the execution to the correct point (that is, differentiated statement of the statement just before the break statement that was hit in the forward pass). In this strategy, each switch case statement of the original function gets transformed to an if condition in the reverse pass. The if condition decides at runtime whether the processing of the differentiated statements of the switch statement body should stop or continue. This is again based on the fact that the processing of statements of the switch statement body always starts at a case statement. For an example, consider this code snippet: ```cpp switch (count) { case 0: a += i; break; case 2: a += 4 * i; break; default: a += 10 * i; } case 0 of this code snippet gets transformed to the following in the differentiated function: forward pass: { case 0: a += i; } { clad::push(_t0, 1UL); // this is used to keep track if this break was hit; 1UL is used to represent the case number break; } reverse pass: case 1UL:; // this case is selected if the corresponding break was hit in the forward pass { { double _r_d0 = _d_a; _d_a += _r_d0; *_d_i += _r_d0; _d_a -= _r_d0; } if (0 == _cond0) // If case 0: was selected in the forward pass, we should break out of processing differentiated switch stmt body here. break; } ```
- Loading branch information
Showing
7 changed files
with
1,100 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.