Skip to content

Commit 1e46f2d

Browse files
Resolve peer enum/union types
1 parent e3ce89b commit 1e46f2d

File tree

2 files changed

+139
-6
lines changed

2 files changed

+139
-6
lines changed

src/analysis.zig

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -982,6 +982,13 @@ pub fn resolveUnwrapErrorUnionType(analyser: *Analyser, ty: Type, side: ErrorUni
982982
}
983983

984984
fn resolveUnionTag(analyser: *Analyser, ty: Type) error{OutOfMemory}!?Type {
985+
const tag_type = try analyser.resolveUnionTagType(ty) orelse
986+
return null;
987+
988+
return try tag_type.instanceTypeVal(analyser);
989+
}
990+
991+
fn resolveUnionTagType(analyser: *Analyser, ty: Type) error{OutOfMemory}!?Type {
985992
if (!ty.is_type_val)
986993
return null;
987994

@@ -1000,12 +1007,10 @@ fn resolveUnionTag(analyser: *Analyser, ty: Type) error{OutOfMemory}!?Type {
10001007
return null;
10011008

10021009
if (container_decl.ast.enum_token != null)
1003-
return .{ .data = .{ .union_tag = try analyser.allocType(ty) }, .is_type_val = false };
1010+
return .{ .data = .{ .union_tag = try analyser.allocType(ty) }, .is_type_val = true };
10041011

1005-
if (container_decl.ast.arg.unwrap()) |arg| {
1006-
const tag_type = (try analyser.resolveTypeOfNode(.of(arg, handle))) orelse return null;
1007-
return try tag_type.instanceTypeVal(analyser) orelse return null;
1008-
}
1012+
if (container_decl.ast.arg.unwrap()) |arg|
1013+
return try analyser.resolveTypeOfNode(.of(arg, handle));
10091014

10101015
return null;
10111016
}
@@ -6979,7 +6984,59 @@ fn resolvePeerTypesInner(analyser: *Analyser, peer_tys: []?Type) !?Type {
69796984
return opt_cur_ty.?;
69806985
},
69816986

6982-
.enum_or_union => return null, // TODO
6987+
.enum_or_union => {
6988+
var opt_cur_ty: ?Type = null;
6989+
6990+
for (peer_tys) |opt_ty| {
6991+
const ty = opt_ty orelse continue;
6992+
switch (ty.zigTypeTag(analyser).?) {
6993+
.enum_literal, .@"enum", .@"union" => {},
6994+
else => return null,
6995+
}
6996+
const cur_ty = opt_cur_ty orelse {
6997+
opt_cur_ty = ty;
6998+
continue;
6999+
};
7000+
7001+
if ((cur_ty.isUnionType() and !cur_ty.isTaggedUnion()) or
7002+
(ty.isUnionType() and !ty.isTaggedUnion()))
7003+
{
7004+
if (cur_ty.eql(ty)) continue;
7005+
return null;
7006+
}
7007+
7008+
switch (cur_ty.zigTypeTag(analyser).?) {
7009+
.enum_literal => {
7010+
opt_cur_ty = ty;
7011+
},
7012+
.@"enum" => switch (ty.zigTypeTag(analyser).?) {
7013+
.enum_literal => {},
7014+
.@"enum" => {
7015+
if (!ty.eql(cur_ty)) return null;
7016+
},
7017+
.@"union" => {
7018+
const tag_ty = try analyser.resolveUnionTagType(ty);
7019+
if (!tag_ty.?.eql(cur_ty)) return null;
7020+
opt_cur_ty = ty;
7021+
},
7022+
else => unreachable,
7023+
},
7024+
.@"union" => switch (ty.zigTypeTag(analyser).?) {
7025+
.enum_literal => {},
7026+
.@"enum" => {
7027+
const cur_tag_ty = try analyser.resolveUnionTagType(cur_ty);
7028+
if (!ty.eql(cur_tag_ty.?)) return null;
7029+
},
7030+
.@"union" => {
7031+
if (!ty.eql(cur_ty)) return null;
7032+
},
7033+
else => unreachable,
7034+
},
7035+
else => unreachable,
7036+
}
7037+
}
7038+
return opt_cur_ty.?;
7039+
},
69837040

69847041
.int_or_float => {
69857042
var ip_indices: std.ArrayListUnmanaged(InternPool.Index) = try .initCapacity(analyser.gpa, peer_tys.len);

tests/analysis/peer_type_resolution.zig

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,36 @@ const S = struct {
55
int: i64,
66
float: f32,
77
};
8+
89
const s: S = .{
910
.int = 0,
1011
.float = 1.2,
1112
};
1213

14+
const E = enum {
15+
foo,
16+
bar,
17+
baz,
18+
};
19+
20+
const e: E = .bar;
21+
22+
const T = union(E) {
23+
foo: void,
24+
bar: void,
25+
baz: u16,
26+
};
27+
28+
const t: T = .{ .baz = 3 };
29+
30+
const U = union {
31+
foo: void,
32+
bar: void,
33+
baz: void,
34+
};
35+
36+
const u: U = .{ .baz = {} };
37+
1338
var runtime_bool: bool = true;
1439

1540
const widened_int_0 = if (runtime_bool) @as(i8, 0) else @as(i16, 0);
@@ -218,6 +243,57 @@ fn void_fn() void {}
218243
const optional_fn = if (comptime_bool) @as(?fn () void, void_fn) else void_fn;
219244
// ^^^^^^^^^^^ (?fn () void)()
220245

246+
const optional_enum_literal = if (comptime_bool) @as(?@Type(.enum_literal), .foo) else .bar;
247+
// ^^^^^^^^^^^^^^^^^^^^^ (?@Type(.enum_literal))()
248+
249+
const enum_literal_and_enum_literal = if (comptime_bool) .foo else .bar;
250+
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ (@Type(.enum_literal))()
251+
252+
const enum_literal_and_enum = if (comptime_bool) .foo else e;
253+
// ^^^^^^^^^^^^^^^^^^^^^ (E)()
254+
255+
const enum_literal_and_union = if (comptime_bool) .foo else u;
256+
// ^^^^^^^^^^^^^^^^^^^^^^ (either type)()
257+
258+
const enum_literal_and_tagged_union = if (comptime_bool) .foo else t;
259+
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ (T)()
260+
261+
const enum_and_enum_literal = if (comptime_bool) e else .foo;
262+
// ^^^^^^^^^^^^^^^^^^^^^ (E)()
263+
264+
const enum_and_enum = if (comptime_bool) e else e;
265+
// ^^^^^^^^^^^^^ (E)()
266+
267+
const enum_and_union = if (comptime_bool) e else u;
268+
// ^^^^^^^^^^^^^^ (either type)()
269+
270+
const enum_and_tagged_union = if (comptime_bool) e else t;
271+
// ^^^^^^^^^^^^^^^^^^^^^ (T)()
272+
273+
const union_and_enum_literal = if (comptime_bool) u else .foo;
274+
// ^^^^^^^^^^^^^^^^^^^^^^ (either type)()
275+
276+
const union_and_enum = if (comptime_bool) u else e;
277+
// ^^^^^^^^^^^^^^ (either type)()
278+
279+
const union_and_union = if (comptime_bool) u else u;
280+
// ^^^^^^^^^^^^^^^ (U)()
281+
282+
const union_and_tagged_union = if (comptime_bool) u else t;
283+
// ^^^^^^^^^^^^^^^^^^^^^^ (either type)()
284+
285+
const tagged_union_and_enum_literal = if (comptime_bool) t else .foo;
286+
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ (T)()
287+
288+
const tagged_union_and_enum = if (comptime_bool) t else e;
289+
// ^^^^^^^^^^^^^^^^^^^^^ (T)()
290+
291+
const tagged_union_and_union = if (comptime_bool) t else u;
292+
// ^^^^^^^^^^^^^^^^^^^^^^ (either type)()
293+
294+
const tagged_union_and_tagged_union = if (comptime_bool) t else t;
295+
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ (T)()
296+
221297
const f32_and_u32 = if (comptime_bool) @as(f32, 0) else @as(i32, 0);
222298
// ^^^^^^^^^^^ (either type)()
223299

0 commit comments

Comments
 (0)