Skip to content

Commit 96b2918

Browse files
committed
implement compile time definition of inference backend
1 parent 36eb55a commit 96b2918

File tree

1 file changed

+196
-15
lines changed

1 file changed

+196
-15
lines changed

crates/larod/src/lib.rs

Lines changed: 196 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@
3636
//! # TODOs:
3737
//! - [ ] [larodDisconnect](https://axiscommunications.github.io/acap-documentation/docs/api/src/api/larod/html/larod_8h.html#ab8f97b4b4d15798384ca25f32ca77bba)
3838
//! indicates it may fail to "kill a session." What are the implications if it fails to kill a session? Can we clear the sessions?
39-
39+
use crate::inference::PrivateSupportedBackend;
4040
use core::slice;
4141
pub use larod_sys::larodAccess as LarodAccess;
42+
pub use larod_sys::larodTensorLayout as TensorLayout;
4243
use larod_sys::*;
4344
use memmap2::{Mmap, MmapMut};
4445
use std::{
@@ -47,7 +48,8 @@ use std::{
4748
fs::File,
4849
marker::PhantomData,
4950
ops,
50-
os::fd::{AsFd, AsRawFd},
51+
os::fd::{AsFd, AsRawFd, BorrowedFd},
52+
path::Path,
5153
ptr::{self},
5254
};
5355

@@ -153,6 +155,8 @@ pub enum Error {
153155
IOError(std::io::Error),
154156
#[error("attempted operation without satisfying all required dependencies")]
155157
UnsatisfiedDependencies,
158+
#[error("an input parameter was incorrect")]
159+
InvalidInput,
156160
}
157161

158162
// impl LarodError {
@@ -844,11 +848,30 @@ pub enum PreProcBackend {
844848
RemoteOpenCLGPU,
845849
}
846850

847-
#[derive(Debug, Default)]
851+
/// Hardware devices a model can be executed on.
///
/// Variant names mirror Axis SoC names; each maps to a larod device-name
/// string via `InferenceChip::as_str`.
///
/// Fieldless public enum: derive the cheap structural traits so callers can
/// copy and compare devices without wrapper code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
    CPU,
    ARTPEC7GPU,
    ARTPEC8DLPU,
    ARTPEC9DLPU,
}
858+
859+
#[derive(Debug)]
848860
pub enum InferenceChip {
849-
#[default]
850-
TFLiteCPU,
851-
TFLiteDLPU,
861+
TFLite(Device),
862+
}
863+
864+
impl InferenceChip {
865+
pub fn as_str(&self) -> &str {
866+
match self {
867+
InferenceChip::TFLite(d) => match d {
868+
Device::CPU => "cpu-tflite",
869+
Device::ARTPEC7GPU => "axis-a7-gpu-tflite",
870+
Device::ARTPEC8DLPU => "axis-a8-dlpu-tflite",
871+
Device::ARTPEC9DLPU => "a9-dlpu-tflite",
872+
},
873+
}
874+
}
852875
}
853876

854877
#[derive(Debug, Default)]
@@ -1218,19 +1241,19 @@ impl<'a> Drop for JobRequest<'a> {
12181241
}
12191242

12201243
// #[derive(Default)]
1221-
// pub struct ModelBuilder {
1222-
// file_path: Option<PathBuf>,
1244+
// pub struct ModelBuilder<'a> {
1245+
// file_path: Option<&'a Path>,
12231246
// device: InferenceChip,
12241247
// crop: Option<(u32, u32, u32, u32)>,
12251248
// }
12261249

1227-
// impl ModelBuilder {
1250+
// impl<'a> ModelBuilder<'a> {
12281251
// pub fn new() -> Self {
12291252
// ModelBuilder::default()
12301253
// }
12311254

1232-
// pub fn source_file(mut self, path: PathBuf) -> Self {
1233-
// self.file_path = Some(path);
1255+
// pub fn source_file<P: AsRef<Path>>(mut self, path: &'a P) -> Self {
1256+
// self.file_path = Some(path.as_ref());
12341257
// self
12351258
// }
12361259

