@@ -10,6 +10,24 @@ namespace Silk.NET.SilkTouch.Mods;
10
10
/// </summary>
11
11
public class TransformVulkan : IMod
12
12
{
13
+ private const string MethodClassName = "Vk" ;
14
+
15
+ private const string InstanceTypeName = "InstanceHandle" ;
16
+ private const string InstanceNativeTypeName = "VkInstance" ;
17
+ private const string InstanceFieldName = "_currentInstance" ;
18
+ private const string InstancePropertyName = "CurrentInstance" ;
19
+
20
+ private const string DeviceTypeName = "DeviceHandle" ;
21
+ private const string DeviceNativeTypeName = "VkDevice" ;
22
+ private const string DeviceFieldName = "_currentDevice" ;
23
+ private const string DevicePropertyName = "CurrentDevice" ;
24
+
25
+ private const string VkCreateInstanceNativeName = "vkCreateInstance" ;
26
+ private const string VkCreateDeviceNativeName = "vkCreateDevice" ;
27
+
28
+ private const string VkResultName = "Result" ;
29
+ private const string VkResultSuccessName = "Success" ;
30
+
13
31
/// <inheritdoc />
14
32
public async Task ExecuteAsync ( IModContext ctx , CancellationToken ct = default )
15
33
{
@@ -38,25 +56,47 @@ public async Task ExecuteAsync(IModContext ctx, CancellationToken ct = default)
38
56
ctx . SourceProject = proj ;
39
57
}
40
58
41
- private class Rewriter : CSharpSyntaxRewriter
59
+ /// <summary>
60
+ /// Used by <see cref="Rewriter"/> to identify methods that call
61
+ /// the native function pointer through the vtable slots field.
62
+ /// </summary>
63
+ private class SlotsMethodIdentifier : CSharpSyntaxWalker
42
64
{
43
- private const string MethodClassName = "Vk" ;
65
+ public bool IsSlotsMethod { get ; private set ; }
66
+
67
+ private bool isInInvocationExpression = false ;
68
+
69
+ public override void VisitMethodDeclaration ( MethodDeclarationSyntax node )
70
+ {
71
+ IsSlotsMethod = false ;
72
+ base . VisitMethodDeclaration ( node ) ;
73
+ }
44
74
45
- private const string InstanceTypeName = "InstanceHandle" ;
46
- private const string InstanceNativeTypeName = "VkInstance" ;
47
- private const string InstanceFieldName = "_currentInstance" ;
48
- private const string InstancePropertyName = "CurrentInstance" ;
75
+ public override void VisitInvocationExpression ( InvocationExpressionSyntax node )
76
+ {
77
+ isInInvocationExpression = true ;
78
+ base . VisitInvocationExpression ( node ) ;
79
+ }
49
80
50
- private const string DeviceTypeName = "DeviceHandle" ;
51
- private const string DeviceNativeTypeName = "VkDevice" ;
52
- private const string DeviceFieldName = "_currentDevice" ;
53
- private const string DevicePropertyName = "CurrentDevice" ;
81
+ public override void VisitFunctionPointerType ( FunctionPointerTypeSyntax node )
82
+ {
83
+ if ( isInInvocationExpression )
84
+ {
85
+ IsSlotsMethod = true ;
86
+ }
87
+ }
88
+ }
54
89
55
- private const string VkCreateInstanceNativeName = "vkCreateInstance" ;
56
- private const string VkCreateDeviceNativeName = "vkCreateDevice" ;
90
+ /// <summary>
91
+ /// This does the following:
92
+ /// 1. Add the instance/device members.
93
+ /// 2. Rewrite the vkCreateInstance and vkCreateDevice methods to set those members.
94
+ /// </summary>
95
+ private class Rewriter : CSharpSyntaxRewriter
96
+ {
97
+ private readonly SlotsMethodIdentifier slotsMethodIdentifier = new ( ) ;
57
98
58
- private const string VkResultName = "Result" ;
59
- private const string VkResultSuccessName = "Success" ;
99
+ private bool hasOutputInstanceDeviceMembers ;
60
100
61
101
public override SyntaxNode ? VisitClassDeclaration ( ClassDeclarationSyntax node )
62
102
{
@@ -65,28 +105,39 @@ private class Rewriter : CSharpSyntaxRewriter
65
105
return base . VisitClassDeclaration ( node ) ;
66
106
}
67
107
68
- var instanceField = FieldDeclaration (
69
- VariableDeclaration ( NullableType ( IdentifierName ( InstanceTypeName ) ) )
70
- . AddVariables ( VariableDeclarator ( InstanceFieldName ) )
71
- ) . AddModifiers ( Token ( SyntaxKind . PrivateKeyword ) ) ;
108
+ // Rewrite members
109
+ node = node . WithMembers ( [
110
+ .. node . Members . SelectMany ( RewriteMember )
111
+ ] ) ;
72
112
113
+ // Output instance/device members if needed
114
+ if ( ! hasOutputInstanceDeviceMembers )
115
+ {
116
+ var instanceField = FieldDeclaration (
117
+ VariableDeclaration ( NullableType ( IdentifierName ( InstanceTypeName ) ) )
118
+ . AddVariables ( VariableDeclarator ( InstanceFieldName ) )
119
+ ) . AddModifiers ( Token ( SyntaxKind . PrivateKeyword ) ) ;
73
120
74
- var deviceField = FieldDeclaration (
75
- VariableDeclaration ( NullableType ( IdentifierName ( DeviceTypeName ) ) )
76
- . AddVariables ( VariableDeclarator ( DeviceFieldName ) )
77
- ) . AddModifiers ( Token ( SyntaxKind . PrivateKeyword ) ) ;
78
121
79
- var instanceProperty = CreateProperty ( InstanceTypeName , InstancePropertyName , InstanceFieldName ) ;
122
+ var deviceField = FieldDeclaration (
123
+ VariableDeclaration ( NullableType ( IdentifierName ( DeviceTypeName ) ) )
124
+ . AddVariables ( VariableDeclarator ( DeviceFieldName ) )
125
+ ) . AddModifiers ( Token ( SyntaxKind . PrivateKeyword ) ) ;
80
126
81
- var deviceProperty = CreateProperty ( DeviceTypeName , DevicePropertyName , DeviceFieldName ) ;
127
+ var instanceProperty = CreateProperty ( InstanceTypeName , InstancePropertyName , InstanceFieldName ) ;
82
128
83
- node = node . WithMembers ( [
84
- instanceField ,
85
- deviceField ,
86
- instanceProperty ,
87
- deviceProperty ,
88
- ..node . Members . SelectMany ( RewriteMember )
89
- ] ) ;
129
+ var deviceProperty = CreateProperty ( DeviceTypeName , DevicePropertyName , DeviceFieldName ) ;
130
+
131
+ node = node . WithMembers ( [
132
+ instanceField ,
133
+ deviceField ,
134
+ instanceProperty ,
135
+ deviceProperty ,
136
+ ..node . Members
137
+ ] ) ;
138
+ }
139
+
140
+ hasOutputInstanceDeviceMembers = true ;
90
141
91
142
return base . VisitClassDeclaration ( node ) ;
92
143
}
@@ -105,13 +156,14 @@ private IEnumerable<MemberDeclarationSyntax> RewriteMember(MemberDeclarationSynt
105
156
yield break ;
106
157
}
107
158
108
- if ( ! method . Modifiers . Any ( modifier => modifier . IsKind ( SyntaxKind . ExternKeyword ) ) )
159
+ if ( entryPoint != VkCreateInstanceNativeName && entryPoint != VkCreateDeviceNativeName )
109
160
{
110
161
yield return member ;
111
162
yield break ;
112
163
}
113
164
114
- if ( entryPoint != VkCreateInstanceNativeName && entryPoint != VkCreateDeviceNativeName )
165
+ slotsMethodIdentifier . Visit ( member ) ;
166
+ if ( ! slotsMethodIdentifier . IsSlotsMethod )
115
167
{
116
168
yield return member ;
117
169
yield break ;
@@ -122,9 +174,9 @@ private IEnumerable<MemberDeclarationSyntax> RewriteMember(MemberDeclarationSynt
122
174
123
175
// Output the original method, but private
124
176
yield return method
177
+ . WithExplicitInterfaceSpecifier ( null )
125
178
. WithIdentifier ( Identifier ( privateMethodName ) )
126
179
. WithModifiers ( [
127
- Token ( SyntaxKind . PrivateKeyword ) ,
128
180
..member . Modifiers . Where ( modifier =>
129
181
! SyntaxFacts . IsAccessibilityModifier ( modifier . Kind ( ) ) )
130
182
] ) ;
0 commit comments