Skip to content

Commit f989552

Browse files
committed
Armadillo 14.4.2 rc
1 parent 331194d commit f989552

File tree

7 files changed

+154
-80
lines changed

7 files changed

+154
-80
lines changed

inst/include/armadillo_bits/config.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,18 @@
303303
#undef ARMA_64BIT_WORD
304304
#endif
305305

306+
#if (defined(ARMA_BLAS_LONG_LONG) && defined(ARMA_USE_WRAPPER))
307+
#pragma message ("WARNING: use of ARMA_BLAS_LONG_LONG in conjunction with ARMA_USE_WRAPPER is not supported")
308+
#endif
309+
310+
#if (defined(ARMA_BLAS_64BIT_INT) && defined(ARMA_USE_WRAPPER))
311+
#pragma message ("WARNING: use of ARMA_BLAS_64BIT_INT in conjunction with ARMA_USE_WRAPPER is not supported")
312+
#endif
313+
314+
#if (defined(ARMA_SUPERLU_64BIT_INT) && defined(ARMA_USE_WRAPPER))
315+
#pragma message ("WARNING: use of ARMA_SUPERLU_64BIT_INT in conjunction with ARMA_USE_WRAPPER is not supported")
316+
#endif
317+
306318
// for compatibility with earlier versions of Armadillo
307319
#if defined(ARMA_BLAS_LONG) || defined(ARMA_BLAS_LONG_LONG)
308320
#undef ARMA_BLAS_64BIT_INT

inst/include/armadillo_bits/def_superlu.hpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,20 @@
1919

2020
extern "C"
2121
{
22-
extern void arma_wrapper(sgssv)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, int*);
23-
extern void arma_wrapper(dgssv)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, int*);
24-
extern void arma_wrapper(cgssv)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, int*);
25-
extern void arma_wrapper(zgssv)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, int*);
22+
extern void arma_wrapper(sgssv)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, superlu::int_t*);
23+
extern void arma_wrapper(dgssv)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, superlu::int_t*);
24+
extern void arma_wrapper(cgssv)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, superlu::int_t*);
25+
extern void arma_wrapper(zgssv)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, superlu::int_t*);
2626

27-
extern void arma_wrapper(sgssvx)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, int*, char*, float*, float*, superlu::SuperMatrix*, superlu::SuperMatrix*, void*, int, superlu::SuperMatrix*, superlu::SuperMatrix*, float*, float*, float*, float*, superlu::GlobalLU_t*, superlu::mem_usage_t*, superlu::SuperLUStat_t*, int*);
28-
extern void arma_wrapper(dgssvx)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, int*, char*, double*, double*, superlu::SuperMatrix*, superlu::SuperMatrix*, void*, int, superlu::SuperMatrix*, superlu::SuperMatrix*, double*, double*, double*, double*, superlu::GlobalLU_t*, superlu::mem_usage_t*, superlu::SuperLUStat_t*, int*);
29-
extern void arma_wrapper(cgssvx)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, int*, char*, float*, float*, superlu::SuperMatrix*, superlu::SuperMatrix*, void*, int, superlu::SuperMatrix*, superlu::SuperMatrix*, float*, float*, float*, float*, superlu::GlobalLU_t*, superlu::mem_usage_t*, superlu::SuperLUStat_t*, int*);
30-
extern void arma_wrapper(zgssvx)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, int*, char*, double*, double*, superlu::SuperMatrix*, superlu::SuperMatrix*, void*, int, superlu::SuperMatrix*, superlu::SuperMatrix*, double*, double*, double*, double*, superlu::GlobalLU_t*, superlu::mem_usage_t*, superlu::SuperLUStat_t*, int*);
27+
extern void arma_wrapper(sgssvx)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, int*, char*, float*, float*, superlu::SuperMatrix*, superlu::SuperMatrix*, void*, superlu::int_t, superlu::SuperMatrix*, superlu::SuperMatrix*, float*, float*, float*, float*, superlu::GlobalLU_t*, superlu::mem_usage_t*, superlu::SuperLUStat_t*, superlu::int_t*);
28+
extern void arma_wrapper(dgssvx)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, int*, char*, double*, double*, superlu::SuperMatrix*, superlu::SuperMatrix*, void*, superlu::int_t, superlu::SuperMatrix*, superlu::SuperMatrix*, double*, double*, double*, double*, superlu::GlobalLU_t*, superlu::mem_usage_t*, superlu::SuperLUStat_t*, superlu::int_t*);
29+
extern void arma_wrapper(cgssvx)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, int*, char*, float*, float*, superlu::SuperMatrix*, superlu::SuperMatrix*, void*, superlu::int_t, superlu::SuperMatrix*, superlu::SuperMatrix*, float*, float*, float*, float*, superlu::GlobalLU_t*, superlu::mem_usage_t*, superlu::SuperLUStat_t*, superlu::int_t*);
30+
extern void arma_wrapper(zgssvx)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, int*, char*, double*, double*, superlu::SuperMatrix*, superlu::SuperMatrix*, void*, superlu::int_t, superlu::SuperMatrix*, superlu::SuperMatrix*, double*, double*, double*, double*, superlu::GlobalLU_t*, superlu::mem_usage_t*, superlu::SuperLUStat_t*, superlu::int_t*);
3131