@@ -1239,12 +1262,111 @@ impl<'a> Drop for JobRequest<'a> {
12391262
// self
12401263
// }
12411264

1242-
// pub fn with_crop(mut self, crop: (u32, u32, u32, u32)) -> Self {
1243-
// self.crop = Some(crop);
1244-
// self
1265+
// pub fn load(self, session: Session) -> Model {
1266+
// File::open(s)
1267+
// }
1268+
// }
1269+
1270+
// Sealed-trait plumbing: the trait lives in a private module so code outside
// this crate can never implement it, while items in this file still can.
mod inference {
    /// Maps a `(format, hardware)` marker-type pair to the larod device-name
    /// string used to look the device up via `larodGetDevice`.
    pub trait PrivateSupportedBackend {
        fn as_str() -> &'static str;
    }
}
1275+
1276+
// pub trait SupportedBackend: inference::PrivateSupportedBackend {
1277+
// fn as_str() -> &'static str;
1278+
// }
1279+
1280+
// Marker types: zero-sized tags naming a model format. Only combinations
// that actually work get a `PrivateSupportedBackend` impl, so unsupported
// pairings fail at compile time.
pub struct TFLite;
pub struct CVFlowNN;
pub struct Native;

// Hardware marker types; paired with a format marker above in a tuple such
// as `(TFLite, Artpec8DLPU)` to select a backend at compile time.
// NOTE(review): EdgeTPU, GPU, and ArmNNCPU currently have no backend impl —
// presumably placeholders for future support; confirm before relying on them.
pub struct CPU;
pub struct EdgeTPU;
pub struct GPU;
pub struct Artpec7GPU;
pub struct Artpec8DLPU;
pub struct Artpec9DLPU;
pub struct ArmNNCPU;
1293+
1294+
impl inference::PrivateSupportedBackend for (TFLite, CPU) {
1295+
fn as_str() -> &'static str {
1296+
"cpu-tflite"
1297+
}
1298+
}
1299+
impl inference::PrivateSupportedBackend for (TFLite, Artpec7GPU) {
1300+
fn as_str() -> &'static str {
1301+
"axis-a7-gpu-tflite"
1302+
}
1303+
}
1304+
impl inference::PrivateSupportedBackend for (TFLite, Artpec8DLPU) {
1305+
fn as_str() -> &'static str {
1306+
"axis-a8-dlpu-tflite"
1307+
}
1308+
}
1309+
impl inference::PrivateSupportedBackend for (TFLite, Artpec9DLPU) {
1310+
fn as_str() -> &'static str {
1311+
"a9-dlpu-tflite"
1312+
}
1313+
}
1314+
1315+
// // A type-safe configuration
1316+
// pub struct InferenceBackend<M, H> {
1317+
// mode: M,
1318+
// hardware: H,
1319+
// }
1320+
1321+
// impl SupportedBackend for (TFLite, CPU) {
1322+
// fn as_str() -> &'static str {
1323+
// "cpu-tflite"
12451324
// }
1325+
// }
12461326

