
Commit

src/PositionalEmbeddings.jl: Align outputs with reference implementation and add sources.
mashu committed Nov 24, 2024
1 parent fd4124c commit 575d3ec
Showing 4 changed files with 171 additions and 114 deletions.
6 changes: 3 additions & 3 deletions Project.toml
@@ -5,7 +5,6 @@ version = "0.3.0"

[deps]
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
CUDA = "5.5.2"
@@ -14,8 +13,9 @@ Zygote = "0.6.41"
julia = "1.9"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "CUDA"]
test = ["Test", "CUDA", "Zygote"]
68 changes: 5 additions & 63 deletions docs/src/index.md
@@ -10,7 +10,7 @@ This package provides various implementations of positional embeddings - techniq

Currently implemented:
- Rotary Position Embeddings (RoPE) - A method that encodes positions by rotating vectors in 2D subspaces
- Frequency Positional Embeddings (FrequencyPE) - The original positional encoding scheme from "Attention Is All You Need"
- Positional Embeddings (AbsolutePE) - The original positional encoding scheme from "Attention Is All You Need"
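
For reference, AbsolutePE follows the fixed sinusoidal formula from "Attention Is All You Need". A minimal sketch of that formula is shown below; the package's actual `AbsolutePE` type may differ in caching and broadcasting details:

```julia
# Sinusoidal positional encoding (0-based notation):
#   PE[2i, pos]   = sin(pos / 10000^(2i/d))
#   PE[2i+1, pos] = cos(pos / 10000^(2i/d))
# Returns a (features, seq_len) matrix to match this package's layout.
function sinusoidal_pe(d::Int, seq_len::Int)
    pe = zeros(Float32, d, seq_len)
    for pos in 0:seq_len-1, i in 0:2:d-1
        freq = Float32(1 / 10000^(i / d))
        pe[i + 1, pos + 1] = sin(pos * freq)
        i + 2 <= d && (pe[i + 2, pos + 1] = cos(pos * freq))
    end
    return pe
end
```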

## API Reference

@@ -33,79 +33,21 @@ features = randn(Float32, 512, 100, 32)
features_with_pos = rope(features)
```

### Basic FrequencyPE Usage
### Basic AbsolutePE Usage

```julia
# Create embeddings for 128-dimensional features up to length 100
pe = FrequencyPE(128, 100)
pe = AbsolutePE(128, 100)

# Apply to input tensor of shape (features, seq_len, batch)
x = randn(Float32, 128, 100, 32)
x_positioned = pe(x)
```
![AbsolutePE](assets/AbsolutePE-128-100.svg)

### Example with Query/Key Matrices
## Rotary Position Embeddings (RoPE) with MultiHeadAttention from Flux

Here's a complete example showing how to use RoPE with attention mechanisms:

```julia
using Flux: MultiHeadAttention
using PositionalEmbeddings: RoPE, RoPEMultiHeadAttention
using NNlib: dot_product_attention

# Initialize RoPE with specific feature dimensions to rotate
dim = 64 # embedding dimension
nheads = 4 # number of attention heads
max_seq_len = 2048
seq_len = 100
batch_size = 32

# Create attention layer with RoPE
rmha = RoPEMultiHeadAttention(
dim, # embedding dimension
nheads; # number of heads
max_seq_len = max_seq_len,
rope_fraction = 1.0 # apply RoPE to all features
)

# Sample input tensors (dim, seq_len, batch)
q_in = randn(Float32, dim, seq_len, batch_size)
k_in = randn(Float32, dim, seq_len, batch_size)
v_in = randn(Float32, dim, seq_len, batch_size)

# The forward pass will:
# 1. Project inputs
mha = rmha.mha
q = mha.q_proj(q_in)
k = mha.k_proj(k_in)
v = mha.v_proj(v_in)

# 2. Apply RoPE to queries and keys
q = rmha.rope(q)
k = rmha.rope(k)

# 3. Compute attention
bias = nothing
mask = nothing
x, α = dot_product_attention(q, k, v, bias;
nheads=mha.nheads,
mask=mask,
fdrop=mha.attn_drop)

# 4. Project output
output = mha.out_proj(x)

# Alternatively, use the provided wrapper:
output, attention_weights = rmha(q_in, k_in, v_in)

# Or for self-attention:
output, attention_weights = rmha(q_in)
```

## Flux Integration

The package provides a wrapper type `RoPEMultiHeadAttention` that adds Rotary Position Embeddings to Flux's MultiHeadAttention. Here's the complete implementation:
This example shows how to add a `RoPEMultiHeadAttention` wrapper that applies Rotary Position Embeddings to Flux's MultiHeadAttention. Here's the complete implementation:

```julia
using Flux: MultiHeadAttention
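# (The rest of the file's implementation is not shown in this diff. What
#  follows is only a sketch of how such a wrapper could look, pieced together
#  from the usage example above; apart from the MultiHeadAttention field names
#  it reuses (q_proj, k_proj, v_proj, out_proj, attn_drop, nheads) and the
#  constructor keywords (max_seq_len, rope_fraction), everything here is an
#  assumption, not this package's actual code.)
using Functors
using NNlib: dot_product_attention
using PositionalEmbeddings: RoPE

struct RoPEMultiHeadAttention{M,R}
    mha::M     # wrapped Flux MultiHeadAttention layer
    rope::R    # RoPE applied to projected queries and keys
end
Functors.@functor RoPEMultiHeadAttention

function RoPEMultiHeadAttention(dim::Int, nheads::Int;
                                max_seq_len::Int = 2048, rope_fraction = 1.0)
    mha = MultiHeadAttention(dim; nheads = nheads)
    rope = RoPE(round(Int, dim * rope_fraction), max_seq_len)
    return RoPEMultiHeadAttention(mha, rope)
end

function (m::RoPEMultiHeadAttention)(q_in, k_in, v_in)
    q = m.rope(m.mha.q_proj(q_in))   # rotate queries after projection
    k = m.rope(m.mha.k_proj(k_in))   # rotate keys after projection
    v = m.mha.v_proj(v_in)
    x, α = dot_product_attention(q, k, v; nheads = m.mha.nheads,
                                 fdrop = m.mha.attn_drop)
    return m.mha.out_proj(x), α
end

# Self-attention convenience form
(m::RoPEMultiHeadAttention)(x) = m(x, x, x)
```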
4 changes: 2 additions & 2 deletions src/PositionalEmbeddings.jl
@@ -139,8 +139,8 @@ module PositionalEmbeddings
"""
function neg_half(x::AbstractArray{T}, dim::Int=1) where T
d_2 = size(x, dim) ÷ 2
vcat(-view(x, d_2+1:size(x,dim), :, :),
view(x, 1:d_2, :, :))
return vcat(view(x, d_2+1:size(x,dim), :, :) .* -1,
view(x, 1:d_2, :, :))
end
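# Illustration (not part of the diff): with the default dim = 1 and a feature
# column [x1, x2, x3, x4], neg_half returns [-x3, -x4, x1, x2]; the second half
# of the features is negated and moved in front of the first half, which gives
# RoPE the partner values it rotates each feature against.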

"""
207 changes: 161 additions & 46 deletions test/runtests.jl
@@ -6,7 +6,156 @@ using PositionalEmbeddings
function has_working_cuda()
return CUDA.has_cuda() && CUDA.functional()
end

# # https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/rope/__init__.py
# # Code from the above link was used to generate the test data
# import torch
# import numpy as np

# base = 10000
# d = 16
# seq_len = 10 # Changed from 16 to 10

# # Calculate frequencies
# theta = 1. / (base ** (torch.arange(0, d, 2).float() / d))
# seq_idx = torch.arange(seq_len).float() # Now 0-9 instead of 0-15
# idx_theta = torch.einsum('n,d->nd', seq_idx, theta)
# idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

# # Cache sin/cos values - now 10x1x1x16 instead of 16x1x1x16
# cos_cached = idx_theta2.cos()[:, None, None, :]
# sin_cached = idx_theta2.sin()[:, None, None, :]

# # Input will be 16x10x1 instead of 16x16x1
# input_original = torch.arange(1, 16*10 + 1, dtype=torch.float32).reshape(16, 10, 1)
# x = input_original.permute(1, 2, 0).unsqueeze(1) # [10, 1, 1, 16]

# # Split features
# d_2 = d // 2
# x_rope, x_pass = x[..., :d], x[..., d:]

# # Calculate neg_half
# neg_half_x = torch.cat([-x_rope[:, :, :, d_2:], x_rope[:, :, :, :d_2]], dim=-1)

# # Final output
# final_output = (x_rope * cos_cached[:x.shape[0]]) + (neg_half_x * sin_cached[:x.shape[0]])

# # Save results
# np.savez('rope_test_data.npz',
# input_original=input_original,
# input_permuted=x.numpy(),
# theta=theta.numpy(),
# idx_theta=idx_theta.numpy(),
# idx_theta2=idx_theta2.numpy(),
# cos_cached=cos_cached.numpy(),
# sin_cached=sin_cached.numpy(),
# neg_half=neg_half_x.numpy(),
# final_output=final_output.numpy())
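# The same frequency grid can be cross-checked in Julia (a sketch, not part of
# the original test file; the hard-coded matrices below were produced with the
# Python script above, and the .npz file could be read with NPZ.jl's npzread):
#   base, d, seq_len = 10000, 16, 10
#   theta = 1 ./ base .^ ((0:2:d-1) ./ d)       # 8 frequencies
#   idx_theta = (0:seq_len-1) .* theta'         # 10×8 position*frequency grid
#   idx_theta2 = hcat(idx_theta, idx_theta)     # duplicated to 10×16
#   cos_cached, sin_cached = cos.(idx_theta2), sin.(idx_theta2)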

@testset "RoPE Tests" begin
@testset "Cached Values Test" begin
features, seq_len = 16, 10
rope = RoPE(features, seq_len)

expected_cos = reshape([
1.0 0.540302 -0.416147 -0.989992 -0.653644 0.283662 0.96017 0.753902 -0.1455 -0.91113;
1.0 0.950415 0.806578 0.582754 0.301137 -0.0103423 -0.320796 -0.599437 -0.818632 -0.956644;
1.0 0.995004 0.980067 0.955337 0.921061 0.877583 0.825336 0.764842 0.696707 0.62161;
1.0 0.9995 0.998001 0.995503 0.992011 0.987526 0.982054 0.9756 0.96817 0.959773;
1.0 0.99995 0.9998 0.99955 0.9992 0.99875 0.998201 0.997551 0.996802 0.995953;
1.0 0.999995 0.99998 0.999955 0.99992 0.999875 0.99982 0.999755 0.99968 0.999595;
1.0 1.0 0.999998 0.999996 0.999992 0.999987 0.999982 0.999976 0.999968 0.99996;
1.0 1.0 1.0 1.0 0.999999 0.999999 0.999998 0.999998 0.999997 0.999996;
1.0 0.540302 -0.416147 -0.989992 -0.653644 0.283662 0.96017 0.753902 -0.1455 -0.91113;
1.0 0.950415 0.806578 0.582754 0.301137 -0.0103423 -0.320796 -0.599437 -0.818632 -0.956644;
1.0 0.995004 0.980067 0.955337 0.921061 0.877583 0.825336 0.764842 0.696707 0.62161;
1.0 0.9995 0.998001 0.995503 0.992011 0.987526 0.982054 0.9756 0.96817 0.959773;
1.0 0.99995 0.9998 0.99955 0.9992 0.99875 0.998201 0.997551 0.996802 0.995953;
1.0 0.999995 0.99998 0.999955 0.99992 0.999875 0.99982 0.999755 0.99968 0.999595;
1.0 1.0 0.999998 0.999996 0.999992 0.999987 0.999982 0.999976 0.999968 0.99996;
1.0 1.0 1.0 1.0 0.999999 0.999999 0.999998 0.999998 0.999997 0.999996
], (16,10,1))

expected_sin = reshape([
0.0 0.841471 0.909297 0.14112 -0.756802 -0.958924 -0.279415 0.656987 0.989358 0.412118;
0.0 0.310984 0.591127 0.812649 0.953581 0.999947 0.947148 0.800422 0.574318 0.291259;
0.0 0.0998334 0.198669 0.29552 0.389418 0.479426 0.564642 0.644218 0.717356 0.783327;
0.0 0.0316175 0.0632034 0.0947261 0.126154 0.157456 0.1886 0.219556 0.250292 0.280778;
0.0 0.00999983 0.0199987 0.0299955 0.0399893 0.0499792 0.059964 0.0699428 0.0799147 0.0898785;
0.0 0.00316227 0.00632451 0.00948669 0.0126488 0.0158107 0.0189725 0.0221341 0.0252955 0.0284567;
0.0 0.001 0.002 0.003 0.00399999 0.00499998 0.00599996 0.00699994 0.00799991 0.00899988;
0.0 0.000316228 0.000632456 0.000948683 0.00126491 0.00158114 0.00189737 0.00221359 0.00252982 0.00284605;
0.0 0.841471 0.909297 0.14112 -0.756802 -0.958924 -0.279415 0.656987 0.989358 0.412118;
0.0 0.310984 0.591127 0.812649 0.953581 0.999947 0.947148 0.800422 0.574318 0.291259;
0.0 0.0998334 0.198669 0.29552 0.389418 0.479426 0.564642 0.644218 0.717356 0.783327;
0.0 0.0316175 0.0632034 0.0947261 0.126154 0.157456 0.1886 0.219556 0.250292 0.280778;
0.0 0.00999983 0.0199987 0.0299955 0.0399893 0.0499792 0.059964 0.0699428 0.0799147 0.0898785;
0.0 0.00316227 0.00632451 0.00948669 0.0126488 0.0158107 0.0189725 0.0221341 0.0252955 0.0284567;
0.0 0.001 0.002 0.003 0.00399999 0.00499998 0.00599996 0.00699994 0.00799991 0.00899988;
0.0 0.000316228 0.000632456 0.000948683 0.00126491 0.00158114 0.00189737 0.00221359 0.00252982 0.00284605
], (16,10,1))

@test isapprox(rope.cos_cached, expected_cos)
@test isapprox(rope.sin_cached, expected_sin)
end

@testset "neg_half Function Test" begin
x = reshape([
1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0;
11.0 12.0 13.0 14.0 15.0 16.0 17.0 18.0 19.0 20.0;
21.0 22.0 23.0 24.0 25.0 26.0 27.0 28.0 29.0 30.0;
31.0 32.0 33.0 34.0 35.0 36.0 37.0 38.0 39.0 40.0;
41.0 42.0 43.0 44.0 45.0 46.0 47.0 48.0 49.0 50.0;
51.0 52.0 53.0 54.0 55.0 56.0 57.0 58.0 59.0 60.0;
61.0 62.0 63.0 64.0 65.0 66.0 67.0 68.0 69.0 70.0;
71.0 72.0 73.0 74.0 75.0 76.0 77.0 78.0 79.0 80.0;
81.0 82.0 83.0 84.0 85.0 86.0 87.0 88.0 89.0 90.0;
91.0 92.0 93.0 94.0 95.0 96.0 97.0 98.0 99.0 100.0;
101.0 102.0 103.0 104.0 105.0 106.0 107.0 108.0 109.0 110.0;
111.0 112.0 113.0 114.0 115.0 116.0 117.0 118.0 119.0 120.0;
121.0 122.0 123.0 124.0 125.0 126.0 127.0 128.0 129.0 130.0;
131.0 132.0 133.0 134.0 135.0 136.0 137.0 138.0 139.0 140.0;
141.0 142.0 143.0 144.0 145.0 146.0 147.0 148.0 149.0 150.0;
151.0 152.0 153.0 154.0 155.0 156.0 157.0 158.0 159.0 160.0
], (16,10,1))

expected_neg_half = reshape([
-81.0 -82.0 -83.0 -84.0 -85.0 -86.0 -87.0 -88.0 -89.0 -90.0;
-91.0 -92.0 -93.0 -94.0 -95.0 -96.0 -97.0 -98.0 -99.0 -100.0;
-101.0 -102.0 -103.0 -104.0 -105.0 -106.0 -107.0 -108.0 -109.0 -110.0;
-111.0 -112.0 -113.0 -114.0 -115.0 -116.0 -117.0 -118.0 -119.0 -120.0;
-121.0 -122.0 -123.0 -124.0 -125.0 -126.0 -127.0 -128.0 -129.0 -130.0;
-131.0 -132.0 -133.0 -134.0 -135.0 -136.0 -137.0 -138.0 -139.0 -140.0;
-141.0 -142.0 -143.0 -144.0 -145.0 -146.0 -147.0 -148.0 -149.0 -150.0;
-151.0 -152.0 -153.0 -154.0 -155.0 -156.0 -157.0 -158.0 -159.0 -160.0;
1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0;
11.0 12.0 13.0 14.0 15.0 16.0 17.0 18.0 19.0 20.0;
21.0 22.0 23.0 24.0 25.0 26.0 27.0 28.0 29.0 30.0;
31.0 32.0 33.0 34.0 35.0 36.0 37.0 38.0 39.0 40.0;
41.0 42.0 43.0 44.0 45.0 46.0 47.0 48.0 49.0 50.0;
51.0 52.0 53.0 54.0 55.0 56.0 57.0 58.0 59.0 60.0;
61.0 62.0 63.0 64.0 65.0 66.0 67.0 68.0 69.0 70.0;
71.0 72.0 73.0 74.0 75.0 76.0 77.0 78.0 79.0 80.0
], (16,10,1))

@test isapprox(PositionalEmbeddings.neg_half(x), expected_neg_half)
end

@testset "Forward Pass Test" begin
x = reshape(Float32.(1:160), (16,10,1))
pe = RoPE(16, 10)

# Manual calculation
neg_half_x = PositionalEmbeddings.neg_half(x)
cos_mat = view(pe.cos_cached, 1:size(x,1), 1:size(x,2), :)
sin_mat = view(pe.sin_cached, 1:size(x,1), 1:size(x,2), :)
expected_output = @. muladd(x * pe.scale, cos_mat, neg_half_x * pe.scale * sin_mat)

# Test the forward pass
actual_output = pe(x)
@test isapprox(actual_output, expected_output)
end

@testset "Gradient Tests (CPU, Float64)" begin
eps = 1e-8
rope = RoPE(8, 4; T=Float64)
@@ -30,27 +179,6 @@
@test isapprox(analytical_grad, numerical_grad, atol=1e-4)
end
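# The collapsed body above typically compares an automatic-differentiation
# gradient against a finite-difference estimate. A sketch of that pattern
# (only `eps`, `rope`, `analytical_grad` and `numerical_grad` come from the
# visible lines; the rest is an assumption, not the file's actual code):
#   loss(v) = sum(abs2, rope(v))
#   x = randn(Float64, 8, 4, 1)
#   analytical_grad = Zygote.gradient(loss, x)[1]
#   perturb(x, i, ε) = (y = copy(x); y[i] += ε; y)
#   numerical_grad = reshape([(loss(perturb(x, i, eps)) - loss(x)) / eps
#                             for i in eachindex(x)], size(x))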

@testset "Reference Output Test (CPU)" begin
features, seq_len, batch_size = 16, 10, 64
x = Float32.(reshape(collect(1:features*seq_len*batch_size),
features, seq_len, batch_size)) ./
(features*seq_len*batch_size)

reference = Float32[
9.76563f-5 0.000195313 0.000292969 0.000390625 0.000488281;
-0.00115739 0.000881045 0.00158297 0.00186569 0.00202236;
-0.00498184 0.000253548 0.00251558 0.00323702 0.00352467;
-0.0055228 -0.00175742 0.00305532 0.00450025 0.00499477;
0.00124607 -0.00495019 0.00317429 0.00565127 0.00643219
]

rope = RoPE(features, seq_len)
output = rope(x)
result = output[1:5,1:5,1]'

@test all(isapprox.(result, reference, atol=1e-4))
end

if has_working_cuda()
@testset "GPU Tests" begin
@testset "Gradient Computation (GPU, Float32)" begin
@@ -63,37 +191,24 @@
)
x = CUDA.randn(Float32, 8, 4, 1)

# Just verify that gradient computation doesn't error
loss(x) = sum(abs2, rope_gpu(x))
@test_nowarn gradient(loss, x)
end

@testset "Reference Output Test (GPU)" begin
features, seq_len, batch_size = 16, 10, 64
x = Float32.(reshape(collect(1:features*seq_len*batch_size),
features, seq_len, batch_size)) ./
(features*seq_len*batch_size)
x = cu(x)

reference = Float32[
9.76563f-5 0.000195313 0.000292969 0.000390625 0.000488281;
-0.00115739 0.000881045 0.00158297 0.00186569 0.00202236;
-0.00498184 0.000253548 0.00251558 0.00323702 0.00352467;
-0.0055228 -0.00175742 0.00305532 0.00450025 0.00499477;
0.00124607 -0.00495019 0.00317429 0.00565127 0.00643219
]

rope = RoPE(features, seq_len)
rope_gpu = RoPE(
rope.features,
cu(rope.cos_cached),
cu(rope.sin_cached),
rope.scale
@testset "Forward Pass (GPU)" begin
x = reshape(Float32.(1:160), (16,10,1))
x_gpu = cu(x)
pe = RoPE(16, 10)
pe_gpu = RoPE(
pe.features,
cu(pe.cos_cached),
cu(pe.sin_cached),
pe.scale
)
output = rope_gpu(x)
result = Array(output[1:5,1:5,1]')

@test all(isapprox.(result, reference, atol=1e-4))
cpu_output = pe(x)
gpu_output = Array(pe_gpu(x_gpu))
@test isapprox(cpu_output, gpu_output)
end
end
else

2 comments on commit 575d3ec

@mashu (Owner Author) commented on 575d3ec Nov 24, 2024


@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/120106

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.0 -m "<description of version>" 575d3ec7093bc0af54988c18acfaf62d3a84ed34
git push origin v0.3.0
