Skip to content

Commit 79ee809

Browse files
committed
fix(virtq): address code review
Signed-off-by: Tomasz Andrzejak <andreiltd@gmail.com>
1 parent 4f3cbf7 commit 79ee809

3 files changed

Lines changed: 30 additions & 33 deletions

File tree

src/hyperlight_common/src/virtq/access.rs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,10 @@ pub trait MemOps {
4747
/// * `addr` - Guest physical address to read from
4848
/// * `dst` - Destination buffer to fill
4949
///
50-
/// # Returns
51-
///
52-
/// Number of bytes actually read (should equal `dst.len()` on success).
53-
///
5450
/// # Safety
5551
///
5652
/// The caller must ensure `addr` is valid and points to at least `dst.len()` bytes.
57-
fn read(&self, addr: u64, dst: &mut [u8]) -> Result<usize, Self::Error>;
53+
fn read(&self, addr: u64, dst: &mut [u8]) -> Result<(), Self::Error>;
5854

5955
/// Write bytes to physical memory.
6056
///
@@ -63,14 +59,10 @@ pub trait MemOps {
6359
/// * `addr` - address to write to
6460
/// * `src` - Source data to write
6561
///
66-
/// # Returns
67-
///
68-
/// Number of bytes actually written (should equal `src.len()` on success).
69-
///
7062
/// # Safety
7163
///
7264
/// The caller must ensure `addr` is valid and points to at least `src.len()` bytes.
73-
fn write(&self, addr: u64, src: &[u8]) -> Result<usize, Self::Error>;
65+
fn write(&self, addr: u64, src: &[u8]) -> Result<(), Self::Error>;
7466

7567
/// Load a u16 with acquire semantics.
7668
///
@@ -140,11 +132,11 @@ pub trait MemOps {
140132
impl<T: MemOps> MemOps for Arc<T> {
141133
type Error = T::Error;
142134

143-
fn read(&self, addr: u64, dst: &mut [u8]) -> Result<usize, Self::Error> {
135+
fn read(&self, addr: u64, dst: &mut [u8]) -> Result<(), Self::Error> {
144136
(**self).read(addr, dst)
145137
}
146138

147-
fn write(&self, addr: u64, src: &[u8]) -> Result<usize, Self::Error> {
139+
fn write(&self, addr: u64, src: &[u8]) -> Result<(), Self::Error> {
148140
(**self).write(addr, src)
149141
}
150142

src/hyperlight_common/src/virtq/mod.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,18 +92,31 @@ impl Layout {
9292
/// Create a Layout from a base address and number of descriptors.
9393
///
9494
/// The base address must be aligned to `Descriptor::ALIGN`.
95+
/// The number of descriptors must be a power of 2.
9596
/// The memory region starting at `base` must be at least `Layout::query_size(num_descs)` bytes.
9697
///
9798
/// # Safety
9899
/// - `base` must be valid for `Layout::query_size(num_descs)` bytes.
99100
/// - `base` must be aligned to `Descriptor::ALIGN`.
100101
/// - Memory must remain valid for the lifetime of the ring.
101102
pub const unsafe fn from_base(base: u64, num_descs: NonZeroU16) -> Result<Self, RingError> {
103+
let num_descs = num_descs.get() as usize;
104+
if !num_descs.is_power_of_two() {
105+
return Err(RingError::InvalidLayout);
106+
}
107+
102108
if !base.is_multiple_of(Descriptor::ALIGN as u64) {
103109
return Err(RingError::InvalidLayout);
104110
}
105111

106-
let desc_size = num_descs.get() as usize * Descriptor::SIZE;
112+
if base
113+
.checked_add(Layout::query_size(num_descs) as u64)
114+
.is_none()
115+
{
116+
return Err(RingError::InvalidLayout);
117+
}
118+
119+
let desc_size = num_descs * Descriptor::SIZE;
107120
let event_size = EventSuppression::SIZE;
108121
let event_align = EventSuppression::ALIGN;
109122

@@ -112,7 +125,7 @@ impl Layout {
112125

113126
Ok(Self {
114127
desc_table_addr: base,
115-
desc_table_len: num_descs.get(),
128+
desc_table_len: num_descs as u16,
116129
drv_evt_addr: base + drv_evt_offset as u64,
117130
dev_evt_addr: base + dev_evt_offset as u64,
118131
})
@@ -170,6 +183,10 @@ const _: () = {
170183
assert!(base + expected_size as u64 == layout_end);
171184
}
172185

186+
unsafe {
187+
assert!(Layout::from_base(u64::MAX, NonZeroU16::new(1).unwrap()).is_err());
188+
}
189+
173190
verify_layout(1);
174191
verify_layout(2);
175192
verify_layout(4);

src/hyperlight_common/src/virtq/ring.rs

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ impl BufferChainBuilder<Writable> {
288288
///
289289
/// Contains a scatter-gather list of [`BufferElement`]s, divided into
290290
/// readable (driver->device) and writable (device->driver) sections.
291-
#[derive(Debug, Default, Clone)]
291+
#[derive(Debug, Clone)]
292292
pub struct BufferChain {
293293
/// All buffer elements (readable followed by writable)
294294
elems: SmallVec<[BufferElement; 16]>,
@@ -1350,20 +1350,20 @@ pub(crate) mod tests {
13501350
impl MemOps for TestMem {
13511351
type Error = core::convert::Infallible;
13521352

1353-
fn read(&self, addr: u64, dst: &mut [u8]) -> Result<usize, Self::Error> {
1353+
fn read(&self, addr: u64, dst: &mut [u8]) -> Result<(), Self::Error> {
13541354
let src = self.ptr_for_addr(addr);
13551355
unsafe {
13561356
ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), dst.len());
13571357
}
1358-
Ok(dst.len())
1358+
Ok(())
13591359
}
13601360

1361-
fn write(&self, addr: u64, src: &[u8]) -> Result<usize, Self::Error> {
1361+
fn write(&self, addr: u64, src: &[u8]) -> Result<(), Self::Error> {
13621362
let dst = self.ptr_for_addr(addr);
13631363
unsafe {
13641364
ptr::copy_nonoverlapping(src.as_ptr(), dst, src.len());
13651365
}
1366-
Ok(src.len())
1366+
Ok(())
13671367
}
13681368

13691369
fn read_val<T: Pod>(&self, addr: u64) -> Result<T, Self::Error> {
@@ -1842,18 +1842,6 @@ pub(crate) mod tests {
18421842
));
18431843
}
18441844

1845-
#[test]
1846-
fn test_empty_chain_rejected() {
1847-
let chain = BufferChain::default();
1848-
assert_eq!(chain.len(), 0);
1849-
1850-
let ring = make_ring(4);
1851-
let mut producer = make_producer(&ring);
1852-
1853-
let result = producer.submit_available(&chain);
1854-
assert!(matches!(result, Err(RingError::EmptyChain)));
1855-
}
1856-
18571845
#[test]
18581846
fn test_wrap_stress() {
18591847
let ring = make_ring(4);
@@ -2491,7 +2479,7 @@ pub(crate) mod tests {
24912479
// Out-of-order multi-length explicit
24922480
#[test]
24932481
fn test_out_of_order_multi_length() {
2494-
let ring = make_ring(12);
2482+
let ring = make_ring(16);
24952483
let mut producer = make_producer(&ring);
24962484
let mut consumer = make_consumer(&ring);
24972485

@@ -3311,7 +3299,7 @@ mod fuzz {
33113299

33123300
impl Arbitrary for Scenario {
33133301
fn arbitrary(g: &mut Gen) -> Self {
3314-
let table_size = usize::arbitrary(g) % MAX_RING + 1;
3302+
let table_size = (usize::arbitrary(g) % MAX_RING + 1).next_power_of_two();
33153303
let num_ops = usize::arbitrary(g) % MAX_OPS + 1;
33163304

33173305
let ops = (0..num_ops).map(|_| Op::arbitrary(g)).collect();

0 commit comments

Comments
 (0)