32-
extern void arma_wrapper(sgstrf)(superlu::superlu_options_t*, superlu::SuperMatrix*, int, int, int*, void*, int, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::GlobalLU_t*, superlu::SuperLUStat_t*, int*);
33-
extern void arma_wrapper(dgstrf)(superlu::superlu_options_t*, superlu::SuperMatrix*, int, int, int*, void*, int, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::GlobalLU_t*, superlu::SuperLUStat_t*, int*);
34-
extern void arma_wrapper(cgstrf)(superlu::superlu_options_t*, superlu::SuperMatrix*, int, int, int*, void*, int, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::GlobalLU_t*, superlu::SuperLUStat_t*, int*);
35-
extern void arma_wrapper(zgstrf)(superlu::superlu_options_t*, superlu::SuperMatrix*, int, int, int*, void*, int, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::GlobalLU_t*, superlu::SuperLUStat_t*, int*);
32+
extern void arma_wrapper(sgstrf)(superlu::superlu_options_t*, superlu::SuperMatrix*, int, int, int*, void*, superlu::int_t, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::GlobalLU_t*, superlu::SuperLUStat_t*, superlu::int_t*);
33+
extern void arma_wrapper(dgstrf)(superlu::superlu_options_t*, superlu::SuperMatrix*, int, int, int*, void*, superlu::int_t, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::GlobalLU_t*, superlu::SuperLUStat_t*, superlu::int_t*);
34+
extern void arma_wrapper(cgstrf)(superlu::superlu_options_t*, superlu::SuperMatrix*, int, int, int*, void*, superlu::int_t, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::GlobalLU_t*, superlu::SuperLUStat_t*, superlu::int_t*);
35+
extern void arma_wrapper(zgstrf)(superlu::superlu_options_t*, superlu::SuperMatrix*, int, int, int*, void*, superlu::int_t, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::GlobalLU_t*, superlu::SuperLUStat_t*, superlu::int_t*);
3636

3737
extern void arma_wrapper(sgstrs)(superlu::trans_t, superlu::SuperMatrix*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, int*);
3838
extern void arma_wrapper(dgstrs)(superlu::trans_t, superlu::SuperMatrix*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, int*);

