
Commit

Write unit-tests for GetOrLoad() and GetOrLoadWithTTL()
vasayxtx committed Feb 22, 2025
1 parent c6995c9 commit 7f303f1
Showing 3 changed files with 208 additions and 17 deletions.
74 changes: 57 additions & 17 deletions lrucache/cache.go
@@ -64,7 +64,7 @@ func NewWithOpts[K comparable, V any](maxEntries int, metricsCollector MetricsCo
return nil, fmt.Errorf("defaultTTL must be greater or equal to 0 (no expiration)")
}
if metricsCollector == nil {
metricsCollector = disabledMetrics{}
metricsCollector = disabledMetricsCollector
}

return &LRUCache[K, V]{
@@ -81,7 +81,7 @@ func NewWithOpts[K comparable, V any](maxEntries int, metricsCollector MetricsCo
func (c *LRUCache[K, V]) Get(key K) (value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
return c.get(key)
return c.get(key, true)
}

// Add adds a value to the cache with the provided key and type.
@@ -94,6 +94,7 @@ func (c *LRUCache[K, V]) Add(key K, value V) {
// If the cache is full, the oldest entry will be removed.
// Please note that expired entries are not removed immediately,
// but only when they are accessed or during periodic cleanup (see RunPeriodicCleanup).
// If the TTL is less than or equal to 0, the value will not expire.
func (c *LRUCache[K, V]) AddWithTTL(key K, value V, ttl time.Duration) {
var expiresAt time.Time
if ttl > 0 {
@@ -113,8 +114,7 @@ func (c *LRUCache[K, V]) AddWithTTL(key K, value V, ttl time.Duration) {

// GetOrAdd returns a value from the cache by the provided key,
// and adds a new value with the default TTL if the key does not exist.
// The new value is provided by the valueProvider function.
// The function is called only if the key does not exist.
// The new value is provided by the valueProvider function, which is called only if the key does not exist.
// Note that the function is called under the LRUCache lock, so it should be fast and non-blocking.
// If you need to perform a blocking operation, consider using GetOrLoad instead.
func (c *LRUCache[K, V]) GetOrAdd(key K, valueProvider func() V) (value V, exists bool) {
@@ -123,15 +123,15 @@ func (c *LRUCache[K, V]) GetOrAdd(key K, valueProvider func() V) (value V, exist

// GetOrAddWithTTL returns a value from the cache by the provided key,
// and adds a new value with the specified TTL if the key does not exist.
// The new value is provided by the valueProvider function.
// The function is called only if the key does not exist.
// The new value is provided by the valueProvider function, which is called only if the key does not exist.
// Note that the function is called under the LRUCache lock, so it should be fast and non-blocking.
// If you need to perform a blocking operation, consider using GetOrLoad instead.
// If you need to perform a blocking operation, consider using GetOrLoadWithTTL instead.
// If the TTL is less than or equal to 0, the value will not expire.
func (c *LRUCache[K, V]) GetOrAddWithTTL(key K, valueProvider func() V, ttl time.Duration) (value V, exists bool) {
c.mu.Lock()
defer c.mu.Unlock()

if value, exists = c.get(key); exists {
if value, exists = c.get(key, true); exists {
return value, exists
}

@@ -146,18 +146,52 @@ func (c *LRUCache[K, V]) GetOrAddWithTTL(key K, valueProvider func() V, ttl time

// GetOrLoad returns a value from the cache by the provided key,
// and loads a new value if the key does not exist.
// The new value is provided by the loadValue function.
// The function is called only if the key does not exist.
// Note that the single flight pattern is used to prevent multiple concurrent calls for the same key.
// The new value is provided by the loadValue function, which is called only if the key does not exist.
// The loadValue function returns the value and error.
// If the loadValue function returns an error, the value will not be added to the cache.
// Single flight pattern is used to prevent multiple concurrent calls for the same key.
func (c *LRUCache[K, V]) GetOrLoad(
key K, loadValue func(K) (value V, err error),
) (value V, exists bool, err error) {
return c.GetOrLoadWithTTL(key, func(k K) (value V, ttl time.Duration, err error) {
val, err := loadValue(k)
return val, c.defaultTTL, err
})
}

// GetOrLoadWithTTL returns a value from the cache by the provided key,
// and loads a new value if the key does not exist.
// The new value is provided by the loadValue function, which is called only if the key does not exist.
// The loadValue function returns the value, TTL, and error.
// If the TTL is less than or equal to 0, the value will not expire.
// If the loadValue function returns an error, the value will not be added to the cache.
// Single flight pattern is used to prevent multiple concurrent calls for the same key.
func (c *LRUCache[K, V]) GetOrLoadWithTTL(
key K, loadValue func(K) (value V, ttl time.Duration, err error),
) (value V, exists bool, err error) {
if val, ok := c.Get(key); ok {
// We have to use a separate function to get the value without modifying hits
// and misses metrics because of the single flight pattern and the double check.
get := func(key K) (value V, exists bool) {
c.mu.Lock()
defer c.mu.Unlock()
return c.get(key, false)
}

defer func() {
// We have to increment metrics after the actual call because of the single flight pattern and the double check.
if exists {
c.metricsCollector.IncHits()
} else {
c.metricsCollector.IncMisses()
}
}()

if val, ok := get(key); ok {
return val, true, nil
}

result, doErr := c.sfGroup.Do(key, func() (singleFlightCallResult[V], error) {
if val, ok := c.Get(key); ok { // double check after possible concurrent call
if val, ok := get(key); ok { // double check after possible concurrent call
return singleFlightCallResult[V]{value: val, exists: true}, nil
}
val, ttl, valErr := loadValue(key)
@@ -234,22 +268,28 @@ func (c *LRUCache[K, V]) Len() int {
return len(c.cache)
}

func (c *LRUCache[K, V]) get(key K) (value V, ok bool) {
func (c *LRUCache[K, V]) get(key K, incHitsAndMisses bool) (value V, ok bool) {
elem, hit := c.cache[key]
if !hit {
c.metricsCollector.IncMisses()
if incHitsAndMisses {
c.metricsCollector.IncMisses()
}
return value, false
}
entry := elem.Value.(*cacheEntry[K, V])
if !entry.expiresAt.IsZero() && entry.expiresAt.Before(time.Now()) {
c.lruList.Remove(elem)
delete(c.cache, key)
c.metricsCollector.SetAmount(len(c.cache))
c.metricsCollector.IncMisses()
if incHitsAndMisses {
c.metricsCollector.IncMisses()
}
return value, false
}
c.lruList.MoveToFront(elem)
c.metricsCollector.IncHits()
if incHitsAndMisses {
c.metricsCollector.IncHits()
}
return entry.value, true
}

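A minimal usage sketch of the loader-based API added in this diff, not part of commit 7f303f1. It is written in the same in-package style as the tests below, so only the constructors and the GetOrLoad/GetOrLoadWithTTL signatures shown above are relied on; the keys, values, and the simulated delay are invented for illustration.

// Usage sketch, assuming it sits next to the tests in package lrucache.
package lrucache

import (
	"fmt"
	"time"
)

func usageSketchGetOrLoad() {
	cache, _ := NewWithOpts[string, string](100, nil, Options{DefaultTTL: time.Minute})

	// Unlike the GetOrAdd value provider, the loader runs outside the cache lock,
	// so it may block; concurrent calls for the same key are collapsed by the
	// single-flight group and only one loader invocation is made.
	load := func(key string) (string, error) {
		time.Sleep(10 * time.Millisecond) // stand-in for a slow lookup
		return "value-for-" + key, nil
	}

	v, cached, err := cache.GetOrLoad("user:42", load)
	fmt.Println(v, cached, err) // "value-for-user:42 false <nil>": freshly loaded

	v, cached, err = cache.GetOrLoad("user:42", load)
	fmt.Println(v, cached, err) // "value-for-user:42 true <nil>": served from the cache

	// Per-entry TTL variant: the loader also decides how long to keep the value.
	token, _, _ := cache.GetOrLoadWithTTL("session:7", func(key string) (string, time.Duration, error) {
		return "token-for-" + key, 30 * time.Second, nil
	})
	fmt.Println(token)
}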
149 changes: 149 additions & 0 deletions lrucache/cache_test.go
@@ -8,6 +8,9 @@ package lrucache

import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"

@@ -341,6 +344,152 @@ func TestLRUCache_PeriodicCleanup(t *testing.T) {
require.True(t, found)
}

func TestLRUCache_GetOrLoad(t *testing.T) {
t.Run("key exists", func(t *testing.T) {
cache, err := New[string, int](10, nil)
require.NoError(t, err)

// Pre-populate the cache.
cache.Add("existing", 42)

// Call GetOrLoad with a load function that should not be invoked.
callCount := 0
val, exists, err := cache.GetOrLoad("existing", func(key string) (int, error) {
callCount++
return 73, nil
})
require.NoError(t, err)
require.True(t, exists)
require.Equal(t, 42, val)
require.Equal(t, 0, callCount)
})

t.Run("load value, success", func(t *testing.T) {
cache, err := New[string, int](10, nil)
require.NoError(t, err)

callCount := 0

// First call: key does not exist, so load function is called.
val, exists, err := cache.GetOrLoad("key", func(key string) (int, error) {
callCount++
return 456, nil
})
require.NoError(t, err)
require.False(t, exists) // fresh load returns exists == false
require.Equal(t, 456, val)
require.Equal(t, 1, callCount)

// Second call: the value should be cached.
val, exists, err = cache.GetOrLoad("key", func(key string) (int, error) {
callCount++
return 789, nil
})
require.NoError(t, err)
require.True(t, exists)
require.Equal(t, 456, val)
require.Equal(t, 1, callCount) // load function is not invoked again
})

t.Run("load value, error", func(t *testing.T) {
loadErr := errors.New("load error")

cache, err := New[string, int](10, nil)
require.NoError(t, err)

callCount := 0

// The first call returns an error.
_, _, err = cache.GetOrLoad("errorKey", func(key string) (int, error) {
callCount++
return 0, loadErr
})
require.ErrorIs(t, err, loadErr)
require.Equal(t, 1, callCount)

// The next call should try to load again.
_, _, err = cache.GetOrLoad("errorKey", func(key string) (int, error) {
callCount++
return 0, loadErr
})
require.ErrorIs(t, err, loadErr)
require.Equal(t, 2, callCount)
})

t.Run("load value, single-flight", func(t *testing.T) {
cache, err := NewWithOpts[string, int](10, nil, Options{DefaultTTL: time.Minute})
require.NoError(t, err)

var callCount atomic.Int64
loadFunc := func(key string) (int, error) {

[GitHub Actions / Lint (1.20) check failure on line 424 in lrucache/cache_test.go: `TestLRUCache_GetOrLoad$4$1` - `key` is unused (unparam)]
time.Sleep(100 * time.Millisecond) // simulate some delay to force overlapping calls
callCount.Add(1)
return 999, nil
}

const numGoroutines = 20
var wg sync.WaitGroup
results := make([]int, numGoroutines)
existsFlags := make([]bool, numGoroutines)
errs := make([]error, numGoroutines)

for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
v, exists, err := cache.GetOrLoad("sf-key", loadFunc)
results[idx] = v
existsFlags[idx] = exists
errs[idx] = err
}(i)
}
wg.Wait()

// Ensure each goroutine received the expected result.
for i := 0; i < numGoroutines; i++ {
require.NoError(t, errs[i])
require.Equal(t, 999, results[i])
require.False(t, existsFlags[i])
}
require.EqualValues(t, 1, callCount.Load())

// A later call should find the key in the cache.
v, exists, err := cache.GetOrLoad("sf-key", loadFunc)
require.NoError(t, err)
require.True(t, exists)
require.Equal(t, 999, v)
require.EqualValues(t, 1, callCount.Load())
})
}

func TestLRUCache_GetOrLoadWithTTL(t *testing.T) {
// Define a custom TTL shorter than the default TTL.
const customTTL = 100 * time.Millisecond

// Set a default TTL that is longer than the custom TTL to ensure the custom TTL is used.
cache, err := NewWithOpts[string, string](10, nil, Options{DefaultTTL: time.Second})
require.NoError(t, err)

v, exists, err := cache.GetOrLoadWithTTL("ttl-key", func(key string) (string, time.Duration, error) {
return "ttl-value", customTTL, nil
})
require.NoError(t, err)
require.False(t, exists)
require.Equal(t, "ttl-value", v)

// Immediately after loading, the value is in the cache.
v2, ok := cache.Get("ttl-key")
require.True(t, ok)
require.Equal(t, "ttl-value", v2)

// Wait longer than the custom TTL.
time.Sleep(2 * customTTL)

// The value should now be expired and thus not returned.
_, ok = cache.Get("ttl-key")
require.False(t, ok, "expected the value to be expired after the custom TTL")
}

type User struct {
ID string
Name string
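The hits/misses accounting reworked in this diff (the unexported get(key, false) plus the deferred IncHits/IncMisses in GetOrLoadWithTTL) records exactly one hit or one miss per call, even though the key may be looked up twice internally: once in the pre-check and once in the double check inside the single-flight callback. The test-style sketch below illustrates that behaviour; it is not part of the commit, the countingCollector helper is hypothetical, and it embeds the package's disabledMetrics so that it keeps satisfying the metrics-collector interface whatever its full method set is. It assumes the in-package test imports shown above (testing, sync/atomic, require).

// Hypothetical in-package helper: counts only hits and misses.
type countingCollector struct {
	disabledMetrics // no-op implementation for everything else
	hits, misses atomic.Int64
}

func (c *countingCollector) IncHits()   { c.hits.Add(1) }
func (c *countingCollector) IncMisses() { c.misses.Add(1) }

func TestLRUCache_GetOrLoad_MetricsSketch(t *testing.T) {
	mc := &countingCollector{}
	cache, err := NewWithOpts[string, int](10, mc, Options{DefaultTTL: time.Minute})
	require.NoError(t, err)

	// First call: the key is absent, the loader runs, and exactly one miss is
	// recorded despite the internal pre-check and double check.
	_, _, err = cache.GetOrLoad("k", func(string) (int, error) { return 1, nil })
	require.NoError(t, err)
	require.EqualValues(t, 1, mc.misses.Load())
	require.EqualValues(t, 0, mc.hits.Load())

	// Second call: served from the cache, exactly one hit is recorded.
	_, _, err = cache.GetOrLoad("k", func(string) (int, error) { return 2, nil })
	require.NoError(t, err)
	require.EqualValues(t, 1, mc.hits.Load())
	require.EqualValues(t, 1, mc.misses.Load())
}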
2 changes: 2 additions & 0 deletions lrucache/metrics.go
@@ -156,3 +156,5 @@ func (disabledMetrics) SetAmount(int) {}
func (disabledMetrics) IncHits() {}
func (disabledMetrics) IncMisses() {}
func (disabledMetrics) AddEvictions(int) {}

var disabledMetricsCollector = disabledMetrics{}

0 comments on commit 7f303f1
