diff --git a/Project.toml b/Project.toml
index b8bec1b..85a7d3d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,7 +5,6 @@ version = "0.3.0"
 
 [deps]
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 CUDA = "5.5.2"
@@ -14,8 +13,9 @@ Zygote = "0.6.41"
 julia = "1.9"
 
 [extras]
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Test", "CUDA"]
+test = ["Test", "CUDA", "Zygote"]
diff --git a/docs/src/index.md b/docs/src/index.md
index 7cff46f..7277790 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -10,7 +10,7 @@ This package provides various implementations of positional embeddings - techniq
 
 Currently implemented:
 - Rotary Position Embeddings (RoPE) - A method that encodes positions by rotating vectors in 2D subspaces
-- Frequency Positional Embeddings (FrequencyPE) - The original positional encoding scheme from "Attention Is All You Need"
+- Positional Embeddings (AbsolutePE) - The original positional encoding scheme from "Attention Is All You Need"
 
 ## API Reference
 
@@ -33,11 +33,11 @@ features = randn(Float32, 512, 100, 32)
 features_with_pos = rope(features)
 ```
 
-### Basic FrequencyPE Usage
+### Basic AbsolutePE Usage
 
 ```julia
 # Create embeddings for 128-dimensional features up to length 100
-pe = FrequencyPE(128, 100)
+pe = AbsolutePE(128, 100)
 
 # Apply to input tensor of shape (features, seq_len, batch)
 x = randn(Float32, 128, 100, 32)