inst/include/armadillo_bits/glue_times_meat.hpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,40 @@ glue_times_redirect2_helper<true>::apply(Mat<typename T1::elem_type>& out, const
104104

105105
const strip_inv<T1> A_strip(X.A);
106106

107+
typedef typename strip_inv<T1>::stored_type T1_stripped;
108+
109+
if( (is_cx<eT>::no) && (strip_inv<T1>::do_inv_gen) && (is_Mat<T1_stripped>::value) && (is_Mat<T2>::value) )
110+
{
111+
const unwrap<T1_stripped> UA(A_strip.M);
112+
const unwrap<T2 > UB(X.B);
113+
114+
const typename unwrap<T1_stripped>::stored_type& A = UA.M;
115+
const typename unwrap<T2 >::stored_type& B = UB.M;
116+
117+
const uword N = A.n_rows;
118+
119+
if( (N > 0) && (N <= uword(3)) && (N == A.n_cols) && (N == B.n_rows) && (void_ptr(&out) != void_ptr(&B)) )
120+
{
121+
arma_debug_print("glue_times_redirect<2>::apply(): inv tiny matrix optimisation");
122+
123+
Mat<eT> AA(N, N, arma_nozeros_indicator());
124+
125+
arrayops::copy(AA.memptr(), A.memptr(), AA.n_elem);
126+
127+
bool inv_status = false;
128+
129+
if(N == 1) { const eT a = AA[0]; AA[0] = eT(1) / a; inv_status = (a != eT(0)); }
130+
if(N == 2) { inv_status = op_inv_gen_full::apply_tiny_2x2(AA); }
131+
if(N == 3) { inv_status = op_inv_gen_full::apply_tiny_3x3(AA); }
132+
133+
if(inv_status) { glue_times::apply<eT,false,false,false>(out, AA, B, eT(0)); return; }
134+
135+
arma_debug_print("glue_times_redirect<2>::apply(): inv tiny matrix optimisation failed");
136+
137+
// fallthrough if optimisation failed
138+
}
139+
}
140+
107141
Mat<eT> A = A_strip.M;
108142

109143
arma_conform_check( (A.is_square() == false), "inv(): given matrix must be square sized" );

inst/include/armadillo_bits/include_superlu.hpp

Lines changed: 52 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
//
1717
// ------------------------------------------------------------------------
1818
//
19-
// This file includes portions of SuperLU 5.2 software,
19+
// This file includes portions of SuperLU 7.0 software,
2020
// licensed under the following conditions.
2121
//
2222
// Copyright (c) 2003, The Regents of the University of California, through
@@ -64,16 +64,23 @@
6464
// and manually specify a few SuperLU structures and function prototypes.
6565
//
6666
// CAVEAT:
67-
// This code requires SuperLU version 5.2,
68-
// and assumes that newer 5.x versions will have no API changes.
67+
// This code requires SuperLU version 7.0, and assumes that newer 7.x versions have no API changes.
6968

7069
namespace arma
7170
{
7271
namespace superlu
7372
{
74-
// slu_*defs.h has int typedefed to int_t.
75-
// I'll just write it as int for simplicity, where I can, but supermatrix.h needs int_t.
76-
typedef int int_t;
73+
// superlu_config.h uses either int or int64_t as int_t
74+
75+
#if defined(ARMA_SUPERLU_64BIT_INT)
76+
#if defined(INT64_MAX)
77+
typedef std::int64_t int_t;
78+
#else
79+
typedef long long int_t;
80+
#endif
81+
#else
82+
typedef int int_t;
83+
#endif
7784
}
7885
}
7986

@@ -84,7 +91,7 @@ namespace arma
8491
namespace superlu
8592
{
8693
// Include supermatrix.h. This gives us SuperMatrix.
87-
// Put it in the slu namespace.
94+
// Put it in the superlu namespace.
8895
// For versions of SuperLU I am familiar with, supermatrix.h does not include any other files.
8996
// Therefore, putting it in the superlu namespace is reasonably safe.
9097
// This same reasoning is true for superlu_enum_consts.h.
@@ -120,7 +127,7 @@ namespace superlu
120127

121128
#undef ARMA_SLU_STR1
122129
#undef ARMA_SLU_STR2
123-
130+
124131
#undef ARMA_SLU_HEADER_A
125132
#undef ARMA_SLU_HEADER_B
126133

@@ -130,7 +137,7 @@ namespace superlu
130137
{
131138
int* panel_histo;
132139
double* utime;
133-
float* ops;
140+
float* ops; // NOTE: orig definition is flops_t* ops, where flops_t = float
134141
int TinyPivots;
135142
int RefineSteps;
136143
int expansions;
@@ -174,33 +181,33 @@ namespace superlu
174181

175182
typedef struct e_node
176183
{
177-
int size;
184+
int_t size;
178185
void* mem;
179186
} ExpHeader;
180187

181188
typedef struct
182189
{
183-
int size;
184-
int used;
185-
int top1;
186-
int top2;
190+
int_t size;
191+
int_t used;
192+
int_t top1;
193+
int_t top2;
187194
void* array;
188195
} LU_stack_t;
189196

190197
typedef struct
191198
{
192199
int* xsup;
193200
int* supno;
194-
int* lsub;
195-
int* xlsub;
201+
int_t* lsub;
202+
int_t* xlsub;
196203
void* lusup;
197-
int* xlusup;
204+
int_t* xlusup;
198205
void* ucol;
199-
int* usub;
200-
int* xusub;
201-
int nzlmax;
202-
int nzumax;
203-
int nzlumax;
206+
int_t* usub;
207+
int_t* xusub;
208+
int_t nzlmax;
209+
int_t nzumax;
210+
int_t nzlumax;
204211
int n;
205212
LU_space_t MemModel;
206213
int num_expansions;
@@ -283,23 +290,23 @@ namespace superlu
283290
{
284291
int* panel_histo;
285292
double* utime;
286-
float* ops;
293+
float* ops; // NOTE: orig definition is flops_t* ops, where flops_t = float
287294
int TinyPivots;
288295
int RefineSteps;
289296
int expansions;
290297
} SuperLUStat_t;
291298

292-
typedef enum {NO, YES} yes_no_t;
299+
typedef enum {NO, YES} yes_no_t;
293300
typedef enum {DOFACT, SamePattern, SamePattern_SameRowPerm, FACTORED} fact_t;
294-
typedef enum {NOROWPERM, LargeDiag, MY_PERMR} rowperm_t;
301+
typedef enum {NOROWPERM, LargeDiag_MC64, LargeDiag_HWPM, MY_PERMR} rowperm_t;
295302
typedef enum {NATURAL, MMD_ATA, MMD_AT_PLUS_A, COLAMD,
296-
METIS_AT_PLUS_A, PARMETIS, ZOLTAN, MY_PERMC} colperm_t;
297-
typedef enum {NOTRANS, TRANS, CONJ} trans_t;
298-
typedef enum {NOREFINE, SLU_SINGLE=1, SLU_DOUBLE, SLU_EXTRA} IterRefine_t;
299-
typedef enum {SYSTEM, USER} LU_space_t;
300-
typedef enum {ONE_NORM, TWO_NORM, INF_NORM} norm_t;
301-
typedef enum {SILU, SMILU_1, SMILU_2, SMILU_3} milu_t;
302-
303+
METIS_AT_PLUS_A, PARMETIS, METIS_ATA, ZOLTAN, MY_PERMC} colperm_t;
304+
typedef enum {NOTRANS, TRANS, CONJ} trans_t;
305+
typedef enum {NOREFINE, SLU_SINGLE=1, SLU_DOUBLE, SLU_EXTRA} IterRefine_t;
306+
typedef enum {SYSTEM, USER} LU_space_t;
307+
typedef enum {ONE_NORM, TWO_NORM, INF_NORM} norm_t;
308+
typedef enum {SILU, SMILU_1, SMILU_2, SMILU_3} milu_t;
309+
303310
typedef struct
304311
{
305312
fact_t Fact;
@@ -352,33 +359,33 @@ namespace superlu
352359

353360
typedef struct e_node
354361
{
355-
int size;
362+
int_t size;
356363
void* mem;
357364
} ExpHeader;
358365

359366
typedef struct
360367
{
361-
int size;
362-
int used;
363-
int top1;
364-
int top2;
368+
int_t size;
369+
int_t used;
370+
int_t top1;
371+
int_t top2;
365372
void* array;
366373
} LU_stack_t;
367374

368375
typedef struct
369376
{
370377
int* xsup;
371378
int* supno;
372-
int* lsub;
373-
int* xlsub;
379+
int_t* lsub;
380+
int_t* xlsub;
374381
void* lusup;
375-
int* xlusup;
382+
int_t* xlusup;
376383
void* ucol;
377-
int* usub;
378-
int* xusub;
379-
int nzlmax;
380-
int nzumax;
381-
int nzlumax;
384+
int_t* usub;
385+
int_t* xusub;
386+
int_t nzlmax;
387+
int_t nzumax;
388+
int_t nzlumax;
382389
int n;
383390
LU_space_t MemModel;
384391
int num_expansions;

inst/include/armadillo_bits/op_expmat_meat.hpp

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,29 @@ op_expmat::apply_direct(Mat<typename T1::elem_type>& out, const Base<typename T1
104104
return true;
105105
}
106106

107+
// trace reduction
108+
109+
const eT diag_shift = arma::trace(A) / T(A.n_rows);
110+
const eT exp_diag_shift = std::exp(diag_shift);
111+
112+
const bool do_trace_reduction = arma_isfinite(diag_shift) && arma_isfinite(exp_diag_shift) && (exp_diag_shift != eT(0)) && ( (is_cx<eT>::yes) ? (std::abs(diag_shift) > T(0)) : (access::tmp_real(diag_shift) > T(0)) );
113+
114+
if(do_trace_reduction)
115+
{
116+
arma_debug_print("op_expmat: diag_shift: ", diag_shift);
117+
118+
A.diag() -= diag_shift;
119+
}
120+
107121
const T norm_val = arma::norm(A, "inf");
108122

109123
if(arma_isfinite(norm_val) == false) { return false; }
110124

111-
const double log2_val = (norm_val > T(0)) ? double(eop_aux::log2(norm_val)) : double(0);
125+
int exponent = int(0); std::frexp(norm_val, &exponent);
112126

113-
int exponent = int(0); std::frexp(log2_val, &exponent);
127+
const uword s = (std::min)( uword( (std::max)(int(0), exponent) ), uword(1023) );
114128

115-
const uword s = uword( (std::max)(int(0), exponent + int(1)) );
129+
arma_debug_print("op_expmat: s: ", s);
116130

117131
A /= eT(eop_aux::pow(double(2), double(s)));
118132

@@ -125,7 +139,7 @@ op_expmat::apply_direct(Mat<typename T1::elem_type>& out, const Base<typename T1
125139

126140
bool positive = true;
127141

128-
const uword N = 6;
142+
const uword N = 8;
129143

130144
for(uword i = 2; i <= N; ++i)
131145
{
@@ -148,6 +162,9 @@ op_expmat::apply_direct(Mat<typename T1::elem_type>& out, const Base<typename T1
148162

149163
for(uword i=0; i < s; ++i) { out = out * out; }
150164

165+
// inverse trace reduction
166+
if(do_trace_reduction) { out *= exp_diag_shift; }
167+
151168
return true;
152169
}
153170

0 commit comments

Comments
 (0)