Skip to content

Commit

Permalink
rot kernel (#438)
Browse files Browse the repository at this point in the history
* start

* works but could be better

* better throughput alt

* comment

* clean
  • Loading branch information
gabizon103 authored May 28, 2024
1 parent 45044c2 commit cfebc86
Show file tree
Hide file tree
Showing 8 changed files with 434 additions and 20 deletions.
131 changes: 131 additions & 0 deletions apps/blas/rot/rot-alt.fil
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import "apps/blas/util.fil";
import "primitives/signed.fil";

// Applies a rotation to vectors x, y:
//   out_1{i} = c*x{i} + s*y{i}
//   out_2{i} = (-s)*x{i} + c*y{i}
// W: element width
// N: vector length
// M: multiplier amount (split evenly over the four products cy, sy, cx, -sx)
// A: adder amount (split evenly over the two output sums)
// Tiles scalar-vector products based on multiplier count,
// then does additions as multiplies finish.
// L (latency) and II (initiation interval) are existential outputs that are
// computed in the body from the Multipliers/Adders parameters.
comp Rot[W, N, M, A]<'G:II>(
  go: interface['G],
  c: ['G, 'G+1] W,
  s: ['G, 'G+1] W,
  x[N]: ['G, 'G+1] W,
  y[N]: ['G, 'G+1] W,
) -> (
  out_1[N]: ['G+L, 'G+L+1] W,
  out_2[N]: ['G+L, 'G+L+1] W
) with {
  some L where L > 0;
  some II where II > 0;
} where W > 0,
        N > 0,
        M > 0,
        A > 0,
        M % 4 == 0, // need to do at least 4 multiplies at once
        A % 2 == 0, // need to do at least 2 adds at once
        N % (M/4) == 0,
        (M/4) % (A/2) == 0
{
  // partition mults into 4 groups, for 4 mults that need to happen at each time step
  let m = M/4;

  // same for adds, but only 2 adds to do
  let a = A/2;

  // reuse for a single scalar-vector computation
  let mult_reuses = N / m;

  // reuse for each of the adder groups
  let add_reuses = m / a;

  // dummy so we can get its params
  M_ := new Multipliers[W, m];
  let mult_latency = M_::L;
  let mult_ii = M_::II;

  A_ := new Adders[W, a];
  let add_ii = A_::II;

  // Overall latency of the component. Bound here, before the loops below,
  // because the output Shift registers use it to line every adder result up
  // with 'G+latency.
  let latency = (mult_reuses*mult_ii + mult_latency) + (add_reuses-1)*(mult_reuses-1) + (add_reuses-1)*add_ii;

  // -s
  negs := new Neg[W]<'G>(s);

  // instantiate multipliers: one shared group per product, each reused
  // mult_reuses times at a spacing of mult_ii
  let last_mult_invoke = (mult_reuses)*mult_ii;
  M_cy := new Multipliers[W, m] in ['G, 'G+last_mult_invoke];
  M_sy := new Multipliers[W, m] in ['G, 'G+last_mult_invoke];
  M_cx := new Multipliers[W, m] in ['G, 'G+last_mult_invoke];
  M_nsx := new Multipliers[W, m] in ['G, 'G+last_mult_invoke];

  // instantiate adders: one shared group per output vector
  let add_end = last_mult_invoke + mult_latency + add_reuses*add_ii + (add_reuses-1)*(mult_reuses-1) + 1;
  A_1 := new Adders[W, a] in ['G+mult_latency, 'G+add_end];
  A_2 := new Adders[W, a] in ['G+mult_latency, 'G+add_end];

  // check which stage is limiting the pipeline
  // NOTE(review): add_end-mult_latency already contains last_mult_invoke as a
  // term, so the adder window looks like it is always the one selected when
  // the remaining terms are non-negative -- confirm.
  let ii = if (add_end-mult_latency) < (last_mult_invoke) {(last_mult_invoke)} else {(add_end-mult_latency)};

  // product tiles; tile k becomes available mult_latency after its multiply
  // was started at 'G+k*mult_ii
  bundle cy[mult_reuses][m]: for<k> ['G+k*mult_ii+mult_latency, 'G+k*mult_ii+mult_latency+1] W;
  bundle sy[mult_reuses][m]: for<k> ['G+k*mult_ii+mult_latency, 'G+k*mult_ii+mult_latency+1] W;
  bundle cx[mult_reuses][m]: for<k> ['G+k*mult_ii+mult_latency, 'G+k*mult_ii+mult_latency+1] W;
  bundle nsx[mult_reuses][m]: for<k> ['G+k*mult_ii+mult_latency, 'G+k*mult_ii+mult_latency+1] W;

  // scalar bundles for multiplications
  bundle c_bundle[m]: ['G, 'G+1] W;
  bundle s_bundle[m]: ['G, 'G+1] W;
  bundle negs_bundle[m]: ['G, 'G+1] W;

  // fill them: every lane carries the same scalar
  for i in 0..m {
    c_bundle{i} = c;
    s_bundle{i} = s;
    negs_bundle{i} = negs.out;
  }

  // start multiplications: iteration i handles elements i*m .. (i+1)*m
  for i in 0..mult_reuses {
    // time (relative to 'G) at which this tile's products are available
    let mult_end = i*mult_ii + mult_latency;

    // register inputs until this tile's multiply slot at 'G+i*mult_ii
    // (presumably Shift[W, D, K] delays a K-wide bundle by D cycles; it is
    // defined outside this file -- confirm)
    x_reg := new Shift[W, i*mult_ii, m]<'G>(x{i*m..(i+1)*m});
    y_reg := new Shift[W, i*mult_ii, m]<'G>(y{i*m..(i+1)*m});
    c_reg := new Shift[W, i*mult_ii, m]<'G>(c_bundle{0..m});
    s_reg := new Shift[W, i*mult_ii, m]<'G>(s_bundle{0..m});
    negs_reg := new Shift[W, i*mult_ii, m]<'G>(negs_bundle{0..m});

    mult_cy := M_cy<'G+i*mult_ii>(c_reg.out{0..m}, y_reg.out{0..m});
    mult_sy := M_sy<'G+i*mult_ii>(s_reg.out{0..m}, y_reg.out{0..m});
    mult_cx := M_cx<'G+i*mult_ii>(c_reg.out{0..m}, x_reg.out{0..m});
    mult_nsx := M_nsx<'G+i*mult_ii>(negs_reg.out{0..m}, x_reg.out{0..m});

    cy{i}{0..m} = mult_cy.out{0..m};
    sy{i}{0..m} = mult_sy.out{0..m};
    cx{i}{0..m} = mult_cx.out{0..m};
    nsx{i}{0..m} = mult_nsx.out{0..m};

    // each product tile of m elements is added a elements at a time
    for j in 0..add_reuses {
      // extra delay applied to tile i's additions, growing with i
      let offset = i*(add_reuses-1);
      mult_cy_reg := new Shift[W, j*add_ii + offset, a]<'G+mult_end>(cy{i}{(j*a)..(j+1)*a});
      mult_sy_reg := new Shift[W, j*add_ii + offset, a]<'G+mult_end>(sy{i}{(j*a)..(j+1)*a});
      mult_cx_reg := new Shift[W, j*add_ii + offset, a]<'G+mult_end>(cx{i}{(j*a)..(j+1)*a});
      mult_nsx_reg := new Shift[W, j*add_ii + offset, a]<'G+mult_end>(nsx{i}{(j*a)..(j+1)*a});

      // out_1 chunk = cx + sy, out_2 chunk = -sx + cy
      add_1 := A_1<'G + mult_end + j*add_ii + offset>(mult_cx_reg.out{0..a}, mult_sy_reg.out{0..a});
      add_2 := A_2<'G + mult_end + j*add_ii + offset>(mult_nsx_reg.out{0..a}, mult_cy_reg.out{0..a});

      // hold each sum for the remaining cycles up to 'G+latency
      add_1_reg := new Shift[W, latency - mult_end - j*add_ii - offset, a]<'G + mult_end + j*add_ii + offset>(add_1.out{0..a});
      add_2_reg := new Shift[W, latency - mult_end - j*add_ii - offset, a]<'G + mult_end + j*add_ii + offset>(add_2.out{0..a});

      out_1{(m*i)+(j*a)..(m*i)+(j+1)*a} = add_1_reg.out{0..a};
      out_2{(m*i)+(j*a)..(m*i)+(j+1)*a} = add_2_reg.out{0..a};
    }
  }

  // bind the existential parameters promised in the signature
  L := latency;
  II := ii;
}
94 changes: 94 additions & 0 deletions apps/blas/rot/rot.fil
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import "primitives/core.fil";
import "apps/blas/scal/scal.fil";
import "apps/blas/util.fil";
import "primitives/signed.fil";

// Applies a rotation to vectors x, y:
//   out_1{i} = c*x{i} + s*y{i}
//   out_2{i} = (-s)*x{i} + c*y{i}
// W: element width
// N: vector length
// M: multiplier amount (a quarter is handed to each Scal instance)
// A: adder amount (half is used for each output vector)
// Uses scal to compute the scalar-vector products cy, sy, cx, -sx
// L (latency) and II (initiation interval) are existential outputs that are
// computed in the body from Scal's parameters and the adder schedule.
comp Rot[W, N, M, A]<'G:II>(
  go: interface['G],
  c: ['G, 'G+1] W,
  s: ['G, 'G+1] W,
  x[N]: ['G, 'G+1] W,
  y[N]: ['G, 'G+1] W,
) -> (
  out_1[N]: ['G+L, 'G+L+1] W,
  out_2[N]: ['G+L, 'G+L+1] W
) with {
  some L where L > 0;
  some II where II > 0;
} where W > 0,
        L > 0,
        N > 0,
        M > 0,
        A > 0,
        M % 4 == 0,
        A % 2 == 0
{

  // dummy instance, only used to read Scal's latency and initiation interval
  scalex := new Scal[W, N, M/4];
  let scale_latency = scalex::L;
  let scale_ii = scalex::II;

  // -s
  neg_s := new Neg[W]<'G>(s);

  // full-length products, all available scale_latency after 'G
  bundle cy[N]: ['G+scale_latency, 'G+scale_latency+1] W;
  bundle sy[N]: ['G+scale_latency, 'G+scale_latency+1] W;
  bundle cx[N]: ['G+scale_latency, 'G+scale_latency+1] W;
  bundle msx[N]: ['G+scale_latency, 'G+scale_latency+1] W;

  SCY := new Scal[W, N, M/4] in ['G, 'G+scale_ii];
  scale_cy := SCY<'G>(y{0..N}, c);
  cy{0..N} = scale_cy.out{0..N};

  SSY := new Scal[W, N, M/4] in ['G, 'G+scale_ii];
  scale_sy := SSY<'G>(y{0..N}, s);
  sy{0..N} = scale_sy.out{0..N};

  SCX := new Scal[W, N, M/4] in ['G, 'G+scale_ii];
  scale_cx := SCX<'G>(x{0..N}, c);
  cx{0..N} = scale_cx.out{0..N};

  SMSX := new Scal[W, N, M/4] in ['G, 'G+scale_ii];
  scale_msx := SMSX<'G>(x{0..N}, neg_s.out);
  msx{0..N} = scale_msx.out{0..N};

  // out_1{i} <- cx{i} + sy{i}
  // out_2{i} <- msx{i} + cy{i}

  // number of passes through each adder group, A/2 elements per pass
  // NOTE(review): the output slices below only cover all N elements when
  // (A/2) divides N, and the where clause does not require that -- confirm
  // whether an `N % (A/2) == 0` constraint is intended.
  let add_uses = N / (A/2);
  let add_ii = 1;

  // use half the adders for x, half for y
  A_x := new Adders[W, A/2] in ['G+scale_latency, 'G+scale_latency+(add_uses-1)*add_ii+1];
  A_y := new Adders[W, A/2] in ['G+scale_latency, 'G+scale_latency+(add_uses-1)*add_ii+1];

  // the last adder pass starts at scale_latency + (add_uses-1)*add_ii
  let latency = scale_latency + (add_uses-1) * add_ii;
  for k in 0..add_uses {
    // save chunked arrays based on when we are ready to add them
    // (presumably Shift[W, D, K] delays a K-wide bundle by D cycles; it is
    // defined outside this file -- confirm)
    cx_reg := new Shift[W, k*add_ii, A/2]<'G+scale_latency>(cx{k*(A/2)..(k+1)*(A/2)});
    sy_reg := new Shift[W, k*add_ii, A/2]<'G+scale_latency>(sy{k*(A/2)..(k+1)*(A/2)});
    cy_reg := new Shift[W, k*add_ii, A/2]<'G+scale_latency>(cy{k*(A/2)..(k+1)*(A/2)});
    msx_reg := new Shift[W, k*add_ii, A/2]<'G+scale_latency>(msx{k*(A/2)..(k+1)*(A/2)});

    ax := A_x<'G + scale_latency + k*add_ii>(cx_reg.out{0..(A/2)}, sy_reg.out{0..(A/2)});
    ay := A_y<'G + scale_latency + k*add_ii>(msx_reg.out{0..(A/2)}, cy_reg.out{0..(A/2)});

    // save add result for the remaining cycles up to 'G+latency
    ax_reg := new Shift[W, latency - scale_latency - k*add_ii, A/2]<'G + scale_latency + k*add_ii>(ax.out{0..(A/2)});
    ay_reg := new Shift[W, latency - scale_latency - k*add_ii, A/2]<'G + scale_latency + k*add_ii>(ay.out{0..(A/2)});

    out_1{k*(A/2)..(k+1)*(A/2)} = ax_reg.out{0..(A/2)};
    out_2{k*(A/2)..(k+1)*(A/2)} = ay_reg.out{0..(A/2)};
  }

  L := latency;
  // this is a thing we can do now?
  // II is the larger of the adder schedule length and Scal's own II
  let ii = if (add_uses*add_ii) > (scale_ii) {add_uses*add_ii} else {scale_ii};
  II := ii;
}
29 changes: 29 additions & 0 deletions apps/blas/rot/sim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# for determining the right answer

def rot(u, v, c, s):
    """Reference model of the rotation computed by the Rot hardware kernel.

    Args:
        u: first input vector (sequence of numbers).
        v: second input vector; must have the same length as ``u``.
        c: scalar multiplied with ``u`` in the first output and ``v`` in the second.
        s: scalar multiplied with ``v`` in the first output; its negation is
            multiplied with ``u`` in the second.

    Returns:
        A tuple ``(x, y)`` of lists where ``x[i] = c*u[i] + s*v[i]`` and
        ``y[i] = (-s)*u[i] + c*v[i]``.

    Raises:
        AssertionError: if ``u`` and ``v`` differ in length.
    """
    assert len(u) == len(v)
    x = [c * ui + s * vi for ui, vi in zip(u, v)]
    y = [(-s) * ui + c * vi for ui, vi in zip(u, v)]
    return (x, y)

if __name__ == '__main__':
    # Each case is (u, v, c, s). The printed vectors are the expected values
    # for the hardware test of the Rot kernel.
    cases = [
        (
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
            1,
            1,
        ),
        (
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
            [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1],
            1,
            1,
        ),
    ]
    for idx, (u, v, c, s) in enumerate(cases):
        # A separator line goes between consecutive cases, not before the first.
        if idx > 0:
            print("\n======================\n")
        (x, y) = rot(u, v, c, s)
        print(f"x{idx}: {x}")
        print(f"y{idx}: {y}")
136 changes: 136 additions & 0 deletions apps/blas/rot/test.fil
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import "apps/blas/rot/rot-alt.fil";

// Test harness for Rot (the variant imported from apps/blas/rot/rot-alt.fil).
// Fixes W=32, N=16, M=8, A=2 and exposes every vector element as its own
// scalar port: u_i / v_i are the two input vectors, x_i / y_i carry
// out_1 / out_2 of Rot. L and II are forwarded from the Rot instance.
comp main<'G:II>(
go: interface['G],
c: ['G, 'G+1] W,
s: ['G, 'G+1] W,
u_0: ['G, 'G+1] W,
u_1: ['G, 'G+1] W,
u_2: ['G, 'G+1] W,
u_3: ['G, 'G+1] W,
u_4: ['G, 'G+1] W,
u_5: ['G, 'G+1] W,
u_6: ['G, 'G+1] W,
u_7: ['G, 'G+1] W,
u_8: ['G, 'G+1] W,
u_9: ['G, 'G+1] W,
u_10: ['G, 'G+1] W,
u_11: ['G, 'G+1] W,
u_12: ['G, 'G+1] W,
u_13: ['G, 'G+1] W,
u_14: ['G, 'G+1] W,
u_15: ['G, 'G+1] W,
v_0: ['G, 'G+1] W,
v_1: ['G, 'G+1] W,
v_2: ['G, 'G+1] W,
v_3: ['G, 'G+1] W,
v_4: ['G, 'G+1] W,
v_5: ['G, 'G+1] W,
v_6: ['G, 'G+1] W,
v_7: ['G, 'G+1] W,
v_8: ['G, 'G+1] W,
v_9: ['G, 'G+1] W,
v_10: ['G, 'G+1] W,
v_11: ['G, 'G+1] W,
v_12: ['G, 'G+1] W,
v_13: ['G, 'G+1] W,
v_14: ['G, 'G+1] W,
v_15: ['G, 'G+1] W,
) -> (
x_0: ['G+L, 'G+L+1] W,
x_1: ['G+L, 'G+L+1] W,
x_2: ['G+L, 'G+L+1] W,
x_3: ['G+L, 'G+L+1] W,
x_4: ['G+L, 'G+L+1] W,
x_5: ['G+L, 'G+L+1] W,
x_6: ['G+L, 'G+L+1] W,
x_7: ['G+L, 'G+L+1] W,
x_8: ['G+L, 'G+L+1] W,
x_9: ['G+L, 'G+L+1] W,
x_10: ['G+L, 'G+L+1] W,
x_11: ['G+L, 'G+L+1] W,
x_12: ['G+L, 'G+L+1] W,
x_13: ['G+L, 'G+L+1] W,
x_14: ['G+L, 'G+L+1] W,
x_15: ['G+L, 'G+L+1] W,
y_0: ['G+L, 'G+L+1] W,
y_1: ['G+L, 'G+L+1] W,
y_2: ['G+L, 'G+L+1] W,
y_3: ['G+L, 'G+L+1] W,
y_4: ['G+L, 'G+L+1] W,
y_5: ['G+L, 'G+L+1] W,
y_6: ['G+L, 'G+L+1] W,
y_7: ['G+L, 'G+L+1] W,
y_8: ['G+L, 'G+L+1] W,
y_9: ['G+L, 'G+L+1] W,
y_10: ['G+L, 'G+L+1] W,
y_11: ['G+L, 'G+L+1] W,
y_12: ['G+L, 'G+L+1] W,
y_13: ['G+L, 'G+L+1] W,
y_14: ['G+L, 'G+L+1] W,
y_15: ['G+L, 'G+L+1] W,
) with {
// Fixed configuration for this test: 8 multipliers, 16-element vectors,
// 32-bit elements, 2 adders.
let M = 8;
let N = 16;
let W = 32;
let A = 2;
// Latency and initiation interval, bound at the end of the body.
some L where L > 0;
some II where II > 0;
} {
Rotx := new Rot[W, N, M, A];

// Pack the scalar u_i input ports into one N-wide bundle.
bundle u[N]: ['G, 'G+1] W;
u{0} = u_0; u{1} = u_1; u{2} = u_2; u{3} = u_3;
u{4} = u_4; u{5} = u_5; u{6} = u_6; u{7} = u_7;
u{8} = u_8; u{9} = u_9; u{10} = u_10; u{11} = u_11;
u{12} = u_12; u{13} = u_13; u{14} = u_14; u{15} = u_15;

// Pack the scalar v_i input ports into one N-wide bundle.
bundle v[N]: ['G, 'G+1] W;
v{0} = v_0; v{1} = v_1; v{2} = v_2; v{3} = v_3;
v{4} = v_4; v{5} = v_5; v{6} = v_6; v{7} = v_7;
v{8} = v_8; v{9} = v_9; v{10} = v_10; v{11} = v_11;
v{12} = v_12; v{13} = v_13; v{14} = v_14; v{15} = v_15;

// u is passed as Rot's x input and v as its y input.
r := Rotx<'G>(c, s, u{0..N}, v{0..N});

// Unpack out_1 into the scalar x_i output ports.
bundle x[N]: ['G+Rotx::L, 'G+Rotx::L+1] W;
x{0..N} = r.out_1{0..N};
x_0 = x{0};
x_1 = x{1};
x_2 = x{2};
x_3 = x{3};
x_4 = x{4};
x_5 = x{5};
x_6 = x{6};
x_7 = x{7};
x_8 = x{8};
x_9 = x{9};
x_10 = x{10};
x_11 = x{11};
x_12 = x{12};
x_13 = x{13};
x_14 = x{14};
x_15 = x{15};

// Unpack out_2 into the scalar y_i output ports.
bundle y[N]: ['G+Rotx::L, 'G+Rotx::L+1] W;
y{0..N} = r.out_2{0..N};
y_0 = y{0};
y_1 = y{1};
y_2 = y{2};
y_3 = y{3};
y_4 = y{4};
y_5 = y{5};
y_6 = y{6};
y_7 = y{7};
y_8 = y{8};
y_9 = y{9};
y_10 = y{10};
y_11 = y{11};
y_12 = y{12};
y_13 = y{13};
y_14 = y{14};
y_15 = y{15};

// Forward the timing parameters of the Rot instance as this component's own.
L := Rotx::L;
II := Rotx::II;
}
Loading

0 comments on commit cfebc86

Please sign in to comment.