@@ -523,7 +523,20 @@ impl<'a> Tensor<'a> {
523
523
}
524
524
}
525
525
526
- pub fn byte_size ( ) { }
526
+ pub fn byte_size ( & self ) -> Result < usize > {
527
+ let mut byte_size: usize = 0 ;
528
+ let ( success, maybe_error) =
529
+ unsafe { try_func ! ( larodGetTensorByteSize, self . ptr, & mut byte_size) } ;
530
+ if success {
531
+ debug_assert ! (
532
+ maybe_error. is_none( ) ,
533
+ "larodGetTensorByteSize indicated success AND returned an error!"
534
+ ) ;
535
+ Ok ( byte_size)
536
+ } else {
537
+ Err ( maybe_error. unwrap_or ( Error :: MissingLarodError ) )
538
+ }
539
+ }
527
540
528
541
pub fn dims ( & self ) -> Result < & [ usize ] > {
529
542
let ( dims, maybe_error) = unsafe { try_func ! ( larodGetTensorDims, self . ptr) } ;
@@ -532,13 +545,6 @@ impl<'a> Tensor<'a> {
532
545
maybe_error. is_none( ) ,
533
546
"larodGetTensorDims indicated success AND returned an error!"
534
547
) ;
535
- // let d = unsafe {
536
- // (*dims)
537
- // .dims
538
- // .into_iter()
539
- // .take((*dims).len)
540
- // .collect::<Vec<usize>>()
541
- // };
542
548
let ( left, _) = unsafe { ( * dims) . dims . split_at ( ( * dims) . len ) } ;
543
549
Ok ( left)
544
550
} else {
@@ -547,13 +553,15 @@ impl<'a> Tensor<'a> {
547
553
}
548
554
549
555
pub fn set_dims ( & self , dims : & [ usize ] ) -> Result < ( ) > {
550
- let mut dim_array: [ usize ; 12 ] = [ 0 ; 12 ] ;
551
- for ( idx, dim) in dims. iter ( ) . take ( 12 ) . enumerate ( ) {
552
- dim_array[ idx] = * dim;
556
+ if dims. len ( ) > 12 {
557
+ return Err ( Error :: InvalidInput ) ;
553
558
}
559
+
560
+ let mut dim_array: [ usize ; 12 ] = [ 0 ; 12 ] ;
561
+ dim_array[ ..12 . min ( dims. len ( ) ) ] . copy_from_slice ( & dims[ ..12 . min ( dims. len ( ) ) ] ) ;
554
562
let dims_struct = larodTensorDims {
555
563
dims : dim_array,
556
- len : dims. len ( ) ,
564
+ len : dims. len ( ) . min ( 12 ) ,
557
565
} ;
558
566
let ( success, maybe_error) =
559
567
unsafe { try_func ! ( larodSetTensorDims, self . ptr, & dims_struct) } ;
@@ -589,7 +597,29 @@ impl<'a> Tensor<'a> {
589
597
Err ( maybe_error. unwrap_or ( Error :: MissingLarodError ) )
590
598
}
591
599
}
592
- pub fn set_pitches ( ) { }
600
+ pub fn set_pitches ( & mut self , pitches : & [ usize ] ) -> Result < ( ) > {
601
+ if pitches. len ( ) > 12 {
602
+ return Err ( Error :: InvalidInput ) ;
603
+ }
604
+
605
+ let mut pitch_array: [ usize ; 12 ] = [ 0 ; 12 ] ;
606
+ pitch_array[ ..12 . min ( pitches. len ( ) ) ] . copy_from_slice ( & pitches[ ..12 . min ( pitches. len ( ) ) ] ) ;
607
+ let pitch_struct = larodTensorPitches {
608
+ pitches : pitch_array,
609
+ len : pitches. len ( ) . min ( 12 ) ,
610
+ } ;
611
+ let ( success, maybe_error) =
612
+ unsafe { try_func ! ( larodSetTensorPitches, self . ptr, & pitch_struct) } ;
613
+ if success {
614
+ debug_assert ! (
615
+ maybe_error. is_none( ) ,
616
+ "larodSetTensorPitches indicated success AND returned an error!"
617
+ ) ;
618
+ Ok ( ( ) )
619
+ } else {
620
+ Err ( maybe_error. unwrap_or ( Error :: MissingLarodError ) )
621
+ }
622
+ }
593
623
pub fn data_type ( ) { }
594
624
pub fn set_data_type ( ) { }
595
625
pub fn layout ( & self ) -> Result < larodTensorLayout > {
@@ -601,11 +631,39 @@ impl<'a> Tensor<'a> {
601
631
Err ( maybe_error. unwrap_or ( Error :: MissingLarodError ) )
602
632
}
603
633
}
604
- pub fn set_layout ( ) { }
634
+ pub fn set_layout ( & mut self , layout : TensorLayout ) -> Result < ( ) > {
635
+ let ( success, maybe_error) = unsafe { try_func ! ( larodSetTensorLayout, self . ptr, layout) } ;
636
+ if success {
637
+ debug_assert ! (
638
+ maybe_error. is_none( ) ,
639
+ "larodSetTensorLayout indicated success AND returned an error!"
640
+ ) ;
641
+ Ok ( ( ) )
642
+ } else {
643
+ Err ( maybe_error. unwrap_or ( Error :: MissingLarodError ) )
644
+ }
645
+ }
605
646
pub fn fd ( & self ) -> Option < std:: os:: fd:: BorrowedFd < ' _ > > {
606
647
self . buffer . as_ref ( ) . map ( |f| f. as_fd ( ) )
607
648
}
608
649
650
+ /// Set the file descriptor for the tensor to use.
651
+ pub fn set_fd ( & mut self , fd : BorrowedFd ) -> Result < ( ) > {
652
+ let ( success, maybe_error) =
653
+ unsafe { try_func ! ( larodSetTensorFd, self . ptr, fd. as_raw_fd( ) ) } ;
654
+ if success {
655
+ debug_assert ! (
656
+ maybe_error. is_none( ) ,
657
+ "larodSetTensorFd indicated success AND returned an error!"
658
+ ) ;
659
+ self . mmap = None ;
660
+ self . buffer = None ;
661
+ Ok ( ( ) )
662
+ } else {
663
+ Err ( maybe_error. unwrap_or ( Error :: MissingLarodError ) )
664
+ }
665
+ }
666
+
609
667
/// Use a memory mapped file as a buffer for this tensor.
610
668
/// The method name here differs a bit from the larodSetTensorFd,
611
669
/// but aligns better with the need for the tensor to own the
@@ -635,14 +693,28 @@ impl<'a> Tensor<'a> {
635
693
}
636
694
}
637
695
pub fn fd_size ( ) { }
638
- pub fn set_fd_size ( ) { }
696
+ pub fn set_fd_size ( & mut self , size : usize ) -> Result < ( ) > {
697
+ let ( success, maybe_error) = unsafe { try_func ! ( larodSetTensorFdSize, self . ptr, size) } ;
698
+ if success {
699
+ debug_assert ! (
700
+ maybe_error. is_none( ) ,
701
+ "larodSetTensorFdSize indicated success AND returned an error!"
702
+ ) ;
703
+ Ok ( ( ) )
704
+ } else {
705
+ Err ( maybe_error. unwrap_or ( Error :: MissingLarodError ) )
706
+ }
707
+ }
639
708
pub fn fd_offset ( ) { }
640
709
pub fn set_fd_offset ( ) { }
641
710
pub fn fd_props ( ) { }
642
711
pub fn set_fd_props ( ) { }
643
712
pub fn as_slice ( & self ) -> Option < & [ u8 ] > {
644
713
self . mmap . as_deref ( )
645
714
}
715
+ pub fn as_mut_slice ( & mut self ) -> Option < & mut [ u8 ] > {
716
+ self . mmap . as_deref_mut ( )
717
+ }
646
718
pub fn copy_from_slice ( & mut self , slice : & [ u8 ] ) {
647
719
if let Some ( mmap) = self . mmap . as_mut ( ) {
648
720
mmap. copy_from_slice ( slice) ;
0 commit comments