Skip to content

Commit 47a4613

Browse files
committed
Improve out arguments assignability, for both validation and expression compilation (STUD-76892)
1 parent 5fefd46 commit 47a4613

File tree

5 files changed

+142
-42
lines changed

5 files changed

+142
-42
lines changed

src/UiPath.Workflow/Activities/Utils/CSharpCompilerHelper.cs

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,11 @@
44
using Microsoft.CodeAnalysis.Scripting.Hosting;
55
using ReflectionMagic;
66
using System.Runtime.InteropServices;
7-
using System.Text;
8-
using System.Threading;
97

108
namespace System.Activities
119
{
1210
public sealed class CSharpCompilerHelper : CompilerHelper
1311
{
14-
private static int crt = 0;
1512
private static readonly dynamic s_typeNameFormatter = GetTypeNameFormatter();
1613
private static readonly dynamic s_typeOptions = GetTypeOptions();
1714

@@ -39,22 +36,19 @@ public override string CreateExpressionCode(string[] types, string[] names, stri
3936
return $"{myDelegate} \n public static Expression<{name}<{typesStr}>> CreateExpression() => ({namesStr}) => {code};";
4037
}
4138

42-
protected override (string, string) DefineDelegateCommon(int argumentsCount)
39+
internal string CreateReferenceCode(string[] types, string returnType, string[] names, string code)
4340
{
44-
var crtValue = Interlocked.Add(ref crt, 1);
45-
46-
var part1 = new StringBuilder();
47-
var part2 = new StringBuilder();
48-
for (var i = 0; i < argumentsCount; i++)
49-
{
50-
part1.Append($"in T{i}, ");
51-
part2.Append($" T{i} arg{i},");
52-
}
53-
part2.Remove(part2.Length - 1, 1);
54-
var name = $"Func{crtValue}";
55-
return ($"public delegate TResult {name}<{part1} out TResult>({part2});", name);
41+
var strTypes = string.Join(Comma, types);
42+
var strNames = string.Join(Comma, names);
43+
return CSharpValidatorCommon.CreateReferenceCode(strTypes, returnType, strNames, code, string.Empty, 0);
5644
}
5745

46+
internal string CreateValueCode(string[] types, string[] names, string code)
47+
=> CSharpValidatorCommon.CreateValueCode(types, string.Join(Comma, names), code, string.Empty, 0);
48+
49+
protected override (string, string) DefineDelegateCommon(int argumentsCount)
50+
=> CSharpValidatorCommon.DefineDelegateCommon(argumentsCount);
51+
5852
private static object GetTypeNameFormatter()
5953
{
6054
return typeof(CSharpScript)

src/UiPath.Workflow/Microsoft/CSharp/CSharpExpressionCompiler.cs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,23 @@ protected override SyntaxTree GetSyntaxTreeForExpression(string expression, bool
5050
.ToArray();
5151

5252
var names = resolvedIdentifiers.Select(var => var.Name).ToArray();
53-
var types = resolvedIdentifiers.Select(var => var.Type).Concat(new[] { returnType }).Select(_compilerHelper.GetTypeName).ToArray();
54-
var lambdaFuncCode = _compilerHelper.CreateExpressionCode(types, names, expression);
55-
return CSharpSyntaxTree.ParseText(lambdaFuncCode, _compilerHelper.ScriptParseOptions);
53+
var types = resolvedIdentifiers.Select(var => var.Type).Select(_compilerHelper.GetTypeName).ToArray();
54+
string expressionCode;
55+
if (isLocation)
56+
{
57+
expressionCode = _compilerHelper.CreateReferenceCode(types: types,
58+
returnType: _compilerHelper.GetTypeName(returnType),
59+
names: names,
60+
code: expression);
61+
}
62+
else
63+
{
64+
types = types.Concat(new[] { _compilerHelper.GetTypeName(returnType) }).ToArray();
65+
expressionCode = _compilerHelper.CreateValueCode(types: types,
66+
names: names,
67+
code: expression);
68+
}
69+
70+
return CSharpSyntaxTree.ParseText(expressionCode, _compilerHelper.ScriptParseOptions);
5671
}
5772
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
using System.Collections.Generic;
2+
using System.Linq;
3+
using System.Text;
4+
using System.Threading;
5+
namespace System.Activities;
6+
7+
internal static class CSharpValidatorCommon
8+
{
9+
private static int crt = 0;
10+
11+
// This is used in case the expression does not properly close (e.g. missing quotes, or multiline comment not closed)
12+
private const string _expressionEnder = "// */ // \"";
13+
14+
private const string _valueValidationTemplate = "public static System.Linq.Expressions.Expression<System.Func<{0}>> CreateExpression{1}()//activityId:{4}\n => ({2}) => {3}; {5}";
15+
private const string _delegateValueValidationTemplate = "{0}\npublic static System.Linq.Expressions.Expression<{1}<{2}>> CreateExpression{3}()//activityId:{6}\n => ({4}) => {5}; {7}";
16+
private const string _referenceValidationTemplate = "public static {0} IsLocation{1}()//activityId:{5}\n => ({2}) => {3} = default({4}); {6}";
17+
18+
internal static string CreateReferenceCode(string types, string returnType, string names, string code, string activityId, int index)
19+
{
20+
var actionDefinition = !string.IsNullOrWhiteSpace(types)
21+
? $"System.Action<{string.Join(CompilerHelper.Comma, types)}>"
22+
: "System.Action";
23+
return string.Format(_referenceValidationTemplate, actionDefinition, index, names, code, returnType, activityId, _expressionEnder);
24+
}
25+
26+
internal static string CreateValueCode(IEnumerable<string> types, string names, string code, string activityId, int index)
27+
{
28+
var serializedArgumentTypes = string.Join(CompilerHelper.Comma, types);
29+
if (types.Count() <= 16) // .net defines Func<TResult>...Func<T1,...T16,TResult)
30+
return string.Format(_valueValidationTemplate, serializedArgumentTypes, index, names, code, activityId, _expressionEnder);
31+
32+
var (myDelegate, name) = DefineDelegateCommon(types.Count() - 1);
33+
return string.Format(_delegateValueValidationTemplate, myDelegate, name, serializedArgumentTypes, index, names, code, activityId, _expressionEnder);
34+
}
35+
36+
internal static (string, string) DefineDelegateCommon(int argumentsCount)
37+
{
38+
var crtValue = Interlocked.Add(ref crt, 1);
39+
40+
var part1 = new StringBuilder();
41+
var part2 = new StringBuilder();
42+
for (var i = 0; i < argumentsCount; i++)
43+
{
44+
part1.Append($"in T{i}, ");
45+
part2.Append($" T{i} arg{i},");
46+
}
47+
part2.Remove(part2.Length - 1, 1);
48+
var name = $"Func{crtValue}";
49+
return ($"public delegate TResult {name}<{part1} out TResult>({part2});", name);
50+
}
51+
}

src/UiPath.Workflow/Validation/CSharpExpressionValidator.cs

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,6 @@ namespace System.Activities.Validation;
1818
/// </summary>
1919
public class CSharpExpressionValidator : RoslynExpressionValidator
2020
{
21-
// This is used in case the expression does not properly close (e.g. missing quotes, or multiline comment not closed)
22-
private const string _expressionEnder = "// */ // \"";
23-
24-
private const string _valueValidationTemplate = "public static System.Linq.Expressions.Expression<System.Func<{0}>> CreateExpression{1}()//activityId:{4}\n => ({2}) => {3}; {5}";
25-
private const string _delegateValueValidationTemplate = "{0}\npublic static System.Linq.Expressions.Expression<{1}<{2}>> CreateExpression{3}()//activityId:{6}\n => ({4}) => {5}; {7}";
26-
private const string _referenceValidationTemplate = "public static {0} IsLocation{1}()//activityId:{5}\n => ({2}) => {3} = default({4}); {6}";
27-
2821
private static readonly Lazy<CSharpExpressionValidator> s_instance = new(() => new());
2922
public override string Language => CSharpHelper.Language;
3023

@@ -66,22 +59,10 @@ protected override Compilation GetCompilation(IReadOnlyCollection<Assembly> asse
6659
}
6760

6861
protected override string CreateValueCode(IEnumerable<string> types, string names, string code, string activityId, int index)
69-
{
70-
var serializedArgumentTypes = string.Join(Comma, types);
71-
if (types.Count() <= 16) // .net defines Func<TResult>...Funct<T1,...T16,TResult)
72-
return string.Format(_valueValidationTemplate, serializedArgumentTypes, index, names, code, activityId, _expressionEnder);
73-
74-
var (myDelegate, name) = CompilerHelper.DefineDelegate(types);
75-
return string.Format(_delegateValueValidationTemplate, myDelegate, name, serializedArgumentTypes, index, names, code, activityId, _expressionEnder);
76-
}
62+
=> CSharpValidatorCommon.CreateValueCode(types, names, code, activityId, index);
7763

7864
protected override string CreateReferenceCode(string types, string names, string code, string activityId, string returnType, int index)
79-
{
80-
var actionDefinition = !string.IsNullOrWhiteSpace(types)
81-
? $"System.Action<{string.Join(Comma, types)}>"
82-
: "System.Action";
83-
return string.Format(_referenceValidationTemplate, actionDefinition, index, names, code, returnType, activityId, _expressionEnder);
84-
}
65+
=> CSharpValidatorCommon.CreateReferenceCode(types, returnType, names, code, activityId, index);
8566

8667
protected override SyntaxTree GetSyntaxTreeForExpression(string expressionText) =>
8768
CSharpSyntaxTree.ParseText(expressionText, CompilerHelper.ScriptParseOptions);

src/UiPath.Workflow/XamlIntegration/TextExpressionCompiler.cs

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1525,7 +1525,8 @@ private void GenerateExpressionGetTreeMethod(Activity activity, CompiledExpressi
15251525
}
15261526
else if (IsCs)
15271527
{
1528-
expressionText = string.Concat(CSharpLambdaString, coreExpressionText);
1528+
var safeCastText = SafeCastText(coreExpressionText, expressionDescriptor.ResultType);
1529+
expressionText = string.Concat(CSharpLambdaString, safeCastText);
15291530
}
15301531

15311532
if (expressionText != null)
@@ -1570,7 +1571,16 @@ private CodeMemberMethod GenerateGetMethod(Activity activity, Type resultType, s
15701571
new CodeAttributeDeclaration(new CodeTypeReference(typeof(DebuggerHiddenAttribute))));
15711572

15721573
AlignText(activity, ref expressionText, out var pragma);
1573-
CodeStatement statement = new CodeMethodReturnStatement(new CodeSnippetExpression(expressionText));
1574+
CodeStatement statement;
1575+
if (IsCs)
1576+
{
1577+
statement = new CodeMethodReturnStatement(SafeCast(expressionText, resultType));
1578+
}
1579+
else
1580+
{
1581+
statement = new CodeMethodReturnStatement(new CodeSnippetExpression(expressionText));
1582+
}
1583+
15741584
statement.LinePragma = pragma;
15751585
expressionMethod.Statements.Add(statement);
15761586

@@ -2512,6 +2522,55 @@ private string GetActivityFullName(TextExpressionCompilerSettings settings)
25122522
return activityFullName;
25132523
}
25142524

2525+
/// <summary>
2526+
/// Creates a CodeSnippetExpression like: "variableName as TargetType".
2527+
/// If variableName or targetType is null/empty, falls back to "null".
2528+
/// </summary>
2529+
private CodeSnippetExpression SafeCast(string variableName, Type targetType)
2530+
=> new CodeSnippetExpression(SafeCastText(variableName, targetType));
2531+
2532+
/// <summary>
2533+
/// Safely formats a string representation of a variable name and its target type.
2534+
/// Validates that the 'as' operator is only applied to reference types or nullable value types.
2535+
/// For non-nullable value types, returns the variable name without any casting.
2536+
/// </summary>
2537+
/// <param name="variableName">The name of the variable to be cast. If null or whitespace, "null" is used instead.</param>
2538+
/// <param name="targetType">The target <see cref="Type"/> to which the variable is being cast.</param>
2539+
/// <returns>A string in the format "<paramref name="variableName"/> as <paramref name="targetType"/>" for valid types, otherwise just the variable name.</returns>
2540+
private string SafeCastText(string variableName, Type targetType)
2541+
{
2542+
string varName = string.IsNullOrWhiteSpace(variableName) ? "null" : variableName;
2543+
2544+
// Early exit if targetType is null or if it's a non-nullable value type
2545+
if (targetType == null || (targetType.IsValueType && !IsNullableValueType(targetType)))
2546+
{
2547+
return varName;
2548+
}
2549+
2550+
string typeName = GetFriendlyTypeName(targetType);
2551+
if (typeName == null)
2552+
{
2553+
return varName;
2554+
}
2555+
2556+
// Use 'as' operator for reference types and nullable value types
2557+
return $"{varName} as {typeName}";
2558+
}
2559+
2560+
private string GetFriendlyTypeName(Type type)
2561+
{
2562+
using var codeDomProvider = CodeDomProvider.CreateProvider(_settings.Language);
2563+
return type is null ? null : codeDomProvider.GetTypeOutput(new CodeTypeReference(type));
2564+
}
2565+
2566+
/// <summary>
2567+
/// Helper method to check if a type is a nullable value type (e.g., int?, DateTime?)
2568+
/// </summary>
2569+
private static bool IsNullableValueType(Type type)
2570+
{
2571+
return type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>);
2572+
}
2573+
25152574
private class ExpressionCompilerActivityVisitor : CompiledExpressionActivityVisitor
25162575
{
25172576
private readonly TextExpressionCompiler _compiler;

0 commit comments

Comments
 (0)