Skip to content

Commit 36eb55a

Browse files
committed
implement more Tensor methods
1 parent 6aa9c1e commit 36eb55a

File tree

1 file changed

+87
-15
lines changed

1 file changed

+87
-15
lines changed

crates/larod/src/lib.rs

Lines changed: 87 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,20 @@ impl<'a> Tensor<'a> {
523523
}
524524
}
525525

526-
pub fn byte_size() {}
526+
pub fn byte_size(&self) -> Result<usize> {
527+
let mut byte_size: usize = 0;
528+
let (success, maybe_error) =
529+
unsafe { try_func!(larodGetTensorByteSize, self.ptr, &mut byte_size) };
530+
if success {
531+
debug_assert!(
532+
maybe_error.is_none(),
533+
"larodGetTensorByteSize indicated success AND returned an error!"
534+
);
535+
Ok(byte_size)
536+
} else {
537+
Err(maybe_error.unwrap_or(Error::MissingLarodError))
538+
}
539+
}
527540

528541
pub fn dims(&self) -> Result<&[usize]> {
529542
let (dims, maybe_error) = unsafe { try_func!(larodGetTensorDims, self.ptr) };
@@ -532,13 +545,6 @@ impl<'a> Tensor<'a> {
532545
maybe_error.is_none(),
533546
"larodGetTensorDims indicated success AND returned an error!"
534547
);
535-
// let d = unsafe {
536-
// (*dims)
537-
// .dims
538-
// .into_iter()
539-
// .take((*dims).len)
540-
// .collect::<Vec<usize>>()
541-
// };
542548
let (left, _) = unsafe { (*dims).dims.split_at((*dims).len) };
543549
Ok(left)
544550
} else {
@@ -547,13 +553,15 @@ impl<'a> Tensor<'a> {
547553
}
548554

549555
pub fn set_dims(&self, dims: &[usize]) -> Result<()> {
550-
let mut dim_array: [usize; 12] = [0; 12];
551-
for (idx, dim) in dims.iter().take(12).enumerate() {
552-
dim_array[idx] = *dim;
556+
if dims.len() > 12 {
557+
return Err(Error::InvalidInput);
553558
}
559+
560+
let mut dim_array: [usize; 12] = [0; 12];
561+
dim_array[..12.min(dims.len())].copy_from_slice(&dims[..12.min(dims.len())]);
554562
let dims_struct = larodTensorDims {
555563
dims: dim_array,
556-
len: dims.len(),
564+
len: dims.len().min(12),
557565
};
558566
let (success, maybe_error) =
559567
unsafe { try_func!(larodSetTensorDims, self.ptr, &dims_struct) };
@@ -589,7 +597,29 @@ impl<'a> Tensor<'a> {
589597
Err(maybe_error.unwrap_or(Error::MissingLarodError))
590598
}
591599
}
592-
pub fn set_pitches() {}
600+
pub fn set_pitches(&mut self, pitches: &[usize]) -> Result<()> {
601+
if pitches.len() > 12 {
602+
return Err(Error::InvalidInput);
603+
}
604+
605+
let mut pitch_array: [usize; 12] = [0; 12];
606+
pitch_array[..12.min(pitches.len())].copy_from_slice(&pitches[..12.min(pitches.len())]);
607+
let pitch_struct = larodTensorPitches {
608+
pitches: pitch_array,
609+
len: pitches.len().min(12),
610+
};
611+
let (success, maybe_error) =
612+
unsafe { try_func!(larodSetTensorPitches, self.ptr, &pitch_struct) };
613+
if success {
614+
debug_assert!(
615+
maybe_error.is_none(),
616+
"larodSetTensorPitches indicated success AND returned an error!"
617+
);
618+
Ok(())
619+
} else {
620+
Err(maybe_error.unwrap_or(Error::MissingLarodError))
621+
}
622+
}
593623
pub fn data_type() {}
594624
pub fn set_data_type() {}
595625
pub fn layout(&self) -> Result<larodTensorLayout> {
@@ -601,11 +631,39 @@ impl<'a> Tensor<'a> {
601631
Err(maybe_error.unwrap_or(Error::MissingLarodError))
602632
}
603633
}
604-
pub fn set_layout() {}
634+
pub fn set_layout(&mut self, layout: TensorLayout) -> Result<()> {
635+
let (success, maybe_error) = unsafe { try_func!(larodSetTensorLayout, self.ptr, layout) };
636+
if success {
637+
debug_assert!(
638+
maybe_error.is_none(),
639+
"larodSetTensorLayout indicated success AND returned an error!"
640+
);
641+
Ok(())
642+
} else {
643+
Err(maybe_error.unwrap_or(Error::MissingLarodError))
644+
}
645+
}
605646
pub fn fd(&self) -> Option<std::os::fd::BorrowedFd<'_>> {
606647
self.buffer.as_ref().map(|f| f.as_fd())
607648
}
608649

650+
/// Set the file descriptor for the tensor to use.
651+
pub fn set_fd(&mut self, fd: BorrowedFd) -> Result<()> {
652+
let (success, maybe_error) =
653+
unsafe { try_func!(larodSetTensorFd, self.ptr, fd.as_raw_fd()) };
654+
if success {
655+
debug_assert!(
656+
maybe_error.is_none(),
657+
"larodSetTensorFd indicated success AND returned an error!"
658+
);
659+
self.mmap = None;
660+
self.buffer = None;
661+
Ok(())
662+
} else {
663+
Err(maybe_error.unwrap_or(Error::MissingLarodError))
664+
}
665+
}
666+
609667
/// Use a memory mapped file as a buffer for this tensor.
610668
/// The method name here differs a bit from the larodSetTensorFd,
611669
/// but aligns better with the need for the tensor to own the
@@ -635,14 +693,28 @@ impl<'a> Tensor<'a> {
635693
}
636694
}
637695
pub fn fd_size() {}
638-
pub fn set_fd_size() {}
696+
pub fn set_fd_size(&mut self, size: usize) -> Result<()> {
697+
let (success, maybe_error) = unsafe { try_func!(larodSetTensorFdSize, self.ptr, size) };
698+
if success {
699+
debug_assert!(
700+
maybe_error.is_none(),
701+
"larodSetTensorFdSize indicated success AND returned an error!"
702+
);
703+
Ok(())
704+
} else {
705+
Err(maybe_error.unwrap_or(Error::MissingLarodError))
706+
}
707+
}
639708
pub fn fd_offset() {}
640709
pub fn set_fd_offset() {}
641710
pub fn fd_props() {}
642711
pub fn set_fd_props() {}
643712
pub fn as_slice(&self) -> Option<&[u8]> {
644713
self.mmap.as_deref()
645714
}
715+
pub fn as_mut_slice(&mut self) -> Option<&mut [u8]> {
716+
self.mmap.as_deref_mut()
717+
}
646718
pub fn copy_from_slice(&mut self, slice: &[u8]) {
647719
if let Some(mmap) = self.mmap.as_mut() {
648720
mmap.copy_from_slice(slice);

0 commit comments

Comments
 (0)