Skip to content

Commit 4e5b42a

Browse files
committed
JIT: Optimize memory usage by patching jump table asap
Signed-off-by: Paul Guyot <pguyot@kallisys.net>
1 parent 1563eeb commit 4e5b42a

File tree

5 files changed

+126
-34
lines changed

5 files changed

+126
-34
lines changed

libs/estdlib/src/code_server.erl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ set_native_code(_Module, _LabelsCount, _Stream) ->
152152
load(Module) ->
153153
case erlang:system_info(emu_flavor) of
154154
jit ->
155-
% atomvm_heap_growth, fibonacci divides compilation time by two
155+
% atomvm_heap_growth, fibonacci reduces compilation time
156156
{Pid, Ref} = spawn_opt(
157157
fun() ->
158158
try

libs/jit/src/jit.erl

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,14 @@
114114
-define(ASSERT_ALL_NATIVE_FREE(St), ok).
115115
-define(ASSERT(Expr), ok).
116116

117+
%-define(JIT_INSTRUMENT, true).
118+
119+
-ifdef(JIT_INSTRUMENT).
120+
-define(INSTRUMENT(Tag, State, MSt), instrument(Tag, State, MSt)).
121+
-else.
122+
-define(INSTRUMENT(Tag, State, MSt), ok).
123+
-endif.
124+
117125
%%-----------------------------------------------------------------------------
118126
%% @param LabelsCount number of labels
119127
%% @param Arch code for the architecture
@@ -137,7 +145,6 @@ compile(
137145
MMod,
138146
MSt0
139147
) when OpcodeMax =< ?OPCODE_MAX ->
140-
MSt1 = MMod:jump_table(MSt0, LabelsCount),
141148
State0 = #state{
142149
line_offsets = [],
143150
labels_count = LabelsCount,
@@ -146,9 +153,15 @@ compile(
146153
type_resolver = TypeResolver,
147154
tail_cache = []
148155
},
156+
?INSTRUMENT("compile_start", State0, MSt0),
157+
MSt1 = MMod:jump_table(MSt0, LabelsCount),
158+
?INSTRUMENT("after_jump_table", State0, MSt1),
149159
{State1, MSt2} = first_pass(Opcodes, MMod, MSt1, State0),
160+
?INSTRUMENT("after_first_pass", State1, MSt2),
150161
MSt3 = second_pass(MMod, MSt2, State1),
162+
?INSTRUMENT("after_second_pass", State1, MSt3),
151163
MSt4 = MMod:flush(MSt3),
164+
?INSTRUMENT("after_flush", State1, MSt4),
152165
{LabelsCount, MSt4};
153166
compile(
154167
<<16:32, 0:32, OpcodeMax:32, _LabelsCount:32, _FunctionsCount:32, _Opcodes/binary>>,
@@ -3798,3 +3811,28 @@ backend(StreamModule, Stream) ->
37983811
Variant = ?MODULE:variant(),
37993812
BackendState = BackendModule:new(Variant, StreamModule, Stream),
38003813
{BackendModule, BackendState}.
3814+
3815+
-ifdef(JIT_INSTRUMENT).
3816+
%%-----------------------------------------------------------------------------
%% @param Tag text printed at the start of the output line (formatted with ~s)
%% @param State compiler state; only line_offsets and tail_cache are read
%% @param MSt backend state to measure
%% @return the result of io:format/2
%% @doc Debug helper printing one line of memory statistics: flat size of the
%% backend state, flat size of {line_offsets, tail_cache}, the length of both
%% lists, the number of pending branches in the backend state and the heap
%% sizes of the calling process. Only compiled when JIT_INSTRUMENT is defined.
%% @end
%%-----------------------------------------------------------------------------
instrument(Tag, #state{line_offsets = Lines, tail_cache = TC}, MSt) ->
3817+
StateSize = erts_debug:flat_size({Lines, TC}),
3818+
MStSize = erts_debug:flat_size(MSt),
3819+
LinesCount = length(Lines),
3820+
TCCount = length(TC),
3821+
3822+
% Extract branches count from backend state
3823+
% state record: {state, stream_module, stream, offset, branches, jump_table_start, ...}
% The backend record is opaque here, so `branches' is read by position (5th
% tuple element). This must be kept in sync with the backends' #state{}
% field order. Anything not tagged `state' is reported as `unknown'.
3824+
BranchesCount = case element(1, MSt) of
3825+
state -> length(element(5, MSt));
3826+
_ -> unknown
3827+
end,
3828+
3829+
% Heap figures are those of the process running the compilation (self()).
{heap_size, HeapSize} = process_info(self(), heap_size),
3830+
{total_heap_size, TotalHeapSize} = process_info(self(), total_heap_size),
3831+
3832+
io:format(
3833+
"~s: mst=~p words, state=~p words (lines=~p, tc=~p, br=~p), "
3834+
"heap=~p, total_heap=~p~n",
3835+
[Tag, MStSize, StateSize, LinesCount, TCCount, BranchesCount,
3836+
HeapSize, TotalHeapSize]
3837+
).
3838+
-endif.

libs/jit/src/jit_aarch64.erl

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@
134134
stream :: stream(),
135135
offset :: non_neg_integer(),
136136
branches :: [{non_neg_integer(), non_neg_integer(), non_neg_integer()}],
137+
jump_table_start :: non_neg_integer(),
137138
available_regs :: [aarch64_register()],
138139
used_regs :: [aarch64_register()],
139140
labels :: [{integer() | reference(), integer()}],
@@ -233,6 +234,7 @@ new(Variant, StreamModule, Stream) ->
233234
stream_module = StreamModule,
234235
stream = Stream,
235236
branches = [],
237+
jump_table_start = 0,
236238
offset = StreamModule:offset(Stream),
237239
available_regs = ?AVAILABLE_REGS,
238240
used_regs = [],
@@ -355,22 +357,21 @@ assert_all_native_free(#state{
355357
%% @return Updated backend state
356358
%%-----------------------------------------------------------------------------
357359
-spec jump_table(state(), pos_integer()) -> state().
358-
jump_table(State, LabelsCount) ->
359-
jump_table0(State, 0, LabelsCount).
360+
jump_table(#state{stream_module = StreamModule, stream = Stream0} = State, LabelsCount) ->
361+
JumpTableStart = StreamModule:offset(Stream0),
362+
jump_table0(State#state{jump_table_start = JumpTableStart}, 0, LabelsCount).
360363

361364
-spec jump_table0(state(), non_neg_integer(), pos_integer()) -> state().
362365
jump_table0(State, N, LabelsCount) when N > LabelsCount ->
363366
State;
364367
jump_table0(
365-
#state{stream_module = StreamModule, stream = Stream0, branches = Branches} = State,
368+
#state{stream_module = StreamModule, stream = Stream0} = State,
366369
N,
367370
LabelsCount
368371
) ->
369-
Offset = StreamModule:offset(Stream0),
370372
BranchInstr = jit_aarch64_asm:b(0),
371-
Reloc = {N, Offset, b},
372373
Stream1 = StreamModule:append(Stream0, BranchInstr),
373-
jump_table0(State#state{stream = Stream1, branches = [Reloc | Branches]}, N + 1, LabelsCount).
374+
jump_table0(State#state{stream = Stream1}, N + 1, LabelsCount).
374375

375376
%%-----------------------------------------------------------------------------
376377
%% @doc Rewrite stream to update all branches for labels.
@@ -2334,5 +2335,22 @@ add_label(#state{stream_module = StreamModule, stream = Stream} = State, Label)
23342335
%% @return Updated backend state
23352336
%%-----------------------------------------------------------------------------
23362337
-spec add_label(state(), integer() | reference(), integer()) -> state().
2338+
% Integer labels: record the label and rewrite its jump table entry in the
% stream right away, instead of keeping a relocation to resolve later.
add_label(
2339+
#state{
2340+
stream_module = StreamModule,
2341+
stream = Stream0,
2342+
jump_table_start = JumpTableStart,
2343+
labels = Labels
2344+
} = State,
2345+
Label,
2346+
LabelOffset
2347+
) when is_integer(Label) ->
2348+
% Patch the jump table entry immediately
% NOTE(review): jump_table_start is 0 until jump_table/2 has run (see new/3),
% and Label is not range-checked against the table size. Presumably callers
% only add integer labels in 0..LabelsCount after jump_table/2 - confirm.
2349+
% Each b instruction is 4 bytes
% so the entry for Label starts at JumpTableStart + Label * 4.
2350+
JumpTableEntryOffset = JumpTableStart + Label * 4,
% The branch displacement is measured from the entry's own offset.
2351+
RelativeOffset = LabelOffset - JumpTableEntryOffset,
2352+
BranchInstr = jit_aarch64_asm:b(RelativeOffset),
2353+
Stream1 = StreamModule:replace(Stream0, JumpTableEntryOffset, BranchInstr),
2354+
State#state{stream = Stream1, labels = [{Label, LabelOffset} | Labels]};
23372355
% Any other label (the spec allows references) is only recorded; the stream
% is left untouched.
add_label(#state{labels = Labels} = State, Label, Offset) ->
23382356
State#state{labels = [{Label, Offset} | Labels]}.

libs/jit/src/jit_armv6m.erl

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@
133133
stream :: stream(),
134134
offset :: non_neg_integer(),
135135
branches :: [{non_neg_integer(), non_neg_integer(), non_neg_integer()}],
136+
jump_table_start :: non_neg_integer(),
136137
available_regs :: [armv6m_register()],
137138
used_regs :: [armv6m_register()],
138139
labels :: [{integer() | reference(), integer()}],
@@ -246,6 +247,7 @@ new(Variant, StreamModule, Stream) ->
246247
stream_module = StreamModule,
247248
stream = Stream,
248249
branches = [],
250+
jump_table_start = 0,
249251
offset = StreamModule:offset(Stream),
250252
available_regs = ?AVAILABLE_REGS,
251253
used_regs = [],
@@ -379,13 +381,14 @@ assert_all_native_free(#state{
379381
%% @return Updated backend state
380382
%%-----------------------------------------------------------------------------
381383
-spec jump_table(state(), pos_integer()) -> state().
382-
jump_table(State, LabelsCount) ->
383-
jump_table0(State, 0, LabelsCount).
384+
jump_table(#state{stream_module = StreamModule, stream = Stream0} = State, LabelsCount) ->
385+
JumpTableStart = StreamModule:offset(Stream0),
386+
jump_table0(State#state{jump_table_start = JumpTableStart}, 0, LabelsCount).
384387

385388
jump_table0(State, N, LabelsCount) when N > LabelsCount ->
386389
State;
387390
jump_table0(
388-
#state{stream_module = StreamModule, stream = Stream0, branches = Branches} = State,
391+
#state{stream_module = StreamModule, stream = Stream0} = State,
389392
N,
390393
LabelsCount
391394
) ->
@@ -398,15 +401,7 @@ jump_table0(
398401
JumpEntry = <<I1/binary, I2/binary, I3/binary, I4/binary, 16#FFFFFFFF:32>>,
399402
Stream1 = StreamModule:append(Stream0, JumpEntry),
400403

401-
% Add relocation for the data entry so update_branches/2 can patch the jump target
402-
DataOffset = StreamModule:offset(Stream1) - 4,
403-
% Calculate the offset of the add instruction (3rd instruction, at offset 4 from entry start)
404-
EntryStartOffset = StreamModule:offset(Stream1) - 12,
405-
AddInstrOffset = EntryStartOffset + 4,
406-
DataReloc = {N, DataOffset, {jump_table_data, AddInstrOffset}},
407-
UpdatedState = State#state{stream = Stream1, branches = [DataReloc | Branches]},
408-
409-
jump_table0(UpdatedState, N + 1, LabelsCount).
404+
jump_table0(State#state{stream = Stream1}, N + 1, LabelsCount).
410405

411406
%%-----------------------------------------------------------------------------
412407
%% @doc Rewrite stream to update all branches for labels.
@@ -499,13 +494,7 @@ update_branches(
499494
I4 = <<RelativeOffset:32/little>>,
500495
<<I1/binary, I2/binary, I3/binary, I4/binary>>
501496
end
502-
end;
503-
{jump_table_data, AddInstrOffset} ->
504-
% Calculate offset from 'add pc, pc, r3' instruction + 4 to target label
505-
% PC when add instruction executes
506-
AddPC = AddInstrOffset + 4,
507-
RelativeOffset = LabelOffset - AddPC,
508-
<<RelativeOffset:32/little>>
497+
end
509498
end,
510499
Stream1 = StreamModule:replace(Stream0, Offset, NewInstr),
511500
update_branches(State#state{stream = Stream1, branches = BranchesT}).
@@ -3143,5 +3132,34 @@ add_label(#state{stream_module = StreamModule, stream = Stream0} = State0, Label
31433132
%% @return Updated backend state
31443133
%%-----------------------------------------------------------------------------
31453134
-spec add_label(state(), integer() | reference(), integer()) -> state().
3135+
% Integer labels: record the label and write its target into the data word of
% its jump table entry right away, instead of keeping a relocation for later.
add_label(
3136+
#state{
3137+
stream_module = StreamModule,
3138+
stream = Stream0,
3139+
jump_table_start = JumpTableStart,
3140+
labels = Labels
3141+
} = State,
3142+
Label,
3143+
LabelOffset
3144+
) when is_integer(Label) ->
3145+
% Patch the jump table entry immediately
% NOTE(review): jump_table_start is 0 until jump_table/2 has run (see new/3),
% and Label is not range-checked against the table size. Presumably callers
% only add integer labels in 0..LabelsCount after jump_table/2 - confirm.
3146+
% Each jump table entry is 12 bytes:
3147+
% - ldr r3, [pc, 4] (2 bytes) at offset 0
3148+
% - push {...} (2 bytes) at offset 2
3149+
% - add pc, r3 (2 bytes) at offset 4
3150+
% - nop (2 bytes) at offset 6
3151+
% - data (4 bytes) at offset 8
3152+
JumpTableEntryStart = JumpTableStart + Label * 12,
3153+
DataOffset = JumpTableEntryStart + 8,
3154+
AddInstrOffset = JumpTableEntryStart + 4,
3155+
3156+
% The data word holds the distance from the PC value used by the
% 'add pc, r3' instruction to the target label.
3157+
% That PC value is the offset of the add instruction + 4.
3158+
AddPC = AddInstrOffset + 4,
3159+
RelativeOffset = LabelOffset - AddPC,
3160+
DataBytes = <<RelativeOffset:32/little>>,
3161+
3162+
% Only the 4-byte data word is rewritten; the instructions of the entry are
% left as emitted.
Stream1 = StreamModule:replace(Stream0, DataOffset, DataBytes),
3163+
State#state{stream = Stream1, labels = [{Label, LabelOffset} | Labels]};
31463164
% Any other label (the spec allows references) is only recorded; the stream
% is left untouched.
add_label(#state{labels = Labels} = State, Label, Offset) ->
31473165
State#state{labels = [{Label, Offset} | Labels]}.

libs/jit/src/jit_x86_64.erl

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@
115115
stream :: stream(),
116116
offset :: non_neg_integer(),
117117
branches :: [{non_neg_integer(), non_neg_integer(), non_neg_integer()}],
118+
jump_table_start :: non_neg_integer(),
118119
available_regs :: [x86_64_register()],
119120
used_regs :: [x86_64_register()],
120121
labels :: [{integer() | reference(), integer()}],
@@ -218,6 +219,7 @@ new(Variant, StreamModule, Stream) ->
218219
stream_module = StreamModule,
219220
stream = Stream,
220221
branches = [],
222+
jump_table_start = 0,
221223
offset = StreamModule:offset(Stream),
222224
available_regs = ?AVAILABLE_REGS,
223225
used_regs = [],
@@ -340,21 +342,20 @@ assert_all_native_free(State) ->
340342
%% @return Updated backend state
341343
%%-----------------------------------------------------------------------------
342344
-spec jump_table(state(), pos_integer()) -> state().
343-
jump_table(State, LabelsCount) ->
344-
jump_table0(State, 0, LabelsCount).
345+
jump_table(#state{stream_module = StreamModule, stream = Stream0} = State, LabelsCount) ->
346+
JumpTableStart = StreamModule:offset(Stream0),
347+
jump_table0(State#state{jump_table_start = JumpTableStart}, 0, LabelsCount).
345348

346349
jump_table0(State, N, LabelsCount) when N > LabelsCount ->
347350
State;
348351
jump_table0(
349-
#state{stream_module = StreamModule, stream = Stream0, branches = Branches} = State,
352+
#state{stream_module = StreamModule, stream = Stream0} = State,
350353
N,
351354
LabelsCount
352355
) ->
353-
Offset = StreamModule:offset(Stream0),
354-
{RelocOffset, I1} = jit_x86_64_asm:jmp_rel32(1),
355-
Reloc = {N, Offset + RelocOffset, 32},
356+
{_RelocOffset, I1} = jit_x86_64_asm:jmp_rel32(1),
356357
Stream1 = StreamModule:append(Stream0, I1),
357-
jump_table0(State#state{stream = Stream1, branches = [Reloc | Branches]}, N + 1, LabelsCount).
358+
jump_table0(State#state{stream = Stream1}, N + 1, LabelsCount).
358359

359360
%%-----------------------------------------------------------------------------
360361
%% @doc Rewrite stream to update all branches for labels.
@@ -2063,5 +2064,22 @@ add_label(#state{stream_module = StreamModule, stream = Stream} = State, Label)
20632064
add_label(State, Label, Offset).
20642065

20652066
-spec add_label(state(), integer() | reference(), integer()) -> state().
2067+
% Integer labels: record the label and rewrite its jump table entry in the
% stream right away, instead of keeping a relocation to resolve later.
add_label(
2068+
#state{
2069+
stream_module = StreamModule,
2070+
stream = Stream0,
2071+
jump_table_start = JumpTableStart,
2072+
labels = Labels
2073+
} = State,
2074+
Label,
2075+
LabelOffset
2076+
) when is_integer(Label) ->
2077+
% Patch the jump table entry immediately
% NOTE(review): jump_table_start is 0 until jump_table/2 has run (see new/3),
% and Label is not range-checked against the table size. Presumably callers
% only add integer labels in 0..LabelsCount after jump_table/2 - confirm.
2078+
% Each jmp_rel32 instruction is 5 bytes
% so the entry for Label starts at JumpTableStart + Label * 5.
2079+
JumpTableEntryOffset = JumpTableStart + Label * 5,
% The offset passed to jmp_rel32/1 is measured from the first byte of the
% entry. NOTE(review): presumably jmp_rel32/1 converts it to the encoded
% displacement - confirm in jit_x86_64_asm.
2080+
RelativeOffset = LabelOffset - JumpTableEntryOffset,
% The whole instruction is rewritten, so the relocation offset returned by
% jmp_rel32/1 is not needed.
2081+
{_RelocOffset, JmpInstruction} = jit_x86_64_asm:jmp_rel32(RelativeOffset),
2082+
Stream1 = StreamModule:replace(Stream0, JumpTableEntryOffset, JmpInstruction),
2083+
State#state{stream = Stream1, labels = [{Label, LabelOffset} | Labels]};
20662084
% Any other label (the spec allows references) is only recorded; the stream
% is left untouched.
add_label(#state{labels = Labels} = State, Label, Offset) ->
20672085
State#state{labels = [{Label, Offset} | Labels]}.

0 commit comments

Comments
 (0)