Skip to content

Commit

Permalink
de-allocate indices on the read size, when closed
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisHegarty committed Feb 27, 2025
1 parent 8cf5087 commit 3837a10
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,25 @@
*/
package org.apache.lucene.sandbox.vectorsearch;

import static org.apache.lucene.sandbox.vectorsearch.CuVSVectorsReader.handleThrowable;

import com.nvidia.cuvs.BruteForceIndex;
import com.nvidia.cuvs.CagraIndex;
import com.nvidia.cuvs.HnswIndex;
import java.io.Closeable;
import java.io.IOException;
import java.util.Objects;

/** This class holds references to the actual CuVS Index (Cagra, Brute force, etc.) */
public class CuVSIndex {
public class CuVSIndex implements Closeable {
private final CagraIndex cagraIndex;
private final BruteForceIndex bruteforceIndex;
private final HnswIndex hnswIndex;

private int maxDocs;
private String fieldName;
private String segmentName;
private volatile boolean closed;

public CuVSIndex(
String segmentName,
Expand All @@ -55,14 +60,17 @@ public CuVSIndex(CagraIndex cagraIndex, BruteForceIndex bruteforceIndex, HnswInd
}

public CagraIndex getCagraIndex() {
ensureOpen();
return cagraIndex;
}

public BruteForceIndex getBruteforceIndex() {
ensureOpen();
return bruteforceIndex;
}

public HnswIndex getHNSWIndex() {
ensureOpen();
return hnswIndex;
}

Expand All @@ -77,4 +85,35 @@ public String getSegmentName() {
public int getMaxDocs() {
return maxDocs;
}

private void ensureOpen() {
if (closed) {
throw new IllegalStateException("index is closed");
}
}

@Override
public void close() throws IOException {
if (closed) {
return;
}
closed = true;
destroyIndices();
}

private void destroyIndices() throws IOException {
try {
if (cagraIndex != null) {
cagraIndex.destroyIndex();
}
if (bruteforceIndex != null) {
bruteforceIndex.destroyIndex();
}
if (hnswIndex != null) {
hnswIndex.destroyIndex();
}
} catch (Throwable t) {
handleThrowable(t);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@
import com.nvidia.cuvs.HnswIndex;
import com.nvidia.cuvs.HnswIndexParams;
import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
Expand Down Expand Up @@ -276,7 +279,15 @@ private CuVSIndex loadCuVSIndex(FieldEntry fieldEntry) throws IOException {

@Override
public void close() throws IOException {
IOUtils.close(flatVectorsReader, cuvsIndexInput);
var closeableStream =
Stream.concat(
Stream.of(flatVectorsReader, cuvsIndexInput),
stream(cuvsIndices.values().iterator()).map(cursor -> cursor.value));
IOUtils.close(closeableStream::iterator);
}

static <T> Stream<T> stream(Iterator<T> iterator) {
return StreamSupport.stream(((Iterable<T>) () -> iterator).spliterator(), false);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import static org.apache.lucene.sandbox.vectorsearch.CuVSVectorsFormat.CUVS_META_CODEC_EXT;
import static org.apache.lucene.sandbox.vectorsearch.CuVSVectorsFormat.CUVS_META_CODEC_NAME;
import static org.apache.lucene.sandbox.vectorsearch.CuVSVectorsFormat.VERSION_CURRENT;
import static org.apache.lucene.sandbox.vectorsearch.CuVSVectorsReader.handleThrowable;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;

Expand Down Expand Up @@ -433,15 +434,6 @@ static void handleThrowableWithIgnore(Throwable t, String msg) throws IOExceptio
handleThrowable(t);
}

static void handleThrowable(Throwable t) throws IOException {
switch (t) {
case IOException ioe -> throw ioe;
case Error error -> throw error;
case RuntimeException re -> throw re;
case null, default -> throw new RuntimeException("UNEXPECTED: exception type", t);
}
}

/** Copies the vector values into dst. Returns the actual number of vectors copied. */
private static int getVectorData(FloatVectorValues floatVectorValues, float[][] dst)
throws IOException {
Expand Down

0 comments on commit 3837a10

Please sign in to comment.