Commit abd2164

Tidy up ACA code
1 parent ec79049 commit abd2164

1 file changed (+129, -59)

1 file changed

+129
-59
lines changed

deep_tensor/ftt/ftt.py

@@ -1,4 +1,5 @@
-from typing import Callable, Dict
+from typing import Callable, Dict, Tuple
+import warnings
 
 import torch
 from torch import linalg
@@ -421,80 +422,148 @@ def compute_fibre_submatrix_random(
 
         return fibre_matrix
 
-    def compute_fibre_submatrix_aca(self, grid: Grid, k: int) -> Tensor:
-
-        for iter in range(self.options.max_fibres):
-
-            random_inds = grid.sample_indices(self.options.num_aca)
-            random_points = grid.indices2points(random_inds)
-
-            M_vals = self.target_func(random_points)
-
-            if iter == 0:
+    @staticmethod
+    def _find_evaluated_points(
+        new_inds: Tensor,
+        inds_eval: Tensor,
+        vals_eval: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        """Returns a mask indicating which elements of a set of indices
+        have been computed previously, as well as the computed values.
+        """
+        diffs = (new_inds[:, None, :] - inds_eval[None, ...]).abs().sum(dim=2)
+        inds_prev = diffs.argmin(dim=1)
+        mask = diffs.min(dim=1).values < EPS
+        mask_vals = vals_eval[inds_prev[mask]]
+        return mask, mask_vals
+
+    def _generate_points_aca(self, n: int, grid: Grid) -> Tuple[Tensor, Tensor]:
+        """Returns a set of random indices and the corresponding
+        function values.
+        """
+        inds_rand = grid.sample_indices(n)
+        random_points = grid.indices2points(inds_rand)
+        func_vals = self.target_func(random_points)
+        self.num_eval_fibres += func_vals.numel()
+        return inds_rand, func_vals
+
+    def _initialise_index_set_aca(self, grid: Grid) -> Tuple[Tensor, Tensor]:
+        """Initialises the index set defining the current cross by
+        sampling from the coefficient tensor at random. This is
+        repeated multiple times in case the sampled elements are
+        uniformly zero (see also implementation by Strossner et al.).
+        """
 
-                max_residual = M_vals.max()
-                max_residual_index = M_vals.abs().argmax()
-                max_index = random_inds[max_residual_index, :]
+        num_initialisation_batches = 5
+        num_aca = self.options.num_aca
 
-                inds = torch.atleast_2d(max_index)
+        for _ in range(num_initialisation_batches):
+            inds_rand, func_vals = self._generate_points_aca(num_aca, grid)
+            if func_vals.abs().max() > 0.0:
+                break
+
+        if func_vals.abs().max() == 0.0:
+            msg = (
+                "ACA: None of the sampled fibre elements are nonzero. "
+                "Consider rescaling the target function. If you are "
+                "confident the target function is scaled appropriately, "
+                "consider using a refined grid, larger core ranks, an "
+                "increased number of bridging densities, or a larger "
+                "value for num_aca."
+            )
+            warnings.warn(msg)
+
+        max_residual_index = func_vals.abs().argmax()
+        inds = torch.atleast_2d(inds_rand[max_residual_index])
+        vals = torch.atleast_1d(func_vals[max_residual_index])
+        return inds, vals
+
+    def compute_fibre_submatrix_aca(self, grid: Grid, k: int) -> Tensor:
 
-            else:
+        num_aca = self.options.num_aca
+        inds, vals = self._initialise_index_set_aca(grid)
 
-                num_inds = inds.shape[0]
+        # Keep track of elements of the cross that have been evaluated
+        inds_eval = inds.clone()
+        vals_eval = vals.clone()
 
-                # Compute intersection matrix (NOTE: some of this
-                # will have actually been computed at previous
-                # iterations...)
-                inds_int = inds.repeat(num_inds, 1)
-                inds_int[:, k] = inds[:, k].repeat_interleave(num_inds, dim=0)
+        for _ in range(1, self.options.max_fibres):
 
-                inds_row = random_inds.repeat(num_inds, 1)
-                inds_row[:, k] = inds[:, k].repeat_interleave(self.options.num_aca, dim=0)
+            num_inds = inds.shape[0]
+            inds_rand, func_vals = self._generate_points_aca(num_aca, grid)
 
-                inds_col = inds.repeat(self.options.num_aca, 1)
-                inds_col[:, k] = random_inds[:, k].repeat_interleave(num_inds, dim=0)
+            inds_int = inds.repeat(num_inds, 1)
+            inds_int[:, k] = inds[:, k].repeat_interleave(num_inds, dim=0)
+            inds_row = inds_rand.repeat(num_inds, 1)
+            inds_row[:, k] = inds[:, k].repeat_interleave(num_aca, dim=0)
+            inds_col = inds.repeat(self.options.num_aca, 1)
+            inds_col[:, k] = inds_rand[:, k].repeat_interleave(num_inds, dim=0)
 
-                points_int = grid.indices2points(inds_int)
-                points_row = grid.indices2points(inds_row)
-                points_col = grid.indices2points(inds_col)
+            points_int = grid.indices2points(inds_int)
+            points_row = grid.indices2points(inds_row)
+            points_col = grid.indices2points(inds_col)
 
-                B_int = self.target_func(points_int)
-                B_int = B_int.reshape(num_inds, num_inds)
-                B_rows = self.target_func(points_row)
-                B_rows = B_rows.reshape(num_inds, self.options.num_aca)
-                B_cols = self.target_func(points_col)
-                B_cols = B_cols.reshape(self.options.num_aca, num_inds)
+            mask, mask_vals = self._find_evaluated_points(
+                inds_int, inds_eval, vals_eval
+            )
 
-                self.num_eval_fibres += (
-                    2 * num_inds * self.options.num_aca
-                    + num_inds ** 2
-                )
+            # Form intersection submatrix (avoiding the evaluation
+            # of function values that were previously computed)
+            B_int = torch.zeros(inds_int.shape[0])
+            B_int[mask] = mask_vals
+            if (~mask).any():
+                B_int[~mask] = self.target_func(points_int[~mask])
+
+            B_rows = self.target_func(points_row)
+            B_cols = self.target_func(points_col)
+
+            B_int = B_int.reshape(num_inds, num_inds)
+            B_rows = B_rows.reshape(num_inds, num_aca)
+            B_cols = B_cols.reshape(num_aca, num_inds)
+
+            inds_eval = inds_int.clone()
+            vals_eval = B_int.flatten()
+
+            num_eval_int = int((~mask).sum())
+            self.num_eval_fibres += (
+                num_eval_int + B_rows.numel() + B_cols.numel()
+            )
 
-                # Check for (near-)singularity of intersection matrix
-                # (also done in implementation by Strossner et al.).
-                if linalg.cond(B_int) > 1.0 / EPS:
-                    break
-
-                # Update index set with index of maximum residual
-                B_vals = B_cols @ linalg.solve(B_int, B_rows)
-                residuals = torch.diag(M_vals - B_vals).abs()
-                max_residual = residuals.max()
-                max_residual_index = residuals.abs().argmax()
-                max_index = random_inds[max_residual_index, :]
-                inds = torch.vstack((inds, max_index))
+            # Check for (near-)singularity of intersection matrix
+            # (also done in implementation by Strossner et al.).
+            # This occurs for functions where the fibre matrices
+            # are exactly low rank.
+            if linalg.cond(B_int) > 1.0 / EPS:
+                break
 
-            if max_residual < self.options.tol_aca and iter > 1:
+            cross_vals = B_cols @ linalg.solve(B_int, B_rows)
+            residuals = torch.diag(func_vals - cross_vals).abs()
+            if residuals.max() < self.options.tol_aca:
                 break
+
+            # Update index set
+            max_index = inds_rand[residuals.argmax(), :]
+            inds = torch.vstack((inds, max_index))
 
         n_k = self.bases[k].cardinality
         num_inds = inds.shape[0]
 
         fibre_inds = inds.repeat(n_k, 1)
-        fibre_inds[:, k] = torch.arange(n_k, device=self.device).repeat_interleave(num_inds, dim=0)
-
+        ii = torch.arange(n_k, device=self.device)
+        fibre_inds[:, k] = ii.repeat_interleave(num_inds, dim=0)
         fibre_points = grid.indices2points(fibre_inds)
-        fibre_matrix = self.target_func(fibre_points).reshape(n_k, num_inds)
-        self.num_eval_fibres += n_k * num_inds
+
+        mask, mask_vals = self._find_evaluated_points(
+            fibre_inds, inds_eval, vals_eval
+        )
+
+        fibre_matrix = torch.zeros((n_k*num_inds,))
+        fibre_matrix[mask] = mask_vals
+        fibre_matrix[~mask] = self.target_func(fibre_points[~mask])
+        fibre_matrix = fibre_matrix.reshape(n_k, num_inds)
+
+        num_eval_new = int((~mask).sum())
+        self.num_eval_fibres += num_eval_new
 
         return fibre_matrix

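A note on the new _find_evaluated_points helper above: it matches each requested multi-index against the previously evaluated ones by summed absolute difference and reuses the stored values wherever an exact match is found, so repeated entries of the cross never trigger a second call to the target function. Below is a minimal, self-contained sketch of that matching step; the names find_evaluated_points, inds_eval, and vals_eval are illustrative here, not part of the library API.

import torch

def find_evaluated_points(new_inds, inds_eval, vals_eval, eps=1e-12):
    # Illustrative sketch (not the library function): L1 distance between
    # every requested index and every previously evaluated index.
    diffs = (new_inds[:, None, :] - inds_eval[None, ...]).abs().sum(dim=2)
    inds_prev = diffs.argmin(dim=1)
    mask = diffs.min(dim=1).values < eps
    return mask, vals_eval[inds_prev[mask]]

inds_eval = torch.tensor([[0, 1], [2, 3]])   # indices already evaluated
vals_eval = torch.tensor([10.0, 20.0])       # their stored function values
new_inds = torch.tensor([[2, 3], [4, 5]])    # indices requested next

mask, vals = find_evaluated_points(new_inds, inds_eval, vals_eval)
print(mask)   # tensor([ True, False]) -> only [2, 3] was seen before
print(vals)   # tensor([20.]) -> its cached value is reused
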
@@ -580,9 +649,10 @@ def approximate(
         return
 
     def clone(self):
-        # Note: can't copy the cores and index sets over, because the
-        # indices corresponding to the DEIM projection onto the reduced
-        # bases in each dimension can change.
+        # Note: we cannot copy the cores and index sets over, because
+        # the indices corresponding to the DEIM projection onto the
+        # reduced bases in each dimension can change. Instead we start
+        # from scratch.
         tt = TT(self.tt.options, device=self.device)
         ftt = EFTT(self.bases, tt, self.options, device=self.device)
         return ftt

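For reference, the update inside the main ACA loop is a standard cross (skeleton) approximation step: the sampled columns, the inverse of the intersection submatrix, and the sampled rows are combined as B_cols @ linalg.solve(B_int, B_rows), and the largest residual over a fresh batch of random entries both selects the next pivot and decides when to stop (tol_aca). The toy illustration below uses an exactly rank-2 matrix; the matrix, index sets, and sizes are invented for this sketch and are not taken from the library.

import torch
from torch import linalg

torch.manual_seed(0)

# Toy stand-in for a fibre matrix: exactly rank 2.
M = torch.randn(40, 2) @ torch.randn(2, 40)

row_inds = torch.tensor([3, 17])      # rows of the current cross
col_inds = torch.tensor([5, 29])      # columns of the current cross

B_int = M[row_inds][:, col_inds]      # intersection submatrix
B_rows = M[row_inds, :]               # selected rows
B_cols = M[:, col_inds]               # selected columns

# The loop above bails out if this condition number exceeds 1 / EPS.
print(f"cond(B_int) = {float(linalg.cond(B_int)):.2e}")

# Skeleton (cross) approximation: M ~ B_cols @ B_int^{-1} @ B_rows.
M_cross = B_cols @ linalg.solve(B_int, B_rows)

# Residuals at a handful of randomly sampled entries; in the ACA loop the
# entry with the largest residual becomes the next pivot, and iteration
# stops once the largest residual drops below the tolerance.
ii = torch.randint(0, 40, (8,))
jj = torch.randint(0, 40, (8,))
residuals = (M[ii, jj] - M_cross[ii, jj]).abs()
print(float(residuals.max()))         # small (floating-point level): M is exactly rank 2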