Merge pull request #209 from psarna/grpc_replication
replication: add gRPC-based replication
penberg authored Jul 17, 2023
2 parents c6c0b30 + 2f06e8a commit d861924
Showing 6 changed files with 222 additions and 65 deletions.
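At a glance, this PR replaces the frame-injection-only replication API with a gRPC-driven sync: a replica is opened with `Database::with_replicator(url, db_path)` and then repeatedly calls `Database::sync()`, which pulls WAL frames from the primary's replication log. Below is a minimal usage sketch pieced together from the updated `replica.rs` example in this diff; the endpoint `http://localhost:5001` and the `test.db` path are the placeholder values used there, not fixed defaults.

```rust
use libsql::Database;

#[tokio::main]
async fn main() {
    // Open an embedded replica that follows a sqld replication endpoint over gRPC.
    let db = Database::with_replicator("http://localhost:5001", "test.db")
        .await
        .unwrap();
    let conn = db.connect().unwrap();

    // `sync()` is blocking, so the example runs it on a blocking thread.
    let db = std::sync::Arc::new(parking_lot::Mutex::new(db));
    tokio::task::spawn_blocking({
        let db = db.clone();
        move || db.lock().sync()
    })
    .await
    .unwrap()
    .unwrap();

    // Query the replicated schema.
    let rows = conn
        .execute("SELECT * FROM sqlite_master", ())
        .unwrap()
        .unwrap();
    while let Ok(Some(row)) = rows.next() {
        println!("{}", row.get::<&str>(0).unwrap());
    }
}
```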
2 changes: 2 additions & 0 deletions crates/core/Cargo.toml
@@ -14,10 +14,12 @@ libsql_replication = { version = "0", path = "../replication", optional = true }
tokio = { version = "1.29.1", features = ["macros"] }
tracing-subscriber = "0.3.17"
tracing = "0.1.37"
parking_lot = "0.12.1"

[dev-dependencies]
criterion = { version = "0.4", features = ["html_reports", "async", "async_futures"] }
pprof = { version = "0.11.1", features = ["criterion", "flamegraph"] }
tokio = { version = "1.29.1", features = ["full"] }

[features]
default = ["replication"]
9 changes: 6 additions & 3 deletions crates/core/examples/from_snapshot.rs
@@ -1,10 +1,13 @@
use libsql::Database;
use libsql_replication::{Frames, TempSnapshot};

fn main() {
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();

let mut db = Database::with_replicator("test.db");
let mut db = Database::with_replicator("http://localhost:5001", "test.db")
.await
.unwrap();
let conn = db.connect().unwrap();

let args = std::env::args().collect::<Vec<String>>();
@@ -15,7 +18,7 @@ fn main() {
let snapshot_path = args.get(1).unwrap();
let snapshot = TempSnapshot::from_snapshot_file(snapshot_path.as_ref()).unwrap();

db.sync(Frames::Snapshot(snapshot)).unwrap();
db.sync_frames(Frames::Snapshot(snapshot)).unwrap();

let rows = conn
.execute("SELECT * FROM sqlite_master", ())
79 changes: 33 additions & 46 deletions crates/core/examples/replica.rs
@@ -1,57 +1,44 @@
use libsql::Database;
use libsql_replication::{Frame, FrameHeader, Frames};

fn frame_data_offset(frame_no: u64) -> u64 {
tracing::debug!(
"WAL offset: {frame_no}->{}",
32 + (frame_no - 1) * (24 + 4096) + 24
);
32 + (frame_no - 1) * (24 + 4096) + 24
}

fn test_frame(frame_no: u64) -> Frame {
let header = FrameHeader {
frame_no,
checksum: 0xdeadc0de,
page_no: frame_no as u32,
size_after: frame_no as u32,
};

let loaded = {
use std::io::{Read, Seek};
let mut f = std::fs::File::open("tests/template.db-wal").unwrap();
f.seek(std::io::SeekFrom::Start(frame_data_offset(frame_no)))
.unwrap();
let mut buf = vec![0; 4096];
f.read_exact(&mut buf).unwrap();
buf
};

Frame::from_parts(&header, &loaded)
}

fn main() {
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();

std::fs::create_dir("data.libsql").ok();
std::fs::copy("tests/template.db", "data.libsql/data").unwrap();

let mut db = Database::with_replicator("data.libsql/data");
let db = Database::with_replicator("http://localhost:5001", "test.db")
.await
.unwrap();
let conn = db.connect().unwrap();

let sync_result = db.sync(Frames::Vec(vec![test_frame(1), test_frame(2)]));
println!("sync result: {:?}", sync_result);
let rows = conn
.execute("SELECT * FROM sqlite_master", ())
.unwrap()
.unwrap();
while let Ok(Some(row)) = rows.next() {
println!(
"| {:024} | {:024} | {:024} | {:024} |",
row.get::<&str>(0).unwrap(),
row.get::<&str>(1).unwrap(),
row.get::<&str>(2).unwrap(),
row.get::<&str>(3).unwrap(),
);
let db = std::sync::Arc::new(parking_lot::Mutex::new(db));
loop {
if let Err(e) = tokio::task::spawn_blocking({
let db = db.clone();
move || db.lock().sync()
})
.await
{
println!("Error: {e}");
break;
};
let response = conn.execute("SELECT * FROM sqlite_master", ()).unwrap();
let rows = match response {
Some(rows) => rows,
None => {
println!("No rows");
continue;
}
};
while let Ok(Some(row)) = rows.next() {
println!(
"| {:024} | {:024} | {:024} | {:024} |",
row.get::<&str>(0).unwrap(),
row.get::<&str>(1).unwrap(),
row.get::<&str>(2).unwrap(),
row.get::<&str>(3).unwrap(),
);
}
}
}
61 changes: 46 additions & 15 deletions crates/core/src/database.rs
@@ -1,23 +1,27 @@
use crate::{connection::Connection, Result};
use crate::{connection::Connection, errors::Error::ConnectionFailed, Result};
#[cfg(feature = "replication")]
use libsql_replication::Replicator;
#[cfg(feature = "replication")]
pub use libsql_replication::{Frames, TempSnapshot};
pub use libsql_replication::{rpc, Client, Frames, TempSnapshot};

pub struct ReplicationContext {
pub replicator: Replicator,
pub client: Client,
}

// A libSQL database.
pub struct Database {
pub url: String,
#[cfg(feature = "replication")]
pub replicator: Option<Replicator>,
pub replication_ctx: Option<ReplicationContext>,
}

impl Database {
pub fn open<S: Into<String>>(url: S) -> Database {
let url = url.into();
if url.starts_with("libsql:") {
let url = url.replace("libsql:", "http:");
tracing::info!("Absolutely ignoring libsql URL: {url}");
let filename = "data.libsql/data".to_string();
if url.starts_with("libsql:") || url.starts_with("http:") {
tracing::warn!("Ignoring {url} in Database::open() and opening a local db");
let filename = "libsql_tmp.db".to_string();
Database::new(filename)
} else {
Database::new(url)
@@ -28,15 +32,28 @@ impl Database {
Database {
url,
#[cfg(feature = "replication")]
replicator: None,
replication_ctx: None,
}
}

#[cfg(feature = "replication")]
pub fn with_replicator(url: impl Into<String>) -> Database {
pub async fn with_replicator(
url: impl Into<String>,
db_path: impl Into<String>,
) -> Result<Database> {
let url = url.into();
let replicator = Some(Replicator::new(&url).unwrap());
Database { url, replicator }
let db_path = db_path.into();
let mut db = Database::open(&db_path);
let replicator = Replicator::new(db_path).map_err(|e| ConnectionFailed(format!("{e}")))?;
let (client, meta) = Replicator::connect_to_rpc(
rpc::Endpoint::from_shared(url.clone())
.map_err(|e| ConnectionFailed(format!("{e}")))?,
)
.await
.map_err(|e| ConnectionFailed(format!("{e}")))?;
*replicator.meta.lock() = Some(meta);
db.replication_ctx = Some(ReplicationContext { replicator, client });
Ok(db)
}

pub fn close(&self) {}
@@ -46,11 +63,25 @@ impl Database {
}

#[cfg(feature = "replication")]
pub fn sync(&mut self, frames: Frames) -> Result<()> {
if let Some(replicator) = &mut self.replicator {
replicator
pub fn sync(&mut self) -> Result<()> {
if let Some(ctx) = &mut self.replication_ctx {
ctx.replicator
.sync_from_rpc(&mut ctx.client)
.map_err(|e| ConnectionFailed(format!("{e}")))
} else {
Err(crate::errors::Error::Misuse(
"No replicator available. Use Database::with_replicator() to enable replication"
.to_string(),
))
}
}

#[cfg(feature = "replication")]
pub fn sync_frames(&mut self, frames: Frames) -> Result<()> {
if let Some(ctx) = &mut self.replication_ctx {
ctx.replicator
.sync(frames)
.map_err(|e| crate::errors::Error::ConnectionFailed(format!("{e}")))
.map_err(|e| ConnectionFailed(format!("{e}")))
} else {
Err(crate::errors::Error::Misuse(
"No replicator available. Use Database::with_replicator() to enable replication"
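With this change the database exposes two sync entry points: `sync()` pulls new frames from the configured gRPC endpoint, while `sync_frames()` keeps the previous behaviour of applying frames the caller already has (as in the `from_snapshot.rs` example above). A sketch of the distinction, assuming a `db` created with `Database::with_replicator(...)` and the `replication` feature enabled:

```rust
// Sketch only; not part of this diff. `Database`, `Frames`, and `TempSnapshot`
// are the types used elsewhere in this PR.
use libsql::Database;
use libsql_replication::{Frames, TempSnapshot};

fn apply_local_snapshot(db: &mut Database, snapshot_path: &str) {
    // sync_frames(): the caller supplies the frames, here read from a snapshot file.
    let snapshot = TempSnapshot::from_snapshot_file(snapshot_path.as_ref()).unwrap();
    db.sync_frames(Frames::Snapshot(snapshot)).unwrap();
}

fn pull_from_primary(db: &mut Database) {
    // sync(): frames are fetched from the gRPC endpoint stored in ReplicationContext.
    db.sync().unwrap();
}
```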
133 changes: 133 additions & 0 deletions crates/replication/src/lib.rs
@@ -14,6 +14,13 @@ pub use replica::snapshot::TempSnapshot;
use std::sync::Arc;
use tokio::sync::mpsc::Sender;

pub mod rpc {
#![allow(clippy::all)]
tonic::include_proto!("wal_log");

pub use tonic::transport::Endpoint;
pub type Client = replication_log_client::ReplicationLogClient<tonic::transport::Channel>;
}
pub struct Replicator {
pub frames_sender: Sender<Frames>,
pub current_frame_no_notifier: tokio::sync::watch::Receiver<FrameNo>,
@@ -24,6 +31,11 @@ pub struct Replicator {
pub injector: replica::injector::FrameInjector<'static>,
}

pub struct Client {
pub inner: rpc::Client,
pub stream: Option<tonic::Streaming<rpc::Frame>>,
}

impl Replicator {
pub fn new(path: impl AsRef<std::path::Path>) -> anyhow::Result<Self> {
let (applied_frame_notifier, current_frame_no_notifier) =
@@ -147,4 +159,125 @@ impl Replicator {
self.injector.step()?;
Ok(())
}

pub async fn connect_to_rpc(
addr: impl Into<tonic::transport::Endpoint>,
) -> anyhow::Result<(Client, replica::meta::WalIndexMeta)> {
let mut client = rpc::Client::connect(addr).await?;
let response = client.hello(rpc::HelloRequest {}).await?.into_inner();
let client = Client {
inner: client,
stream: None,
};
// FIXME: not that simple, we need to figure out if we always start from frame 1?
let meta = replica::meta::WalIndexMeta {
pre_commit_frame_no: 0,
post_commit_frame_no: 0,
generation_id: response.generation_id.parse::<uuid::Uuid>()?.to_u128_le(),
database_id: response.database_id.parse::<uuid::Uuid>()?.to_u128_le(),
};
tracing::debug!("Hello response: {response:?}");
Ok((client, meta))
}

// Syncs frames from RPC, returns true if it succeeded in applying a whole transaction
async fn sync_from_rpc_internal(&mut self, client: &mut Client) -> anyhow::Result<bool> {
use futures::StreamExt;
const MAX_REPLICA_REPLICATION_BUFFER_LEN: usize = 10_000_000 / 4096; // ~10MB
tracing::trace!("Syncing frames from RPC");
// Reuse the stream if it exists, otherwise create a new one
let stream = match &mut client.stream {
Some(stream) => stream,
None => {
tracing::trace!("Creating new stream");
// FIXME: sqld code uses the frame_no_notifier here - investigate if so should we
let next_offset = self.meta.lock().unwrap().pre_commit_frame_no;
client.stream = Some(
client
.inner
.log_entries(rpc::LogOffset { next_offset })
.await?
.into_inner(),
);
client.stream.as_mut().unwrap()
}
};

let mut buffer = Vec::new();
loop {
match stream.next().await {
Some(Ok(frame)) => {
let frame = Frame::try_from_bytes(frame.data)?;
tracing::trace!(
"Received frame {frame:?}, buffer has {} frames, size_after={}",
buffer.len(),
frame.header().size_after
);
buffer.push(frame.clone());
if frame.header().size_after != 0
|| buffer.len() > MAX_REPLICA_REPLICATION_BUFFER_LEN
{
tracing::trace!("Sending {} frames to the injector", buffer.len());
let _ = self
.frames_sender
.send(Frames::Vec(std::mem::take(&mut buffer)))
.await;
// Let's return here to indicate that we made progress.
// There may be more data in the stream and it's fine, the user would just ask to sync again.
return Ok(frame.header().size_after != 0);
}
}
Some(Err(err))
if err.code() == tonic::Code::FailedPrecondition
&& err.message() == "NEED_SNAPSHOT" =>
{
tracing::info!("loading snapshot");
// remove any outstanding frames in the buffer that are not part of a
// transaction: they are now part of the snapshot.
buffer.clear();
let _ = stream;
self.sync_from_snapshot(client).await?;
return Ok(true);
}
Some(Err(e)) => return Err(e.into()),
None => return Ok(true),
}
}
}

pub fn sync_from_rpc(&mut self, client: &mut Client) -> anyhow::Result<()> {
let runtime = tokio::runtime::Handle::current();
loop {
let done = runtime.block_on(self.sync_from_rpc_internal(client))?;
tracing::trace!("Injecting frames");
self.injector.step()?;
tracing::trace!("Injected frames");
if done {
break;
}
}
Ok(())
}

async fn sync_from_snapshot(&mut self, client: &mut Client) -> anyhow::Result<()> {
use futures::StreamExt;

let next_offset = self.meta.lock().unwrap().pre_commit_frame_no;
let frames = client
.inner
.snapshot(rpc::LogOffset { next_offset })
.await?
.into_inner();

let stream = frames.map(|data| match data {
Ok(frame) => Frame::try_from_bytes(frame.data),
Err(e) => anyhow::bail!(e),
});
// FIXME: do not hardcode the temporary path for downloading snapshots
let snap = TempSnapshot::from_stream("data.sqld".as_ref(), stream).await?;

let _ = self.frames_sender.send(Frames::Snapshot(snap)).await;

Ok(())
}
}
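The buffering rule in `sync_from_rpc_internal` is the core of the streaming logic: frames accumulate until one of them closes a transaction (`size_after != 0`) or the buffer grows past roughly 10 MB of 4 KiB pages, and only then is the batch handed to the frame injector. An illustrative restatement of that condition (not part of the diff):

```rust
/// Sketch of the flush condition used above; `MAX_REPLICA_REPLICATION_BUFFER_LEN`
/// is the same ~10 MB bound defined in `sync_from_rpc_internal`.
fn should_flush(buffered_frames: usize, size_after: u32) -> bool {
    const MAX_REPLICA_REPLICATION_BUFFER_LEN: usize = 10_000_000 / 4096;
    // A non-zero `size_after` marks the last frame of a committed transaction;
    // oversized buffers are flushed early even mid-transaction.
    size_after != 0 || buffered_frames > MAX_REPLICA_REPLICATION_BUFFER_LEN
}
```

Note also that `sync_from_rpc` drives this async loop with `Handle::block_on`, which is presumably why the `replica.rs` example calls `Database::sync()` from `spawn_blocking` rather than directly on an async task.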
3 changes: 2 additions & 1 deletion crates/replication/src/replica/hook.rs
@@ -21,6 +21,7 @@ pub enum Frames {
Snapshot(TempSnapshot),
}

#[derive(Debug)]
pub struct Headers<'a> {
ptr: *mut PgHdr,
_pth: PhantomData<&'a ()>,
@@ -152,10 +153,10 @@ unsafe impl WalHook for InjectorHook {
let wal_ptr = wal as *mut _;
let ctx = Self::wal_extract_ctx(wal);
loop {
tracing::trace!("Waiting for a frame");
match ctx.receiver.blocking_recv() {
Some(frames) => {
let (headers, last_frame_no, size_after) = frames.to_headers();

let ret = ctx.inject_pages(
headers,
last_frame_no,
