Skip to content

Support Covariant Return Types #334

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/AspectCore.Core/Extensions/CollectionExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// ReSharper disable once CheckNamespace
namespace System.Collections.Generic
{
internal static class CollectionExtensions
{
#if NETSTANDARD2_0
public static TValue GetValueOrDefault<TKey, TValue>(this IReadOnlyDictionary<TKey, TValue> dictionary, TKey key, TValue defaultValue)
{
return dictionary.TryGetValue(key, out var obj)
? obj
: defaultValue;
}
#endif
}
}
22 changes: 22 additions & 0 deletions src/AspectCore.Core/Extensions/EnumerableExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using System.Collections.Generic;

// ReSharper disable once CheckNamespace
namespace System.Linq
{
internal static class EnumerableExtensions
{
#if NETSTANDARD2_0 || NETSTANDARD2_1
public static IEnumerable<(TFirst First, TSecond Second)> Zip<TFirst, TSecond>(this IEnumerable<TFirst> first, IEnumerable<TSecond> second)
{
return first.Zip(second, (f, s) => (f, s));
}
#endif

#if NETSTANDARD2_0
public static HashSet<TSource> ToHashSet<TSource>(this IEnumerable<TSource> source, IEqualityComparer<TSource> comparer = null)
{
return new HashSet<TSource>(source, comparer);
}
#endif
}
}
45 changes: 45 additions & 0 deletions src/AspectCore.Core/Extensions/MethodInfoExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using static AspectCore.Extensions.TypeExtensions;

// ReSharper disable once CheckNamespace
namespace AspectCore.Extensions
{
internal static class MethodInfoExtensions
{
public static IEnumerable<MethodInfo> GetInterfaceDeclarations(this MethodInfo method)
{
var typeInfo = method.ReflectedType?.GetTypeInfo();
if (typeInfo is null)
yield break;

foreach (var implementedInterface in typeInfo.ImplementedInterfaces)
{
var map = typeInfo.GetInterfaceMap(implementedInterface);
foreach (var (interfaceMethod, targetMethod) in map.InterfaceMethods.Zip(map.TargetMethods))
{
if (targetMethod == method)
yield return interfaceMethod;
}
}
}

public static bool IsPreserveBaseOverride(this MethodInfo method, bool checkBase)
{
if (PreserveBaseOverridesAttribute is null)
return false;

if (method.IsDefined(PreserveBaseOverridesAttribute))
return true;

return checkBase && method.GetBaseDefinition().IsDefined(PreserveBaseOverridesAttribute);
}

public static bool IsSameBaseDefinition(this MethodInfo method, MethodInfo other)
{
return method.GetBaseDefinition() == other.GetBaseDefinition();
}
}
}

100 changes: 100 additions & 0 deletions src/AspectCore.Core/Extensions/TypeExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

// ReSharper disable once CheckNamespace
namespace AspectCore.Extensions
{
internal readonly struct CovariantReturnMethodInfo
{
public readonly MethodInfo CovariantReturnMethod;
public readonly MethodInfo OverridenMethod;
public readonly HashSet<MethodInfo> InterfaceDeclarations;

public CovariantReturnMethodInfo(MethodInfo covariantReturnMethod, MethodInfo overridenMethod, HashSet<MethodInfo> interfaceDeclarations)
{
InterfaceDeclarations = interfaceDeclarations;
OverridenMethod = overridenMethod;
CovariantReturnMethod = covariantReturnMethod;
}
}

internal static class TypeExtensions
{
public static readonly Type PreserveBaseOverridesAttribute = Type.GetType("System.Runtime.CompilerServices.PreserveBaseOverridesAttribute", false);

public static IReadOnlyList<CovariantReturnMethodInfo> GetCovariantReturnMethods(this Type type)
{
var result = new List<CovariantReturnMethodInfo>();
// No PreserveBaseOverridesAttribute means that the runtime does not support covariant return types.
if (PreserveBaseOverridesAttribute is null)
return result;

var methods = type
.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)
.GroupBy(m => m.IsPreserveBaseOverride(true))
.ToDictionary(m => m.Key, m => m.ToArray());

var covariantReturnMethods = methods.GetValueOrDefault(true, Array.Empty<MethodInfo>());
var otherMethods = methods.GetValueOrDefault(false, Array.Empty<MethodInfo>());

foreach (var covariantReturnMethod in covariantReturnMethods)
{
var overridenMethod = otherMethods.FirstOrDefault(m => Match(covariantReturnMethod, m));
if (overridenMethod is null)
continue;

var interfaceDeclarations = covariantReturnMethod.GetInterfaceDeclarations().ToHashSet();
result.Add(new CovariantReturnMethodInfo(covariantReturnMethod, overridenMethod, interfaceDeclarations));
}

return result;

bool Match(MethodInfo covariantReturnMethod, MethodInfo other)
{
if (covariantReturnMethod.Name != other.Name)
return false;

// return types should not be the same.
if (covariantReturnMethod.ReturnType == other.ReturnType)
return false;

if (other.ReturnType.IsAssignableFrom(covariantReturnMethod.ReturnType) == false)
return false;

var params1 = covariantReturnMethod.GetParameters();
var params2 = other.GetParameters();

if (params1.Length != params2.Length)
return false;

foreach (var (p1, p2) in params1.Zip(params2))
{
if (p1.ParameterType != p2.ParameterType)
return false;
}

var isGeneric = covariantReturnMethod.IsGenericMethod;
if (isGeneric != other.IsGenericMethod)
return false;

if (isGeneric)
{
var args1 = covariantReturnMethod.GetGenericArguments();
var args2 = other.GetGenericArguments();
if (args1.Length != args2.Length)
return false;

foreach (var (a1, a2) in args1.Zip(args2))
{
if (a1 != a2)
return false;
}
}

return true;
}
}
}
}
Loading
Loading