Improve AD system, clear memory explosions, add dmnormAD and PDinverse_logdet #1574

Merged Aug 7, 2025 (40 commits)

Commits
36c3caf
improve CppAD new_dynamic by catching IdenticalZero and IdenticalOne …
perrydv Jun 5, 2025
b5ac955
Ad chol_PDlogdet functions in C++, with GitHub copilot support
perrydv Jun 6, 2025
aec408d
progress towards a version of dmnorm for AD using prec_ldet setup
perrydv Jun 9, 2025
722b49a
ignore .vscode
perrydv Jun 14, 2025
c16c527
Remove .vscode (after adding to .gitignore)
perrydv Jun 14, 2025
93203e7
working on dmnormAD
perrydv Jun 19, 2025
18c9cf7
Updates for 2nd order reverse and clean up
perrydv Jun 23, 2025
e1076c1
Updates to PDinverse_logDet and getDerivs_internal to use subgraph_re…
perrydv Jul 3, 2025
4cb7428
add inDir, outDir, and outInds to nimDerivs. Also PDinverse_logdet
perrydv Jul 13, 2025
56fcc19
Add test-ADPDinverse_logdet.R
perrydv Jul 13, 2025
dd09382
Export dmnormAD stuff and do not pass n to C code.
paciorek Jul 16, 2025
f106ed2
Lift `PDinverse_logdet`.
paciorek Jul 16, 2025
0ad31ff
Fix transpose bug in `calc_dmnorm_prec_ldet_AltParams`.
paciorek Jul 17, 2025
0e2c3f3
Merge branch 'devel' into fix-cppad-memory-explosion
paciorek Jul 18, 2025
f5129a3
Add missing input for PDinverse_logdet test;
paciorek Jul 28, 2025
6a441ae
Fix seeming bug in error reporting in `test_AD2_oneCall`.
paciorek Jul 28, 2025
c958715
Fix first PDinverse test.
paciorek Jul 28, 2025
004efdb
Automate assignment of dmnormAD and add dmnormAD tests (#1575)
paciorek Jul 28, 2025
6f10359
Resolve merge conflict.
paciorek Jul 30, 2025
3fdbb21
Monkey with test-ADdmnorm.R to address finicky test failures.
paciorek Jul 30, 2025
f9310dd
Check for equality not identical in cOutput01$jac vs. cOutput012$jac …
paciorek Jul 30, 2025
63c298e
Tweak tests in light of AD changes.
paciorek Jul 30, 2025
99af509
Fix tweak to AD_test_utils.
paciorek Jul 30, 2025
9888258
Tweak verbosity and test ordering in test-ADdmnorm.R
paciorek Jul 30, 2025
cefe32e
write nimDerivs_ versions of CppAD conditionals based on atomic step …
perrydv Aug 3, 2025
3936c74
one more log_or_exp
perrydv Aug 3, 2025
bfef83e
initial steps to make dmnormAD take prec or cov parameterization
perrydv Aug 3, 2025
63d72dc
fix Weibull log_or_exp. fix dhalfflat use of new AD conditional.
perrydv Aug 3, 2025
9847635
Fix up handling of prec for dmnormAD.
paciorek Aug 5, 2025
a142ce2
fix updated dmnorm_inv_pd including renaming and adding tests
perrydv Aug 5, 2025
53b2377
Make minor testing comment change.
paciorek Aug 6, 2025
214c200
Fix dmnormAD alt param calc based on inv_ld having inverse.
paciorek Aug 6, 2025
f717b0b
Refine calc_dmnorm_inv_ld_AltParams.
paciorek Aug 6, 2025
174da49
Fix dwish-dmnormAD conjugacy checking and add tests.
paciorek Aug 6, 2025
dfd2c24
change order of res and log(res) in nimDerivs weibull
perrydv Aug 6, 2025
738556a
Fix minor merge conflict.
paciorek Aug 6, 2025
39eb2fa
Remove stray browsers in test file.
paciorek Aug 6, 2025
2ebf3cd
Remove another `browser`.
paciorek Aug 6, 2025
a0e1eb6
Fix dumb mistake in checking in cc_otherParamsCheck.
paciorek Aug 7, 2025
a90fef9
Fix dumb mistake #2 in checking in cc_otherParamsCheck.
paciorek Aug 7, 2025
2 changes: 2 additions & 0 deletions packages/.gitignore
@@ -14,3 +14,5 @@ Eigen_local*
config.*
*.Rproj
profile/
+
+.vscode
4 changes: 4 additions & 0 deletions packages/nimble/NAMESPACE
@@ -18,6 +18,7 @@ S3method(as.list, modelValuesBaseClass)
S3method(length, nimPointerList)
export(calc_dmnormConjugacyContributions)
export(calc_dmnormAltParams)
+export(calc_dmnorm_inv_ld_AltParams)
export(calc_dwishAltParams)
export(calc_dcatConjugacyContributions)
export(CAR_calcM)
@@ -103,6 +104,7 @@ export(dinvgamma)
export(dinvwish_chol)
export(dlkj_corr_cholesky)
export(dmnorm_chol)
+export(dmnorm_inv_ld)
export(dmulti)
export(dmvt_chol)
export(dsqrtinvgamma)
@@ -193,6 +195,7 @@ export(optimDefaultControl)
export(optimResultNimbleList)
export(parameterTransform)
export(pdexp)
+export(PDinverse_logdet)
export(pexp_nimble)
export(phi)
export(pinvgamma)
@@ -225,6 +228,7 @@ export(rinvgamma)
export(rinvwish_chol)
export(rlkj_corr_cholesky)
export(rmnorm_chol)
+export(rmnorm_inv_ld)
export(rmulti)
export(rmvt_chol)
export(rsqrtinvgamma)
1 change: 1 addition & 0 deletions packages/nimble/R/BUGS_BUGSdecl.R
@@ -77,6 +77,7 @@ nimblePreevaluationFunctionNames <- c('+',
'asCol',
'logdet',
'chol',
+'PDinverse_logdet',
'inverse',
'forwardsolve',
'backsolve',
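
The indexing code added in `BUGS_modelDef.R` in this PR sizes the lifted `PDinverse_logdet` node as the number of matrix elements plus one, which suggests the result packs a matrix inverse and a log-determinant into a single vector. A minimal base-R sketch under that assumption follows; `pd_inverse_logdet_sketch` is a hypothetical name, and the element ordering (inverse first, log-determinant last) is assumed rather than taken from the NIMBLE source.

```r
# Hypothetical reference sketch; this is NOT NIMBLE's implementation.
# Assumed layout: the n*n elements of the inverse (column-major),
# followed by the log-determinant as the final element.
pd_inverse_logdet_sketch <- function(S) {
  R <- chol(S)                      # upper-triangular factor, S = t(R) %*% R
  Sinv <- chol2inv(R)               # inverse of S computed from its Cholesky factor
  logdet <- 2 * sum(log(diag(R)))   # log|S| from the factor's diagonal
  c(as.numeric(Sinv), logdet)
}

Sigma <- matrix(c(2, 0.5, 0.5, 1), 2, 2)
out <- pd_inverse_logdet_sketch(Sigma)   # length n^2 + 1 = 5
```

Computing both quantities from one Cholesky factorization is what would make such a combined node attractive for `dmnorm`, since the density needs the precision and the log-determinant together.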
33 changes: 33 additions & 0 deletions packages/nimble/R/BUGS_modelDef.R
@@ -831,6 +831,17 @@ modelDefClass$methods(reparameterizeDists = function() {
BUGSdecl <- declInfo[[i]] ## grab this current BUGS declation info object
if(BUGSdecl$type == 'determ') next ## skip deterministic nodes
code <- BUGSdecl$code ## grab the original code
+if(BUGSdecl$distributionName == "dmnorm" && buildDerivs && getNimbleOption('useADdmnorm')) {
+if(length(BUGSdecl$code) > 2 && "cholesky" %in% names(BUGSdecl$code[[3]])) {
+messageIfVerbose(" [Note] Detected use of `cholesky` parameterization of `dmnorm` with a\n",
+" derivative-enabled model. AD-optimized `dmnorm` is only available\n",
+" for the `prec` or `cov` parameterizations. NIMBLE will use a version\n",
+" of `dmnorm` not optimized for AD, which may result in inefficiency.")
+} else {
+BUGSdecl$distributionName <- "dmnormAD"
+BUGSdecl$valueExpr[[1]] <- quote(dmnormAD)
+}
+}
valueExpr <- BUGSdecl$valueExpr ## grab the RHS (distribution)
distName <- BUGSdecl$distributionName #as.character(valueExpr[[1]])
if(!(distName %in% getAllDistributionsInfo('namesVector'))) stop('unknown distribution name: ', distName) ## error if the distribution isn't something we recognize
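
The `cholesky` check in the hunk above relies on argument names of the right-hand side of the declaration. A minimal base-R sketch of that test on quoted declarations follows; `uses_cholesky_param` and the two example declarations are hypothetical and only illustrate the expression handling, not NIMBLE's model processing.

```r
# Hypothetical helper mirroring the condition in the diff: a `~` call has
# length 3, and its third element is the distribution call whose argument
# names are inspected.
uses_cholesky_param <- function(code) {
  length(code) > 2 && "cholesky" %in% names(code[[3]])
}

chol_decl <- quote(y[1:3] ~ dmnorm(mu[1:3], cholesky = ch[1:3, 1:3], prec_param = 1))
cov_decl  <- quote(y[1:3] ~ dmnorm(mu[1:3], cov = Sigma[1:3, 1:3]))

chol_flag <- uses_cholesky_param(chol_decl)  # TRUE: keeps the non-AD dmnorm
cov_flag  <- uses_cholesky_param(cov_decl)   # FALSE: eligible for dmnormAD

# The swap itself replaces only the head of the distribution call.
rhs <- cov_decl[[3]]
rhs[[1]] <- quote(dmnormAD)
```

Unnamed arguments yield empty names (or `NULL` when no argument is named), so a declaration that passes a Cholesky factor positionally would not be caught by this test.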
@@ -1059,6 +1070,16 @@ liftedCallsGetIndexingFromArgumentNumbers <- list(
CAR_calcEVs3 = c(3)
)

+liftedCallsGetIndexingOther <- list(
+## This is general in that it finds the number of elements of the matrix,
+## but the input shouldn't be anything other than square.
+PDinverse_logdet = function(argList) {
+getlen <- function(arg) length(eval(arg))
+list(substitute(1:N, list(N = prod(sapply(argList[[1]][3:length(argList[[1]])], getlen))+1)))
+}
+)
+
+
modelDefClass$methods(liftExpressionArgs = function() {
## overwrites declInfo (*and adds*), lifts any expressions in distribution arguments to new nodes
newDeclInfo <- list()
@@ -1125,6 +1146,7 @@ isExprLiftable <- function(paramExpr, type = NULL) {
callText <- getCallText(paramExpr)
if(callText == 'chol') return(TRUE) ## do lift calls to chol(...)
if(callText == 'inverse') return(TRUE) ## do lift calls to inverse(...)
+if(callText == 'PDinverse_logdet') return(TRUE) ## do lift calls to PDinverse_logdet(...)
if(callText == 'CAR_calcNumIslands') return(TRUE) ## do lift calls to CAR_calcNumIslands(...)
if(callText == 'CAR_calcC') return(TRUE) ## do lift calls to CAR_calcC(...)
if(callText == 'CAR_calcM' ) return(TRUE) ## do lift calls to CAR_calcM(...)
@@ -1148,6 +1170,8 @@ isExprLiftable <- function(paramExpr, type = NULL) {
addNecessaryIndexingToNewNode <- function(newNodeNameExpr, paramExpr, indexVarExprs) {
if(is.call(paramExpr) && safeDeparse(paramExpr[[1]], warn = TRUE) %in% names(liftedCallsGetIndexingFromArgumentNumbers))
return(addNecessaryIndexingFromArgumentNumbers(newNodeNameExpr, paramExpr, indexVarExprs))
+if(is.call(paramExpr) && safeDeparse(paramExpr[[1]], warn = TRUE) %in% names(liftedCallsGetIndexingOther))
+return(addNecessaryIndexingOther(newNodeNameExpr, paramExpr, indexVarExprs))
usedIndexVarsList <- indexVarExprs[indexVarExprs %in% all.vars(paramExpr)] # this extracts any index variables which appear in 'paramExpr'
vectorizedIndexExprsList <- extractAnyVectorizedIndexExprs(paramExpr) # creates a list of any vectorized (:) indexing expressions appearing in 'paramExpr'
neededIndexExprsList <- c(usedIndexVarsList, vectorizedIndexExprsList)
@@ -1165,6 +1189,15 @@ addNecessaryIndexingFromArgumentNumbers <- function(newNodeNameExpr, paramExpr,
newNodeNameExprIndexed[3:(2+length(neededIndexExprsList))] <- neededIndexExprsList
return(newNodeNameExprIndexed)
}
+addNecessaryIndexingOther <- function(newNodeNameExpr, paramExpr, indexVarExprs) {
+paramExprCallName <- as.character(paramExpr[[1]])
+neededIndexExprsList <- liftedCallsGetIndexingOther[[paramExprCallName]](as.list(paramExpr[-1]))
+newNodeNameExprIndexed <- substitute(NAME[], list(NAME = newNodeNameExpr))
+newNodeNameExprIndexed[3:(2+length(neededIndexExprsList))] <- neededIndexExprsList
+return(newNodeNameExprIndexed)
+}
+
+
extractAnyVectorizedIndexExprs <- function(expr) {
if(!(':' %in% all.names(expr))) return(list())
if(!is.call(expr)) return(list())
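
To see what index the new lifting code attaches to a lifted `PDinverse_logdet(...)` node, the two functions added in this file can be run on their own in base R. The sketch below copies them from the diff; the node name `lifted_node` and the 3-by-3 argument `Sigma[1:3, 1:3]` are hypothetical inputs chosen for illustration.

```r
# Copied from the diff above, reindented.
liftedCallsGetIndexingOther <- list(
  PDinverse_logdet = function(argList) {
    getlen <- function(arg) length(eval(arg))
    # Product of the index-range lengths of the first argument, plus one.
    list(substitute(1:N, list(N = prod(sapply(argList[[1]][3:length(argList[[1]])], getlen)) + 1)))
  }
)

addNecessaryIndexingOther <- function(newNodeNameExpr, paramExpr, indexVarExprs) {
  paramExprCallName <- as.character(paramExpr[[1]])
  neededIndexExprsList <- liftedCallsGetIndexingOther[[paramExprCallName]](as.list(paramExpr[-1]))
  newNodeNameExprIndexed <- substitute(NAME[], list(NAME = newNodeNameExpr))
  newNodeNameExprIndexed[3:(2 + length(neededIndexExprsList))] <- neededIndexExprsList
  newNodeNameExprIndexed
}

# Hypothetical inputs: a 3 x 3 matrix argument and an illustrative node name.
paramExpr <- quote(PDinverse_logdet(Sigma[1:3, 1:3]))
lifted <- addNecessaryIndexingOther(quote(lifted_node), paramExpr, list())
idx <- eval(lifted[[3]])   # 3 * 3 matrix elements + 1 extra element
```

Because `getlen` evaluates each index range with `eval`, this approach needs index ranges that are literal or otherwise resolvable at that point; a range such as `1:n` with `n` undefined in the evaluation environment would fail.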
33 changes: 27 additions & 6 deletions packages/nimble/R/MCMC_conjugacy.R
@@ -133,7 +133,8 @@ conjugacyRelationshipsInputList <- list(
link = 'linear',
dependents = list(
##dmnorm = list(param = 'mean', contribution_mean = '(t(coeff) %*% prec %*% asCol(value-offset))[,1]', contribution_prec = 't(coeff) %*% prec %*% coeff')),
-dmnorm = list(param = 'mean', contribution_mean = '(calc_dmnormConjugacyContributions(coeff, prec, value-offset, 1, 0))[,1]', contribution_prec = 'calc_dmnormConjugacyContributions(coeff, prec, value-offset, 2, 0)')),
+dmnorm = list(param = 'mean', contribution_mean = '(calc_dmnormConjugacyContributions(coeff, prec, value-offset, 1, 0))[,1]', contribution_prec = 'calc_dmnormConjugacyContributions(coeff, prec, value-offset, 2, 0)'),
+dmnormAD = list(param = 'mean', contribution_mean = '(calc_dmnormConjugacyContributions(coeff, prec, value-offset, 1, 0))[,1]', contribution_prec = 'calc_dmnormConjugacyContributions(coeff, prec, value-offset, 2, 0)')),
## LINK will be replaced with appropriate link via code processing
## original less efficient posterior definition:
## posterior = 'dmnorm_chol(mean = (inverse(prior_prec + contribution_prec) %*% (prior_prec %*% asCol(prior_mean) + asCol(contribution_mean)))[,1],
@@ -144,6 +145,21 @@ conjugacyRelationshipsInputList <- list(
mu <- backsolve(R, forwardsolve(t(R), A))[,1]
dmnorm_chol(mean = mu, cholesky = R, prec_param = 1) }'),

+list(prior = 'dmnormAD',
+link = 'linear',
+dependents = list(
+##dmnorm = list(param = 'mean', contribution_mean = '(t(coeff) %*% prec %*% asCol(value-offset))[,1]', contribution_prec = 't(coeff) %*% prec %*% coeff')),
+dmnorm = list(param = 'mean', contribution_mean = '(calc_dmnormConjugacyContributions(coeff, prec, value-offset, 1, 0))[,1]', contribution_prec = 'calc_dmnormConjugacyContributions(coeff, prec, value-offset, 2, 0)'),
+dmnormAD = list(param = 'mean', contribution_mean = '(calc_dmnormConjugacyContributions(coeff, prec, value-offset, 1, 0))[,1]', contribution_prec = 'calc_dmnormConjugacyContributions(coeff, prec, value-offset, 2, 0)')),
+## LINK will be replaced with appropriate link via code processing
+## original less efficient posterior definition:
+## posterior = 'dmnorm_chol(mean = (inverse(prior_prec + contribution_prec) %*% (prior_prec %*% asCol(prior_mean) + asCol(contribution_mean)))[,1],
+## cholesky = chol(prior_prec + contribution_prec),
+## prec_param = 1)'),
+posterior = '{ R <- chol(prior_prec + contribution_prec)
+A <- prior_prec %*% asCol(prior_mean) + asCol(contribution_mean)
+mu <- backsolve(R, forwardsolve(t(R), A))[,1]
+dmnorm_chol(mean = mu, cholesky = R, prec_param = 1) }'),

## wishart
list(prior = 'dwish',
@@ -159,7 +175,8 @@ conjugacyRelationshipsInputList <- list(
## changing to only use link='identity' case, since the link='linear' case was not correct
## -DT March 2017
## dmnorm = list(param = 'prec', contribution_R = 'asCol(value-mean) %*% (asRow(value-mean) %*% coeff)', contribution_df = '1')),
-dmnorm = list(param = 'prec', contribution_R = 'coeff * asCol(value-mean) %*% asRow(value-mean)', contribution_df = '1')),
+dmnorm = list(param = 'prec', contribution_R = 'coeff * asCol(value-mean) %*% asRow(value-mean)', contribution_df = '1'),
+dmnormAD = list(param = 'prec', contribution_R = 'coeff * asCol(value-mean) %*% asRow(value-mean)', contribution_df = '1')),
posterior = 'dwish_chol(cholesky = chol(prior_R + contribution_R),
df = prior_df + contribution_df,
scale_param = 0)'),
@@ -168,7 +185,8 @@ conjugacyRelationshipsInputList <- list(
list(prior = 'dinvwish',
link = 'multiplicativeScalar', # we only handle scalar 'coeff'; this naming is slightly awkward since for univar dists, link is of course scalar
dependents = list(
-dmnorm = list(param = 'cov', contribution_S = 'asCol(value-mean) %*% asRow(value-mean) / coeff', contribution_df = '1')),
+dmnorm = list(param = 'cov', contribution_S = 'asCol(value-mean) %*% asRow(value-mean) / coeff', contribution_df = '1'),
+dmnormAD = list(param = 'cov', contribution_S = 'asCol(value-mean) %*% asRow(value-mean) / coeff', contribution_df = '1')),
posterior = 'dinvwish_chol(cholesky = chol(prior_S + contribution_S),
df = prior_df + contribution_df,
scale_param = 1)')
@@ -999,7 +1017,7 @@ conjugacyClass <- setRefClass(
for(contributionName in posteriorObject$neededContributionNames) {
if(!(contributionName %in% dependents[[distName]]$contributionNames)) next
contributionExpr <- dependents[[distName]]$contributionExprs[[contributionName]]
-if(distName == 'dmnorm' && prior == 'dmnorm') {
+if(distName %in% c('dmnorm','dmnormAD') && prior %in% c('dmnorm','dmnormAD')) {
## need to deal with [,1] in contribution_mean
if(contributionName == 'contribution_mean') tmpExpr <- contributionExpr[[2]][[2]] else tmpExpr <- contributionExpr
if(getNimbleOption('allowDynamicIndexing') && doDependentScreen) {
@@ -1348,8 +1366,11 @@ cc_otherParamsCheck <- function(model, depNode, targetNode, skipExpansionsNode =
if(!missing(depParamNodeName) && (names(paramsList)[i] == depParamNodeName)) {
expr <- depNodeExprExpanded
} else { expr <- cc_expandDetermNodesInExpr(model, paramsList[[i]], targetNode, skipExpansionsNode) }
-if(cc_vectorizedComponentCheck(targetNode, expr)) return(FALSE)
-if(cc_nodeInExpr(targetNode, expr)) { timesFound <- timesFound + 1 } ## we found 'targetNode'
+## We expect to find target in PDinverse_logdet() one extra time when dmnormAD used.
+if(!(length(expr) > 1 && expr[[1]] == "PDinverse_logdet")) {
+if(cc_vectorizedComponentCheck(targetNode, expr)) return(FALSE)
+if(cc_nodeInExpr(targetNode, expr)) { timesFound <- timesFound + 1 } ## we found 'targetNode'
+}
}
if(timesFound == 0) stop('something went wrong; targetNode not found in any parameter expressions')
if(timesFound == 1) return(TRUE)
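
The guard added to `cc_otherParamsCheck` compares the head of a call against a string, which works in R because a symbol is coerced to character by `==`. A minimal sketch of that test in isolation follows; `is_pdinverse_call` and the example expressions are hypothetical and stand in for the expanded parameter expressions.

```r
# Hypothetical helper mirroring the guard condition in the diff.
is_pdinverse_call <- function(expr) {
  length(expr) > 1 && expr[[1]] == "PDinverse_logdet"
}

exprs <- list(
  quote(PDinverse_logdet(Sigma[1:3, 1:3])),  # skipped by the guard
  quote(Sigma[1:3, 1:3]),                    # counted: head is `[`
  quote(Sigma)                               # counted: bare symbol, length 1
)
skipped <- vapply(exprs, is_pdinverse_call, logical(1))

# Only the non-skipped expressions would contribute to timesFound.
n_checked <- sum(!skipped)
```

The `length(expr) > 1` test runs first, so a bare symbol short-circuits before `expr[[1]]` is evaluated.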
10 changes: 5 additions & 5 deletions packages/nimble/R/MCMC_samplers.R
@@ -1170,7 +1170,7 @@ essNF_multivariate <- nimbleFunction(
name = 'essNF_multivariate',
contains = essNFList_virtual,
setup = function(model, node) {
-if(!(model$getDistribution(node) == 'dmnorm')) stop('something went wrong')
+if(!(model$getDistribution(node) %in% c('dmnorm', 'dmnormAD'))) stop('sampler_ess: node `', node, '` does not have a dmnorm distribution')
},
run = function() {
mean <- model$getParam(node, 'mean')
@@ -1201,10 +1201,10 @@ sampler_ess <- nimbleFunction(
## nested function and function list definitions
essNFList <- nimbleFunctionList(essNFList_virtual)
if(model$getDistribution(target) == 'dnorm') essNFList[[1]] <- essNF_univariate(model, target)
-if(model$getDistribution(target) == 'dmnorm') essNFList[[1]] <- essNF_multivariate(model, target)
+if(model$getDistribution(target) %in% c('dmnorm','dmnormAD')) essNFList[[1]] <- essNF_multivariate(model, target)
## checks
if(length(target) > 1) stop('elliptical slice sampler only applies to one target node')
-if(!(model$getDistribution(target) %in% c('dnorm', 'dmnorm'))) stop('elliptical slice sampler only applies to normal distributions')
+if(!(model$getDistribution(target) %in% c('dnorm', 'dmnorm', 'dmnormAD'))) stop('elliptical slice sampler only applies to normal distributions')
targetNames <- createNamesString(target)
},
run = function() {
@@ -2862,7 +2862,7 @@ sampler_polyagamma <- nimbleFunction(

## Conjugacy checking, part 1.
if(check) {
-if(!all(targetDists %in% c("dnorm", "dmnorm")))
+if(!all(targetDists %in% c("dnorm", "dmnorm", "dmnormAD")))
stop("polyagamma sampler: all target nodes must have `dnorm` or `dmnorm` priors. ", checkMessage)
if(!all(model$getDistribution(yNodes) %in% c("dbern", "dbin")) )
stop("polyagamma sampler: response nodes must be distributed `dbern` or `dbin`. ", checkMessage)
@@ -2981,7 +2981,7 @@ sampler_polyagamma <- nimbleFunction(
singleSize <- FALSE

dnormNodes <- targetDists == "dnorm"
-dmnormNodes <- targetDists == "dmnorm"
+dmnormNodes <- targetDists %in% c("dmnorm", "dmnormAD")
n_dnorm <- sum(dnormNodes)
n_dmnorm <- sum(dmnormNodes)

22 changes: 21 additions & 1 deletion packages/nimble/R/cppDefs_ADtools.R
@@ -198,6 +198,20 @@ make_deriv_function <- function(origFun,
nDim = 1, type = 'double', ref = TRUE, const = TRUE))
newFun$args$addSymbol(cppVar(baseType = 'bool', name = "DO_UPDATE_"))
newFun$args$addSymbol(cppVar(baseType = 'bool', name = "RESET_"))
+newFun$args$addSymbol(cppNimArr(name = 'ARGZ_outInds__',
+nDim = 1, type = 'double', ref = TRUE, const = TRUE))
+
+inDirSym <- cppNimArr(name = 'ARGZ_inDir__',
+nDim = 1, type = 'double', ref = TRUE, const = TRUE)
+outDirSym <- cppNimArr(name = 'ARGZ_outDir__',
+nDim = 1, type = 'double', ref = TRUE, const = TRUE)
+if(meta) {
+inDirSym <- cppVarSym2templateTypeCppVarSym(inDirSym)
+outDirSym <- cppVarSym2templateTypeCppVarSym(outDirSym)
+}
+newFun$args$addSymbol(inDirSym)
+newFun$args$addSymbol(outDirSym)
+
newFun$args$addSymbol(cppVar(name = 'ARGZ_ADinfo_',
ref = TRUE,
baseType = "nimbleCppADinfoClass"))
@@ -233,14 +247,20 @@ make_deriv_function <- function(origFun,
getDerivsRcall <- substitute(returnList_ <- GETDERIVS_WRAPPER( INNERCALL,
ARGZ_nimDerivsOrders_,
ARGZ_wrtVector_ ,
+ARGZ_outInds__,
+ARGZ_inDir__,
+ARGZ_outDir__,
recordingInfo_),
list(INNERCALL = innerRcall,
GETDERIVS_WRAPPER = as.name(getDerivs_wrapper)))
} else {
getDerivs_wrapper <- 'getDerivs_wrapper'
getDerivsRcall <- substitute(returnList_ <- GETDERIVS_WRAPPER( INNERCALL,
ARGZ_nimDerivsOrders_,
-ARGZ_wrtVector_ ),
+ARGZ_wrtVector_ ,
+ARGZ_outInds__,
+ARGZ_inDir__,
+ARGZ_outDir__),
list(INNERCALL = innerRcall,
GETDERIVS_WRAPPER = as.name(getDerivs_wrapper)))
}