Add info about GetOrLoad() in README.md and example_test.go in lrucache package
vasayxtx committed Feb 27, 2025
1 parent 94c72ec commit 9c10992
Showing 3 changed files with 114 additions and 47 deletions.
32 changes: 24 additions & 8 deletions lrucache/README.md
@@ -9,6 +9,7 @@ The `lrucache` package provides an in-memory cache with an LRU (Least Recently U
- **LRU Eviction Policy**: Automatically removes the least recently used items when the cache reaches its maximum size.
- **Prometheus Metrics**: Collects and exposes metrics to monitor cache usage and performance.
- **Expiration**: Supports setting TTL (Time To Live) for entries. Expired entries are removed during cleanup or when accessed.
- **Cache Stampede Mitigation**: Prevents multiple goroutines from loading the same key concurrently by using a single-flight pattern.
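A minimal sketch of the two loader-based calls behind these bullets, GetOrLoad (single-flight loading) and GetOrLoadWithTTL (per-entry TTL), assuming the constructor and method signatures shown in the example and tests later in this commit; the import path, key names, and values are illustrative:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"github.com/acronis/go-appkit/lrucache" // illustrative import path; adjust to this module
)

func main() {
	// The default TTL applies unless a loader returns a per-entry TTL.
	cache, err := lrucache.NewWithOpts[string, string](1_000, nil, lrucache.Options{DefaultTTL: time.Minute})
	if err != nil {
		panic(err)
	}

	// Single-flight: concurrent callers for the same missing key share one loader call.
	var loads atomic.Int32
	load := func(key string) (string, error) {
		loads.Add(1)
		return "value-for-" + key, nil
	}
	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			v, cached, loadErr := cache.GetOrLoad("hot-key", load)
			_, _, _ = v, cached, loadErr // cached is false here: the key was not in the cache before these calls
		}()
	}
	wg.Wait()
	fmt.Println("loader calls:", loads.Load()) // typically 1, not 10

	// Per-entry TTL: the loader's TTL takes precedence over Options.DefaultTTL for this key only.
	v, _, err := cache.GetOrLoadWithTTL("session:42", func(key string) (string, time.Duration, error) {
		return "short-lived value", 100 * time.Millisecond, nil
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(v)

	time.Sleep(200 * time.Millisecond)
	if _, found := cache.Get("session:42"); !found {
		fmt.Println("session:42 expired after its custom TTL")
	}
}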

## Usage

@@ -62,7 +63,7 @@ func Example() {
fmt.Printf("User: %s, %s\n", user.UUID, user.Name)
}

// LRU cache for posts.
// LRU cache for posts. Posts are loaded from DB if not found in cache.
const post1UUID = "823e50c7-984d-4de3-8a09-92fa21d3cc3b"
const post2UUID = "24707009-ddf6-4e88-bd51-84ae236b7fda"
postsCache, err := lrucache.NewWithOpts[string, Post](1_000,
@@ -77,29 +78,44 @@ func Example() {
defer cleanupCancel()
go postsCache.RunPeriodicCleanup(cleanupCtx, 10*time.Minute) // Run cleanup every 10 minutes.

postsCache.Add(post1UUID, Post{post1UUID, "Lorem ipsum dolor sit amet..."})
if post, found := postsCache.Get(post1UUID); found {
fmt.Printf("Post: %s, %s\n", post.UUID, post.Text)
loadPostFromDatabase := func(id string) (value Post, err error) {
// Emulate loading post from DB.
if id == post1UUID {
return Post{id, "Lorem ipsum dolor sit amet..."}, nil
}
return Post{}, fmt.Errorf("not found")
}
if _, found := postsCache.Get(post2UUID); !found {
fmt.Printf("Post: %s is missing\n", post2UUID)

for _, postID := range []string{post1UUID, post1UUID, post2UUID} {
// Get post from cache or load it from DB. If two goroutines try to load the same post concurrently,
// only one of them will actually load the post, while the other will wait for the first one to finish.
if post, exists, loadErr := postsCache.GetOrLoad(postID, loadPostFromDatabase); loadErr != nil {
fmt.Printf("Failed to load post %s: %v\n", postID, loadErr)
} else {
if exists {
fmt.Printf("Post: %s, %s\n", post.UUID, post.Text)
} else {
fmt.Printf("Post (loaded from db): %s, %s\n", post.UUID, post.Text)
}
}
}

// The following Prometheus metrics will be exposed:
// my_app_cache_entries_amount{app_version="1.2.3",entry_type="note"} 1
// my_app_cache_entries_amount{app_version="1.2.3",entry_type="user"} 2
// my_app_cache_hits_total{app_version="1.2.3",entry_type="note"} 1
// my_app_cache_hits_total{app_version="1.2.3",entry_type="user"} 2
// my_app_cache_misses_total{app_version="1.2.3",entry_type="note"} 1
// my_app_cache_misses_total{app_version="1.2.3",entry_type="note"} 2

fmt.Printf("Users: %d\n", usersCache.Len())
fmt.Printf("Posts: %d\n", postsCache.Len())

// Output:
// User: 966971df-a592-4e7e-a309-52501016fa44, Alice
// User: 848adf28-84c1-4259-97a2-acba7cf5c0b6, Bob
// Post (loaded from db): 823e50c7-984d-4de3-8a09-92fa21d3cc3b, Lorem ipsum dolor sit amet...
// Post: 823e50c7-984d-4de3-8a09-92fa21d3cc3b, Lorem ipsum dolor sit amet...
// Post: 24707009-ddf6-4e88-bd51-84ae236b7fda is missing
// Failed to load post 24707009-ddf6-4e88-bd51-84ae236b7fda: not found
// Users: 2
// Posts: 1
}
98 changes: 67 additions & 31 deletions lrucache/cache_test.go
@@ -344,7 +344,7 @@ func TestLRUCache_PeriodicCleanup(t *testing.T) {
require.True(t, found)
}

func TestLRUCache_GetOrLoad(t *testing.T) {
func TestLRUCache_GetOrLoad_GetOrLoadWithTTL(t *testing.T) {
t.Run("key exists", func(t *testing.T) {
metrics := NewPrometheusMetrics()
cache, err := New[string, int](10, metrics)
@@ -363,7 +363,6 @@ func TestLRUCache_GetOrLoad(t *testing.T) {
require.True(t, exists)
require.Equal(t, 42, val)
require.Equal(t, 0, callCount)

assertPrometheusMetrics(t, expectedMetrics{EntriesAmount: 1, HitsTotal: 1}, metrics)
})

@@ -423,6 +422,9 @@ func TestLRUCache_GetOrLoad(t *testing.T) {
require.ErrorIs(t, err, loadErr)
require.Equal(t, 2, callCount)
assertPrometheusMetrics(t, expectedMetrics{EntriesAmount: 0, MissesTotal: 2}, metrics)

_, found := cache.Get("errorKey")
require.False(t, found)
})

t.Run("load value, single-flight", func(t *testing.T) {
@@ -439,27 +441,22 @@ func TestLRUCache_GetOrLoad(t *testing.T) {

const numGoroutines = 20
var wg sync.WaitGroup
results := make([]int, numGoroutines)
existsFlags := make([]bool, numGoroutines)
errs := make([]error, numGoroutines)
results := make([]getOrLoadResult[int], numGoroutines)

for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
v, exists, err := cache.GetOrLoad("sf-key", loadFunc)
results[idx] = v
existsFlags[idx] = exists
errs[idx] = err
var r getOrLoadResult[int]
r.value, r.exists, r.err = cache.GetOrLoad("sf-key", loadFunc)
results[idx] = r
}(i)
}
wg.Wait()

// Ensure each goroutine received the expected result.
for i := 0; i < numGoroutines; i++ {
require.NoError(t, errs[i])
require.Equal(t, 999, results[i])
require.False(t, existsFlags[i])
require.Equal(t, getOrLoadResult[int]{value: 999}, results[i], "goroutine %d received unexpected result", i)
}
require.EqualValues(t, 1, callCount.Load())
assertPrometheusMetrics(t, expectedMetrics{EntriesAmount: 1, MissesTotal: numGoroutines}, metrics)
@@ -475,31 +472,70 @@ func TestLRUCache_GetOrLoad(t *testing.T) {
}

func TestLRUCache_GetOrLoadWithTTL(t *testing.T) {
// Define a custom TTL shorter than the default TTL.
const customTTL = 100 * time.Millisecond
t.Run("custom TTL", func(t *testing.T) {
// Define a custom TTL shorter than the default TTL.
const customTTL = 100 * time.Millisecond

// Set a default TTL that is longer than the custom TTL to ensure the custom TTL is used.
cache, err := NewWithOpts[string, string](10, nil, Options{DefaultTTL: time.Second})
require.NoError(t, err)
// Set a default TTL that is longer than the custom TTL to ensure the custom TTL is used.
cache, err := NewWithOpts[string, string](10, nil, Options{DefaultTTL: time.Second})
require.NoError(t, err)

v, exists, err := cache.GetOrLoadWithTTL("ttl-key", func(key string) (string, time.Duration, error) {
return "ttl-value", customTTL, nil
})
require.NoError(t, err)
require.False(t, exists)
require.Equal(t, "ttl-value", v)

// Immediately after loading, the value is in the cache.
v2, ok := cache.Get("ttl-key")
require.True(t, ok)
require.Equal(t, "ttl-value", v2)

// Wait longer than the custom TTL.
time.Sleep(2 * customTTL)

v, exists, err := cache.GetOrLoadWithTTL("ttl-key", func(key string) (string, time.Duration, error) {
return "ttl-value", customTTL, nil
// The value should now be expired and thus not returned.
_, ok = cache.Get("ttl-key")
require.False(t, ok, "expected the value to be expired after the custom TTL")
})
require.NoError(t, err)
require.False(t, exists)
require.Equal(t, "ttl-value", v)

// Immediately after loading, the value is in the cache.
v2, ok := cache.Get("ttl-key")
require.True(t, ok)
require.Equal(t, "ttl-value", v2)
t.Run("custom TTL, error", func(t *testing.T) {
loadErr := errors.New("load error")

metrics := NewPrometheusMetrics()
cache, err := New[string, int](10, metrics)
require.NoError(t, err)

callCount := 0

// Wait longer than the custom TTL.
time.Sleep(2 * customTTL)
// The first call returns an error.
_, _, err = cache.GetOrLoadWithTTL("errorKey", func(key string) (int, time.Duration, error) {
callCount++
return 73, time.Second, loadErr
})
require.ErrorIs(t, err, loadErr)
require.Equal(t, 1, callCount)
assertPrometheusMetrics(t, expectedMetrics{EntriesAmount: 0, MissesTotal: 1}, metrics)

// The next call should try to load again.
_, _, err = cache.GetOrLoadWithTTL("errorKey", func(key string) (int, time.Duration, error) {
callCount++
return 73, time.Second, loadErr
})
require.ErrorIs(t, err, loadErr)
require.Equal(t, 2, callCount)
assertPrometheusMetrics(t, expectedMetrics{EntriesAmount: 0, MissesTotal: 2}, metrics)

_, found := cache.Get("errorKey")
require.False(t, found)
})
}

// The value should now be expired and thus not returned.
_, ok = cache.Get("ttl-key")
require.False(t, ok, "expected the value to be expired after the custom TTL")
type getOrLoadResult[V any] struct {
value V
exists bool
err error
}

type User struct {
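The error-path cases above assert that a failed load is not cached and that a subsequent call invokes the loader again; a minimal sketch of that behavior, assuming the same API as in these tests (the import path is illustrative):

package main

import (
	"errors"
	"fmt"

	"github.com/acronis/go-appkit/lrucache" // illustrative import path; adjust to this module
)

func main() {
	cache, err := lrucache.New[string, int](10, lrucache.NewPrometheusMetrics())
	if err != nil {
		panic(err)
	}

	calls := 0
	failingLoad := func(key string) (int, error) {
		calls++
		return 0, errors.New("backend unavailable")
	}

	// A loader error is returned to the caller and nothing is stored,
	// so the next GetOrLoad for the same key calls the loader again.
	_, _, err1 := cache.GetOrLoad("flaky-key", failingLoad)
	_, _, err2 := cache.GetOrLoad("flaky-key", failingLoad)
	fmt.Println(err1 != nil, err2 != nil, calls) // true true 2

	if _, found := cache.Get("flaky-key"); !found {
		fmt.Println("failed loads are not cached")
	}
}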
31 changes: 23 additions & 8 deletions lrucache/example_test.go
@@ -53,7 +53,7 @@ func Example() {
fmt.Printf("User: %s, %s\n", user.UUID, user.Name)
}

// LRU cache for posts.
// LRU cache for posts. Posts are loaded from DB if not found in cache.
const post1UUID = "823e50c7-984d-4de3-8a09-92fa21d3cc3b"
const post2UUID = "24707009-ddf6-4e88-bd51-84ae236b7fda"
postsCache, err := lrucache.NewWithOpts[string, Post](1_000,
@@ -68,29 +68,44 @@ func Example() {
defer cleanupCancel()
go postsCache.RunPeriodicCleanup(cleanupCtx, 10*time.Minute) // Run cleanup every 10 minutes.

postsCache.Add(post1UUID, Post{post1UUID, "Lorem ipsum dolor sit amet..."})
if post, found := postsCache.Get(post1UUID); found {
fmt.Printf("Post: %s, %s\n", post.UUID, post.Text)
loadPostFromDatabase := func(id string) (value Post, err error) {
// Emulate loading post from DB.
if id == post1UUID {
return Post{id, "Lorem ipsum dolor sit amet..."}, nil
}
return Post{}, fmt.Errorf("not found")
}
if _, found := postsCache.Get(post2UUID); !found {
fmt.Printf("Post: %s is missing\n", post2UUID)

for _, postID := range []string{post1UUID, post1UUID, post2UUID} {
// Get post from cache or load it from DB. If two goroutines try to load the same post concurrently,
// only one of them will actually load the post, while the other will wait for the first one to finish.
if post, exists, loadErr := postsCache.GetOrLoad(postID, loadPostFromDatabase); loadErr != nil {
fmt.Printf("Failed to load post %s: %v\n", postID, loadErr)
} else {
if exists {
fmt.Printf("Post: %s, %s\n", post.UUID, post.Text)
} else {
fmt.Printf("Post (loaded from db): %s, %s\n", post.UUID, post.Text)
}
}
}

// The following Prometheus metrics will be exposed:
// my_app_cache_entries_amount{app_version="1.2.3",entry_type="note"} 1
// my_app_cache_entries_amount{app_version="1.2.3",entry_type="user"} 2
// my_app_cache_hits_total{app_version="1.2.3",entry_type="note"} 1
// my_app_cache_hits_total{app_version="1.2.3",entry_type="user"} 2
// my_app_cache_misses_total{app_version="1.2.3",entry_type="note"} 1
// my_app_cache_misses_total{app_version="1.2.3",entry_type="note"} 2

fmt.Printf("Users: %d\n", usersCache.Len())
fmt.Printf("Posts: %d\n", postsCache.Len())

// Output:
// User: 966971df-a592-4e7e-a309-52501016fa44, Alice
// User: 848adf28-84c1-4259-97a2-acba7cf5c0b6, Bob
// Post (loaded from db): 823e50c7-984d-4de3-8a09-92fa21d3cc3b, Lorem ipsum dolor sit amet...
// Post: 823e50c7-984d-4de3-8a09-92fa21d3cc3b, Lorem ipsum dolor sit amet...
// Post: 24707009-ddf6-4e88-bd51-84ae236b7fda is missing
// Failed to load post 24707009-ddf6-4e88-bd51-84ae236b7fda: not found
// Users: 2
// Posts: 1
}
