Commit 5b6e734

Vectorized float-to-integer conversions on AVX2 and AVX512

1 parent ab30275 commit 5b6e734

File tree

7 files changed
+352 -32 lines changed

compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/asm/amd64/AMD64Assembler.java

Lines changed: 12 additions & 2 deletions
@@ -2021,8 +2021,10 @@ public void emit(AMD64Assembler asm, AVXSize size, Register dst, AMD64Address sr
      */
     public static final class EvexRMIOp extends VexRMIOp {
         // @formatter:off
-        public static final EvexRMIOp EVFPCLASSSS = new EvexRMIOp("EVFPCLASS",   VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F3A, VEXPrefixConfig.W0, 0x67, VEXOpAssertion.MASK_XMM_AVX512DQ_128, EVEXTuple.T1S_32BIT, VEXPrefixConfig.W0);
-        public static final EvexRMIOp EVFPCLASSSD = new EvexRMIOp("EVFPCLASD",   VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F3A, VEXPrefixConfig.W1, 0x67, VEXOpAssertion.MASK_XMM_AVX512DQ_128, EVEXTuple.T1S_64BIT, VEXPrefixConfig.W1);
+        public static final EvexRMIOp EVFPCLASSSS = new EvexRMIOp("EVFPCLASSSS", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F3A, VEXPrefixConfig.W0, 0x67, VEXOpAssertion.MASK_XMM_AVX512DQ_128, EVEXTuple.T1S_32BIT, VEXPrefixConfig.W0);
+        public static final EvexRMIOp EVFPCLASSSD = new EvexRMIOp("EVFPCLASSSD", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F3A, VEXPrefixConfig.W1, 0x67, VEXOpAssertion.MASK_XMM_AVX512DQ_128, EVEXTuple.T1S_64BIT, VEXPrefixConfig.W1);
+        public static final EvexRMIOp EVFPCLASSPS = new EvexRMIOp("EVFPCLASSPS", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F3A, VEXPrefixConfig.W0, 0x66, VEXOpAssertion.MASK_XMM_AVX512DQ_VL, EVEXTuple.T1F_32BIT, VEXPrefixConfig.W0);
+        public static final EvexRMIOp EVFPCLASSPD = new EvexRMIOp("EVFPCLASSPD", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F3A, VEXPrefixConfig.W1, 0x66, VEXOpAssertion.MASK_XMM_AVX512DQ_VL, EVEXTuple.T1F_64BIT, VEXPrefixConfig.W1);
         // @formatter:on

         private EvexRMIOp(String opcode, int pp, int mmmmm, int w, int op, VEXOpAssertion assertion, EVEXTuple evexTuple, int wEvex) {
@@ -6273,6 +6275,14 @@ public final void kshiftrw(Register dst, Register src, int imm8) {
         VexMaskRRIOp.KSHIFTRW.emit(this, AVXSize.XMM, dst, src, imm8);
     }

+    public final void ktestb(Register src1, Register src2) {
+        VexRROp.KTESTB.emit(this, AVXSize.XMM, src1, src2);
+    }
+
+    public final void ktestw(Register src1, Register src2) {
+        VexRROp.KTESTW.emit(this, AVXSize.XMM, src1, src2);
+    }
+
     public final void ktestd(Register src1, Register src2) {
         VexRROp.KTESTD.emit(this, AVXSize.XMM, src1, src2);
     }
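The new ktestb/ktestw overloads mirror the existing ktestd emitter: KTEST ANDs two opmask registers and sets ZF when the result is zero. The conversion op added below uses ktestw to skip the entire fixup sequence when no lane produced MIN_VALUE; a minimal sketch of that pattern (register names illustrative, mirroring emitAVX512 further down):

```java
Label done = new Label();
// badElementMask = (dst == MIN_VALUE), element-wise
EVPCMPEQD.emit(masm, dstSize, badElementMask, dst, minValueVector);
// ZF is set iff badElementMask & badElementMask == 0, i.e. no lane needs fixing.
masm.ktestw(badElementMask, badElementMask);
masm.jcc(AMD64Assembler.ConditionFlag.Equal, done, true);
// ... per-lane fixup ...
masm.bind(done);
```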

compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/core/common/calc/FloatConvert.java

Lines changed: 8 additions & 0 deletions
@@ -93,6 +93,14 @@ public int getInputBits() {
         return inputBits;
     }

+    /**
+     * Returns {@code true} if this operation's input bit size is strictly larger than its output
+     * bit size.
+     */
+    public boolean isNarrowing() {
+        return inputBits > reverse().inputBits;
+    }
+
     /**
      * Returns the conversion operation corresponding to a conversion from {@code from} to
      * {@code to}. Returns {@code null} if the given stamps don't correspond to a conversion
compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/lir/amd64/vector/AMD64VectorConvertFloatToIntegerOp.java

Lines changed: 295 additions & 0 deletions
@@ -0,0 +1,295 @@
+/*
+ * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation. Oracle designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package jdk.graal.compiler.lir.amd64.vector;
+
+import static jdk.graal.compiler.asm.amd64.AMD64Assembler.EvexRMIOp.EVFPCLASSPD;
+import static jdk.graal.compiler.asm.amd64.AMD64Assembler.EvexRMIOp.EVFPCLASSPS;
+import static jdk.graal.compiler.asm.amd64.AMD64Assembler.VexFloatCompareOp.EVCMPPD;
+import static jdk.graal.compiler.asm.amd64.AMD64Assembler.VexFloatCompareOp.EVCMPPS;
+import static jdk.graal.compiler.asm.amd64.AMD64Assembler.VexFloatCompareOp.VCMPPD;
+import static jdk.graal.compiler.asm.amd64.AMD64Assembler.VexFloatCompareOp.VCMPPS;
+import static jdk.graal.compiler.asm.amd64.AMD64Assembler.VexMoveOp.EVMOVDQU32;
+import static jdk.graal.compiler.asm.amd64.AMD64Assembler.VexMoveOp.EVMOVDQU64;
+import static jdk.graal.compiler.asm.amd64.AMD64Assembler.VexRMIOp.VPERMQ;
+import static jdk.graal.compiler.asm.amd64.AMD64Assembler.VexRMIOp.VPSHUFD;
+import static jdk.graal.compiler.asm.amd64.AMD64Assembler.VexRVMIOp.EVPTERNLOGD;
+import static jdk.graal.compiler.asm.amd64.AMD64Assembler.VexRVMIOp.EVPTERNLOGQ;
+import static jdk.graal.compiler.asm.amd64.AMD64Assembler.VexRVMOp.EVPCMPEQD;
+import static jdk.graal.compiler.asm.amd64.AMD64Assembler.VexRVMOp.EVPCMPEQQ;
+import static jdk.graal.compiler.asm.amd64.AMD64Assembler.VexRVMOp.VPCMPEQD;
+import static jdk.graal.compiler.asm.amd64.AMD64Assembler.VexRVROp.KANDNW;
+import static jdk.graal.compiler.lir.LIRInstruction.OperandFlag.ILLEGAL;
+import static jdk.graal.compiler.lir.LIRInstruction.OperandFlag.REG;
+import static jdk.vm.ci.code.ValueUtil.asRegister;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+import jdk.graal.compiler.asm.Label;
+import jdk.graal.compiler.asm.amd64.AMD64Address;
+import jdk.graal.compiler.asm.amd64.AMD64Assembler;
+import jdk.graal.compiler.asm.amd64.AMD64BaseAssembler.EVEXPrefixConfig;
+import jdk.graal.compiler.asm.amd64.AMD64MacroAssembler;
+import jdk.graal.compiler.asm.amd64.AVXKind;
+import jdk.graal.compiler.core.common.LIRKind;
+import jdk.graal.compiler.core.common.NumUtil.Signedness;
+import jdk.graal.compiler.debug.GraalError;
+import jdk.graal.compiler.lir.LIRInstructionClass;
+import jdk.graal.compiler.lir.amd64.vector.AMD64VectorUnary.FloatPointClassTestOp;
+import jdk.graal.compiler.lir.asm.CompilationResultBuilder;
+import jdk.graal.compiler.lir.gen.LIRGeneratorTool;
+import jdk.vm.ci.amd64.AMD64;
+import jdk.vm.ci.amd64.AMD64Kind;
+import jdk.vm.ci.code.Register;
+import jdk.vm.ci.meta.Value;
+
+/**
+ * Floating point to integer conversion according to Java semantics. This wraps an AMD64 conversion
+ * instruction and adjusts its result as needed. According to Java semantics, NaN inputs should be
+ * mapped to 0, and values outside the integer type's range should be mapped to {@code MIN_VALUE} or
+ * {@code MAX_VALUE} according to the input's sign. The AMD64 instructions produce {@code MIN_VALUE}
+ * for NaNs and all values outside the integer type's range. So we need to fix up the result for
+ * NaNs and positive overflowing values. Negative overflowing values keep {@code MIN_VALUE}.
+ * <p>
+ * AVX1 is not supported. On AVX2, only conversions to {@code int} but not {@code long} are
+ * available.
+ * <p>
+ * Unsigned mode ({@link jdk.graal.compiler.lir.amd64.AMD64ConvertFloatToIntegerOp}) is currently
+ * not supported.
+ */
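To make the fixup conditions concrete, this is what Java's scalar semantics (JLS 5.1.3) require per lane; the input values are illustrative:

```java
int nan = (int) Float.NaN;   // 0
int pos = (int) 1e30f;       // Integer.MAX_VALUE (positive overflow)
int neg = (int) -1e30f;      // Integer.MIN_VALUE (negative overflow)
// The truncating AMD64 conversions yield MIN_VALUE in all three cases, so only
// the NaN lanes (-> 0) and positive-overflow lanes (-> MAX_VALUE) need fixing.
```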
+public class AMD64VectorConvertFloatToIntegerOp extends AMD64VectorInstruction {
+    public static final LIRInstructionClass<AMD64VectorConvertFloatToIntegerOp> TYPE = LIRInstructionClass.create(AMD64VectorConvertFloatToIntegerOp.class);
+
+    @Def({REG}) protected Value dstValue;
+    @Alive({REG}) protected Value srcValue;
+    /** Mask indicating those vector elements that may need to be fixed to match Java semantics. */
+    @Temp({REG, ILLEGAL}) protected Value badElementMaskValue;
+    /** Mask used for the result of various intermediate operations. */
+    @Temp({REG, ILLEGAL}) protected Value compareMaskValue;
+
+    private final OpcodeEmitter emitter;
+    private final boolean canBeNaN;
+    private final boolean canOverflow;
+
+    @FunctionalInterface
+    public interface OpcodeEmitter {
+        /** Emit the actual conversion instruction. */
+        void emit(CompilationResultBuilder crb, AMD64MacroAssembler masm, Register dst, Register src);
+    }
+
+    public AMD64VectorConvertFloatToIntegerOp(LIRGeneratorTool tool, OpcodeEmitter emitter, AVXKind.AVXSize size, Value dstValue, Value srcValue, boolean canBeNaN, boolean canOverflow,
+                    Signedness signedness) {
+        super(TYPE, size);
+        this.dstValue = dstValue;
+        this.srcValue = srcValue;
+        this.emitter = emitter;
+        if (canBeNaN || canOverflow) {
+            AMD64Kind maskKind;
+            if (((AMD64) tool.target().arch).getFeatures().contains(AMD64.CPUFeature.AVX512F)) {
+                GraalError.guarantee(Math.max(dstValue.getPlatformKind().getVectorLength(), srcValue.getPlatformKind().getVectorLength()) <= 16, "expect at most 16-element vectors");
+                maskKind = AMD64Kind.MASK16;
+            } else {
+                maskKind = AVXKind.getAVXKind(AMD64Kind.BYTE, Math.max(srcValue.getPlatformKind().getSizeInBytes(), dstValue.getPlatformKind().getSizeInBytes()));
+            }
+            this.badElementMaskValue = tool.newVariable(LIRKind.value(maskKind));
+            this.compareMaskValue = tool.newVariable(LIRKind.value(maskKind));
+        } else {
+            this.badElementMaskValue = Value.ILLEGAL;
+            this.compareMaskValue = Value.ILLEGAL;
+        }
+        this.canBeNaN = canBeNaN;
+        this.canOverflow = canOverflow;
+
+        GraalError.guarantee(signedness == Signedness.SIGNED, "only signed vector float-to-integer conversions are supported");
+    }
+
+    @Override
+    public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
+        if (masm.getFeatures().contains(AMD64.CPUFeature.AVX512F)) {
+            GraalError.guarantee(masm.supportsFullAVX512(), "expect full AVX-512 support");
+            emitAVX512(crb, masm);
+        } else {
+            emitAVX2(crb, masm);
+        }
+    }
+
+    private void emitAVX512(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
+        AMD64Kind srcKind = (AMD64Kind) srcValue.getPlatformKind();
+        AVXKind.AVXSize srcSize = AVXKind.getRegisterSize(srcKind);
+        AMD64Assembler.VexFloatCompareOp floatCompare = srcKind.getScalar().equals(AMD64Kind.DOUBLE) ? EVCMPPD : EVCMPPS;
+        AMD64Assembler.EvexRMIOp floatClassify = srcKind.getScalar().equals(AMD64Kind.DOUBLE) ? EVFPCLASSPD : EVFPCLASSPS;
+        AMD64Kind dstKind = (AMD64Kind) dstValue.getPlatformKind();
+        AVXKind.AVXSize dstSize = AVXKind.getRegisterSize(dstKind);
+        AMD64Assembler.VexRVMOp integerEquals = dstKind.getScalar().equals(AMD64Kind.QWORD) ? EVPCMPEQQ : EVPCMPEQD;
+        AMD64Assembler.VexMoveOp integerMove = dstKind.getScalar().equals(AMD64Kind.QWORD) ? EVMOVDQU64 : EVMOVDQU32;
+        AMD64Assembler.VexRVMIOp ternlog = dstKind.getScalar().equals(AMD64Kind.QWORD) ? EVPTERNLOGQ : EVPTERNLOGD;
+
+        Register dst = asRegister(dstValue);
+        Register src = asRegister(srcValue);
+
+        emitter.emit(crb, masm, dst, src);
+
+        if (!canBeNaN && !canOverflow) {
+            /* No fixup needed. */
+            return;
+        }
+
+        Register badElementMask = asRegister(badElementMaskValue);
+        Register compareMask = asRegister(compareMaskValue);
+        Label done = new Label();
+
+        /* badElementMask = (dst == MIN_VALUE); (element-wise) */
+        AMD64Address minValueVector = minValueVector(crb, dstKind);
+        integerEquals.emit(masm, dstSize, badElementMask, dst, minValueVector);
+        /* if (!anySet(badElementMask)) { goto done; } */
+        masm.ktestw(badElementMask, badElementMask);
+        masm.jcc(AMD64Assembler.ConditionFlag.Equal, done, true);
+
+        if (canBeNaN) {
+            /* compareMask = !isNaN(src); (element-wise) */
+            floatCompare.emit(masm, srcSize, compareMask, src, src, AMD64Assembler.VexFloatCompareOp.Predicate.ORD_Q);
+            /* Zero all elements where compareMask is 0, i.e., all elements where src is NaN. */
+            integerMove.emit(masm, dstSize, dst, dst, compareMask, EVEXPrefixConfig.Z1, EVEXPrefixConfig.B0);
+        }
+
+        if (canOverflow) {
+            /* compareMask = !(src >= 0.0); (element-wise) */
+            int anyNaN = FloatPointClassTestOp.QUIET_NAN | FloatPointClassTestOp.SIG_NAN;
+            int anyNegative = FloatPointClassTestOp.FIN_NEG | FloatPointClassTestOp.NEG_INF | FloatPointClassTestOp.NEG_ZERO;
+            floatClassify.emit(masm, srcSize, compareMask, src, anyNaN | anyNegative);
+            /* compareMask = (src >= 0.0) & badElement; (element-wise) */
+            KANDNW.emit(masm, compareMask, compareMask, badElementMask);
+            /*
+             * Now the compareMask marks just the positive overflown elements. They are MIN_VALUE,
+             * we want them to be MAX_VALUE. This is bitwise negation.
+             */
+            int ternlogNotA = 0x0F; // Intel SDM, Table 5-1
+            ternlog.emit(masm, dstSize, dst, dst, dst, compareMask, ternlogNotA);
+        }
+
+        masm.bind(done);
+    }
+
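VPTERNLOG's immediate is the truth table of an arbitrary three-input boolean function: for operand bits (A, B, C) the result bit is bit (A&lt;&lt;2 | B&lt;&lt;1 | C) of the immediate, so 0x0F (set exactly where A == 0) computes NOT A on the destination. Bitwise negation is the right fixup thanks to a two's-complement identity; a sketch:

```java
// dst = ~dst in the lanes selected by compareMask; all other lanes are untouched.
// Flipping every bit maps MIN_VALUE to MAX_VALUE:
assert ~Integer.MIN_VALUE == Integer.MAX_VALUE;
assert ~Long.MIN_VALUE == Long.MAX_VALUE;
```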
+    private void emitAVX2(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
+        AMD64Kind srcKind = (AMD64Kind) srcValue.getPlatformKind();
+        AVXKind.AVXSize srcSize = AVXKind.getRegisterSize(srcKind);
+        AMD64Assembler.VexFloatCompareOp floatCompare = srcKind.getScalar().equals(AMD64Kind.DOUBLE) ? VCMPPD : VCMPPS;
+        AMD64Kind dstKind = (AMD64Kind) dstValue.getPlatformKind();
+        GraalError.guarantee(dstKind.getScalar().equals(AMD64Kind.DWORD), "only expect conversions to int on AVX2: %s", dstKind);
+        AVXKind.AVXSize dstSize = AVXKind.getRegisterSize(dstKind);
+        AMD64Assembler.VexRVMOp integerEquals = VPCMPEQD;
+
+        Register dst = asRegister(dstValue);
+        Register src = asRegister(srcValue);
+
+        emitter.emit(crb, masm, dst, src);
+
+        if (!canBeNaN && !canOverflow) {
+            /* No fixup needed. */
+            return;
+        }
+
+        Register badElementMask = asRegister(badElementMaskValue);
+        Register compareMask = asRegister(compareMaskValue);
+        Label done = new Label();
+
+        /* badElementMask = (dst == MIN_VALUE); (element-wise) */
+        AMD64Address minValueVector = minValueVector(crb, dstKind);
+        integerEquals.emit(masm, dstSize, badElementMask, dst, minValueVector);
+        /* if (!anySet(badElementMask)) { goto done; } */
+        masm.vptest(badElementMask, badElementMask, dstSize);
+        masm.jcc(AMD64Assembler.ConditionFlag.Equal, done, true);
+
+        if (canBeNaN) {
+            /* compareMask = !isNaN(src); (element-wise) */
+            floatCompare.emit(masm, srcSize, compareMask, src, src, AMD64Assembler.VexFloatCompareOp.Predicate.ORD_Q);
+            convertAvx2Mask(masm, compareMask, srcKind, dstKind);
+            /* Zero all elements where compareMask is 0, i.e., all elements where src is NaN. */
+            masm.vpand(dst, dst, compareMask, dstSize);
+        }
+
+        if (canOverflow) {
+            /* compareMask = (0.0 <= src); (element-wise) */
+            masm.vpxor(compareMask, compareMask, compareMask, srcSize);
+            floatCompare.emit(masm, srcSize, compareMask, compareMask, src, AMD64Assembler.VexFloatCompareOp.Predicate.LE_OS);
+            convertAvx2Mask(masm, compareMask, srcKind, dstKind);
+            /*
+             * Negate bitwise all elements that are bad and where the source value is positive and
+             * not NaN (i.e., compareMask is set). Bitwise negation will flip these elements from
+             * MIN_VALUE to MAX_VALUE as required.
+             */
+            masm.vpand(compareMask, compareMask, badElementMask, dstSize);
+            masm.vpxor(dst, dst, compareMask, dstSize);
+        }
+
+        masm.bind(done);
+    }
+
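Without opmask registers, the AVX2 path keeps its masks in ordinary vector registers: VCMPPS/VPCMPEQD write all-ones into lanes where the predicate holds and all-zeros elsewhere, so the fixup reduces to plain bitwise arithmetic. Per lane the effect is (a sketch):

```java
int raw = Integer.MIN_VALUE;            // lane value produced by the conversion
// NaN lane: the ORD_Q compare left 0 in the mask, and AND clears the lane to (int) NaN == 0.
assert (raw & 0) == 0;
// Positive-overflow lane: the mask lane is all-ones, and XOR flips MIN_VALUE to MAX_VALUE.
assert (raw ^ -1) == Integer.MAX_VALUE;
```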
+    /**
+     * Returns the address of a constant of vector kind {@code dstKind} where each element is the
+     * minimal value for the underlying scalar kind. For example, if {@code dstKind} is the kind
+     * representing a vector of 4 {@code int}s, then the result will be a 4 * 4 byte constant
+     * containing the bytes {@code 0x00, 0x00, 0x00, 0x80} ({@link Integer#MIN_VALUE} in a
+     * little-endian representation) four times.
+     */
+    private static AMD64Address minValueVector(CompilationResultBuilder crb, AMD64Kind dstKind) {
+        byte[] minValueBytes = new byte[dstKind.getSizeInBytes()];
+        int elementBytes = dstKind.getScalar().getSizeInBytes();
+        GraalError.guarantee(dstKind.getScalar().isInteger() && (elementBytes == Integer.BYTES || elementBytes == Long.BYTES), "unexpected destination: %s", dstKind);
+        ByteBuffer buffer = ByteBuffer.wrap(minValueBytes).order(ByteOrder.LITTLE_ENDIAN);
+        for (int i = 0; i < dstKind.getVectorLength(); i++) {
+            if (elementBytes == Integer.BYTES) {
+                buffer.putInt(Integer.MIN_VALUE);
+            } else {
+                buffer.putLong(Long.MIN_VALUE);
+            }
+        }
+        int alignment = crb.dataBuilder.ensureValidDataAlignment(minValueBytes.length);
+        return (AMD64Address) crb.recordDataReferenceInCode(minValueBytes, alignment);
+    }
+
+    /**
+     * If the {@code fromKind}'s element kind is larger than the {@code toKind}'s element kind,
+     * narrow the mask in {@code maskRegister} from the wider size to the narrower one. The
+     * conversion is done in place.
+     */
+    private static void convertAvx2Mask(AMD64MacroAssembler masm, Register maskRegister, AMD64Kind fromKind, AMD64Kind toKind) {
+        GraalError.guarantee(fromKind.getVectorLength() == toKind.getVectorLength(), "vector length mismatch: %s, %s", fromKind, toKind);
+        int fromBytes = fromKind.getScalar().getSizeInBytes();
+        int toBytes = toKind.getScalar().getSizeInBytes();
+        GraalError.guarantee((fromBytes == Integer.BYTES || fromBytes == Long.BYTES) && toBytes == Integer.BYTES, "unexpected sizes: %s, %s", fromKind, toKind);
+        if (fromBytes > toBytes) {
+            AVXKind.AVXSize shuffleSize = AVXKind.getRegisterSize(fromKind);
+            /* Narrow using shuffles. */
+            VPSHUFD.emit(masm, shuffleSize, maskRegister, maskRegister, 0x08);
+            if (shuffleSize == AVXKind.AVXSize.YMM) {
+                VPERMQ.emit(masm, shuffleSize, maskRegister, maskRegister, 0x08);
+            }
+        }
+    }
+}
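The narrowing relies on each qword mask lane being all-ones or all-zeros, so its low dword already equals the narrowed mask and the shuffles only have to collect those low dwords. Both shuffle immediates are 0x08, which selects source elements (0, 2, 0, 0); a small decoder (a hypothetical helper, for illustration) shows why:

```java
// Split an 8-bit shuffle immediate into its four 2-bit source selectors.
static int[] selectors(int imm8) {
    return new int[]{imm8 & 3, (imm8 >> 2) & 3, (imm8 >> 4) & 3, (imm8 >> 6) & 3};
}
// selectors(0x08) == {0, 2, 0, 0}: VPSHUFD packs dwords 0 and 2 (the low halves
// of the two qword masks in each 128-bit lane) into that lane's low 64 bits;
// for YMM masks, VPERMQ then gathers qwords 0 and 2 into the low 128 bits.
```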
