Skip to content

Commit

Permalink
4-BIT DSP-MVU: Allow very wide lanes requiring even a high-lane exten…
Browse files Browse the repository at this point in the history
…sion.
  • Loading branch information
preusser committed Sep 30, 2024
1 parent 945a4a4 commit a7f29e0
Show file tree
Hide file tree
Showing 5 changed files with 2,508 additions and 23 deletions.
59 changes: 39 additions & 20 deletions finn-rtllib/mvu/mvu_4sx4u.sv
Original file line number Diff line number Diff line change
Expand Up @@ -80,35 +80,47 @@ module mvu_4sx4u #(
* - Internal lane widths differ, at most, by a single bit.
* - The rightmost lane (#0) has the maximum internal width.
* - The leftmost lane (#3) extends into the wide DSP accumulation path and
* is constrained by ACCU_WIDTH rather than the next lane. It doesn't have
* an external high extension.
* is typically constrained by ACCU_WIDTH rather than the next lane. If so,
* it doesn't have an external high extension.
* - The one but leftmost lane (#2) has the minimum internal width and, hence,
the maximum external high extension.
*/
typedef int unsigned lane_offset_v[4:0];
function lane_offset_v sliceLanes();
automatic lane_offset_v res;
unique case(VERSION)
1: begin
return NARROW_WEIGHTS?
res = NARROW_WEIGHTS?
lane_offset_v'{ ACCU_WIDTH+21, 21, 14, 7, 0 } :
lane_offset_v'{ 0, 0, 0, 0, 0 }; // not supported
end
2: begin
return NARROW_WEIGHTS?
res = NARROW_WEIGHTS?
lane_offset_v'{ ACCU_WIDTH+23, 23, 16, 8, 0 } :
lane_offset_v'{ ACCU_WIDTH+22, 22, 15, 8, 0 };
end
endcase
if(res[4] > 48) res[4] = 48;
return res;
endfunction : sliceLanes
localparam lane_offset_v OFFSETS = sliceLanes();

// Bit width allotted to the sum of `n` unsigned operands of `w` bits each.
//  - w <= 16: exact width of the largest possible sum, n*(2**w - 1).
//  - w  > 16: the bound w + $clog2(n) is used instead
//             (presumably to keep 2**w within 32-bit arithmetic - TODO confirm).
function int unsigned sum_width(input int unsigned n, input int unsigned w);
	if(w > 16) return w + $clog2(n);
	return $clog2(1 + n*(2**w - 1));
endfunction : sum_width
// Internal (low) width of lane `i`: the distance between the lane's own
// offset and the offset of the next lane up, both taken from OFFSETS.
function int unsigned lo_width(input int unsigned i);
	automatic int unsigned lane_base = OFFSETS[i];
	automatic int unsigned lane_top  = OFFSETS[i+1];
	return lane_top - lane_base;
endfunction : lo_width
function int unsigned hi_width(input int unsigned i);
return 1 + $clog2(2**(ACCU_WIDTH-lo_width(i)-1)+SIMD);
automatic int unsigned lw = lo_width(i);
return ACCU_WIDTH <= lw?
0 :
1 + ($clog2(SIMD) < ACCU_WIDTH-lw?
ACCU_WIDTH-lw :
$clog2(2**(ACCU_WIDTH-lw-1)+SIMD)
);
endfunction : hi_width
localparam int unsigned LO_WIDTH_MAX = OFFSETS[1] - OFFSETS[0];
localparam int unsigned LO_WIDTH_MAX = lo_width(3);
localparam int unsigned HI_WIDTH_MAX = hi_width(2);

localparam int unsigned A_WIDTH = 23 + 2*VERSION; // Width of A datapath
Expand Down Expand Up @@ -139,7 +151,7 @@ module mvu_4sx4u #(
localparam int unsigned PE_REM = 4*(c+1) - PE_END;

uwire [47:0] p3[SIMD];
uwire signed [ 1:0] h3[SIMD][3];
uwire signed [ 1:0] h3[SIMD][4];
for(genvar s = 0; s < SIMD; s++) begin : genSIMD

// Input Lane Assembly
Expand Down Expand Up @@ -500,6 +512,16 @@ module mvu_4sx4u #(
for(genvar i = 0; i < 3; i++) begin
assign h3[s][i] = pp[OFFSETS[i+1]+:2] - X3[i+1];
end
// Overflow out of high lane
logic PZ = 0;
always_ff @(posedge clk) begin
if(rst) PZ <= 0;
else if(en) PZ <= L[3]? 0 : pp[$left(pp)];
end
assign h3[s][3] =
( PZ && !pp[$left(pp)-:2])? +1 :
(!PZ && &pp[$left(pp)-:2])? -1 : 0;

assign p3[s] = pp;

end : genSIMD
Expand All @@ -509,17 +531,16 @@ module mvu_4sx4u #(
// Count leaves reachable from each node
localparam leave_load_t LEAVE_LOAD = SIMD > 1 ? init_leave_loads() : '{ default: 1 }; // SIMD=1 requires no adder tree, so zero-ing out, otherwise init_leave_loads ends up in infinite loop

uwire signed [ACCU_WIDTH-1:0] up4;
uwire signed [ HI_WIDTH_MAX-1:0] hi4[3];
uwire [$clog2(SIMD)+LO_WIDTH_MAX-1:0] lo4[3];
uwire signed [HI_WIDTH_MAX-1:0] hi4[4];
uwire [LO_WIDTH_MAX-1:0] lo4[4];
for(genvar i = 0; i < 4; i++) begin

// Conclusive high part accumulation
if(i < 3) begin : genHi
if(i < PE_REM) assign hi4[i] = '0;
if(i < PE_REM) assign hi4[i] = 0;
else begin : genHi
localparam int unsigned HI_WIDTH = hi_width(i);
if(HI_WIDTH == 0) assign hi4[i] = 0;
else begin
localparam int unsigned HI_WIDTH = hi_width(i);

// Adder Tree across all SIMD high contributions, each from [-1:1]
uwire signed [2*SIMD-2:0][$clog2(1+SIMD):0] tree;
for(genvar s = 0; s < SIMD; s++) assign tree[SIMD-1+s] = h3[s][i];
Expand All @@ -543,7 +564,6 @@ module mvu_4sx4u #(
end
end
assign hi4[i] = Hi4;

end
end : genHi

Expand All @@ -553,12 +573,12 @@ module mvu_4sx4u #(
localparam int unsigned LO_WIDTH = lo_width(i);

// Adder Tree across all SIMD low contributions
localparam int unsigned ROOT_WIDTH = $clog2(1 + SIMD*(2**LO_WIDTH-1));
localparam int unsigned ROOT_WIDTH = sum_width(SIMD, LO_WIDTH);
uwire [2*SIMD-2:0][ROOT_WIDTH-1:0] tree;
for(genvar s = 0; s < SIMD; s++) assign tree[SIMD-1+s] = p3[s][OFFSETS[i]+:LO_WIDTH];
for(genvar n = 0; n < SIMD-1; n++) begin
// Sum truncated to actual maximum bit width at this node
localparam int unsigned NODE_WIDTH = $clog2(1 + LEAVE_LOAD[n]*(2**LO_WIDTH-1));
localparam int unsigned NODE_WIDTH = sum_width(LEAVE_LOAD[n], LO_WIDTH);
uwire [NODE_WIDTH-1:0] s = tree[2*n+1] + tree[2*n+2];
assign tree[n] = s;
end
Expand All @@ -569,8 +589,7 @@ module mvu_4sx4u #(
else if(en) Lo4 <= tree[0];
end

if(i == 3) assign up4 = Lo4;
else assign lo4[i] = Lo4;
assign lo4[i] = Lo4;
end : genLo

end
Expand All @@ -580,7 +599,7 @@ module mvu_4sx4u #(
always_ff @(posedge clk) begin
if(rst) Res5 <= '{ default: 0 };
else if(en) begin
Res5[3] <= up4 - hi4[2];
Res5[3] <= $signed({ hi4[3], {(lo_width(3)){1'b0}} }) + $signed({ 1'b0, lo4[3] }) - hi4[2];
Res5[2] <= $signed({ hi4[2], {(lo_width(2)){1'b0}} }) + $signed({ 1'b0, lo4[2] }) - hi4[1];
Res5[1] <= $signed({ hi4[1], {(lo_width(1)){1'b0}} }) + $signed({ 1'b0, lo4[1] }) - hi4[0];
Res5[0] <= $signed({ hi4[0], {(lo_width(0)){1'b0}} }) + $signed({ 1'b0, lo4[0] });
Expand Down
Loading

0 comments on commit a7f29e0

Please sign in to comment.