@@ -45,67 +45,9 @@ x_positioned = pe(x)
 ```
 ![AbsolutePE](assets/AbsolutePE-128-100.svg)
 
-### Example with Query/Key Matrices
+## Rotary Position Embeddings (RoPE) with MultiHeadAttention from Flux
 
-Here's a complete example showing how to use RoPE with attention mechanisms:
-
-```julia
-using Flux: MultiHeadAttention
-using PositionalEmbeddings: RoPE, RoPEMultiHeadAttention
-using NNlib: dot_product_attention
-
-# Initialize RoPE with specific feature dimensions to rotate
-dim = 64 # embedding dimension
-nheads = 4 # number of attention heads
-max_seq_len = 2048
-seq_len = 100
-batch_size = 32
-
-# Create attention layer with RoPE
-rmha = RoPEMultiHeadAttention(
- dim, # embedding dimension
- nheads; # number of heads
- max_seq_len = max_seq_len,
- rope_fraction = 1.0 # apply RoPE to all features
-)
-
-# Sample input tensors (dim, seq_len, batch)
-q_in = randn(Float32, dim, seq_len, batch_size)
-k_in = randn(Float32, dim, seq_len, batch_size)
-v_in = randn(Float32, dim, seq_len, batch_size)
-
-# The forward pass will:
-# 1. Project inputs
-mha = rmha.mha
-q = mha.q_proj(q_in)
-k = mha.k_proj(k_in)
-v = mha.v_proj(v_in)
-
-# 2. Apply RoPE to queries and keys
-q = rmha.rope(q)
-k = rmha.rope(k)
-
-# 3. Compute attention
-bias = nothing
-mask = nothing
-x, α = dot_product_attention(q, k, v, bias;
- nheads=mha.nheads,
- mask=mask,
- fdrop=mha.attn_drop)
-
-# 4. Project output
-output = mha.out_proj(x)
-
-# Alternatively, use the provided wrapper:
-output, attention_weights = rmha(q_in, k_in, v_in)
-
-# Or for self-attention:
-output, attention_weights = rmha(q_in)
-```
-
-## Flux Integration
-
-The package provides a wrapper type `RoPEMultiHeadAttention` that adds Rotary Position Embeddings to Flux's MultiHeadAttention. Here's the complete implementation:
+This example shows how to build a `RoPEMultiHeadAttention` wrapper that adds Rotary Position Embeddings to Flux's `MultiHeadAttention`.
Here's the complete implementation: ```julia using Flux: MultiHeadAttention diff --git a/src/PositionalEmbeddings.jl b/src/PositionalEmbeddings.jl index 7d6e61a..b94f271 100644 --- a/src/PositionalEmbeddings.jl +++ b/src/PositionalEmbeddings.jl @@ -139,8 +139,8 @@ module PositionalEmbeddings """ function neg_half(x::AbstractArray{T}, dim::Int=1) where T d_2 = size(x, dim) ÷ 2 - vcat(-view(x, d_2+1:size(x,dim), :, :), - view(x, 1:d_2, :, :)) + return vcat(view(x, d_2+1:size(x,dim), :, :) .* -1, + view(x, 1:d_2, :, :)) end """ diff --git a/test/runtests.jl b/test/runtests.jl index f1bd855..4d277a2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,7 +6,156 @@ using PositionalEmbeddings function has_working_cuda() return CUDA.has_cuda() && CUDA.functional() end + +# # https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/rope/__init__.py +# # Code from the above link was used to generate the test data +# import torch +# import numpy as np + +# base = 10000 +# d = 16 +# seq_len = 10 # Changed from 16 to 10 + +# # Calculate frequencies +# theta = 1. / (base ** (torch.arange(0, d, 2).float() / d)) +# seq_idx = torch.arange(seq_len).float() # Now 0-9 instead of 0-15 +# idx_theta = torch.einsum('n,d->nd', seq_idx, theta) +# idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1) + +# # Cache sin/cos values - now 10x1x1x16 instead of 16x1x1x16 +# cos_cached = idx_theta2.cos()[:, None, None, :] +# sin_cached = idx_theta2.sin()[:, None, None, :] + +# # Input will be 16x10x1 instead of 16x16x1 +# input_original = torch.arange(1, 16*10 + 1, dtype=torch.float32).reshape(16, 10, 1) +# x = input_original.permute(1, 2, 0).unsqueeze(1) # [10, 1, 1, 16] + +# # Split features +# d_2 = d // 2 +# x_rope, x_pass = x[..., :d], x[..., d:] + +# # Calculate neg_half +# neg_half_x = torch.cat([-x_rope[:, :, :, d_2:], x_rope[:, :, :, :d_2]], dim=-1) + +# # Final output +# final_output = (x_rope * cos_cached[:x.shape[0]]) + (neg_half_x * sin_cached[:x.shape[0]]) + +# # Save results +# np.savez('rope_test_data.npz', +# input_original=input_original, +# input_permuted=x.numpy(), +# theta=theta.numpy(), +# idx_theta=idx_theta.numpy(), +# idx_theta2=idx_theta2.numpy(), +# cos_cached=cos_cached.numpy(), +# sin_cached=sin_cached.numpy(), +# neg_half=neg_half_x.numpy(), +# final_output=final_output.numpy()) + @testset "RoPE Tests" begin + @testset "Cached Values Test" begin + features, seq_len = 16, 10 + rope = RoPE(features, seq_len) + + expected_cos = reshape([ + 1.0 0.540302 -0.416147 -0.989992 -0.653644 0.283662 0.96017 0.753902 -0.1455 -0.91113; + 1.0 0.950415 0.806578 0.582754 0.301137 -0.0103423 -0.320796 -0.599437 -0.818632 -0.956644; + 1.0 0.995004 0.980067 0.955337 0.921061 0.877583 0.825336 0.764842 0.696707 0.62161; + 1.0 0.9995 0.998001 0.995503 0.992011 0.987526 0.982054 0.9756 0.96817 0.959773; + 1.0 0.99995 0.9998 0.99955 0.9992 0.99875 0.998201 0.997551 0.996802 0.995953; + 1.0 0.999995 0.99998 0.999955 0.99992 0.999875 0.99982 0.999755 0.99968 0.999595; + 1.0 1.0 0.999998 0.999996 0.999992 0.999987 0.999982 0.999976 0.999968 0.99996; + 1.0 1.0 1.0 1.0 0.999999 0.999999 0.999998 0.999998 0.999997 0.999996; + 1.0 0.540302 -0.416147 -0.989992 -0.653644 0.283662 0.96017 0.753902 -0.1455 -0.91113; + 1.0 0.950415 0.806578 0.582754 0.301137 -0.0103423 -0.320796 -0.599437 -0.818632 -0.956644; + 1.0 0.995004 0.980067 0.955337 0.921061 0.877583 0.825336 0.764842 0.696707 0.62161; + 1.0 0.9995 0.998001 0.995503 0.992011 0.987526 0.982054 0.9756 0.96817 
0.959773; + 1.0 0.99995 0.9998 0.99955 0.9992 0.99875 0.998201 0.997551 0.996802 0.995953; + 1.0 0.999995 0.99998 0.999955 0.99992 0.999875 0.99982 0.999755 0.99968 0.999595; + 1.0 1.0 0.999998 0.999996 0.999992 0.999987 0.999982 0.999976 0.999968 0.99996; + 1.0 1.0 1.0 1.0 0.999999 0.999999 0.999998 0.999998 0.999997 0.999996 + ], (16,10,1)) + + expected_sin = reshape([ + 0.0 0.841471 0.909297 0.14112 -0.756802 -0.958924 -0.279415 0.656987 0.989358 0.412118; + 0.0 0.310984 0.591127 0.812649 0.953581 0.999947 0.947148 0.800422 0.574318 0.291259; + 0.0 0.0998334 0.198669 0.29552 0.389418 0.479426 0.564642 0.644218 0.717356 0.783327; + 0.0 0.0316175 0.0632034 0.0947261 0.126154 0.157456 0.1886 0.219556 0.250292 0.280778; + 0.0 0.00999983 0.0199987 0.0299955 0.0399893 0.0499792 0.059964 0.0699428 0.0799147 0.0898785; + 0.0 0.00316227 0.00632451 0.00948669 0.0126488 0.0158107 0.0189725 0.0221341 0.0252955 0.0284567; + 0.0 0.001 0.002 0.003 0.00399999 0.00499998 0.00599996 0.00699994 0.00799991 0.00899988; + 0.0 0.000316228 0.000632456 0.000948683 0.00126491 0.00158114 0.00189737 0.00221359 0.00252982 0.00284605; + 0.0 0.841471 0.909297 0.14112 -0.756802 -0.958924 -0.279415 0.656987 0.989358 0.412118; + 0.0 0.310984 0.591127 0.812649 0.953581 0.999947 0.947148 0.800422 0.574318 0.291259; + 0.0 0.0998334 0.198669 0.29552 0.389418 0.479426 0.564642 0.644218 0.717356 0.783327; + 0.0 0.0316175 0.0632034 0.0947261 0.126154 0.157456 0.1886 0.219556 0.250292 0.280778; + 0.0 0.00999983 0.0199987 0.0299955 0.0399893 0.0499792 0.059964 0.0699428 0.0799147 0.0898785; + 0.0 0.00316227 0.00632451 0.00948669 0.0126488 0.0158107 0.0189725 0.0221341 0.0252955 0.0284567; + 0.0 0.001 0.002 0.003 0.00399999 0.00499998 0.00599996 0.00699994 0.00799991 0.00899988; + 0.0 0.000316228 0.000632456 0.000948683 0.00126491 0.00158114 0.00189737 0.00221359 0.00252982 0.00284605 + ], (16,10,1)) + + @test isapprox(rope.cos_cached, expected_cos) + @test isapprox(rope.sin_cached, expected_sin) + end + + @testset "neg_half Function Test" begin + x = x = reshape([ + 1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0; + 11.0 12.0 13.0 14.0 15.0 16.0 17.0 18.0 19.0 20.0; + 21.0 22.0 23.0 24.0 25.0 26.0 27.0 28.0 29.0 30.0; + 31.0 32.0 33.0 34.0 35.0 36.0 37.0 38.0 39.0 40.0; + 41.0 42.0 43.0 44.0 45.0 46.0 47.0 48.0 49.0 50.0; + 51.0 52.0 53.0 54.0 55.0 56.0 57.0 58.0 59.0 60.0; + 61.0 62.0 63.0 64.0 65.0 66.0 67.0 68.0 69.0 70.0; + 71.0 72.0 73.0 74.0 75.0 76.0 77.0 78.0 79.0 80.0; + 81.0 82.0 83.0 84.0 85.0 86.0 87.0 88.0 89.0 90.0; + 91.0 92.0 93.0 94.0 95.0 96.0 97.0 98.0 99.0 100.0; + 101.0 102.0 103.0 104.0 105.0 106.0 107.0 108.0 109.0 110.0; + 111.0 112.0 113.0 114.0 115.0 116.0 117.0 118.0 119.0 120.0; + 121.0 122.0 123.0 124.0 125.0 126.0 127.0 128.0 129.0 130.0; + 131.0 132.0 133.0 134.0 135.0 136.0 137.0 138.0 139.0 140.0; + 141.0 142.0 143.0 144.0 145.0 146.0 147.0 148.0 149.0 150.0; + 151.0 152.0 153.0 154.0 155.0 156.0 157.0 158.0 159.0 160.0 + ], (16,10,1)) + + expected_neg_half = reshape([ + -81.0 -82.0 -83.0 -84.0 -85.0 -86.0 -87.0 -88.0 -89.0 -90.0; + -91.0 -92.0 -93.0 -94.0 -95.0 -96.0 -97.0 -98.0 -99.0 -100.0; + -101.0 -102.0 -103.0 -104.0 -105.0 -106.0 -107.0 -108.0 -109.0 -110.0; + -111.0 -112.0 -113.0 -114.0 -115.0 -116.0 -117.0 -118.0 -119.0 -120.0; + -121.0 -122.0 -123.0 -124.0 -125.0 -126.0 -127.0 -128.0 -129.0 -130.0; + -131.0 -132.0 -133.0 -134.0 -135.0 -136.0 -137.0 -138.0 -139.0 -140.0; + -141.0 -142.0 -143.0 -144.0 -145.0 -146.0 -147.0 -148.0 -149.0 -150.0; + -151.0 -152.0 -153.0 -154.0 -155.0 -156.0 -157.0 
-158.0 -159.0 -160.0; + 1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0; + 11.0 12.0 13.0 14.0 15.0 16.0 17.0 18.0 19.0 20.0; + 21.0 22.0 23.0 24.0 25.0 26.0 27.0 28.0 29.0 30.0; + 31.0 32.0 33.0 34.0 35.0 36.0 37.0 38.0 39.0 40.0; + 41.0 42.0 43.0 44.0 45.0 46.0 47.0 48.0 49.0 50.0; + 51.0 52.0 53.0 54.0 55.0 56.0 57.0 58.0 59.0 60.0; + 61.0 62.0 63.0 64.0 65.0 66.0 67.0 68.0 69.0 70.0; + 71.0 72.0 73.0 74.0 75.0 76.0 77.0 78.0 79.0 80.0 + ], (16,10,1)) + + @test isapprox(PositionalEmbeddings.neg_half(x), expected_neg_half) + end + + @testset "Forward Pass Test" begin + x = reshape(Float32.(1:160), (16,10,1)) + pe = RoPE(16, 10) + + # Manual calculation + neg_half_x = PositionalEmbeddings.neg_half(x) + cos_mat = view(pe.cos_cached, 1:size(x,1), 1:size(x,2), :) + sin_mat = view(pe.sin_cached, 1:size(x,1), 1:size(x,2), :) + expected_output = @. muladd(x * pe.scale, cos_mat, neg_half_x * pe.scale * sin_mat) + + # Test the forward pass + actual_output = pe(x) + @test isapprox(actual_output, expected_output) + end + @testset "Gradient Tests (CPU, Float64)" begin eps = 1e-8 rope = RoPE(8, 4; T=Float64) @@ -30,27 +179,6 @@ end @test isapprox(analytical_grad, numerical_grad, atol=1e-4) end - @testset "Reference Output Test (CPU)" begin - features, seq_len, batch_size = 16, 10, 64 - x = Float32.(reshape(collect(1:features*seq_len*batch_size), - features, seq_len, batch_size)) ./ - (features*seq_len*batch_size) - - reference = Float32[ - 9.76563f-5 0.000195313 0.000292969 0.000390625 0.000488281; - -0.00115739 0.000881045 0.00158297 0.00186569 0.00202236; - -0.00498184 0.000253548 0.00251558 0.00323702 0.00352467; - -0.0055228 -0.00175742 0.00305532 0.00450025 0.00499477; - 0.00124607 -0.00495019 0.00317429 0.00565127 0.00643219 - ] - - rope = RoPE(features, seq_len) - output = rope(x) - result = output[1:5,1:5,1]' - - @test all(isapprox.(result, reference, atol=1e-4)) - end - if has_working_cuda() @testset "GPU Tests" begin @testset "Gradient Computation (GPU, Float32)" begin @@ -63,37 +191,24 @@ end ) x = CUDA.randn(Float32, 8, 4, 1) - # Just verify that gradient computation doesn't error loss(x) = sum(abs2, rope_gpu(x)) @test_nowarn gradient(loss, x) end - @testset "Reference Output Test (GPU)" begin - features, seq_len, batch_size = 16, 10, 64 - x = Float32.(reshape(collect(1:features*seq_len*batch_size), - features, seq_len, batch_size)) ./ - (features*seq_len*batch_size) - x = cu(x) - - reference = Float32[ - 9.76563f-5 0.000195313 0.000292969 0.000390625 0.000488281; - -0.00115739 0.000881045 0.00158297 0.00186569 0.00202236; - -0.00498184 0.000253548 0.00251558 0.00323702 0.00352467; - -0.0055228 -0.00175742 0.00305532 0.00450025 0.00499477; - 0.00124607 -0.00495019 0.00317429 0.00565127 0.00643219 - ] - - rope = RoPE(features, seq_len) - rope_gpu = RoPE( - rope.features, - cu(rope.cos_cached), - cu(rope.sin_cached), - rope.scale + @testset "Forward Pass (GPU)" begin + x = reshape(Float32.(1:160), (16,10,1)) + x_gpu = cu(x) + pe = RoPE(16, 10) + pe_gpu = RoPE( + pe.features, + cu(pe.cos_cached), + cu(pe.sin_cached), + pe.scale ) - output = rope_gpu(x) - result = Array(output[1:5,1:5,1]') - @test all(isapprox.(result, reference, atol=1e-4)) + cpu_output = pe(x) + gpu_output = Array(pe_gpu(x_gpu)) + @test isapprox(cpu_output, gpu_output) end end else