Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 67 additions & 3 deletions inc/zoo/swar/associative_iteration.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ template<int NB, typename B>
constexpr auto makeLaneMaskFromMSB(SWAR<NB, B> input) {
using S = SWAR<NB, B>;
auto msb = input & S{S::MostSignificantBit};
auto msbCopiedToLSB = S{msb.value() >> (NB - 1)};
auto msbCopiedToLSB = S{static_cast<B>(msb.value() >> (NB - 1))};
return impl::makeLaneMaskFromMSB_and_LSB(msb, msbCopiedToLSB);
}

Expand Down Expand Up @@ -218,16 +218,69 @@ constexpr auto multiplication_OverflowUnsafe_SpecificBitCount(

auto halver = [](auto counts) {
auto msbCleared = counts & ~S{S::MostSignificantBit};
return S{msbCleared.value() << 1};
return S{static_cast<T>(msbCleared.value() << 1)};
};

multiplier = S{multiplier.value() << (NB - ActualBits)};
multiplier = S{static_cast<T>(multiplier.value() << (NB - ActualBits))};
return associativeOperatorIterated_regressive(
multiplicand, S{0}, multiplier, S{S::MostSignificantBit}, operation,
ActualBits, halver
);
}


/*
// extended from mathematics to generic programming
// see https://github.com/jamierpond/fmtgp/blob/main/2_first_algo/main.cpp

template <typename T> constexpr T exp_acc(T r, T a, T n) {
for (;;) {
if (is_odd(n)) {
r = multiply(r, a);
if (n == 1) {
return r;
}
}
n = half(n);
a = multiply(a, a);
}
}
*/

template<int ActualBits, int NB, typename T>
constexpr auto expo_OverflowUnsafe_SpecificBitCount(
SWAR<NB, T> x,
SWAR<NB, T> exponent
) {
using S = SWAR<NB, T>;

// Per-lane conditional multiply used as the step of the iterated operator:
// for every lane, yields left * right when that lane of `counts` has its
// most significant bit set, and `left` unchanged otherwise.
auto operation = [](auto left, auto right, auto counts) {
    // Lane mask built from the MSB of each lane of `counts`.
    const auto selected = makeLaneMaskFromMSB(counts);
    const auto multiplied =
        multiplication_OverflowUnsafe_SpecificBitCount<ActualBits>(left, right);
    // Branch-free select, equivalent to:
    //     selected ? multiplied : left
    return (multiplied & selected) | (left & ~selected);
};

// halver should work same as multiplication... i think...
auto halver = [](auto counts) {
auto msbCleared = counts & ~S{S::MostSignificantBit};
return S{static_cast<T>(msbCleared.value() << 1)};
};

exponent = S{static_cast<T>(exponent.value() << (NB - ActualBits))};
return associativeOperatorIterated_regressive(
x, S{1}, exponent, S{S::MostSignificantBit}, operation,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your neutral element is, lane-wise, { 0, ..., 1 }, so the upper lanes are initialized to 0.
You could use S{meta::BitmaskMaker<BaseType, 1, BitsPerLane>::value()} instead.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ooh thanks for that, that makes total sense!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would broadcast(S{1}) be more readable?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly, but using broadcast is a complication because it relies on multiplication.
I have been slowly internalizing what needs to be done to support SIMD: no assumptions about the availability of whole-base-type multiplication. In ARM NEON, x86-64 SSE of 128 bits, there is no such thing, but at most lane-wise 64 bit multiplication, same with AVX 256 bits.

BitmaskMaker can be upgraded to support longer types.
The problem is that SIMD negates constexpr; it's going to be painful to make the non-constexpr flavor of the library.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is, indeed, subtle.

Perhaps not so subtle in the SIMD constexpr negation...

ActualBits, halver
);
}

/// \note Not removed yet because it is an example of "progressive" associative exponentiation
template<int ActualBits, int NB, typename T>
constexpr auto multiplication_OverflowUnsafe_SpecificBitCount_deprecated(
Expand Down Expand Up @@ -261,6 +314,17 @@ constexpr auto multiplication_OverflowUnsafe(
);
}

/// Lane-wise exponentiation that treats the full lane width as the exponent
/// width: forwards to expo_OverflowUnsafe_SpecificBitCount with
/// ActualBits equal to NB.
/// \param base SWAR holding the bases
/// \param exponent SWAR holding the exponents
/// \return whatever expo_OverflowUnsafe_SpecificBitCount<NB> returns
template<int NB, typename T>
constexpr auto expo_OverflowUnsafe(
    SWAR<NB, T> base,
    SWAR<NB, T> exponent
) {
    constexpr int ExponentBits = NB;
    return expo_OverflowUnsafe_SpecificBitCount<ExponentBits>(base, exponent);
}

template<int NB, typename T>
struct SWAR_Pair{
SWAR<NB, T> even, odd;
Expand Down
16 changes: 16 additions & 0 deletions test/swar/BasicOperations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,24 @@ static_assert(
multiplication_OverflowUnsafe_SpecificBitCount<3>(Micand, Mplier).value()
);

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GOOD NEWS. Looks like my initial, first-pass implementation was correct after all!
Seems like part of the issue was how I'm trying to compare the binary literals.

All the values I tried, when evaluated using the hex idiom you guys already have, give consistently correct answers.

// Work-in-progress runtime check of expo_OverflowUnsafe on SWAR<8, u32>
// operands. Besides the CHECK, it prints the expected and actual results as
// 32-character binary strings to help compare them by eye.
TEST_CASE("Jamie's wip expo") {
// the LSB lanes seem to be correct, but the MSB lanes are not...
// NOTE(review): each literal below only populates the low 8 bits of the u32,
// and the trailing "a | b" comments do not match the nibbles actually written
// (0b0001'0011 is 1 | 3, 0b0001'0010 is 1 | 2, 0b0001'1001 is 1 | 9).
// Confirm the intended lane layout and values before relying on this test.
constexpr auto base = SWAR<8, u32>{0b0001'0011}; // 2 | 3
constexpr auto exponent = SWAR<8, u32>{0b0001'0010}; // 3 | 2
constexpr auto expected = SWAR<8, u32>{0b0001'1001}; // 8 | 9
// Compile-time form of the check, currently disabled in favor of the runtime
// CHECK below.
// static_assert(
// expected.value() == expo_OverflowUnsafe(base, exponent).value()
// );
auto actual = expo_OverflowUnsafe(base, exponent);
CHECK(expected.value() == actual.value());
// Dump both values bit by bit (most significant bit first) for inspection.
auto expected_as_bits = std::bitset<32>(expected.value());
auto actual_as_bits = std::bitset<32>(actual.value());
printf("expected: %s\n", expected_as_bits.to_string().c_str());
printf("actual: %s\n", actual_as_bits.to_string().c_str());
}

} // namespace Multiplication

#define HE(nbits, t, v0, v1) \
static_assert(horizontalEquality<nbits, t>(\
SWAR<nbits, t>(v0),\
Expand Down