diff --git a/examples/test_qp.rs b/examples/test_qp.rs index 798c290..4826f61 100644 --- a/examples/test_qp.rs +++ b/examples/test_qp.rs @@ -1,9 +1,9 @@ -use rdma_mummy_sys::ibv_access_flags; use sideway::verbs::{ address::AddressHandleAttribute, device, device_context::Mtu, queue_pair::{QueuePair, QueuePairAttribute, QueuePairState}, + AccessFlags, }; fn main() -> Result<(), Box> { @@ -34,7 +34,7 @@ fn main() -> Result<(), Box> { attr.setup_state(QueuePairState::Init) .setup_pkey_index(0) .setup_port(1) - .setup_access_flags(ibv_access_flags::IBV_ACCESS_REMOTE_WRITE); + .setup_access_flags(AccessFlags::LocalWrite | AccessFlags::RemoteWrite); qp.modify(&attr).unwrap(); assert_eq!(QueuePairState::Init, qp.state()); diff --git a/src/verbs/mod.rs b/src/verbs/mod.rs index 9169ac0..f382c3b 100644 --- a/src/verbs/mod.rs +++ b/src/verbs/mod.rs @@ -5,3 +5,22 @@ pub mod device_context; pub mod memory_region; pub mod protection_domain; pub mod queue_pair; + +use bitmask_enum::bitmask; +use rdma_mummy_sys::ibv_access_flags; + +#[bitmask(i32)] +#[bitmask_config(vec_debug)] +pub enum AccessFlags { + LocalWrite = ibv_access_flags::IBV_ACCESS_LOCAL_WRITE.0 as _, + RemoteWrite = ibv_access_flags::IBV_ACCESS_REMOTE_WRITE.0 as _, + RemoteRead = ibv_access_flags::IBV_ACCESS_REMOTE_READ.0 as _, + RemoteAtomic = ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC.0 as _, + MemoryWindowBind = ibv_access_flags::IBV_ACCESS_MW_BIND.0 as _, + ZeroBased = ibv_access_flags::IBV_ACCESS_ZERO_BASED.0 as _, + OnDemand = ibv_access_flags::IBV_ACCESS_ON_DEMAND.0 as _, + HugeTlb = ibv_access_flags::IBV_ACCESS_HUGETLB.0 as _, + FlushGlobal = ibv_access_flags::IBV_ACCESS_FLUSH_GLOBAL.0 as _, + FlushPersistent = ibv_access_flags::IBV_ACCESS_FLUSH_PERSISTENT.0 as _, + RelaxedOrdering = ibv_access_flags::IBV_ACCESS_RELAXED_ORDERING.0 as _, +} diff --git a/src/verbs/protection_domain.rs b/src/verbs/protection_domain.rs index 53d4d56..cf1cef5 100644 --- a/src/verbs/protection_domain.rs +++ b/src/verbs/protection_domain.rs @@ -7,6 +7,7 @@ use super::{ device_context::DeviceContext, memory_region::{Buffer, MemoryRegion}, queue_pair::QueuePairBuilder, + AccessFlags, }; #[derive(Debug)] @@ -35,7 +36,14 @@ impl ProtectionDomain<'_> { pub fn reg_managed_mr(&self, size: usize) -> Result { let buf = Buffer::from_len_zeroed(size); - let mr = unsafe { ibv_reg_mr(self.pd.as_ptr(), buf.data.as_ptr() as _, buf.len, 0) }; + let mr = unsafe { + ibv_reg_mr( + self.pd.as_ptr(), + buf.data.as_ptr() as _, + buf.len, + (AccessFlags::RemoteWrite | AccessFlags::LocalWrite).into(), + ) + }; if mr.is_null() { return Err(format!("{:?}", io::Error::last_os_error())); diff --git a/src/verbs/queue_pair.rs b/src/verbs/queue_pair.rs index def6eb3..9468250 100644 --- a/src/verbs/queue_pair.rs +++ b/src/verbs/queue_pair.rs @@ -1,8 +1,9 @@ +use bitmask_enum::bitmask; use lazy_static::lazy_static; use rdma_mummy_sys::{ - ibv_access_flags, ibv_create_qp, ibv_create_qp_ex, ibv_destroy_qp, ibv_modify_qp, ibv_qp, ibv_qp_attr, - ibv_qp_attr_mask, ibv_qp_cap, ibv_qp_ex, ibv_qp_init_attr, ibv_qp_init_attr_ex, ibv_qp_state, ibv_qp_type, - ibv_rx_hash_conf, + ibv_create_qp, ibv_create_qp_ex, ibv_destroy_qp, ibv_modify_qp, ibv_qp, ibv_qp_attr, ibv_qp_attr_mask, ibv_qp_cap, + ibv_qp_create_send_ops_flags, ibv_qp_ex, ibv_qp_init_attr, ibv_qp_init_attr_ex, ibv_qp_init_attr_mask, + ibv_qp_state, ibv_qp_to_qp_ex, ibv_qp_type, ibv_rx_hash_conf, }; use std::{ io, @@ -11,11 +12,9 @@ use std::{ ptr::{null_mut, NonNull}, }; -use bitmask_enum::bitmask; - use super::{ address::AddressHandleAttribute, completion::CompletionQueue, device_context::Mtu, - protection_domain::ProtectionDomain, + protection_domain::ProtectionDomain, AccessFlags, }; #[repr(u32)] @@ -58,6 +57,24 @@ impl From for QueuePairState { } } +#[bitmask(u64)] +#[bitmask_config(vec_debug)] +pub enum SendOperationFlags { + Write = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_RDMA_WRITE.0 as _, + WriteWithImmediate = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_RDMA_WRITE_WITH_IMM.0 as _, + Send = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_SEND.0 as _, + SendWithImmediate = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_SEND_WITH_IMM.0 as _, + Read = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_RDMA_READ.0 as _, + AtomicCompareAndSwap = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_ATOMIC_CMP_AND_SWP.0 as _, + AtomicFetchAndAdd = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_ATOMIC_FETCH_AND_ADD.0 as _, + LocalInvalidate = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_LOCAL_INV.0 as _, + BindMemoryWindow = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_BIND_MW.0 as _, + SendWithInvalidate = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_SEND_WITH_INV.0 as _, + TcpSegmentationOffload = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_TSO.0 as _, + Flush = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_FLUSH.0 as _, + AtomicWrite = ibv_qp_create_send_ops_flags::IBV_QP_EX_WITH_ATOMIC_WRITE.0 as _, +} + pub trait QueuePair { //! return the basic handle of QP; //! we mark this method unsafe because the lifetime of ibv_qp is not @@ -280,24 +297,21 @@ impl QueuePair for BasicQueuePair<'_> { #[derive(Debug)] pub struct ExtendedQueuePair<'res> { - // TODO: ibv_create_qp_ex returns ibv_qp instead of ibv_qp_ex, to be fixed - pub(crate) qp_ex: NonNull, + pub(crate) qp_ex: NonNull, // phantom data for protection domain & completion queues _phantom: PhantomData<&'res ()>, } impl Drop for ExtendedQueuePair<'_> { fn drop(&mut self) { - // TODO: convert qp_ex to qp (port ibv_qp_ex_to_qp in rdma-mummy-sys) - let ret = unsafe { ibv_destroy_qp(self.qp_ex.as_ptr().cast()) }; + let ret = unsafe { ibv_destroy_qp(self.qp().as_ptr()) }; assert_eq!(ret, 0) } } impl QueuePair for ExtendedQueuePair<'_> { unsafe fn qp(&self) -> NonNull { - // TODO: convert qp_ex to qp (port ibv_qp_ex_to_qp in rdma-mummy-sys) - self.qp_ex.cast() + NonNull::new_unchecked(&mut (*self.qp_ex.as_ptr()).qp_base as _) } } @@ -327,7 +341,10 @@ impl<'res> QueuePairBuilder<'res> { }, qp_type: QueuePairType::ReliableConnection as _, sq_sig_all: 0, - comp_mask: 0, + // when building an extended qp instead of a basic qp, we need to pass in + // these essential attributes. + comp_mask: ibv_qp_init_attr_mask::IBV_QP_INIT_ATTR_PD.0 + | ibv_qp_init_attr_mask::IBV_QP_INIT_ATTR_SEND_OPS_FLAGS.0, pd: pd.pd.as_ptr(), xrcd: null_mut(), create_flags: 0, @@ -335,7 +352,14 @@ impl<'res> QueuePairBuilder<'res> { rwq_ind_tbl: null_mut(), rx_hash_conf: unsafe { MaybeUninit::::zeroed().assume_init() }, source_qpn: 0, - send_ops_flags: 0, + // unless user specified, we assume every extended qp would support send, + // write and read, just as what basic qp supports. + send_ops_flags: (SendOperationFlags::Send + | SendOperationFlags::SendWithImmediate + | SendOperationFlags::Write + | SendOperationFlags::WriteWithImmediate + | SendOperationFlags::Read) + .into(), }, _phantom: PhantomData, } @@ -387,6 +411,11 @@ impl<'res> QueuePairBuilder<'res> { self } + pub fn setup_send_ops_flags(&mut self, send_ops_flags: SendOperationFlags) -> &mut Self { + self.init_attr.send_ops_flags = send_ops_flags.bits; + self + } + // build basic qp pub fn build(&self) -> Result, String> { let qp = unsafe { @@ -411,11 +440,14 @@ impl<'res> QueuePairBuilder<'res> { } // build extended qp - pub fn build_ex(&mut self) -> Result, String> { - let qp = unsafe { ibv_create_qp_ex((*self.init_attr.pd).context, &mut self.init_attr).unwrap_or(null_mut()) }; + pub fn build_ex(&self) -> Result, String> { + let mut attr = self.init_attr.clone(); + + let qp = unsafe { ibv_create_qp_ex((*(attr.pd)).context, &mut attr).unwrap_or(null_mut()) }; Ok(ExtendedQueuePair { - qp_ex: NonNull::new(qp).ok_or(format!("ibv_create_qp failed, {}", io::Error::last_os_error()))?, + qp_ex: NonNull::new(unsafe { ibv_qp_to_qp_ex(qp) }) + .ok_or(format!("ibv_create_qp_ex failed, {}", io::Error::last_os_error()))?, _phantom: PhantomData, }) } @@ -462,9 +494,8 @@ impl QueuePairAttribute { self } - // TODO(fuji): use ibv_access_flags directly or wrap a type for this? - pub fn setup_access_flags(&mut self, access_flags: ibv_access_flags) -> &mut Self { - self.attr.qp_access_flags = access_flags.0; + pub fn setup_access_flags(&mut self, access_flags: AccessFlags) -> &mut Self { + self.attr.qp_access_flags = access_flags.bits as _; self.attr_mask |= QueuePairAttributeMask::AccessFlags; self }