diff --git a/src/embed_tests/ClassManagerTests.cs b/src/embed_tests/ClassManagerTests.cs index 15da61e3b..dcdf66edb 100644 --- a/src/embed_tests/ClassManagerTests.cs +++ b/src/embed_tests/ClassManagerTests.cs @@ -1,4 +1,5 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Linq; using System.Reflection; @@ -1083,6 +1084,164 @@ def is_enum_value_defined(): Assert.Throws(() => module.InvokeMethod("is_enum_value_defined")); } } + + private static TestCaseData[] IDictionaryContainsTestCases => + [ + new(typeof(TestDictionary)), + new(typeof(Dictionary)), + new(typeof(TestKeyValueContainer)), + new(typeof(DynamicClassDictionary)), + ]; + + [TestCaseSource(nameof(IDictionaryContainsTestCases))] + public void IDictionaryContainsMethodIsBound(Type dictType) + { + using var _ = Py.GIL(); + + var module = PyModule.FromString("IDictionaryContainsMethodIsBound", $@" +from clr import AddReference +AddReference(""Python.EmbeddingTest"") + +from Python.EmbeddingTest import * + +def contains(dictionary, key): + return key in dictionary +"); + + using var contains = module.GetAttr("contains"); + + var dictionary = Convert.ChangeType(Activator.CreateInstance(dictType), dictType); + var key1 = "key1"; + (dictionary as dynamic).Add(key1, "value1"); + + using var pyDictionary = dictionary.ToPython(); + using var pyKey1 = key1.ToPython(); + + var result = contains.Invoke(pyDictionary, pyKey1).As(); + Assert.IsTrue(result); + + using var pyKey2 = "key2".ToPython(); + result = contains.Invoke(pyDictionary, pyKey2).As(); + Assert.IsFalse(result); + } + + [TestCaseSource(nameof(IDictionaryContainsTestCases))] + public void CanCheckIfNoneIsInDictionary(Type dictType) + { + using var _ = Py.GIL(); + + var module = PyModule.FromString("CanCheckIfNoneIsInDictionary", $@" +from clr import AddReference +AddReference(""Python.EmbeddingTest"") + +from Python.EmbeddingTest import * + +def contains(dictionary, key): + return key in dictionary +"); + + using var contains = module.GetAttr("contains"); + + var dictionary = Convert.ChangeType(Activator.CreateInstance(dictType), dictType); + (dictionary as dynamic).Add("key1", "value1"); + + using var pyDictionary = dictionary.ToPython(); + + var result = false; + Assert.DoesNotThrow(() => result = contains.Invoke(pyDictionary, PyObject.None).As()); + Assert.IsFalse(result); + } + + public class TestDictionary : IDictionary + { + private readonly Dictionary _data = new(); + + public object this[object key] { get => ((IDictionary)_data)[key]; set => ((IDictionary)_data)[key] = value; } + + public bool IsFixedSize => ((IDictionary)_data).IsFixedSize; + + public bool IsReadOnly => ((IDictionary)_data).IsReadOnly; + + public ICollection Keys => ((IDictionary)_data).Keys; + + public ICollection Values => ((IDictionary)_data).Values; + + public int Count => ((ICollection)_data).Count; + + public bool IsSynchronized => ((ICollection)_data).IsSynchronized; + + public object SyncRoot => ((ICollection)_data).SyncRoot; + + public void Add(object key, object value) + { + ((IDictionary)_data).Add(key, value); + } + + public void Clear() + { + ((IDictionary)_data).Clear(); + } + + public bool Contains(object key) + { + return ((IDictionary)_data).Contains(key); + } + + public void CopyTo(Array array, int index) + { + ((ICollection)_data).CopyTo(array, index); + } + + public IDictionaryEnumerator GetEnumerator() + { + return ((IDictionary)_data).GetEnumerator(); + } + + public void Remove(object key) + { + ((IDictionary)_data).Remove(key); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return ((IEnumerable)_data).GetEnumerator(); + } + + public bool ContainsKey(TKey key) + { + return Contains(key); + } + } + + public class TestKeyValueContainer + where TKey: class + where TValue: class + { + private readonly Dictionary _data = new(); + public int Count => _data.Count; + public bool ContainsKey(TKey key) + { + return _data.ContainsKey(key); + } + public void Add(TKey key, TValue value) + { + _data.Add(key, value); + } + } + + public class DynamicClassDictionary : TestPropertyAccess.DynamicFixture + { + private readonly Dictionary _data = new(); + public int Count => _data.Count; + public bool ContainsKey(TKey key) + { + return _data.ContainsKey(key); + } + public void Add(TKey key, TValue value) + { + _data.Add(key, value); + } + } } public class NestedTestParent diff --git a/src/runtime/ClassManager.cs b/src/runtime/ClassManager.cs index 58f80ce30..bf852112c 100644 --- a/src/runtime/ClassManager.cs +++ b/src/runtime/ClassManager.cs @@ -205,7 +205,19 @@ internal static ClassBase CreateClass(Type type) else if (typeof(IDynamicMetaObjectProvider).IsAssignableFrom(type)) { - impl = new DynamicClassObject(type); + if (type.IsLookUp()) + { + impl = new DynamicClassLookUpObject(type); + } + else + { + impl = new DynamicClassObject(type); + } + } + + else if (type.IsLookUp()) + { + impl = new LookUpObject(type); } else diff --git a/src/runtime/Types/DynamicClassLookUpObject.cs b/src/runtime/Types/DynamicClassLookUpObject.cs new file mode 100644 index 000000000..2c570fe20 --- /dev/null +++ b/src/runtime/Types/DynamicClassLookUpObject.cs @@ -0,0 +1,34 @@ +using System; + +namespace Python.Runtime +{ + /// + /// Implements a Python type for managed DynamicClass objects that support look up (dictionaries), + /// that is, they implement ContainsKey(). + /// This type is essentially the same as a ClassObject, except that it provides + /// sequence semantics to support natural dictionary usage (__contains__ and __len__) + /// from Python. + /// + internal class DynamicClassLookUpObject : DynamicClassObject + { + internal DynamicClassLookUpObject(Type tp) : base(tp) + { + } + + /// + /// Implements __len__ for dictionary types. + /// + public static int mp_length(BorrowedReference ob) + { + return LookUpObject.mp_length(ob); + } + + /// + /// Implements __contains__ for dictionary types. + /// + public static int sq_contains(BorrowedReference ob, BorrowedReference v) + { + return LookUpObject.sq_contains(ob, v); + } + } +} diff --git a/src/runtime/Types/KeyValuePairEnumerableObject.cs b/src/runtime/Types/KeyValuePairEnumerableObject.cs index 95a0180e1..04c3f66f9 100644 --- a/src/runtime/Types/KeyValuePairEnumerableObject.cs +++ b/src/runtime/Types/KeyValuePairEnumerableObject.cs @@ -1,6 +1,5 @@ using System; using System.Collections.Generic; -using System.Reflection; namespace Python.Runtime { @@ -10,75 +9,14 @@ namespace Python.Runtime /// sequence semantics to support natural dictionary usage (__contains__ and __len__) /// from Python. /// - internal class KeyValuePairEnumerableObject : ClassObject + internal class KeyValuePairEnumerableObject : LookUpObject { - [NonSerialized] - private static Dictionary, MethodInfo> methodsByType = new Dictionary, MethodInfo>(); - private static List requiredMethods = new List { "Count", "ContainsKey" }; - - internal static bool VerifyMethodRequirements(Type type) - { - foreach (var requiredMethod in requiredMethods) - { - var method = type.GetMethod(requiredMethod); - if (method == null) - { - method = type.GetMethod($"get_{requiredMethod}"); - if (method == null) - { - return false; - } - } - - var key = Tuple.Create(type, requiredMethod); - methodsByType.Add(key, method); - } - - return true; - } - internal KeyValuePairEnumerableObject(Type tp) : base(tp) { } internal override bool CanSubclass() => false; - - /// - /// Implements __len__ for dictionary types. - /// - public static int mp_length(BorrowedReference ob) - { - var obj = (CLRObject)GetManagedObject(ob); - var self = obj.inst; - - var key = Tuple.Create(self.GetType(), "Count"); - var methodInfo = methodsByType[key]; - - return (int)methodInfo.Invoke(self, null); - } - - /// - /// Implements __contains__ for dictionary types. - /// - public static int sq_contains(BorrowedReference ob, BorrowedReference v) - { - var obj = (CLRObject)GetManagedObject(ob); - var self = obj.inst; - - var key = Tuple.Create(self.GetType(), "ContainsKey"); - var methodInfo = methodsByType[key]; - - var parameters = methodInfo.GetParameters(); - object arg; - if (!Converter.ToManaged(v, parameters[0].ParameterType, out arg, false)) - { - Exceptions.SetError(Exceptions.TypeError, - $"invalid parameter type for sq_contains: should be {Converter.GetTypeByAlias(v)}, found {parameters[0].ParameterType}"); - } - - return (bool)methodInfo.Invoke(self, new[] { arg }) ? 1 : 0; - } } public static class KeyValuePairEnumerableObjectExtension @@ -102,7 +40,7 @@ public static bool IsKeyValuePairEnumerable(this Type type) a.GetGenericTypeDefinition() == keyValuePairType && a.GetGenericArguments().Length == 2) { - return KeyValuePairEnumerableObject.VerifyMethodRequirements(type); + return LookUpObject.VerifyMethodRequirements(type); } } } diff --git a/src/runtime/Types/LookUpObject.cs b/src/runtime/Types/LookUpObject.cs new file mode 100644 index 000000000..04520132c --- /dev/null +++ b/src/runtime/Types/LookUpObject.cs @@ -0,0 +1,121 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; + +namespace Python.Runtime +{ + /// + /// Implements a Python type for managed objects that support look up (dictionaries), + /// that is, they implement ContainsKey(). + /// This type is essentially the same as a ClassObject, except that it provides + /// sequence semantics to support natural dictionary usage (__contains__ and __len__) + /// from Python. + /// + internal class LookUpObject : ClassObject + { + [NonSerialized] + private static Dictionary, MethodInfo> methodsByType = new Dictionary, MethodInfo>(); + private static List<(string, int)> requiredMethods = new (){ ("Count", 0), ("ContainsKey", 1) }; + + private static MethodInfo GetRequiredMethod(MethodInfo[] methods, string methodName, int parametersCount) + { + return methods.FirstOrDefault(m => m.Name == methodName && m.GetParameters().Length == parametersCount); + } + + internal static bool VerifyMethodRequirements(Type type) + { + var methods = type.GetMethods(); + + foreach (var (requiredMethod, parametersCount) in requiredMethods) + { + var method = GetRequiredMethod(methods, requiredMethod, parametersCount); + if (method == null) + { + var getterName = $"get_{requiredMethod}"; + method = GetRequiredMethod(methods, getterName, parametersCount); + if (method == null) + { + return false; + } + } + + var key = Tuple.Create(type, requiredMethod); + methodsByType.Add(key, method); + } + + return true; + } + + internal LookUpObject(Type tp) : base(tp) + { + } + + /// + /// Implements __len__ for dictionary types. + /// + public static int mp_length(BorrowedReference ob) + { + return LookUpObjectExtensions.Length(ob, methodsByType); + } + + /// + /// Implements __contains__ for dictionary types. + /// + public static int sq_contains(BorrowedReference ob, BorrowedReference v) + { + return LookUpObjectExtensions.Contains(ob, v, methodsByType); + } + } + + internal static class LookUpObjectExtensions + { + internal static bool IsLookUp(this Type type) + { + return LookUpObject.VerifyMethodRequirements(type); + } + + /// + /// Implements __len__ for dictionary types. + /// + internal static int Length(BorrowedReference ob, Dictionary, MethodInfo> methodsByType) + { + var obj = (CLRObject)ManagedType.GetManagedObject(ob); + var self = obj.inst; + + var key = Tuple.Create(self.GetType(), "Count"); + var methodInfo = methodsByType[key]; + + return (int)methodInfo.Invoke(self, null); + } + + /// + /// Implements __contains__ for dictionary types. + /// + internal static int Contains(BorrowedReference ob, BorrowedReference v, Dictionary, MethodInfo> methodsByType) + { + var obj = (CLRObject)ManagedType.GetManagedObject(ob); + var self = obj.inst; + + var key = Tuple.Create(self.GetType(), "ContainsKey"); + var methodInfo = methodsByType[key]; + + var parameters = methodInfo.GetParameters(); + object arg; + if (!Converter.ToManaged(v, parameters[0].ParameterType, out arg, false)) + { + Exceptions.SetError(Exceptions.TypeError, + $"invalid parameter type for sq_contains: should be {Converter.GetTypeByAlias(v)}, found {parameters[0].ParameterType}"); + } + + // If the argument is None, we return false. Python allows using None as key, + // but C# doesn't and will throw, so we shortcut here + if (arg == null) + { + return 0; + } + + return (bool)methodInfo.Invoke(self, new[] { arg }) ? 1 : 0; + } + } +}