Skip to content

Commit b5c19e0

Browse files
committed
Add decimal class and divmod tests
1 parent 0ca8d5a commit b5c19e0

File tree

4 files changed

+294
-26
lines changed

4 files changed

+294
-26
lines changed

cp-algo/math/bigint.hpp

Lines changed: 98 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,90 @@ namespace cp_algo::math {
1111
};
1212
template<base_v base = x10>
1313
struct bigint {
14+
static constexpr uint64_t Base = uint64_t(base);
1415
static constexpr uint16_t digit_length = base == x10 ? 16 : 15;
1516
static constexpr uint16_t sub_base = base == x10 ? 10 : 16;
1617
static constexpr uint32_t meta_base = base == x10 ? uint32_t(1e4) : uint32_t(1 << 15);
1718
big_basic_string<uint64_t> digits;
1819
bool negative;
1920

21+
auto operator <=> (bigint const& other) const {
22+
// Handle zero cases
23+
if (digits.empty() && other.digits.empty()) {
24+
return std::strong_ordering::equal;
25+
}
26+
if (digits.empty()) {
27+
return other.negative ? std::strong_ordering::greater : std::strong_ordering::less;
28+
}
29+
if (other.digits.empty()) {
30+
return negative ? std::strong_ordering::less : std::strong_ordering::greater;
31+
}
32+
33+
// Handle sign differences
34+
if (negative != other.negative) {
35+
return negative ? std::strong_ordering::less : std::strong_ordering::greater;
36+
}
37+
38+
// Both have the same sign - compare magnitudes
39+
if (digits.size() != other.digits.size()) {
40+
auto size_cmp = digits.size() <=> other.digits.size();
41+
// If both negative, reverse the comparison
42+
return negative ? 0 <=> size_cmp : size_cmp;
43+
}
44+
45+
// Same size, compare digits from most significant to least
46+
for (auto i = ssize(digits) - 1; i >= 0; i--) {
47+
auto digit_cmp = digits[i] <=> other.digits[i];
48+
if (digit_cmp != std::strong_ordering::equal) {
49+
return negative ? 0 <=> digit_cmp : digit_cmp;
50+
}
51+
}
52+
53+
return std::strong_ordering::equal;
54+
}
55+
2056
bigint() {}
2157

58+
bigint(big_basic_string<uint64_t> d, bool neg): digits(std::move(d)), negative(neg) {
59+
normalize();
60+
}
61+
62+
bigint& pad_inplace(size_t to_add) {
63+
digits.insert(0, to_add, 0);
64+
return normalize();
65+
}
66+
bigint& drop_inplace(size_t to_drop) {
67+
digits.erase(0, std::min(to_drop, size(digits)));
68+
return normalize();
69+
}
70+
bigint& take_inplace(size_t to_keep) {
71+
digits.erase(std::min(to_keep, size(digits)), std::string::npos);
72+
return normalize();
73+
}
74+
bigint& top_inplace(size_t to_keep) {
75+
if (to_keep >= size(digits)) {
76+
return pad_inplace(to_keep - size(digits));
77+
} else {
78+
return drop_inplace(size(digits) - to_keep);
79+
}
80+
}
81+
bigint pad(size_t to_add) const {
82+
return bigint{big_basic_string<uint64_t>(to_add, 0) + digits, negative}.normalize();
83+
}
84+
bigint drop(size_t to_drop) const {
85+
return bigint{digits.substr(std::min(to_drop, size(digits))), negative}.normalize();
86+
}
87+
bigint take(size_t to_keep) const {
88+
return bigint{digits.substr(0, std::min(to_keep, size(digits))), negative}.normalize();
89+
}
90+
bigint top(size_t to_keep) const {
91+
if (to_keep >= size(digits)) {
92+
return pad(to_keep - size(digits));
93+
} else {
94+
return drop(size(digits) - to_keep);
95+
}
96+
}
97+
2298
bigint& normalize() {
2399
while (!empty(digits) && digits.back() == 0) {
24100
digits.pop_back();
@@ -223,43 +299,39 @@ namespace cp_algo::math {
223299
return in;
224300
}
225301

226-
template<base_v base>
227-
decltype(std::cout)& operator << (decltype(std::cout) &out, cp_algo::math::bigint<base> const& x) {
302+
template<base_v base, bool fill = true>
303+
auto& print_digit(auto &out, uint64_t d) {
228304
char buf[16];
229-
if (size(x.digits) <= 1) {
230-
if (x.negative) {
231-
out << '-';
232-
}
233-
if constexpr (base == x16) {
234-
auto [ptr, ec] = std::to_chars(buf, buf + sizeof(buf), empty(x.digits) ? 0 : x.digits[0], bigint<base>::sub_base);
235-
std::ranges::transform(buf, buf, toupper);
236-
return out << std::string_view(buf, ptr - buf);
237-
} else {
238-
return out << (empty(x.digits) ? 0 : x.digits[0]);
239-
}
305+
auto [ptr, ec] = std::to_chars(buf, buf + sizeof(buf), d, bigint<base>::sub_base);
306+
if constexpr (base == x16) {
307+
std::ranges::transform(buf, buf, toupper);
240308
}
309+
auto len = ptr - buf;
310+
if constexpr (fill) {
311+
out << std::string(bigint<base>::digit_length - len, '0');
312+
}
313+
return out << std::string_view(buf, len);
314+
}
315+
316+
template<bool fill_all = false, base_v base>
317+
auto& print_bigint(auto &out, cp_algo::math::bigint<base> const& x) {
241318
if (x.negative) {
242319
out << '-';
243320
}
244321
if (empty(x.digits)) {
245-
return out << '0';
246-
}
247-
auto [ptr, ec] = std::to_chars(buf, buf + sizeof(buf), x.digits.back(), bigint<base>::sub_base);
248-
if constexpr (base == x16) {
249-
std::ranges::transform(buf, buf, toupper);
322+
return print_digit<base, fill_all>(out, 0);
250323
}
251-
out << std::string_view(buf, ptr - buf);
324+
print_digit<base, fill_all>(out, x.digits.back());
252325
for (auto d: x.digits | std::views::reverse | std::views::drop(1)) {
253-
auto [ptr, ec] = std::to_chars(buf, buf + sizeof(buf), d, bigint<base>::sub_base);
254-
if constexpr (base == x16) {
255-
std::ranges::transform(buf, buf, toupper);
256-
}
257-
auto len = ptr - buf;
258-
out << std::string(bigint<base>::digit_length - len, '0');
259-
out << std::string_view(buf, len);
326+
print_digit<base, true>(out, d);
260327
}
261328
return out;
262329
}
330+
331+
template<base_v base>
332+
decltype(std::cout)& operator << (decltype(std::cout) &out, cp_algo::math::bigint<base> const& x) {
333+
return print_bigint(out, x);
334+
}
263335
}
264336

265337
#endif // CP_ALGO_MATH_BIGINT_HPP

cp-algo/math/decimal.hpp

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
#ifndef CP_ALGO_MATH_DECIMAL_HPP
2+
#define CP_ALGO_MATH_DECIMAL_HPP
3+
#include "bigint.hpp"
4+
#include <utility>
5+
6+
namespace cp_algo::math {
7+
template<base_v base = x10>
8+
struct decimal {
9+
bigint<base> value;
10+
int64_t scale; // value * base^scale
11+
12+
decimal(int64_t v=0, int64_t s=0): value(bigint<base>(v)), scale(s) {}
13+
decimal(bigint<base> v, int64_t s=0): value(v), scale(s) {}
14+
15+
decimal& operator *= (const decimal &other) {
16+
value *= other.value;
17+
scale += other.scale;
18+
return *this;
19+
}
20+
decimal& operator += (decimal const& other) {
21+
if (scale < other.scale) {
22+
value += other.value.pad(other.scale - scale);
23+
} else {
24+
value.pad_inplace(scale - other.scale);
25+
value += other.value;
26+
scale = other.scale;
27+
}
28+
return *this;
29+
}
30+
decimal& operator -= (decimal const& other) {
31+
if (scale < other.scale) {
32+
value -= other.value.pad(other.scale - scale);
33+
} else {
34+
value.pad_inplace(scale - other.scale);
35+
value -= other.value;
36+
scale = other.scale;
37+
}
38+
return *this;
39+
}
40+
decimal operator * (const decimal &other) const {
41+
return decimal(*this) *= other;
42+
}
43+
decimal operator + (const decimal &other) const {
44+
return decimal(*this) += other;
45+
}
46+
decimal operator - (const decimal &other) const {
47+
return decimal(*this) -= other;
48+
}
49+
auto split() const {
50+
auto int_part = scale >= -ssize(value.digits) ? value.top(ssize(value.digits) + scale) : bigint<base>(0);
51+
auto frac_part = *this - decimal(int_part);
52+
return std::pair{int_part, frac_part};
53+
}
54+
void print() {
55+
auto [int_part, frac_part] = split();
56+
print_bigint(std::cout, int_part);
57+
if (frac_part.value != bigint<base>(0)) {
58+
std::cout << '.';
59+
std::cout << std::string(bigint<base>::digit_length * (-frac_part.magnitude()), '0');
60+
frac_part.value.negative = false;
61+
print_bigint<true>(std::cout, frac_part.value);
62+
}
63+
std::cout << std::endl;
64+
}
65+
bigint<base> trunc() const {
66+
if (scale >= 0) {
67+
return value.pad(scale);
68+
} else if (-scale >= ssize(value.digits)) {
69+
return 0;
70+
} else {
71+
return value.top(ssize(value.digits) + scale);
72+
}
73+
}
74+
bigint<base> round() const {
75+
if (scale >= 0) {
76+
return value.pad(scale);
77+
} else if (-scale > ssize(value.digits)) {
78+
return 0;
79+
} else {
80+
auto res = value.top(ssize(value.digits) + scale);
81+
if (value.digits[-scale - 1] * 2 >= bigint<base>::Base) {
82+
res += 1;
83+
}
84+
return res;
85+
}
86+
}
87+
decimal trunc(size_t digits) const {
88+
digits = std::min(digits, size(value.digits));
89+
return decimal(
90+
value.top(digits),
91+
scale + ssize(value.digits) - digits
92+
);
93+
}
94+
auto magnitude() const {
95+
static constexpr int64_t inf = 1e18;
96+
if (value.digits.empty()) return -inf;
97+
return ssize(value.digits) + scale;
98+
}
99+
decimal inv(int64_t precision) {
100+
assert(precision >= 0);
101+
int64_t lead = llround((double)bigint<base>::Base / (double)value.digits.back());
102+
decimal d(bigint<base>(lead), -ssize(value.digits));
103+
size_t cur = 2;
104+
decimal amend = decimal(1) - trunc(cur) * d;
105+
while(-amend.magnitude() <= precision) {
106+
d += d * amend;
107+
cur = 2 * (1 - amend.magnitude());
108+
d = d.trunc(cur);
109+
amend = decimal(1) - trunc(cur) * d;
110+
}
111+
return d;
112+
}
113+
};
114+
115+
template<base_v base>
116+
auto divmod(bigint<base> const& a, bigint<base> const& b) {
117+
if (a < b) {
118+
return std::pair{bigint<base>(0), a};
119+
}
120+
auto A = decimal<base>(a);
121+
auto B = decimal<base>(b);
122+
auto d = (A * B.inv(A.magnitude())).trunc();
123+
auto r = a - d * b;
124+
if (r >= b) {
125+
d += 1;
126+
r -= b;
127+
}
128+
return std::pair{d, r};
129+
}
130+
}
131+
132+
#endif // CP_ALGO_MATH_DECIMAL_HPP

verify/bigint/divmod.test.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// @brief Division of Big Integers
2+
#define PROBLEM "https://judge.yosupo.jp/problem/division_of_big_integers"
3+
#pragma GCC optimize("O3,unroll-loops")
4+
#include <bits/allocator.h>
5+
#pragma GCC target("avx2")
6+
//#define CP_ALGO_CHECKPOINT
7+
#include <iostream>
8+
#include "blazingio/blazingio.min.hpp"
9+
#include "cp-algo/math/decimal.hpp"
10+
#include "cp-algo/util/checkpoint.hpp"
11+
#include <bits/stdc++.h>
12+
13+
using namespace std;
14+
using namespace cp_algo::math;
15+
16+
void solve() {
17+
bigint a, b;
18+
cin >> a >> b;
19+
auto [d, r] = divmod(a, b);
20+
cout << d << ' ' << r << '\n';
21+
}
22+
23+
signed main() {
24+
ios::sync_with_stdio(0);
25+
cin.tie(0);
26+
int t = 1;
27+
cin >> t;
28+
while(t--) {
29+
solve();
30+
}
31+
cp_algo::checkpoint<1>();
32+
}

verify/bigint/hex_divmod.test.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// @brief Division of Hex Big Integers
2+
#define PROBLEM "https://judge.yosupo.jp/problem/division_of_hex_big_integers"
3+
#pragma GCC optimize("O3,unroll-loops")
4+
#include <bits/allocator.h>
5+
#pragma GCC target("avx2")
6+
//#define CP_ALGO_CHECKPOINT
7+
#include <iostream>
8+
#include "blazingio/blazingio.min.hpp"
9+
#include "cp-algo/math/decimal.hpp"
10+
#include "cp-algo/util/checkpoint.hpp"
11+
#include <bits/stdc++.h>
12+
13+
using namespace std;
14+
using namespace cp_algo::math;
15+
16+
void solve() {
17+
bigint<x16> a, b;
18+
cin >> a >> b;
19+
auto [d, r] = divmod(a, b);
20+
cout << d << ' ' << r << '\n';
21+
}
22+
23+
signed main() {
24+
ios::sync_with_stdio(0);
25+
cin.tie(0);
26+
int t = 1;
27+
cin >> t;
28+
while(t--) {
29+
solve();
30+
}
31+
cp_algo::checkpoint<1>();
32+
}

0 commit comments

Comments
 (0)