1- from typing import Callable , Dict
1+ from typing import Callable , Dict , Tuple
2+ import warnings
23
34import torch
45from torch import linalg
@@ -421,80 +422,148 @@ def compute_fibre_submatrix_random(
421422
422423 return fibre_matrix
423424
424- def compute_fibre_submatrix_aca (self , grid : Grid , k : int ) -> Tensor :
425-
426- for iter in range (self .options .max_fibres ):
427-
428- random_inds = grid .sample_indices (self .options .num_aca )
429- random_points = grid .indices2points (random_inds )
430-
431- M_vals = self .target_func (random_points )
432-
433- if iter == 0 :
425+ @staticmethod
426+ def _find_evaluated_points (
427+ new_inds : Tensor ,
428+ inds_eval : Tensor ,
429+ vals_eval : Tensor
430+ ) -> Tuple [Tensor , Tensor ]:
431+ """Returns a mask elements of a set of indices that have been
432+ computed previously, as well as the computed values.
433+ """
434+ diffs = (new_inds [:, None , :] - inds_eval [None , ...]).abs ().sum (dim = 2 )
435+ inds_prev = diffs .argmin (dim = 1 )
436+ mask = diffs .min (dim = 1 ).values < EPS
437+ mask_vals = vals_eval [inds_prev [mask ]]
438+ return mask , mask_vals
439+
440+ def _generate_points_aca (self , n : int , grid : Grid ) -> Tuple [Tensor , Tensor ]:
441+ """Returns a set of random indices and the corresponding
442+ function values.
443+ """
444+ inds_rand = grid .sample_indices (n )
445+ random_points = grid .indices2points (inds_rand )
446+ func_vals = self .target_func (random_points )
447+ self .num_eval_fibres += func_vals .numel ()
448+ return inds_rand , func_vals
449+
450+ def _initialise_index_set_aca (self , grid : Grid ) -> Tuple [Tensor , Tensor ]:
451+ """Initialises the index set defining the current cross by
452+ sampling from the coefficient tensor at random. This is
453+ repeated multiple times in case the sampled elements are
454+ uniformly zero (see also implementation by Strossner et al.).
455+ """
434456
435- max_residual = M_vals .max ()
436- max_residual_index = M_vals .abs ().argmax ()
437- max_index = random_inds [max_residual_index , :]
457+ num_initialisation_batches = 5
458+ num_aca = self .options .num_aca
438459
439- inds = torch .atleast_2d (max_index )
460+ for _ in range (num_initialisation_batches ):
461+ inds_rand , func_vals = self ._generate_points_aca (num_aca , grid )
462+ if func_vals .abs ().max () > 0.0 :
463+ break
464+
465+ if func_vals .abs ().max () == 0.0 :
466+ msg = (
467+ "ACA: None of the sampled fibre elements are nonzero. "
468+ "Consider rescaling the target function. If you are "
469+ "confident the target function is scaled appropriately, "
470+ "consider using a refined grid, larger core ranks, an "
471+ "increased number of bridging densities, or a larger "
472+ "value for num_aca."
473+ )
474+ warnings .warn (msg )
475+
476+ max_residual_index = func_vals .abs ().argmax ()
477+ inds = torch .atleast_2d (inds_rand [max_residual_index ])
478+ vals = torch .atleast_1d (func_vals [max_residual_index ])
479+ return inds , vals
480+
481+ def compute_fibre_submatrix_aca (self , grid : Grid , k : int ) -> Tensor :
440482
441- else :
483+ num_aca = self .options .num_aca
484+ inds , vals = self ._initialise_index_set_aca (grid )
442485
443- num_inds = inds .shape [0 ]
486+ # Keep track of elements of the cross that have been evaluated
487+ inds_eval = inds .clone ()
488+ vals_eval = vals .clone ()
444489
445- # Compute intersection matrix (NOTE: some of this
446- # will have actually been computed at previous
447- # iterations...)
448- inds_int = inds .repeat (num_inds , 1 )
449- inds_int [:, k ] = inds [:, k ].repeat_interleave (num_inds , dim = 0 )
490+ for _ in range (1 , self .options .max_fibres ):
450491
451- inds_row = random_inds . repeat ( num_inds , 1 )
452- inds_row [:, k ] = inds [:, k ]. repeat_interleave ( self .options . num_aca , dim = 0 )
492+ num_inds = inds . shape [ 0 ]
493+ inds_rand , func_vals = self ._generate_points_aca ( num_aca , grid )
453494
454- inds_col = inds .repeat (self .options .num_aca , 1 )
455- inds_col [:, k ] = random_inds [:, k ].repeat_interleave (num_inds , dim = 0 )
495+ inds_int = inds .repeat (num_inds , 1 )
496+ inds_int [:, k ] = inds [:, k ].repeat_interleave (num_inds , dim = 0 )
497+ inds_row = inds_rand .repeat (num_inds , 1 )
498+ inds_row [:, k ] = inds [:, k ].repeat_interleave (num_aca , dim = 0 )
499+ inds_col = inds .repeat (self .options .num_aca , 1 )
500+ inds_col [:, k ] = inds_rand [:, k ].repeat_interleave (num_inds , dim = 0 )
456501
457- points_int = grid .indices2points (inds_int )
458- points_row = grid .indices2points (inds_row )
459- points_col = grid .indices2points (inds_col )
502+ points_int = grid .indices2points (inds_int )
503+ points_row = grid .indices2points (inds_row )
504+ points_col = grid .indices2points (inds_col )
460505
461- B_int = self .target_func (points_int )
462- B_int = B_int .reshape (num_inds , num_inds )
463- B_rows = self .target_func (points_row )
464- B_rows = B_rows .reshape (num_inds , self .options .num_aca )
465- B_cols = self .target_func (points_col )
466- B_cols = B_cols .reshape (self .options .num_aca , num_inds )
506+ mask , mask_vals = self ._find_evaluated_points (
507+ inds_int , inds_eval , vals_eval
508+ )
467509
468- self .num_eval_fibres += (
469- 2 * num_inds * self .options .num_aca
470- + num_inds ** 2
471- )
510+ # Form intersection submatrix (avoiding the evaluation
511+ # of function values that were previously computed)
512+ B_int = torch .zeros (inds_int .shape [0 ])
513+ B_int [mask ] = mask_vals
514+ if (~ mask ).any ():
515+ B_int [~ mask ] = self .target_func (points_int [~ mask ])
516+
517+ B_rows = self .target_func (points_row )
518+ B_cols = self .target_func (points_col )
519+
520+ B_int = B_int .reshape (num_inds , num_inds )
521+ B_rows = B_rows .reshape (num_inds , num_aca )
522+ B_cols = B_cols .reshape (num_aca , num_inds )
523+
524+ inds_eval = inds_int .clone ()
525+ vals_eval = B_int .flatten ()
526+
527+ num_eval_int = int ((~ mask ).sum ())
528+ self .num_eval_fibres += (
529+ num_eval_int + B_rows .numel () + B_cols .numel ()
530+ )
472531
473- # Check for (near-)singularity of intersection matrix
474- # (also done in implementation by Strossner et al.).
475- if linalg .cond (B_int ) > 1.0 / EPS :
476- break
477-
478- # Update index set with index of maximum residual
479- B_vals = B_cols @ linalg .solve (B_int , B_rows )
480- residuals = torch .diag (M_vals - B_vals ).abs ()
481- max_residual = residuals .max ()
482- max_residual_index = residuals .abs ().argmax ()
483- max_index = random_inds [max_residual_index , :]
484- inds = torch .vstack ((inds , max_index ))
532+ # Check for (near-)singularity of intersection matrix
533+ # (also done in implementation by Strossner et al.).
534+ # This occurs for functions where the fibre matrices
535+ # are exactly low rank.
536+ if linalg .cond (B_int ) > 1.0 / EPS :
537+ break
485538
486- if max_residual < self .options .tol_aca and iter > 1 :
539+ cross_vals = B_cols @ linalg .solve (B_int , B_rows )
540+ residuals = torch .diag (func_vals - cross_vals ).abs ()
541+ if residuals .max () < self .options .tol_aca :
487542 break
543+
544+ # Update index set
545+ max_index = inds_rand [residuals .argmax (), :]
546+ inds = torch .vstack ((inds , max_index ))
488547
489548 n_k = self .bases [k ].cardinality
490549 num_inds = inds .shape [0 ]
491550
492551 fibre_inds = inds .repeat (n_k , 1 )
493- fibre_inds [:, k ] = torch .arange (n_k , device = self .device ). repeat_interleave ( num_inds , dim = 0 )
494-
552+ ii = torch .arange (n_k , device = self .device )
553+ fibre_inds [:, k ] = ii . repeat_interleave ( num_inds , dim = 0 )
495554 fibre_points = grid .indices2points (fibre_inds )
496- fibre_matrix = self .target_func (fibre_points ).reshape (n_k , num_inds )
497- self .num_eval_fibres += n_k * num_inds
555+
556+ mask , mask_vals = self ._find_evaluated_points (
557+ fibre_inds , inds_eval , vals_eval
558+ )
559+
560+ fibre_matrix = torch .zeros ((n_k * num_inds ,))
561+ fibre_matrix [mask ] = mask_vals
562+ fibre_matrix [~ mask ] = self .target_func (fibre_points [~ mask ])
563+ fibre_matrix = fibre_matrix .reshape (n_k , num_inds )
564+
565+ num_eval_new = int ((~ mask ).sum ())
566+ self .num_eval_fibres += num_eval_new
498567
499568 return fibre_matrix
500569
@@ -580,9 +649,10 @@ def approximate(
580649 return
581650
582651 def clone (self ):
583- # Note: can't copy the cores and index sets over, because the
584- # indices corresponding to the DEIM projection onto the reduced
585- # bases in each dimension can change.
652+ # Note: we cannot copy the cores and index sets over, because
653+ # the indices corresponding to the DEIM projection onto the
654+ # reduced bases in each dimension can change. Instead we start
655+ # from scratch.
586656 tt = TT (self .tt .options , device = self .device )
587657 ftt = EFTT (self .bases , tt , self .options , device = self .device )
588658 return ftt
0 commit comments