Skip to content

Commit a42c3f8

Browse files
Resolve peer enum/union types
1 parent 534d912 commit a42c3f8

File tree

2 files changed

+139
-6
lines changed

2 files changed

+139
-6
lines changed

src/analysis.zig

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -982,6 +982,13 @@ pub fn resolveUnwrapErrorUnionType(analyser: *Analyser, ty: Type, side: ErrorUni
982982
}
983983

984984
fn resolveUnionTag(analyser: *Analyser, ty: Type) error{OutOfMemory}!?Type {
985+
const tag_type = try analyser.resolveUnionTagType(ty) orelse
986+
return null;
987+
988+
return try tag_type.instanceTypeVal(analyser);
989+
}
990+
991+
fn resolveUnionTagType(analyser: *Analyser, ty: Type) error{OutOfMemory}!?Type {
985992
if (!ty.is_type_val)
986993
return null;
987994

@@ -1000,12 +1007,10 @@ fn resolveUnionTag(analyser: *Analyser, ty: Type) error{OutOfMemory}!?Type {
10001007
return null;
10011008

10021009
if (container_decl.ast.enum_token != null)
1003-
return .{ .data = .{ .union_tag = try analyser.allocType(ty) }, .is_type_val = false };
1010+
return .{ .data = .{ .union_tag = try analyser.allocType(ty) }, .is_type_val = true };
10041011

1005-
if (container_decl.ast.arg.unwrap()) |arg| {
1006-
const tag_type = (try analyser.resolveTypeOfNode(.of(arg, handle))) orelse return null;
1007-
return try tag_type.instanceTypeVal(analyser) orelse return null;
1008-
}
1012+
if (container_decl.ast.arg.unwrap()) |arg|
1013+
return try analyser.resolveTypeOfNode(.of(arg, handle));
10091014

10101015
return null;
10111016
}
@@ -6984,7 +6989,59 @@ fn resolvePeerTypesInner(analyser: *Analyser, peer_tys: []?Type) !?Type {
69846989
return opt_cur_ty.?;
69856990
},
69866991

6987-
.enum_or_union => return null, // TODO
6992+
.enum_or_union => {
6993+
var opt_cur_ty: ?Type = null;
6994+
6995+
for (peer_tys) |opt_ty| {
6996+
const ty = opt_ty orelse continue;
6997+
switch (ty.zigTypeTag(analyser).?) {
6998+
.enum_literal, .@"enum", .@"union" => {},
6999+
else => return null,
7000+
}
7001+
const cur_ty = opt_cur_ty orelse {
7002+
opt_cur_ty = ty;
7003+
continue;
7004+
};
7005+
7006+
if ((cur_ty.isUnionType() and !cur_ty.isTaggedUnion()) or
7007+
(ty.isUnionType() and !ty.isTaggedUnion()))
7008+
{
7009+
if (cur_ty.eql(ty)) continue;
7010+
return null;
7011+
}
7012+
7013+
switch (cur_ty.zigTypeTag(analyser).?) {
7014+
.enum_literal => {
7015+
opt_cur_ty = ty;
7016+
},
7017+
.@"enum" => switch (ty.zigTypeTag(analyser).?) {
7018+
.enum_literal => {},
7019+
.@"enum" => {
7020+
if (!ty.eql(cur_ty)) return null;
7021+
},
7022+
.@"union" => {
7023+
const tag_ty = try analyser.resolveUnionTagType(ty);
7024+
if (!tag_ty.?.eql(cur_ty)) return null;
7025+
opt_cur_ty = ty;
7026+
},
7027+
else => unreachable,
7028+
},
7029+
.@"union" => switch (ty.zigTypeTag(analyser).?) {
7030+
.enum_literal => {},
7031+
.@"enum" => {
7032+
const cur_tag_ty = try analyser.resolveUnionTagType(cur_ty);
7033+
if (!ty.eql(cur_tag_ty.?)) return null;
7034+
},
7035+
.@"union" => {
7036+
if (!ty.eql(cur_ty)) return null;
7037+
},
7038+
else => unreachable,
7039+
},
7040+
else => unreachable,
7041+
}
7042+
}
7043+
return opt_cur_ty.?;
7044+
},
69887045

69897046
.int_or_float => {
69907047
var ip_indices: std.ArrayListUnmanaged(InternPool.Index) = try .initCapacity(analyser.gpa, peer_tys.len);

tests/analysis/peer_type_resolution.zig

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,36 @@ const S = struct {
55
int: i64,
66
float: f32,
77
};
8+
89
const s: S = .{
910
.int = 0,
1011
.float = 1.2,
1112
};
1213

14+
const E = enum {
15+
foo,
16+
bar,
17+
baz,
18+
};
19+
20+
const e: E = .bar;
21+
22+
const T = union(E) {
23+
foo: void,
24+
bar: void,
25+
baz: u16,
26+
};
27+
28+
const t: T = .{ .baz = 3 };
29+
30+
const U = union {
31+
foo: void,
32+
bar: void,
33+
baz: void,
34+
};
35+
36+
const u: U = .{ .baz = {} };
37+
1338
var runtime_bool: bool = true;
1439

1540
const widened_int_0 = if (runtime_bool) @as(i8, 0) else @as(i16, 0);
@@ -218,6 +243,57 @@ fn void_fn() void {}
218243
const optional_fn = if (comptime_bool) @as(?fn () void, void_fn) else void_fn;
219244
// ^^^^^^^^^^^ (?fn () void)()
220245

246+
const optional_enum_literal = if (comptime_bool) @as(?@Type(.enum_literal), .foo) else .bar;
247+
// ^^^^^^^^^^^^^^^^^^^^^ (?@Type(.enum_literal))()
248+
249+
const enum_literal_and_enum_literal = if (comptime_bool) .foo else .bar;
250+
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ (@Type(.enum_literal))()
251+
252+
const enum_literal_and_enum = if (comptime_bool) .foo else e;
253+
// ^^^^^^^^^^^^^^^^^^^^^ (E)()
254+
255+
const enum_literal_and_union = if (comptime_bool) .foo else u;
256+
// ^^^^^^^^^^^^^^^^^^^^^^ (either type)()
257+
258+
const enum_literal_and_tagged_union = if (comptime_bool) .foo else t;
259+
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ (T)()
260+
261+
const enum_and_enum_literal = if (comptime_bool) e else .foo;
262+
// ^^^^^^^^^^^^^^^^^^^^^ (E)()
263+
264+
const enum_and_enum = if (comptime_bool) e else e;
265+
// ^^^^^^^^^^^^^ (E)()
266+
267+
const enum_and_union = if (comptime_bool) e else u;
268+
// ^^^^^^^^^^^^^^ (either type)()
269+
270+
const enum_and_tagged_union = if (comptime_bool) e else t;
271+
// ^^^^^^^^^^^^^^^^^^^^^ (T)()
272+
273+
const union_and_enum_literal = if (comptime_bool) u else .foo;
274+
// ^^^^^^^^^^^^^^^^^^^^^^ (either type)()
275+
276+
const union_and_enum = if (comptime_bool) u else e;
277+
// ^^^^^^^^^^^^^^ (either type)()
278+
279+
const union_and_union = if (comptime_bool) u else u;
280+
// ^^^^^^^^^^^^^^^ (U)()
281+
282+
const union_and_tagged_union = if (comptime_bool) u else t;
283+
// ^^^^^^^^^^^^^^^^^^^^^^ (either type)()
284+
285+
const tagged_union_and_enum_literal = if (comptime_bool) t else .foo;
286+
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ (T)()
287+
288+
const tagged_union_and_enum = if (comptime_bool) t else e;
289+
// ^^^^^^^^^^^^^^^^^^^^^ (T)()
290+
291+
const tagged_union_and_union = if (comptime_bool) t else u;
292+
// ^^^^^^^^^^^^^^^^^^^^^^ (either type)()
293+
294+
const tagged_union_and_tagged_union = if (comptime_bool) t else t;
295+
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ (T)()
296+
221297
const f32_and_u32 = if (comptime_bool) @as(f32, 0) else @as(i32, 0);
222298
// ^^^^^^^^^^^ (either type)()
223299

0 commit comments

Comments
 (0)