-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* start * works but could be better * better throughput alt * comment * clean
- Loading branch information
1 parent
45044c2
commit cfebc86
Showing
8 changed files
with
434 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
import "apps/blas/util.fil"; | ||
import "primitives/signed.fil"; | ||
|
||
// Applies a plane rotation to vectors x, y:
//   out_1{i} = c*x{i} + s*y{i}
//   out_2{i} = (-s)*x{i} + c*y{i}
// W: element width
// N: vector length
// M: multiplier amount (split into 4 groups of M/4, one group per product: cy, sy, cx, -sx)
// A: adder amount (split into 2 groups of A/2, one group per output vector)
// L:  (output parameter) offset from 'G at which out_1/out_2 are valid
// II: (output parameter) initiation interval, set by the slower of the multiply and add stages
// Tiles scalar-vector products based on multiplier count,
// then does additions as multiplies finish
comp Rot[W, N, M, A]<'G:II>(
  go: interface['G],
  c: ['G, 'G+1] W,
  s: ['G, 'G+1] W,
  x[N]: ['G, 'G+1] W,
  y[N]: ['G, 'G+1] W,
) -> (
  out_1[N]: ['G+L, 'G+L+1] W,
  out_2[N]: ['G+L, 'G+L+1] W
) with {
  some L where L > 0;
  some II where II > 0;
} where W > 0,
        N > 0,
        M > 0,
        A > 0,
        M % 4 == 0, // need to do at least 4 multiplies at once
        A % 2 == 0, // need to do at least 2 adds at once
        N % (M/4) == 0,     // the vector splits evenly into multiplier tiles
        (M/4) % (A/2) == 0  // each multiplier tile splits evenly into adder tiles
{
  // partition mults into 4 groups, for 4 mults that need to happen at each time step
  let m = M/4;

  // same for adds, but only 2 adds to do
  let a = A/2;

  // number of times each multiplier group is reused for a single scalar-vector computation
  let mult_reuses = N / m;

  // number of times each adder group is reused per multiplier tile
  let add_reuses = m / a;

  // dummy instances so we can read their output parameters
  M_ := new Multipliers[W, m];
  let mult_latency = M_::L;
  let mult_ii = M_::II;

  A_ := new Adders[W, a];
  let add_ii = A_::II;

  // -s
  negs := new Neg[W]<'G>(s);

  // instantiate multipliers: one shared group per product kind,
  // live for the whole window in which tiles are issued
  let last_mult_invoke = (mult_reuses)*mult_ii;
  M_cy := new Multipliers[W, m] in ['G, 'G+last_mult_invoke];
  M_sy := new Multipliers[W, m] in ['G, 'G+last_mult_invoke];
  M_cx := new Multipliers[W, m] in ['G, 'G+last_mult_invoke];
  M_nsx := new Multipliers[W, m] in ['G, 'G+last_mult_invoke];

  // instantiate adders: one shared group per output vector
  let add_end = last_mult_invoke + mult_latency + add_reuses*add_ii + (add_reuses-1)*(mult_reuses-1) + 1;
  A_1 := new Adders[W, a] in ['G+mult_latency, 'G+add_end];
  A_2 := new Adders[W, a] in ['G+mult_latency, 'G+add_end];

  // check which stage is limiting the pipeline
  let ii = if (add_end-mult_latency) < (last_mult_invoke) {(last_mult_invoke)} else {(add_end-mult_latency)};

  // overall latency; bound before the loop below, which uses it to align
  // every adder result to the common output cycle 'G+latency
  let latency = (mult_reuses*mult_ii + mult_latency) + (add_reuses-1)*(mult_reuses-1) + (add_reuses-1)*add_ii;

  // product bundles: row k holds tile k's products, valid when that tile's multiply finishes
  bundle cy[mult_reuses][m]: for<k> ['G+k*mult_ii+mult_latency, 'G+k*mult_ii+mult_latency+1] W;
  bundle sy[mult_reuses][m]: for<k> ['G+k*mult_ii+mult_latency, 'G+k*mult_ii+mult_latency+1] W;
  bundle cx[mult_reuses][m]: for<k> ['G+k*mult_ii+mult_latency, 'G+k*mult_ii+mult_latency+1] W;
  bundle nsx[mult_reuses][m]: for<k> ['G+k*mult_ii+mult_latency, 'G+k*mult_ii+mult_latency+1] W;

  // scalar bundles for multiplications (each scalar replicated m times)
  bundle c_bundle[m]: ['G, 'G+1] W;
  bundle s_bundle[m]: ['G, 'G+1] W;
  bundle negs_bundle[m]: ['G, 'G+1] W;

  // fill them
  for i in 0..m {
    c_bundle{i} = c;
    s_bundle{i} = s;
    negs_bundle{i} = negs.out;
  }

  // start multiplications: tile i is issued at 'G+i*mult_ii
  for i in 0..mult_reuses {
    // cycle (relative to 'G) at which tile i's products are available
    let mult_end = i*mult_ii + mult_latency;

    // register inputs until tile i is issued
    x_reg := new Shift[W, i*mult_ii, m]<'G>(x{i*m..(i+1)*m});
    y_reg := new Shift[W, i*mult_ii, m]<'G>(y{i*m..(i+1)*m});
    c_reg := new Shift[W, i*mult_ii, m]<'G>(c_bundle{0..m});
    s_reg := new Shift[W, i*mult_ii, m]<'G>(s_bundle{0..m});
    negs_reg := new Shift[W, i*mult_ii, m]<'G>(negs_bundle{0..m});

    mult_cy := M_cy<'G+i*mult_ii>(c_reg.out{0..m}, y_reg.out{0..m});
    mult_sy := M_sy<'G+i*mult_ii>(s_reg.out{0..m}, y_reg.out{0..m});
    mult_cx := M_cx<'G+i*mult_ii>(c_reg.out{0..m}, x_reg.out{0..m});
    mult_nsx := M_nsx<'G+i*mult_ii>(negs_reg.out{0..m}, x_reg.out{0..m});

    cy{i}{0..m} = mult_cy.out{0..m};
    sy{i}{0..m} = mult_sy.out{0..m};
    cx{i}{0..m} = mult_cx.out{0..m};
    nsx{i}{0..m} = mult_nsx.out{0..m};

    // additions for tile i, a elements at a time
    for j in 0..add_reuses {
      // extra delay that grows with the tile index; presumably keeps consecutive
      // tiles from contending for the shared adders -- TODO confirm for add_ii > 1
      let offset = i*(add_reuses-1);
      mult_cy_reg := new Shift[W, j*add_ii + offset, a]<'G+mult_end>(cy{i}{(j*a)..(j+1)*a});
      mult_sy_reg := new Shift[W, j*add_ii + offset, a]<'G+mult_end>(sy{i}{(j*a)..(j+1)*a});
      mult_cx_reg := new Shift[W, j*add_ii + offset, a]<'G+mult_end>(cx{i}{(j*a)..(j+1)*a});
      mult_nsx_reg := new Shift[W, j*add_ii + offset, a]<'G+mult_end>(nsx{i}{(j*a)..(j+1)*a});

      // out_1 chunk: c*x + s*y ; out_2 chunk: (-s)*x + c*y
      add_1 := A_1<'G + mult_end + j*add_ii + offset>(mult_cx_reg.out{0..a}, mult_sy_reg.out{0..a});
      add_2 := A_2<'G + mult_end + j*add_ii + offset>(mult_nsx_reg.out{0..a}, mult_cy_reg.out{0..a});

      // hold each sum until the common output cycle 'G+latency
      add_1_reg := new Shift[W, latency - mult_end - j*add_ii - offset, a]<'G + mult_end + j*add_ii + offset>(add_1.out{0..a});
      add_2_reg := new Shift[W, latency - mult_end - j*add_ii - offset, a]<'G + mult_end + j*add_ii + offset>(add_2.out{0..a});

      out_1{(m*i)+(j*a)..(m*i)+(j+1)*a} = add_1_reg.out{0..a};
      out_2{(m*i)+(j*a)..(m*i)+(j+1)*a} = add_2_reg.out{0..a};
    }
  }

  L := latency;
  II := ii;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import "primitives/core.fil"; | ||
import "apps/blas/scal/scal.fil"; | ||
import "apps/blas/util.fil"; | ||
import "primitives/signed.fil"; | ||
|
||
// Applies a plane rotation to vectors x, y:
//   out_1{i} = c*x{i} + s*y{i}
//   out_2{i} = (-s)*x{i} + c*y{i}
// W: element width
// N: vector length
// M: multiplier amount (each of the 4 Scal instances gets M/4)
// A: adder amount (half used for out_1, half for out_2)
// L:  (output parameter) offset from 'G at which out_1/out_2 are valid
// II: (output parameter) initiation interval
// Uses scal to compute the scalar-vector products cy, sy, cx, -sx
comp Rot[W, N, M, A]<'G:II>(
  go: interface['G],
  c: ['G, 'G+1] W,
  s: ['G, 'G+1] W,
  x[N]: ['G, 'G+1] W,
  y[N]: ['G, 'G+1] W,
) -> (
  out_1[N]: ['G+L, 'G+L+1] W,
  out_2[N]: ['G+L, 'G+L+1] W
) with {
  some L where L > 0;
  some II where II > 0;
} where W > 0,
        L > 0,
        N > 0,
        M > 0,
        A > 0,
        M % 4 == 0,
        A % 2 == 0
{
  // NOTE(review): the adder loop below covers add_uses*(A/2) elements, so it
  // appears to assume N % (A/2) == 0 -- confirm and consider adding the constraint.

  // dummy instance so we can read Scal's output parameters
  scalex := new Scal[W, N, M/4];
  let scale_latency = scalex::L;
  let scale_ii = scalex::II;

  // -s
  neg_s := new Neg[W]<'G>(s);

  // all four products become valid scale_latency cycles after 'G
  bundle cy[N]: ['G+scale_latency, 'G+scale_latency+1] W;
  bundle sy[N]: ['G+scale_latency, 'G+scale_latency+1] W;
  bundle cx[N]: ['G+scale_latency, 'G+scale_latency+1] W;
  bundle msx[N]: ['G+scale_latency, 'G+scale_latency+1] W;

  // c*y
  SCY := new Scal[W, N, M/4] in ['G, 'G+scale_ii];
  scale_cy := SCY<'G>(y{0..N}, c);
  cy{0..N} = scale_cy.out{0..N};

  // s*y
  SSY := new Scal[W, N, M/4] in ['G, 'G+scale_ii];
  scale_sy := SSY<'G>(y{0..N}, s);
  sy{0..N} = scale_sy.out{0..N};

  // c*x
  SCX := new Scal[W, N, M/4] in ['G, 'G+scale_ii];
  scale_cx := SCX<'G>(x{0..N}, c);
  cx{0..N} = scale_cx.out{0..N};

  // (-s)*x
  SMSX := new Scal[W, N, M/4] in ['G, 'G+scale_ii];
  scale_msx := SMSX<'G>(x{0..N}, neg_s.out);
  msx{0..N} = scale_msx.out{0..N};

  // out_1{i} <- cx{i} + sy{i}
  // out_2{i} <- msx{i} + cy{i}

  // number of passes through each adder group, and the spacing between passes
  let add_uses = N / (A/2);
  let add_ii = 1;

  // use half the adders for x, half for y
  A_x := new Adders[W, A/2] in ['G+scale_latency, 'G+scale_latency+(add_uses-1)*add_ii+1];
  A_y := new Adders[W, A/2] in ['G+scale_latency, 'G+scale_latency+(add_uses-1)*add_ii+1];

  // the last adder pass is issued at scale_latency + (add_uses-1)*add_ii
  let latency = scale_latency + (add_uses-1) * add_ii;
  for k in 0..add_uses {
    // save chunked arrays based on when we are ready to add them
    cx_reg := new Shift[W, k*add_ii, A/2]<'G+scale_latency>(cx{k*(A/2)..(k+1)*(A/2)});
    sy_reg := new Shift[W, k*add_ii, A/2]<'G+scale_latency>(sy{k*(A/2)..(k+1)*(A/2)});
    cy_reg := new Shift[W, k*add_ii, A/2]<'G+scale_latency>(cy{k*(A/2)..(k+1)*(A/2)});
    msx_reg := new Shift[W, k*add_ii, A/2]<'G+scale_latency>(msx{k*(A/2)..(k+1)*(A/2)});

    ax := A_x<'G + scale_latency + k*add_ii>(cx_reg.out{0..(A/2)}, sy_reg.out{0..(A/2)});
    ay := A_y<'G + scale_latency + k*add_ii>(msx_reg.out{0..(A/2)}, cy_reg.out{0..(A/2)});

    // save add result until the common output cycle 'G+latency
    ax_reg := new Shift[W, latency - scale_latency - k*add_ii, A/2]<'G + scale_latency + k*add_ii>(ax.out{0..(A/2)});
    ay_reg := new Shift[W, latency - scale_latency - k*add_ii, A/2]<'G + scale_latency + k*add_ii>(ay.out{0..(A/2)});

    out_1{k*(A/2)..(k+1)*(A/2)} = ax_reg.out{0..(A/2)};
    out_2{k*(A/2)..(k+1)*(A/2)} = ay_reg.out{0..(A/2)};
  }

  L := latency;
  // II is the larger of the adder stage's occupancy (add_uses*add_ii) and Scal's own II
  let ii = if (add_uses*add_ii) > (scale_ii) {add_uses*add_ii} else {scale_ii};
  II := ii;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# for determining the right answer | ||
|
||
def rot(u, v, c, s):
    """Reference model of the rotation applied to two equal-length vectors.

    For every index i:
        x[i] = c*u[i] + s*v[i]
        y[i] = (-s)*u[i] + c*v[i]

    Args:
        u, v: input sequences of the same length.
        c, s: scalar rotation coefficients.

    Returns:
        A tuple ``(x, y)`` of two new lists.

    Raises:
        AssertionError: if ``u`` and ``v`` differ in length.
    """
    assert len(u) == len(v)
    pairs = list(zip(u, v))
    x = [c * ui + s * vi for ui, vi in pairs]
    y = [(-s) * ui + c * vi for ui, vi in pairs]
    return (x, y)
|
||
if __name__ == '__main__':
    # Case 0: both inputs are the ramp 0..15, with c = s = 1.
    u0 = list(range(16))
    v0 = list(range(16))
    c0 = 1
    s0 = 1
    x0, y0 = rot(u0, v0, c0, s0)
    print(f"x0: {x0}")
    print(f"y0: {y0}")

    print("\n======================\n")

    # Case 1: ascending ramp 1..16 against descending ramp 16..1, with c = s = 1.
    u1 = list(range(1, 17))
    v1 = list(range(16, 0, -1))
    c1 = 1
    s1 = 1
    x1, y1 = rot(u1, v1, c1, s1)
    print(f"x1: {x1}")
    print(f"y1: {y1}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
import "apps/blas/rot/rot-alt.fil"; | ||
|
||
// Top-level wrapper around Rot with W=32, N=16, M=8, A=2.
// Each element of the input vectors u, v and the output vectors x, y is
// exposed as its own scalar port; the body packs the inputs into bundles,
// invokes Rot once at 'G, and unpacks the results.
// L and II are forwarded from the Rot instance.
comp main<'G:II>(
  go: interface['G],
  c: ['G, 'G+1] W,
  s: ['G, 'G+1] W,
  u_0: ['G, 'G+1] W,
  u_1: ['G, 'G+1] W,
  u_2: ['G, 'G+1] W,
  u_3: ['G, 'G+1] W,
  u_4: ['G, 'G+1] W,
  u_5: ['G, 'G+1] W,
  u_6: ['G, 'G+1] W,
  u_7: ['G, 'G+1] W,
  u_8: ['G, 'G+1] W,
  u_9: ['G, 'G+1] W,
  u_10: ['G, 'G+1] W,
  u_11: ['G, 'G+1] W,
  u_12: ['G, 'G+1] W,
  u_13: ['G, 'G+1] W,
  u_14: ['G, 'G+1] W,
  u_15: ['G, 'G+1] W,
  v_0: ['G, 'G+1] W,
  v_1: ['G, 'G+1] W,
  v_2: ['G, 'G+1] W,
  v_3: ['G, 'G+1] W,
  v_4: ['G, 'G+1] W,
  v_5: ['G, 'G+1] W,
  v_6: ['G, 'G+1] W,
  v_7: ['G, 'G+1] W,
  v_8: ['G, 'G+1] W,
  v_9: ['G, 'G+1] W,
  v_10: ['G, 'G+1] W,
  v_11: ['G, 'G+1] W,
  v_12: ['G, 'G+1] W,
  v_13: ['G, 'G+1] W,
  v_14: ['G, 'G+1] W,
  v_15: ['G, 'G+1] W,
) -> (
  x_0: ['G+L, 'G+L+1] W,
  x_1: ['G+L, 'G+L+1] W,
  x_2: ['G+L, 'G+L+1] W,
  x_3: ['G+L, 'G+L+1] W,
  x_4: ['G+L, 'G+L+1] W,
  x_5: ['G+L, 'G+L+1] W,
  x_6: ['G+L, 'G+L+1] W,
  x_7: ['G+L, 'G+L+1] W,
  x_8: ['G+L, 'G+L+1] W,
  x_9: ['G+L, 'G+L+1] W,
  x_10: ['G+L, 'G+L+1] W,
  x_11: ['G+L, 'G+L+1] W,
  x_12: ['G+L, 'G+L+1] W,
  x_13: ['G+L, 'G+L+1] W,
  x_14: ['G+L, 'G+L+1] W,
  x_15: ['G+L, 'G+L+1] W,
  y_0: ['G+L, 'G+L+1] W,
  y_1: ['G+L, 'G+L+1] W,
  y_2: ['G+L, 'G+L+1] W,
  y_3: ['G+L, 'G+L+1] W,
  y_4: ['G+L, 'G+L+1] W,
  y_5: ['G+L, 'G+L+1] W,
  y_6: ['G+L, 'G+L+1] W,
  y_7: ['G+L, 'G+L+1] W,
  y_8: ['G+L, 'G+L+1] W,
  y_9: ['G+L, 'G+L+1] W,
  y_10: ['G+L, 'G+L+1] W,
  y_11: ['G+L, 'G+L+1] W,
  y_12: ['G+L, 'G+L+1] W,
  y_13: ['G+L, 'G+L+1] W,
  y_14: ['G+L, 'G+L+1] W,
  y_15: ['G+L, 'G+L+1] W,
) with {
  let M = 8;
  let N = 16;
  let W = 32;
  let A = 2;
  some L where L > 0;
  some II where II > 0;
} {
  // the rotation unit under test
  Rotator := new Rot[W, N, M, A];

  // pack the scalar u_* ports into a bundle
  bundle u_vec[N]: ['G, 'G+1] W;
  u_vec{0} = u_0;
  u_vec{1} = u_1;
  u_vec{2} = u_2;
  u_vec{3} = u_3;
  u_vec{4} = u_4;
  u_vec{5} = u_5;
  u_vec{6} = u_6;
  u_vec{7} = u_7;
  u_vec{8} = u_8;
  u_vec{9} = u_9;
  u_vec{10} = u_10;
  u_vec{11} = u_11;
  u_vec{12} = u_12;
  u_vec{13} = u_13;
  u_vec{14} = u_14;
  u_vec{15} = u_15;

  // pack the scalar v_* ports into a bundle
  bundle v_vec[N]: ['G, 'G+1] W;
  v_vec{0} = v_0;
  v_vec{1} = v_1;
  v_vec{2} = v_2;
  v_vec{3} = v_3;
  v_vec{4} = v_4;
  v_vec{5} = v_5;
  v_vec{6} = v_6;
  v_vec{7} = v_7;
  v_vec{8} = v_8;
  v_vec{9} = v_9;
  v_vec{10} = v_10;
  v_vec{11} = v_11;
  v_vec{12} = v_12;
  v_vec{13} = v_13;
  v_vec{14} = v_14;
  v_vec{15} = v_15;

  // single invocation at 'G
  rotated := Rotator<'G>(c, s, u_vec{0..N}, v_vec{0..N});

  // unpack out_1 onto the x_* ports
  bundle x_vec[N]: ['G+Rotator::L, 'G+Rotator::L+1] W;
  x_vec{0..N} = rotated.out_1{0..N};
  x_0 = x_vec{0}; x_1 = x_vec{1}; x_2 = x_vec{2}; x_3 = x_vec{3};
  x_4 = x_vec{4}; x_5 = x_vec{5}; x_6 = x_vec{6}; x_7 = x_vec{7};
  x_8 = x_vec{8}; x_9 = x_vec{9}; x_10 = x_vec{10}; x_11 = x_vec{11};
  x_12 = x_vec{12}; x_13 = x_vec{13}; x_14 = x_vec{14}; x_15 = x_vec{15};

  // unpack out_2 onto the y_* ports
  bundle y_vec[N]: ['G+Rotator::L, 'G+Rotator::L+1] W;
  y_vec{0..N} = rotated.out_2{0..N};
  y_0 = y_vec{0}; y_1 = y_vec{1}; y_2 = y_vec{2}; y_3 = y_vec{3};
  y_4 = y_vec{4}; y_5 = y_vec{5}; y_6 = y_vec{6}; y_7 = y_vec{7};
  y_8 = y_vec{8}; y_9 = y_vec{9}; y_10 = y_vec{10}; y_11 = y_vec{11};
  y_12 = y_vec{12}; y_13 = y_vec{13}; y_14 = y_vec{14}; y_15 = y_vec{15};

  // forward the instance's timing parameters
  L := Rotator::L;
  II := Rotator::II;
}
Oops, something went wrong.