Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 11 additions & 40 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ edition = "2021"
crate-type = ["staticlib", "cdylib"]

[dependencies]
ocaml = { version = "^1.0.0-beta" }
rsdd = { git = "https://github.com/neuppl/rsdd", rev = "1613459" }
ocaml = {git = "https://github.com/zshipko/ocaml-rs"}
rsdd = { git = "https://github.com/minsungc/rsdd-dappl", rev = "43816ddc9aadb96606fc10bf2d33cfe7e3f018d5" }

[build-dependencies]
ocaml-build = {version = "^1.0.0-beta"}
111 changes: 83 additions & 28 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::HashMap;

use rsdd::{
builder::{bdd::RobddBuilder, cache::AllIteTable, BottomUpBuilder},
builder::{bdd::{RobddBuilder, BddBuilder}, cache::AllIteTable, BottomUpBuilder},
constants::primes,
repr::{BddPtr, Cnf, DDNNFPtr, PartialModel, VarLabel, VarOrder, WmcParams},
util::semirings::{ExpectedUtility, FiniteField, RealSemiring, Semiring},
Expand Down Expand Up @@ -36,15 +36,16 @@ unsafe impl ocaml::ToValue for RsddVarLabel {
unsafe impl ocaml::FromValue for RsddVarLabel {
fn from_value(v: ocaml::Value) -> Self {
let i = unsafe { v.int64_val() };
RsddVarLabel(VarLabel::new(i as u64))
RsddVarLabel(VarLabel::new(i.try_into().unwrap()))
}
}


// disc/dice interface

#[ocaml::func]
#[ocaml::sig("int64 -> rsdd_bdd_builder")]
pub fn mk_bdd_builder_default_order(num_vars: u64) -> ocaml::Pointer<RsddBddBuilder> {
pub fn mk_bdd_builder_default_order(num_vars: i64) -> ocaml::Pointer<RsddBddBuilder> {
RsddBddBuilder(RobddBuilder::<AllIteTable<BddPtr>>::new(
VarOrder::linear_order(num_vars as usize),
))
Expand All @@ -56,9 +57,35 @@ pub fn mk_bdd_builder_default_order(num_vars: u64) -> ocaml::Pointer<RsddBddBuil
pub fn bdd_new_var(
builder: &'static RsddBddBuilder,
polarity: bool,
) -> (u64, ocaml::Pointer<RsddBddPtr>) {
) -> (i64, ocaml::Pointer<RsddBddPtr>) {
let (lbl, ptr) = builder.0.new_var(polarity);
(lbl.value(), RsddBddPtr(ptr).into())
(lbl.value().try_into().unwrap(), RsddBddPtr(ptr).into())
}

#[ocaml::func]
#[ocaml::sig("int64 -> rsdd_var_label")]
pub fn mk_varlabel(
i : i64
) -> RsddVarLabel {
RsddVarLabel(VarLabel::new(i.try_into().unwrap()))
}

#[ocaml::func]
#[ocaml::sig("rsdd_var_label -> int64")]
pub fn extract_varlabel(
v : RsddVarLabel
) -> i64 {
v.0.value() as i64
}

#[ocaml::func]
#[ocaml::sig("rsdd_bdd_builder -> rsdd_var_label -> bool -> rsdd_bdd_ptr")]
pub fn bdd_var(
builder: &'static RsddBddBuilder,
lbl : RsddVarLabel,
polarity : bool,
) -> ocaml::Pointer<RsddBddPtr> {
RsddBddPtr(builder.0.var(lbl.0, polarity)).into()
}

#[ocaml::func]
Expand Down Expand Up @@ -101,6 +128,16 @@ pub fn bdd_negate(
RsddBddPtr(builder.0.negate(bdd.0)).into()
}

#[ocaml::func]
#[ocaml::sig("rsdd_bdd_builder -> int64 list -> rsdd_bdd_ptr")]
pub fn bdd_exactlyone(
builder: &'static RsddBddBuilder,
l : ocaml::List<i64>,
) -> ocaml::Pointer<RsddBddPtr> {
let l_of_varlabels : Vec<_> = l.into_vec().iter().map(|x| VarLabel::new_usize(*x as usize)).collect();
RsddBddPtr(builder.0.exactly_one_of_varlabels(&l_of_varlabels)).into()
}

#[ocaml::func]
#[ocaml::sig("rsdd_bdd_builder -> rsdd_bdd_ptr")]
pub fn bdd_true(builder: &'static RsddBddBuilder) -> ocaml::Pointer<RsddBddPtr> {
Expand Down Expand Up @@ -139,10 +176,10 @@ pub fn bdd_eq(builder: &'static RsddBddBuilder, a: &RsddBddPtr, b: &RsddBddPtr)

#[ocaml::func]
#[ocaml::sig("rsdd_bdd_ptr -> int64")]
pub fn bdd_topvar(bdd: &RsddBddPtr) -> u64 {
pub fn bdd_topvar(bdd: &RsddBddPtr) -> i64 {
match (bdd.0).var_safe() {
Some(x) => x.value(),
None => 0, // TODO: provide a better version for this, maybe a Maybe/Option?
Some(x) => x.value().try_into().unwrap(),
None => -1, // TODO: provide a better version for this, maybe a Maybe/Option?
}
}

Expand Down Expand Up @@ -180,7 +217,7 @@ pub fn new_wmc_params_r(weights: ocaml::List<(f64, f64)>) -> ocaml::Pointer<Rsdd
.enumerate()
.map(|(index, (a, b))| {
(
VarLabel::new(index as u64),
VarLabel::new(index.try_into().unwrap()),
(RealSemiring(*a), RealSemiring(*b)),
)
}),
Expand All @@ -190,57 +227,75 @@ pub fn new_wmc_params_r(weights: ocaml::List<(f64, f64)>) -> ocaml::Pointer<Rsdd

// branch & bound, expected semiring items
#[ocaml::sig]
#[derive(ocaml::ToValue, ocaml::FromValue)]
pub struct RsddExpectedUtility(ExpectedUtility);
ocaml::custom!(RsddExpectedUtility);

#[ocaml::sig]
pub struct RsddWmcParamsEU(WmcParams<ExpectedUtility>);
ocaml::custom!(RsddWmcParamsEU);


#[ocaml::func]
#[ocaml::sig("rsdd_bdd_ptr -> rsdd_var_label list -> int64 -> rsdd_wmc_params_e_u -> rsdd_expected_utility * rsdd_partial_model")]
pub fn bdd_bb(
#[ocaml::sig("rsdd_expected_utility -> float * float")]
pub fn extract(
eu : RsddExpectedUtility
) -> (f64, f64) {
let v = eu.0 ;
(v.0, v.1)
}


#[ocaml::func]
#[ocaml::sig("rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_var_label list -> int64 -> rsdd_wmc_params_e_u -> rsdd_expected_utility * rsdd_partial_model * int64")]
pub fn bdd_meu_without_cache(
bdd: &'static RsddBddPtr,
evidence: &'static RsddBddPtr,
join_vars: ocaml::List<RsddVarLabel>,
num_vars: u64,
num_vars: i64,
wmc: &RsddWmcParamsEU,
) -> (
ocaml::Pointer<RsddExpectedUtility>,
RsddExpectedUtility,
ocaml::Pointer<RsddPartialModel>,
i64
) {
let (eu, pm) = bdd.0.bb(
let (eu, pm, size) = bdd.0.meu(
evidence.0,
&join_vars
.into_linked_list()
.iter()
.map(|x| x.0)
.collect::<Vec<_>>(),
num_vars as usize,
num_vars.try_into().unwrap(),
&wmc.0,
);
(RsddExpectedUtility(eu).into(), RsddPartialModel(pm).into())
(RsddExpectedUtility(eu), RsddPartialModel(pm).into(), size as i64)
}

#[ocaml::func]
#[ocaml::sig("rsdd_bdd_ptr -> rsdd_var_label list -> int64 -> rsdd_wmc_params_e_u -> rsdd_expected_utility * rsdd_partial_model")]
#[ocaml::sig("rsdd_bdd_ptr -> rsdd_bdd_ptr -> rsdd_var_label list -> int64 -> rsdd_wmc_params_e_u -> rsdd_expected_utility * rsdd_partial_model * int64")]
pub fn bdd_meu(
bdd: &'static RsddBddPtr,
decision_vars: ocaml::List<RsddVarLabel>,
num_vars: u64,
evidence: &'static RsddBddPtr,
join_vars: ocaml::List<RsddVarLabel>,
num_vars: i64,
wmc: &RsddWmcParamsEU,
) -> (
ocaml::Pointer<RsddExpectedUtility>,
RsddExpectedUtility,
ocaml::Pointer<RsddPartialModel>,
i64
) {
let (eu, pm) = bdd.0.bb(
&decision_vars
let (eu, pm, size) = bdd.0.bb(
evidence.0,
&join_vars
.into_linked_list()
.iter()
.map(|x| x.0)
.collect::<Vec<_>>(),
num_vars as usize,
num_vars.try_into().unwrap(),
&wmc.0,
);
(RsddExpectedUtility(eu).into(), RsddPartialModel(pm).into())
(RsddExpectedUtility(eu), RsddPartialModel(pm).into(), size as i64)
}

#[ocaml::func]
Expand All @@ -255,7 +310,7 @@ pub fn new_wmc_params_eu(
.enumerate()
.map(|(index, (a, b))| {
(
VarLabel::new(index as u64),
VarLabel::new(index.try_into().unwrap()),
(ExpectedUtility(a.0, a.1), ExpectedUtility(b.0, b.1)),
)
}),
Expand All @@ -282,15 +337,15 @@ pub fn bdd_builder_compile_cnf(

#[ocaml::func]
#[ocaml::sig("rsdd_bdd_builder -> rsdd_bdd_ptr -> int64")]
pub fn bdd_model_count(builder: &'static RsddBddBuilder, bdd: &'static RsddBddPtr) -> u64 {
pub fn bdd_model_count(builder: &'static RsddBddBuilder, bdd: &'static RsddBddPtr) -> i64 {
let num_vars = builder.0.num_vars();
let smoothed = builder.0.smooth(bdd.0, num_vars);
let unweighted_params: WmcParams<FiniteField<{ primes::U64_LARGEST }>> =
WmcParams::new(HashMap::from_iter(
(0..num_vars as u64)
(0..num_vars.try_into().unwrap())
.map(|v| (VarLabel::new(v), (FiniteField::one(), FiniteField::one()))),
));

let mc = smoothed.unsmoothed_wmc(&unweighted_params).value();
mc as u64
mc.try_into().unwrap()
}
Loading