Skip to content

Commit

Permalink
fix: 2*numLayers + 1
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed Jun 21, 2023
1 parent 8d91ef7 commit d667215
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public CausalLMOutput forward(NDList input, NDList pastKeyValues, NDManager mana
NDList pastKeyValuesOutput = output.subNDList(1, config.getNumLayers() * 2 + 1);
NDArray hiddenStatesOutput = manager.zeros(new Shape(1));
if (output.size() > config.getNumLayers() * 2 + 2) {
hiddenStatesOutput = output.subNDList(config.getNumLayers() * 2 + 2).get(0);
hiddenStatesOutput = output.subNDList(config.getNumLayers() * 2 + 1).get(0);
}

if (flagDummyKvCach) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ public boolean mainContrastivePt(String[] args) {
NDArray expected =
manager.create(
new long[][] {
{284, 8494, 3716, 2761, 11, 884, 355, 1692, 1535, 11},
{4436, 329, 257, 2910, 1332, 13, 632, 373, 257, 3487}
{1212, 2708, 318, 546, 262, 2095, 13, 921, 743, 307},
{379, 502, 351, 10953, 287, 607, 2951, 13, 366, 40}
});
return output.get(":, -10:").equals(expected);
}
Expand Down

0 comments on commit d667215

Please sign in to comment.