Skip to content

Commit 6541fea

Browse files
committed
feat(virtq): address code review comments
Signed-off-by: Tomasz Andrzejak <andreiltd@gmail.com>
1 parent d18a41c commit 6541fea

5 files changed

Lines changed: 188 additions & 76 deletions

File tree

src/hyperlight_common/src/virtq/access.rs

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,17 @@ use bytemuck::Pod;
2929
/// # Safety
3030
///
3131
/// Implementations must ensure that:
32-
/// - Pointers passed to methods are valid for the duration of the call
33-
/// - Memory ordering guarantees are upheld as documented
34-
/// - Reads and writes don't cause undefined behavior (alignment, validity)
32+
/// - Addresses accepted by these methods are translated according to the
33+
/// backend's memory model.
34+
/// - Invalid or inaccessible addresses are reported with `Self::Error` rather
35+
/// than causing undefined behavior.
36+
/// - Memory ordering guarantees are upheld as documented.
37+
/// - Typed reads/writes and atomic operations honor alignment and initialized
38+
/// memory requirements for the translated addresses.
3539
///
3640
/// [`RingProducer`]: super::RingProducer
3741
/// [`RingConsumer`]: super::RingConsumer
38-
pub trait MemOps {
42+
pub unsafe trait MemOps {
3943
type Error;
4044

4145
/// Read bytes from physical memory.
@@ -47,9 +51,8 @@ pub trait MemOps {
4751
/// * `addr` - Guest physical address to read from
4852
/// * `dst` - Destination buffer to fill
4953
///
50-
/// # Safety
51-
///
52-
/// The caller must ensure `addr` is valid and points to at least `dst.len()` bytes.
54+
/// Implementations must return an error if `addr` cannot be read for
55+
/// at least `dst.len()` bytes.
5356
fn read(&self, addr: u64, dst: &mut [u8]) -> Result<(), Self::Error>;
5457

5558
/// Write bytes to physical memory.
@@ -59,23 +62,20 @@ pub trait MemOps {
5962
/// * `addr` - Guest physical address to write to
6063
/// * `src` - Source data to write
6164
///
62-
/// # Safety
63-
///
64-
/// The caller must ensure `addr` is valid and points to at least `src.len()` bytes.
65+
/// Implementations must return an error if `addr` cannot be written for
66+
/// at least `src.len()` bytes.
6567
fn write(&self, addr: u64, src: &[u8]) -> Result<(), Self::Error>;
6668

6769
/// Load a u16 with acquire semantics.
6870
///
69-
/// # Safety
70-
///
71-
/// `addr` must translate to a valid, aligned `AtomicU16` in shared memory.
71+
/// Implementations must return an error if `addr` does not translate to a
72+
/// valid, aligned `AtomicU16` in shared memory.
7273
fn load_acquire(&self, addr: u64) -> Result<u16, Self::Error>;
7374

7475
/// Store a u16 with release semantics.
7576
///
76-
/// # Safety
77-
///
78-
/// `addr` must translate to a valid `AtomicU16` in shared memory.
77+
/// Implementations must return an error if `addr` does not translate to a
78+
/// valid, aligned `AtomicU16` in shared memory.
7979
fn store_release(&self, addr: u64, val: u16) -> Result<(), Self::Error>;
8080

8181
/// Get a direct read-only slice into shared memory.
@@ -106,9 +106,8 @@ pub trait MemOps {
106106

107107
/// Read a Pod type at the given pointer.
108108
///
109-
/// # Safety
110-
///
111-
/// The caller must ensure `addr` is valid, aligned, and translates to initialized memory.
109+
/// Implementations must return an error unless `addr` is valid, aligned,
110+
/// and initialized for `T`.
112111
fn read_val<T: Pod>(&self, addr: u64) -> Result<T, Self::Error> {
113112
let mut val = T::zeroed();
114113
let bytes = bytemuck::bytes_of_mut(&mut val);
@@ -119,17 +118,18 @@ pub trait MemOps {
119118

120119
/// Write a Pod type at the given pointer.
121120
///
122-
/// # Safety
123-
///
124-
/// The caller ensures that `ptr` is valid.
121+
/// Implementations must return an error unless `addr` is valid and aligned
122+
/// for `T`.
125123
fn write_val<T: Pod>(&self, addr: u64, val: T) -> Result<(), Self::Error> {
126124
let bytes = bytemuck::bytes_of(&val);
127125
self.write(addr, bytes)?;
128126
Ok(())
129127
}
130128
}
131129

132-
impl<T: MemOps> MemOps for Arc<T> {
130+
// SAFETY: Arc delegates all memory operations to the wrapped backend, preserving
131+
// that backend's MemOps contract.
132+
unsafe impl<T: MemOps> MemOps for Arc<T> {
133133
type Error = T::Error;
134134

135135
fn read(&self, addr: u64, dst: &mut [u8]) -> Result<(), Self::Error> {

src/hyperlight_common/src/virtq/desc.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ use super::MemOps;
2727

2828
bitflags! {
2929
/// Descriptor flags as defined by VIRTIO specification.
30+
///
31+
/// Note: The implementation never follows the indirect-table interpretation,
32+
/// so the INDIRECT bit is effectively ignored.
3033
#[repr(transparent)]
3134
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
3235
pub struct DescFlags: u16 {
@@ -90,7 +93,7 @@ impl Descriptor {
9093
DescFlags::from_bits_truncate(self.flags)
9194
}
9295

93-
/// Did the guest mark this descriptor in the current guest round?
96+
/// Did the driver make this descriptor available in the current driver round?
9497
#[inline]
9598
pub fn is_avail(&self, wrap: bool) -> bool {
9699
let f = self.flags();
@@ -99,7 +102,7 @@ impl Descriptor {
99102
avail == wrap && used != wrap
100103
}
101104

102-
/// Did the host mark this descriptor used in the current host round?
105+
/// Did the device mark this descriptor used in the current device round?
103106
#[inline]
104107
pub fn is_used(&self, wrap: bool) -> bool {
105108
let f = self.flags();

src/hyperlight_common/src/virtq/event.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ const _: () = assert!(EventSuppression::WRAP_OFFSET == 0);
5353
const _: () = assert!(EventSuppression::FLAGS_OFFSET == 2);
5454

5555
impl EventSuppression {
56+
const FLAGS_MASK: u16 = 0x3;
57+
const DESC_EVENT_OFF_MASK: u16 = 0x7FFF;
58+
const DESC_EVENT_WRAP: u16 = 0x8000;
59+
5660
pub const SIZE: usize = core::mem::size_of::<Self>();
5761
pub const ALIGN: usize = core::mem::align_of::<Self>();
5862
pub const WRAP_OFFSET: usize = core::mem::offset_of!(Self, off_wrap);
@@ -68,27 +72,28 @@ impl EventSuppression {
6872

6973
/// Get the event flags.
7074
pub fn flags(&self) -> EventFlags {
71-
EventFlags::from_bits_truncate(self.flags & 0x3)
75+
EventFlags::from_bits_truncate(self.flags & Self::FLAGS_MASK)
7276
}
7377

7478
/// Set the event flags.
7579
pub fn set_flags(&mut self, flags: EventFlags) {
76-
self.flags = (self.flags & !0x3) | (flags.bits() & 0x3);
80+
self.flags = (self.flags & !Self::FLAGS_MASK) | (flags.bits() & Self::FLAGS_MASK);
7781
}
7882

7983
/// Get the descriptor event offset (bits 0-14).
8084
pub fn desc_event_off(&self) -> u16 {
81-
self.off_wrap & 0x7FFF
85+
self.off_wrap & Self::DESC_EVENT_OFF_MASK
8286
}
8387

8488
/// Check if the descriptor event wrap bit (bit 15) is set.
8589
pub fn desc_event_wrap(&self) -> bool {
86-
(self.off_wrap & 0x8000) != 0
90+
(self.off_wrap & Self::DESC_EVENT_WRAP) != 0
8791
}
8892

8993
/// Set the descriptor event offset and wrap bit.
9094
pub fn set_desc_event(&mut self, off: u16, wrap: bool) {
91-
self.off_wrap = (off & 0x7FFF) | if wrap { 0x8000 } else { 0 };
95+
self.off_wrap =
96+
(off & Self::DESC_EVENT_OFF_MASK) | if wrap { Self::DESC_EVENT_WRAP } else { 0 };
9297
}
9398

9499
/// Create an `EventSuppression` from a raw pointer with acquire semantics.

src/hyperlight_common/src/virtq/mod.rs

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,18 +74,18 @@ pub use ring::*;
7474
#[derive(Clone, Copy, Debug)]
7575
pub struct Layout {
7676
/// Packed ring descriptor table base in shared memory.
77-
pub desc_table_addr: u64,
77+
desc_table_addr: u64,
7878
/// Number of descriptors (ring size, must be power of 2).
79-
pub desc_table_len: u16,
79+
desc_table_len: u16,
8080
/// Driver-written event suppression area in shared memory.
81-
pub drv_evt_addr: u64,
81+
drv_evt_addr: u64,
8282
/// Device-written event suppression area in shared memory.
83-
pub dev_evt_addr: u64,
83+
dev_evt_addr: u64,
8484
}
8585

8686
#[inline]
8787
const fn align_up(val: usize, align: usize) -> usize {
88-
(val + align - 1) & !(align - 1)
88+
val.next_multiple_of(align)
8989
}
9090

9191
impl Layout {
@@ -131,6 +131,26 @@ impl Layout {
131131
})
132132
}
133133

134+
/// Packed ring descriptor table base in shared memory.
135+
pub const fn desc_table_addr(&self) -> u64 {
136+
self.desc_table_addr
137+
}
138+
139+
/// Number of descriptors in the ring.
140+
pub const fn desc_table_len(&self) -> u16 {
141+
self.desc_table_len
142+
}
143+
144+
/// Driver-written event suppression area in shared memory.
145+
pub const fn drv_evt_addr(&self) -> u64 {
146+
self.drv_evt_addr
147+
}
148+
149+
/// Device-written event suppression area in shared memory.
150+
pub const fn dev_evt_addr(&self) -> u64 {
151+
self.dev_evt_addr
152+
}
153+
134154
/// Calculate the memory size needed for a ring with `num_descs` descriptors,
135155
/// accounting for alignment requirements.
136156
pub const fn query_size(num_descs: usize) -> usize {
@@ -160,26 +180,26 @@ const _: () = {
160180

161181
let expected_size = Layout::query_size(num_descs);
162182

163-
assert!(layout.desc_table_addr == base);
164-
assert!(layout.desc_table_len as usize == num_descs);
183+
assert!(layout.desc_table_addr() == base);
184+
assert!(layout.desc_table_len() as usize == num_descs);
165185
assert!(
166186
layout
167-
.drv_evt_addr
187+
.drv_evt_addr()
168188
.is_multiple_of(EventSuppression::ALIGN as u64)
169189
);
170190
assert!(
171191
layout
172-
.dev_evt_addr
192+
.dev_evt_addr()
173193
.is_multiple_of(EventSuppression::ALIGN as u64)
174194
);
175195

176196
// Events don't overlap with descriptor table
177197
let desc_end = base + (num_descs * Descriptor::SIZE) as u64;
178-
assert!(layout.drv_evt_addr >= desc_end);
179-
assert!(layout.dev_evt_addr >= layout.drv_evt_addr + EventSuppression::SIZE as u64);
198+
assert!(layout.drv_evt_addr() >= desc_end);
199+
assert!(layout.dev_evt_addr() >= layout.drv_evt_addr() + EventSuppression::SIZE as u64);
180200

181201
// Total size from query_size covers entire layout
182-
let layout_end = layout.dev_evt_addr + EventSuppression::SIZE as u64;
202+
let layout_end = layout.dev_evt_addr() + EventSuppression::SIZE as u64;
183203
assert!(base + expected_size as u64 == layout_end);
184204
}
185205

0 commit comments

Comments
 (0)