Skip to content

Commit 56bfa8b

Browse files
Resolve peer types for enum literal, enum, and union
1 parent dcd6fbd commit 56bfa8b

File tree

2 files changed

+139
-6
lines changed

2 files changed

+139
-6
lines changed

src/analysis.zig

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -982,6 +982,13 @@ pub fn resolveUnwrapErrorUnionType(analyser: *Analyser, ty: Type, side: ErrorUni
982982
}
983983

984984
fn resolveUnionTag(analyser: *Analyser, ty: Type) error{OutOfMemory}!?Type {
985+
const tag_type = try analyser.resolveUnionTagType(ty) orelse
986+
return null;
987+
988+
return try tag_type.instanceTypeVal(analyser);
989+
}
990+
991+
fn resolveUnionTagType(analyser: *Analyser, ty: Type) error{OutOfMemory}!?Type {
985992
if (!ty.is_type_val)
986993
return null;
987994

@@ -1000,12 +1007,10 @@ fn resolveUnionTag(analyser: *Analyser, ty: Type) error{OutOfMemory}!?Type {
10001007
return null;
10011008

10021009
if (container_decl.ast.enum_token != null)
1003-
return .{ .data = .{ .union_tag = try analyser.allocType(ty) }, .is_type_val = false };
1010+
return .{ .data = .{ .union_tag = try analyser.allocType(ty) }, .is_type_val = true };
10041011

1005-
if (container_decl.ast.arg.unwrap()) |arg| {
1006-
const tag_type = (try analyser.resolveTypeOfNode(.of(arg, handle))) orelse return null;
1007-
return try tag_type.instanceTypeVal(analyser) orelse return null;
1008-
}
1012+
if (container_decl.ast.arg.unwrap()) |arg|
1013+
return try analyser.resolveTypeOfNode(.of(arg, handle));
10091014

10101015
return null;
10111016
}
@@ -7031,7 +7036,59 @@ fn resolvePeerTypesInner(analyser: *Analyser, peer_tys: []?Type) !?Type {
70317036
return opt_cur_ty.?;
70327037
},
70337038

7034-
.enum_or_union => return null, // TODO
7039+
.enum_or_union => {
7040+
var opt_cur_ty: ?Type = null;
7041+
7042+
for (peer_tys) |opt_ty| {
7043+
const ty = opt_ty orelse continue;
7044+
switch (ty.zigTypeTag(analyser).?) {
7045+
.enum_literal, .@"enum", .@"union" => {},
7046+
else => return null,
7047+
}
7048+
const cur_ty = opt_cur_ty orelse {
7049+
opt_cur_ty = ty;
7050+
continue;
7051+
};
7052+
7053+
if ((cur_ty.isUnionType() and !cur_ty.isTaggedUnion()) or
7054+
(ty.isUnionType() and !ty.isTaggedUnion()))
7055+
{
7056+
if (cur_ty.eql(ty)) continue;
7057+
return null;
7058+
}
7059+
7060+
switch (cur_ty.zigTypeTag(analyser).?) {
7061+
.enum_literal => {
7062+
opt_cur_ty = ty;
7063+
},
7064+
.@"enum" => switch (ty.zigTypeTag(analyser).?) {
7065+
.enum_literal => {},
7066+
.@"enum" => {
7067+
if (!ty.eql(cur_ty)) return null;
7068+
},
7069+
.@"union" => {
7070+
const tag_ty = try analyser.resolveUnionTagType(ty);
7071+
if (!tag_ty.?.eql(cur_ty)) return null;
7072+
opt_cur_ty = ty;
7073+
},
7074+
else => unreachable,
7075+
},
7076+
.@"union" => switch (ty.zigTypeTag(analyser).?) {
7077+
.enum_literal => {},
7078+
.@"enum" => {
7079+
const cur_tag_ty = try analyser.resolveUnionTagType(cur_ty);
7080+
if (!ty.eql(cur_tag_ty.?)) return null;
7081+
},
7082+
.@"union" => {
7083+
if (!ty.eql(cur_ty)) return null;
7084+
},
7085+
else => unreachable,
7086+
},
7087+
else => unreachable,
7088+
}
7089+
}
7090+
return opt_cur_ty.?;
7091+
},
70357092

70367093
.int_or_float => {
70377094
var ip_indices: std.ArrayListUnmanaged(InternPool.Index) = try .initCapacity(analyser.gpa, peer_tys.len);

tests/analysis/peer_type_resolution.zig

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,36 @@ const S = struct {
55
int: i64,
66
float: f32,
77
};
8+
89
const s: S = .{
910
.int = 0,
1011
.float = 1.2,
1112
};
1213

14+
const E = enum {
15+
foo,
16+
bar,
17+
baz,
18+
};
19+
20+
const e: E = .bar;
21+
22+
const T = union(E) {
23+
foo: void,
24+
bar: void,
25+
baz: u16,
26+
};
27+
28+
const t: T = .{ .baz = 3 };
29+
30+
const U = union {
31+
foo: void,
32+
bar: void,
33+
baz: void,
34+
};
35+
36+
const u: U = .{ .baz = {} };
37+
1338
var runtime_bool: bool = true;
1439
var runtime_int: i32 = 0;
1540

@@ -270,6 +295,57 @@ fn void_fn() void {}
270295
const optional_fn = if (comptime_bool) @as(?fn () void, void_fn) else void_fn;
271296
// ^^^^^^^^^^^ (?fn () void)()
272297

298+
const optional_enum_literal = if (comptime_bool) @as(?@Type(.enum_literal), .foo) else .bar;
299+
// ^^^^^^^^^^^^^^^^^^^^^ (?@Type(.enum_literal))()
300+
301+
const enum_literal_and_enum_literal = if (comptime_bool) .foo else .bar;
302+
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ (@Type(.enum_literal))()
303+
304+
const enum_literal_and_enum = if (comptime_bool) .foo else e;
305+
// ^^^^^^^^^^^^^^^^^^^^^ (E)()
306+
307+
const enum_literal_and_union = if (comptime_bool) .foo else u;
308+
// ^^^^^^^^^^^^^^^^^^^^^^ (either type)()
309+
310+
const enum_literal_and_tagged_union = if (comptime_bool) .foo else t;
311+
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ (T)()
312+
313+
const enum_and_enum_literal = if (comptime_bool) e else .foo;
314+
// ^^^^^^^^^^^^^^^^^^^^^ (E)()
315+
316+
const enum_and_enum = if (comptime_bool) e else e;
317+
// ^^^^^^^^^^^^^ (E)()
318+
319+
const enum_and_union = if (comptime_bool) e else u;
320+
// ^^^^^^^^^^^^^^ (either type)()
321+
322+
const enum_and_tagged_union = if (comptime_bool) e else t;
323+
// ^^^^^^^^^^^^^^^^^^^^^ (T)()
324+
325+
const union_and_enum_literal = if (comptime_bool) u else .foo;
326+
// ^^^^^^^^^^^^^^^^^^^^^^ (either type)()
327+
328+
const union_and_enum = if (comptime_bool) u else e;
329+
// ^^^^^^^^^^^^^^ (either type)()
330+
331+
const union_and_union = if (comptime_bool) u else u;
332+
// ^^^^^^^^^^^^^^^ (U)()
333+
334+
const union_and_tagged_union = if (comptime_bool) u else t;
335+
// ^^^^^^^^^^^^^^^^^^^^^^ (either type)()
336+
337+
const tagged_union_and_enum_literal = if (comptime_bool) t else .foo;
338+
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ (T)()
339+
340+
const tagged_union_and_enum = if (comptime_bool) t else e;
341+
// ^^^^^^^^^^^^^^^^^^^^^ (T)()
342+
343+
const tagged_union_and_union = if (comptime_bool) t else u;
344+
// ^^^^^^^^^^^^^^^^^^^^^^ (either type)()
345+
346+
const tagged_union_and_tagged_union = if (comptime_bool) t else t;
347+
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ (T)()
348+
273349
const f32_and_u32 = if (comptime_bool) @as(f32, 0) else @as(i32, 0);
274350
// ^^^^^^^^^^^ (either type)()
275351

0 commit comments

Comments
 (0)