1247-
// pub fn load(self, session: Session) -> Model {}
1327+
// impl SupportedBackend for (TFLite, Artpec8DLPU) {
1328+
// fn new() -> (TFLite, Artpec8DLPU) {
1329+
// (TFLite, Artpec8DLPU)
1330+
// }
1331+
// fn as_str() -> &str {
1332+
// "cpu-tflite"
1333+
// }
1334+
// }
1335+
1336+
// impl SupportedBackend for InferenceBackend<TFLite, Artpec7GPU> {
1337+
// fn new() -> InferenceBackend<TFLite, Artpec7GPU> {
1338+
// InferenceBackend {
1339+
// mode: TFLite,
1340+
// hardware: Artpec7GPU,
1341+
// }
1342+
// }
1343+
// fn as_str(&self) -> &str {
1344+
// "axis-a7-gpu-tflite"
1345+
// }
1346+
// }
1347+
1348+
// impl SupportedBackend for InferenceBackend<TFLite, Artpec8DLPU> {
1349+
// fn new() -> InferenceBackend<TFLite, Artpec8DLPU> {
1350+
// InferenceBackend {
1351+
// mode: TFLite,
1352+
// hardware: Artpec8DLPU,
1353+
// }
1354+
// }
1355+
// fn as_str(&self) -> &str {
1356+
// "axis-a8-dlpu-tflite"
1357+
// }
1358+
// }
1359+
1360+
// impl SupportedBackend for InferenceBackend<TFLite, Artpec9DLPU> {
1361+
// fn new() -> InferenceBackend<TFLite, Artpec9DLPU> {
1362+
// InferenceBackend {
1363+
// mode: TFLite,
1364+
// hardware: Artpec9DLPU,
1365+
// }
1366+
// }
1367+
// fn as_str(&self) -> &str {
1368+
// "a9-dlpu-tflite"
1369+
// }
12481370
// }
12491371

12501372
pub struct InferenceModel<'a> {
@@ -1254,9 +1376,68 @@ pub struct InferenceModel<'a> {
12541376
num_inputs: usize,
12551377
output_tensors: Option<LarodTensorContainer<'a>>,
12561378
num_outputs: usize,
1379+
params: Option<LarodMap>,
12571380
}
12581381

12591382
impl<'a> InferenceModel<'a> {
1383+
pub fn new<M, H, P>(
1384+
session: &'a Session,
1385+
model_file: P,
1386+
// chip: InferenceBackend<M, H>,
1387+
access: LarodAccess,
1388+
name: &str,
1389+
params: Option<LarodMap>,
1390+
) -> Result<InferenceModel<'a>>
1391+
where
1392+
(M, H): inference::PrivateSupportedBackend,
1393+
P: AsRef<Path>,
1394+
{
1395+
let f = File::open(model_file).map_err(Error::IOError)?;
1396+
let Ok(device_name) = CString::new(<(M, H)>::as_str()) else {
1397+
return Err(Error::CStringAllocation);
1398+
};
1399+
let (device, maybe_device_error) =
1400+
unsafe { try_func!(larodGetDevice, session.conn, device_name.as_ptr(), 0) };
1401+
if !device.is_null() {
1402+
debug_assert!(
1403+
maybe_device_error.is_none(),
1404+
"larodGetDevice indicated success AND returned an error!"
1405+
);
1406+
} else {
1407+
return Err(maybe_device_error.unwrap_or(Error::MissingLarodError));
1408+
}
1409+
let Ok(name) = CString::new(name) else {
1410+
return Err(Error::CStringAllocation);
1411+
};
1412+
let (larod_model_ptr, maybe_error) = unsafe {
1413+
try_func!(
1414+
larodLoadModel,
1415+
session.conn,
1416+
f.as_raw_fd(),
1417+
device,
1418+
access,
1419+
name.as_ptr(),
1420+
params.map_or_else(|| ptr::null(), |p| p.raw)
1421+
)
1422+
};
1423+
if !larod_model_ptr.is_null() {
1424+
debug_assert!(
1425+
maybe_device_error.is_none(),
1426+
"larodLoadModel indicated success AND returned an error!"
1427+
);
1428+
Ok(InferenceModel {
1429+
session: session,
1430+
ptr: larod_model_ptr,
1431+
input_tensors: None,
1432+
num_inputs: 0,
1433+
output_tensors: None,
1434+
num_outputs: 0,
1435+
params: None,
1436+
})
1437+
} else {
1438+
Err(maybe_error.unwrap_or(Error::MissingLarodError))
1439+
}
1440+
}
12601441
/// Placeholder for retrieving the model's larod id.
///
/// NOTE(review): currently a stub — always returns `Ok(())` without calling
/// into larod; presumably to be implemented later. Confirm intended contract.
pub fn id() -> Result<()> {
    Ok(())
}

0 commit comments

Comments
 (0)