diff --git a/src/NHibernate.DomainModel/Northwind/Entities/Animal.cs b/src/NHibernate.DomainModel/Northwind/Entities/Animal.cs index 21fac97d4ae..0aec780eb02 100644 --- a/src/NHibernate.DomainModel/Northwind/Entities/Animal.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/Animal.cs @@ -12,7 +12,8 @@ public class Animal public virtual Animal Father { get; set; } public virtual IList Children { get; set; } public virtual string SerialNumber { get; set; } - + public virtual string FatherSerialNumber => Father?.SerialNumber; + public virtual bool HasFather => Father != null; public virtual Animal FatherOrMother => Father ?? Mother; } diff --git a/src/NHibernate.DomainModel/Northwind/Entities/DynamicUser.cs b/src/NHibernate.DomainModel/Northwind/Entities/DynamicUser.cs new file mode 100644 index 00000000000..c8c9fb3cf03 --- /dev/null +++ b/src/NHibernate.DomainModel/Northwind/Entities/DynamicUser.cs @@ -0,0 +1,18 @@ +using System.Collections; + +namespace NHibernate.DomainModel.Northwind.Entities +{ + public class DynamicUser : IEnumerable + { + public virtual int Id { get; set; } + + public virtual dynamic Properties { get; set; } + + public virtual IDictionary Settings { get; set; } + + public virtual IEnumerator GetEnumerator() + { + throw new System.NotImplementedException(); + } + } +} diff --git a/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs b/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs index c4cbda23f26..4551ce0e9d8 100755 --- a/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs @@ -69,6 +69,11 @@ public IQueryable Users get { return _session.Query(); } } + public IQueryable DynamicUsers + { + get { return _session.Query(); } + } + public IQueryable PatientRecords { get { return _session.Query(); } diff --git a/src/NHibernate.DomainModel/Northwind/Entities/Role.cs b/src/NHibernate.DomainModel/Northwind/Entities/Role.cs index 2643c03d285..021f39cce19 100644 --- a/src/NHibernate.DomainModel/Northwind/Entities/Role.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/Role.cs @@ -7,5 +7,6 @@ public class Role public virtual bool IsActive { get; set; } public virtual AnotherEntity Entity { get; set; } public virtual Role ParentRole { get; set; } + public virtual User CreatedBy { get; set; } // Not mapped } } diff --git a/src/NHibernate.DomainModel/Northwind/Entities/User.cs b/src/NHibernate.DomainModel/Northwind/Entities/User.cs index c23e667be9b..14096dac912 100644 --- a/src/NHibernate.DomainModel/Northwind/Entities/User.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/User.cs @@ -48,10 +48,16 @@ public class User : IUser, IEntity public virtual FeatureSet Features { get; set; } + public virtual User NotMappedUser => this; + public virtual EnumStoredAsString Enum1 { get; set; } + public virtual EnumStoredAsString? NullableEnum1 { get; set; } + public virtual EnumStoredAsInt32 Enum2 { get; set; } + public virtual EnumStoredAsInt32? NullableEnum2 { get; set; } + public virtual IUser CreatedBy { get; set; } public virtual IUser ModifiedBy { get; set; } diff --git a/src/NHibernate.DomainModel/Northwind/Entities/UserComponent.cs b/src/NHibernate.DomainModel/Northwind/Entities/UserComponent.cs index bd94d3a8f87..868c856eb80 100644 --- a/src/NHibernate.DomainModel/Northwind/Entities/UserComponent.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/UserComponent.cs @@ -5,10 +5,14 @@ public class UserComponent public string Property1 { get; set; } public string Property2 { get; set; } public UserComponent2 OtherComponent { get; set; } + + public string Property3 => $"{Property1}{Property2}"; } public class UserComponent2 { public string OtherProperty1 { get; set; } + + public string OtherProperty2 => OtherProperty1; } -} \ No newline at end of file +} diff --git a/src/NHibernate.DomainModel/Northwind/Entities/UserDto.cs b/src/NHibernate.DomainModel/Northwind/Entities/UserDto.cs index 32662ea2bcc..b9ab6e8fade 100644 --- a/src/NHibernate.DomainModel/Northwind/Entities/UserDto.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/UserDto.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; namespace NHibernate.DomainModel.Northwind.Entities { @@ -9,11 +10,13 @@ public class UserDto public virtual int InvalidLoginAttempts { get; set; } public virtual string RoleName { get; set; } public virtual UserDto2 Dto2 { get; set; } + public virtual List Dto2List { get; set; } = new List(); public UserDto(int id, string name) { Id = id; Name = name; + Dto2 = new UserDto2(); } } diff --git a/src/NHibernate.DomainModel/Northwind/Mappings/DynamicUser.hbm.xml b/src/NHibernate.DomainModel/Northwind/Mappings/DynamicUser.hbm.xml new file mode 100644 index 00000000000..1b6775b29c6 --- /dev/null +++ b/src/NHibernate.DomainModel/Northwind/Mappings/DynamicUser.hbm.xml @@ -0,0 +1,30 @@ + + + + + select * from Users + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml b/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml index 2764cb70898..f249de9574e 100644 --- a/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml +++ b/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml @@ -24,8 +24,14 @@ + + + + + diff --git a/src/NHibernate.Test/Async/Linq/EnumTests.cs b/src/NHibernate.Test/Async/Linq/EnumTests.cs index 622a806ed30..6e9355d294c 100644 --- a/src/NHibernate.Test/Async/Linq/EnumTests.cs +++ b/src/NHibernate.Test/Async/Linq/EnumTests.cs @@ -61,5 +61,42 @@ public async Task CanQueryOnEnumStoredAsString_Small_1Async() Assert.AreEqual(expectedCount, query.Count); } + + [Test] + public async Task ConditionalNavigationPropertyAsync() + { + EnumStoredAsString? type = null; + await (db.Users.Where(o => o.Enum1 == EnumStoredAsString.Large).ToListAsync()); + await (db.Users.Where(o => EnumStoredAsString.Large != o.Enum1).ToListAsync()); + await (db.Users.Where(o => (o.NullableEnum1 ?? EnumStoredAsString.Large) == EnumStoredAsString.Medium).ToListAsync()); + await (db.Users.Where(o => ((o.NullableEnum1 ?? type) ?? o.Enum1) == EnumStoredAsString.Medium).ToListAsync()); + + await (db.Users.Where(o => (o.NullableEnum1.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) == EnumStoredAsString.Medium).ToListAsync()); + await (db.Users.Where(o => (o.Enum1 != EnumStoredAsString.Large + ? (o.NullableEnum1.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) + : EnumStoredAsString.Small) == EnumStoredAsString.Medium).ToListAsync()); + + await (db.Users.Where(o => (o.Enum1 == EnumStoredAsString.Large ? o.Role : o.Role).Name == "test").ToListAsync()); + } + + [Test] + public async Task CanQueryComplexExpressionOnEnumStoredAsStringAsync() + { + var type = EnumStoredAsString.Unspecified; + var query = await ((from user in db.Users + where (user.NullableEnum1 == EnumStoredAsString.Large + ? EnumStoredAsString.Medium + : user.NullableEnum1 ?? user.Enum1 + ) == type + select new + { + user, + simple = user.Enum1, + condition = user.Enum1 == EnumStoredAsString.Large ? EnumStoredAsString.Medium : user.Enum1, + coalesce = user.NullableEnum1 ?? EnumStoredAsString.Medium + }).ToListAsync()); + + Assert.That(query.Count, Is.EqualTo(0)); + } } } diff --git a/src/NHibernate.Test/Async/Linq/ParameterTests.cs b/src/NHibernate.Test/Async/Linq/ParameterTests.cs index 4fbebe3e78b..0956fdfe92b 100644 --- a/src/NHibernate.Test/Async/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Async/Linq/ParameterTests.cs @@ -88,6 +88,34 @@ public async Task UsingTwoEntityParametersAsync() 2)); } + [Test] + public async Task UsingEntityEnumerableParameterTwiceAsync() + { + if (!Dialect.SupportsSubSelects) + { + Assert.Ignore(); + } + + var enumerable = await (db.DynamicUsers.FirstAsync()); + await (AssertTotalParametersAsync( + db.DynamicUsers.Where(o => o == enumerable && o != enumerable), + 1)); + } + + [Test] + public async Task UsingEntityEnumerableListParameterTwiceAsync() + { + if (!Dialect.SupportsSubSelects) + { + Assert.Ignore(); + } + + var enumerable = new[] {await (db.DynamicUsers.FirstAsync())}; + await (AssertTotalParametersAsync( + db.DynamicUsers.Where(o => enumerable.Contains(o) && enumerable.Contains(o)), + 1)); + } + [Test] public async Task UsingValueTypeParameterTwiceAsync() { diff --git a/src/NHibernate.Test/Async/Linq/SelectionTests.cs b/src/NHibernate.Test/Async/Linq/SelectionTests.cs index 2404abae275..c2775573396 100644 --- a/src/NHibernate.Test/Async/Linq/SelectionTests.cs +++ b/src/NHibernate.Test/Async/Linq/SelectionTests.cs @@ -13,8 +13,12 @@ using System.Linq; using NHibernate.DomainModel.NHSpecific; using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Driver; +using NHibernate.Exceptions; +using NHibernate.Proxy; using NHibernate.Type; using NUnit.Framework; +using static NHibernate.Linq.ExpressionEvaluation; using NHibernate.Linq; namespace NHibernate.Test.Linq @@ -131,7 +135,7 @@ public async Task CanSelectNestedMemberInitExpressionAsync() { InvalidLoginAttempts = user.InvalidLoginAttempts, Dto2 = new UserDto2 - { + { RegisteredAt = user.RegisteredAt, Enum = user.Enum2 }, @@ -154,7 +158,7 @@ public async Task CanSelectNestedMemberInitWithinNewExpressionAsync() user.Name, user.InvalidLoginAttempts, Dto = new UserDto2 - { + { RegisteredAt = user.RegisteredAt, Enum = user.Enum2 }, @@ -272,6 +276,161 @@ public async Task CanSelectWithAnySubQueryAsync() Assert.AreEqual(1, list.Count(t => !t.HasEntries)); } + [Test] + public async Task CanSelectConditionalAsync() + { + // SqlServerCeDriver and OdbcDriver have an issue matching the case statements inside select and order by statement, + // when having one or more parameters inside them. Throws with the following error: + // ORDER BY items must appear in the select list if SELECT DISTINCT is specified. + if (!(Sfi.ConnectionProvider.Driver is OdbcDriver) && !(Sfi.ConnectionProvider.Driver is SqlServerCeDriver)) + { + using (var sqlLog = new SqlLogSpy()) + { + var q = await (db.Orders.Where(o => o.Customer.CustomerId == "test") + .Select(o => o.ShippedTo.Contains("test") ? o.ShippedTo : o.Customer.CompanyName) + .OrderBy(o => o) + .Distinct() + .ToListAsync()); + + Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "case"), Is.EqualTo(2)); + } + } + + using (var sqlLog = new SqlLogSpy()) + { + var q = await (db.Orders.Where(o => o.Customer.CustomerId == "test") + .Select(o => o.OrderDate.HasValue ? o.OrderDate : o.ShippingDate) + .FirstOrDefaultAsync()); + + Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "case"), Is.EqualTo(1)); + } + + using (var sqlLog = new SqlLogSpy()) + { + var q = await (db.Orders.Where(o => o.Customer.CustomerId == "test") + .Select(o => new + { + Value = o.OrderDate.HasValue + ? o.Customer.CompanyName + : (o.ShippingDate.HasValue + ? o.Shipper.CompanyName + "Shipper" + : o.ShippedTo) + }) + .FirstOrDefaultAsync()); + + var log = sqlLog.GetWholeLog(); + Assert.That(FindAllOccurrences(log, "as col"), Is.EqualTo(1)); + } + + using (var sqlLog = new SqlLogSpy()) + { + var q = await (db.Orders.Where(o => o.Customer.CustomerId == "test") + .Select(o => new + { + Value = o.OrderDate.HasValue + ? o.Customer.CompanyName + : (o.ShippingDate.HasValue + ? o.Shipper.CompanyName + "Shipper" + : null) + }) + .FirstOrDefaultAsync()); + + var log = sqlLog.GetWholeLog(); + Assert.That(FindAllOccurrences(log, "as col"), Is.EqualTo(1)); + } + + using (var sqlLog = new SqlLogSpy()) + { + var q = await (db.Orders.Where(o => o.Customer.CustomerId == "test") + .Select(o => new + { + Value = o.OrderDate.HasValue + ? o.Customer.CompanyName + : (o.ShippingDate.HasValue + ? o.Shipper.CompanyName + "Shipper" + : "default") + }) + .FirstOrDefaultAsync()); + + var log = sqlLog.GetWholeLog(); + Assert.That(FindAllOccurrences(log, "as col"), Is.EqualTo(1)); + } + + var defaultValue = "default"; + using (var sqlLog = new SqlLogSpy()) + { + var q = await (db.Orders.Where(o => o.Customer.CustomerId == "test") + .Select(o => new + { + Value = o.OrderDate.HasValue + ? o.Customer.CompanyName + : (o.ShippingDate.HasValue + ? o.Shipper.CompanyName + "Shipper" + : defaultValue) + }) + .FirstOrDefaultAsync()); + + var log = sqlLog.GetWholeLog(); + Assert.That(FindAllOccurrences(log, "as col"), Is.EqualTo(1)); + } + } + + [Test] + public async Task CanSelectConditionalSubQueryAsync() + { + if (!Dialect.SupportsScalarSubSelects) + Assert.Ignore(Dialect.GetType().Name + " does not support scalar sub-queries"); + + var list = await (db.Customers + .Select(c => new + { + Date = db.Orders.Where(o => o.Customer.CustomerId == c.CustomerId) + .Select(o => o.OrderDate.HasValue ? o.OrderDate : o.ShippingDate) + .Max() + }) + .ToListAsync()); + Assert.That(list, Has.Count.GreaterThan(0)); + + var list2 = await (db.Orders + .Select( + o => new + { + UnitPrice = o.Freight.HasValue + ? o.OrderLines.Where(l => l.Discount == 1) + .Select(l => l.Product.UnitPrice.HasValue ? l.Product.UnitPrice : l.UnitPrice) + .Max() + : o.OrderLines.Where(l => l.Discount == 0) + .Select(l => l.Product.UnitPrice.HasValue ? l.Product.UnitPrice : l.UnitPrice) + .Max() + }) + .ToListAsync()); + Assert.That(list2, Has.Count.GreaterThan(0)); + + var list3 = await (db.Orders + .Select(o => new + { + Date = o.OrderLines.Any(l => o.OrderDate.HasValue) + ? db.Employees + .Select(e => e.BirthDate.HasValue ? e.BirthDate : e.HireDate) + .Max() + : o.Employee.Superior != null ? o.Employee.Superior.BirthDate : o.Employee.BirthDate + }) + .ToListAsync()); + Assert.That(list3, Has.Count.GreaterThan(0)); + + var list4 = await (db.Orders + .Select(o => new + { + Employee = db.Employees.Any(e => e.Superior != null) + ? db.Employees + .Where(e => e.Superior != null) + .Select(e => e.Superior).FirstOrDefault() + : o.Employee.Superior != null ? o.Employee.Superior : o.Employee + }) + .ToListAsync()); + Assert.That(list4, Has.Count.GreaterThan(0)); + } + [Test, KnownBug("NH-3045")] public async Task CanSelectFirstElementFromChildCollectionAsync() { @@ -371,56 +530,56 @@ public async Task CanSelectConditionalKnownTypesAsync() if (!Dialect.SupportsScalarSubSelects) Assert.Ignore(Dialect.GetType().Name + " does not support scalar sub-queries"); - var moreThanTwoOrderLinesBool = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? true : false }).ToListAsync()); + var moreThanTwoOrderLinesBool = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? true : false, Param = true }).ToListAsync()); Assert.That(moreThanTwoOrderLinesBool.Count(x => x.HasMoreThanTwo == true), Is.EqualTo(410)); - var moreThanTwoOrderLinesNBool = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? true : (bool?)null }).ToListAsync()); + var moreThanTwoOrderLinesNBool = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? true : (bool?)null, Param = (bool?)null }).ToListAsync()); Assert.That(moreThanTwoOrderLinesNBool.Count(x => x.HasMoreThanTwo == true), Is.EqualTo(410)); - var moreThanTwoOrderLinesShort = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? (short)1 : (short)0 }).ToListAsync()); + var moreThanTwoOrderLinesShort = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? (short)1 : (short)0, Param = (short)0 }).ToListAsync()); Assert.That(moreThanTwoOrderLinesShort.Count(x => x.HasMoreThanTwo == 1), Is.EqualTo(410)); - var moreThanTwoOrderLinesNShort = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? (short?)1 : (short?)null }).ToListAsync()); + var moreThanTwoOrderLinesNShort = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? (short?)1 : (short?)null, Param = (short?)null }).ToListAsync()); Assert.That(moreThanTwoOrderLinesNShort.Count(x => x.HasMoreThanTwo == 1), Is.EqualTo(410)); - var moreThanTwoOrderLinesInt = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1 : 0 }).ToListAsync()); + var moreThanTwoOrderLinesInt = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1 : 0, Param = 1 }).ToListAsync()); Assert.That(moreThanTwoOrderLinesInt.Count(x => x.HasMoreThanTwo == 1), Is.EqualTo(410)); - var moreThanTwoOrderLinesNInt = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1 : (int?)null }).ToListAsync()); + var moreThanTwoOrderLinesNInt = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1 : (int?)null, Param = (int?)null }).ToListAsync()); Assert.That(moreThanTwoOrderLinesNInt.Count(x => x.HasMoreThanTwo == 1), Is.EqualTo(410)); - var moreThanTwoOrderLinesDecimal = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1m : 0m }).ToListAsync()); + var moreThanTwoOrderLinesDecimal = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1m : 0m, Param = 1m }).ToListAsync()); Assert.That(moreThanTwoOrderLinesDecimal.Count(x => x.HasMoreThanTwo == 1m), Is.EqualTo(410)); - var moreThanTwoOrderLinesNDecimal = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1m : (decimal?)null }).ToListAsync()); + var moreThanTwoOrderLinesNDecimal = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1m : (decimal?)null, Param = (decimal?)null }).ToListAsync()); Assert.That(moreThanTwoOrderLinesNDecimal.Count(x => x.HasMoreThanTwo == 1m), Is.EqualTo(410)); - var moreThanTwoOrderLinesSingle = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1f : 0f }).ToListAsync()); + var moreThanTwoOrderLinesSingle = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1f : 0f, Param = 1f }).ToListAsync()); Assert.That(moreThanTwoOrderLinesSingle.Count(x => x.HasMoreThanTwo == 1f), Is.EqualTo(410)); - var moreThanTwoOrderLinesNSingle = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1f : (float?)null }).ToListAsync()); + var moreThanTwoOrderLinesNSingle = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1f : (float?)null, Param = (float?)null }).ToListAsync()); Assert.That(moreThanTwoOrderLinesNSingle.Count(x => x.HasMoreThanTwo == 1f), Is.EqualTo(410)); - var moreThanTwoOrderLinesDouble = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1d : 0d }).ToListAsync()); + var moreThanTwoOrderLinesDouble = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1d : 0d, Param = 1d }).ToListAsync()); Assert.That(moreThanTwoOrderLinesDouble.Count(x => x.HasMoreThanTwo == 1d), Is.EqualTo(410)); - var moreThanTwoOrderLinesNDouble = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1d : (double?)null }).ToListAsync()); + var moreThanTwoOrderLinesNDouble = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1d : (double?)null, Param = (double?)null }).ToListAsync()); Assert.That(moreThanTwoOrderLinesNDouble.Count(x => x.HasMoreThanTwo == 1d), Is.EqualTo(410)); - var moreThanTwoOrderLinesString = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? "yes" : "no" }).ToListAsync()); + var moreThanTwoOrderLinesString = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? "yes" : "no", Param = "no" }).ToListAsync()); Assert.That(moreThanTwoOrderLinesString.Count(x => x.HasMoreThanTwo == "yes"), Is.EqualTo(410)); var now = DateTime.Now.Date; - var moreThanTwoOrderLinesDateTime = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? o.OrderDate.Value : now }).ToListAsync()); + var moreThanTwoOrderLinesDateTime = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? o.OrderDate.Value : now, Param = now }).ToListAsync()); Assert.That(moreThanTwoOrderLinesDateTime.Count(x => x.HasMoreThanTwo != now), Is.EqualTo(410)); - var moreThanTwoOrderLinesNDateTime = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? o.OrderDate : null }).ToListAsync()); + var moreThanTwoOrderLinesNDateTime = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? o.OrderDate : null, Param = (DateTime?)null }).ToListAsync()); Assert.That(moreThanTwoOrderLinesNDateTime.Count(x => x.HasMoreThanTwo != null), Is.EqualTo(410)); - var moreThanTwoOrderLinesGuid = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? o.Shipper.Reference : Guid.Empty }).ToListAsync()); + var moreThanTwoOrderLinesGuid = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? o.Shipper.Reference : Guid.Empty, Param = Guid.Empty }).ToListAsync()); Assert.That(moreThanTwoOrderLinesGuid.Count(x => x.HasMoreThanTwo != Guid.Empty), Is.EqualTo(410)); - var moreThanTwoOrderLinesNGuid = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? o.Shipper.Reference : (Guid?)null }).ToListAsync()); + var moreThanTwoOrderLinesNGuid = await (db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? o.Shipper.Reference : (Guid?)null, Param = (Guid?)null }).ToListAsync()); Assert.That(moreThanTwoOrderLinesNGuid.Count(x => x.HasMoreThanTwo != null), Is.EqualTo(410)); } @@ -453,6 +612,632 @@ public async Task CanSelectConditionalEntityValueWithEntityComparisonAsync() Assert.That(fatherInsteadOfChild, Has.Exactly(2).EqualTo("5678")); } + [Test] + public async Task CanSelectModulusAsync() + { + var list = await (db.Animals.Select(a => new { Sql = a.Id % 2.1f, a.Id }).ToListAsync()); + Assert.That(list.Select(o => o.Sql), Is.EqualTo(list.Select(o => o.Id % 2.1f)).Within(GetTolerance())); + var list1 = await (db.Animals.Select(a => new { Sql = a.Id % 2.1d, a.Id }).ToListAsync()); + Assert.That(list1.Select(o => o.Sql), Is.EqualTo(list1.Select(o => o.Id % 2.1d)).Within(GetTolerance())); + var list2 = await (db.Animals.Select(a => new { Sql = a.BodyWeight % 2.1f, a.BodyWeight }).ToListAsync()); + Assert.That(list2.Select(o => o.Sql), Is.EqualTo(list2.Select(o => o.BodyWeight % 2.1f)).Within(GetTolerance())); + var list3 = await (db.Animals.Select(a => new { Sql = a.Id % 2.1m, a.Id }).ToListAsync()); + Assert.That(list3.Select(o => o.Sql), Is.EqualTo(list3.Select(o => o.Id % 2.1m))); + var list4 = await (db.Animals.Select(a => new { Sql = a.Id % 2, a.Id }).ToListAsync()); + Assert.That(list4.Select(o => o.Sql), Is.EqualTo(list4.Select(o => o.Id % 2))); + var list5 = await (db.Animals.Select(a => new { Sql = a.Id % 2L, a.Id }).ToListAsync()); + Assert.That(list5.Select(o => o.Sql), Is.EqualTo(list5.Select(o => o.Id % 2L))); + var list7 = await (db.Animals.Select(a => new { Sql = a.BodyWeight % 2, a.BodyWeight }).ToListAsync()); + Assert.That(list7.Select(o => o.Sql), Is.EqualTo(list7.Select(o => o.BodyWeight % 2))); + var list8 = await (db.Animals.Select(a => new { Sql = a.BodyWeight % 2L, a.BodyWeight }).ToListAsync()); + Assert.That(list8.Select(o => o.Sql), Is.EqualTo(list8.Select(o => o.BodyWeight % 2L))); + var list9 = await (db.Products.Select(a => new { Sql = a.UnitPrice % 2L, a.UnitPrice }).ToListAsync()); + Assert.That(list9.Select(o => o.Sql), Is.EqualTo(list9.Select(o => o.UnitPrice % 2L))); + var list10 = await (db.Products.Select(a => new { Sql = a.UnitPrice % 2, a.UnitPrice }).ToListAsync()); + Assert.That(list10.Select(o => o.Sql), Is.EqualTo(list10.Select(o => o.UnitPrice % 2))); + } + + [Test] + public async Task CanSelectModulusSameExpressionAsync() + { + var list1 = await (db.Animals.Select(a => new ObjectDto { CalculatedValue = a.Id % 2.1m, OriginalValue = a.Id }).ToListAsync()); + Assert.That(list1.Select(o => o.CalculatedValue), Is.EqualTo(list1.Select(o => o.OriginalValue % 2.1m))); + var list2 = await (db.Animals.Select(a => new ObjectDto { CalculatedValue = a.Id % 2L, OriginalValue = a.Id }).ToListAsync()); + Assert.That(list2.Select(o => o.CalculatedValue), Is.EqualTo(list2.Select(o => o.OriginalValue % 2L))); + var list3 = await (db.Animals.Select(a => new ObjectDto { CalculatedValue = a.Id % 2.1f, OriginalValue = a.Id }).ToListAsync()); + Assert.That(list3.Select(o => o.CalculatedValue), Is.EqualTo(list3.Select(o => o.OriginalValue % 2.1f)).Within(GetTolerance())); + var list4 = await (db.Animals.Select(a => new ObjectDto { CalculatedValue = a.Id % 2.1d, OriginalValue = a.Id }).ToListAsync()); + Assert.That(list4.Select(o => o.CalculatedValue), Is.EqualTo(list4.Select(o => o.OriginalValue % 2.1d)).Within(GetTolerance())); + } + + [Test] + public async Task CanForceClientEvaluationAsync() + { + var query = db.Animals.Select(a => ClientEval(() => a.Id + 5)); + Assert.That(GetSqlSelect(query), Does.Not.Contain("+")); + Assert.That(await (query.ToListAsync()), Is.EqualTo(await (db.Animals.Select(a => a.Id + 5).ToListAsync()))); + + query = db.Animals.Select(a => ClientEval(() => a.SerialNumber.Length)); + Assert.That(GetSqlSelect(query), Does.Not.Contain("len(").And.Not.Contain("length(")); + Assert.That(await (query.ToListAsync()), Is.EqualTo(await (db.Animals.Select(a => a.SerialNumber.Length).ToListAsync()))); + + var query2 = db.Animals.Select(a => ClientEval(() => a.SerialNumber.Substring(0, 1))); + Assert.That(GetSqlSelect(query2), Does.Not.Contain("substr(").And.Not.Contain("substring(")); + Assert.That(await (query2.ToListAsync()), Is.EqualTo(await (db.Animals.Select(a => a.SerialNumber.Substring(0, 1)).ToListAsync()))); + + query2 = db.Animals.Select(a => ClientEval(() => a.Id % 2 == 0 ? a.SerialNumber : a.Description)); + Assert.That(GetSqlSelect(query2), Does.Not.Contain("case")); + Assert.That(await (query2.ToListAsync()), Is.EqualTo(await (db.Animals.Select(a => a.Id % 2 == 0 ? a.SerialNumber : a.Description).ToListAsync()))); + + var query3 = await (db.Animals.Select(a => new + { + Client = ClientEval(() => a.Id % 2 == 0 ? a.SerialNumber.Substring(0, 1) : a.Description), + Server = a.Id % 2 == 0 ? a.SerialNumber.Substring(0, 1) : a.Description, + }).ToListAsync()); + Assert.That(query3.Select(o => o.Client), Is.EqualTo(query3.Select(o => o.Server))); + } + + [Test] + public async Task CanSelectMultiplyOperatorAsync() + { + var list1 = await (db.Animals.Select(a => new { Sql = a.Id * 5, a.Id }).ToListAsync()); + Assert.That(list1.Select(o => o.Sql), Is.EqualTo(list1.Select(o => o.Id * 5))); + var list2 = await (db.Animals.Select(a => new { Sql = a.Id * 12345.54321m, a.Id }).ToListAsync()); + Assert.That(list2.Select(o => o.Sql), Is.EqualTo(list2.Select(o => o.Id * 12345.54321m))); + var list3 = await (db.Animals.Select(a => new { Sql = a.Id * 123.321f, a.Id }).ToListAsync()); + Assert.That(list3.Select(o => o.Sql), Is.EqualTo(list3.Select(o => o.Id * 123.321f)).Within(GetTolerance())); + var list4 = await (db.Animals.Select(a => new { Sql = a.Id * 12345.54321d, a.Id }).ToListAsync()); + Assert.That(list4.Select(o => o.Sql), Is.EqualTo(list4.Select(o => o.Id * 12345.54321d)).Within(GetTolerance())); + var list5 = await (db.Animals.Select(a => new { Sql = a.Id * 2L, a.Id }).ToListAsync()); + Assert.That(list5.Select(o => o.Sql), Is.EqualTo(list5.Select(o => o.Id * 2L))); + + var list6 = await (db.Products.Select(a => new { Sql = a.UnitPrice * 12345.54321m, a.UnitPrice }).ToListAsync()); + Assert.That(list6.Select(o => o.Sql), Is.EqualTo(list6.Select(o => o.UnitPrice * 12345.54321m))); + var list7 = await (db.Products.Select(a => new { Sql = a.UnitPrice * 12345L, a.UnitPrice }).ToListAsync()); + Assert.That(list7.Select(o => o.Sql), Is.EqualTo(list7.Select(o => o.UnitPrice * 12345L))); + + var list8 = await (db.Animals.Select(a => new { Sql = a.BodyWeight * 12345.54321f, a.BodyWeight }).ToListAsync()); + Assert.That(list8.Select(o => o.Sql), Is.EqualTo(list8.Select(o => o.BodyWeight * 12345.54321f))); + } + + [Test] + public async Task CanSelectDivideOperatorAsync() + { + var list1 = await (db.Animals.Select(a => new { Sql = a.Id / 5, a.Id }).ToListAsync()); + Assert.That(list1.Select(o => o.Sql), Is.EqualTo(list1.Select(o => o.Id / 5))); + var list2 = await (db.Animals.Select(a => new { Sql = a.Id / 12345.54321m, a.Id }).ToListAsync()); + Assert.That(list2.Select(o => o.Sql), Is.EqualTo(list2.Select(o => o.Id / 12345.54321m))); + var list3 = await (db.Animals.Select(a => new { Sql = a.Id / 12345.54321f, a.Id }).ToListAsync()); + Assert.That(list3.Select(o => o.Sql), Is.EqualTo(list3.Select(o => o.Id / 12345.54321f)).Within(GetTolerance())); + var list4 = await (db.Animals.Select(a => new { Sql = a.Id / 12345.54321d, a.Id }).ToListAsync()); + Assert.That(list4.Select(o => o.Sql), Is.EqualTo(list4.Select(o => o.Id / 12345.54321d)).Within(GetTolerance())); + var list5 = await (db.Animals.Select(a => new { Sql = a.Id / 2L, a.Id }).ToListAsync()); + Assert.That(list5.Select(o => o.Sql), Is.EqualTo(list5.Select(o => o.Id / 2L))); + + var list6 = await (db.Products.Select(a => new { Sql = a.UnitPrice / 12345.54321m, a.UnitPrice }).ToListAsync()); + Assert.That(list6.Select(o => o.Sql), Is.EqualTo(list6.Select(o => o.UnitPrice / 12345.54321m))); + var list7 = await (db.Products.Select(a => new { Sql = a.UnitPrice.Value / 12345L, a.UnitPrice }).ToListAsync()); + Assert.That(list7.Select(o => o.Sql), Is.EqualTo(list7.Select(o => o.UnitPrice / 12345L))); + + var list8 = await (db.Animals.Select(a => new { Sql = a.BodyWeight / 12345.54321f, a.BodyWeight }).ToListAsync()); + Assert.That(list8.Select(o => o.Sql), Is.EqualTo(list8.Select(o => o.BodyWeight / 12345.54321f)).Within(GetTolerance())); + } + + [Test] + public async Task CanSelectAddOperatorAsync() + { + var list1 = await (db.Animals.Select(a => new { Sql = a.Id + 5, a.Id }).ToListAsync()); + Assert.That(list1.Select(o => o.Sql), Is.EqualTo(list1.Select(o => o.Id + 5))); + var list2 = await (db.Animals.Select(a => new { Sql = a.Id + 12345.54321m, a.Id }).ToListAsync()); + Assert.That(list2.Select(o => o.Sql), Is.EqualTo(list2.Select(o => o.Id + 12345.54321m))); + var list3 = await (db.Animals.Select(a => new { Sql = a.Id + 12345.54321f, a.Id }).ToListAsync()); + Assert.That(list3.Select(o => o.Sql), Is.EqualTo(list3.Select(o => o.Id + 12345.54321f)).Within(GetTolerance())); + var list4 = await (db.Animals.Select(a => new { Sql = a.Id + 12345.54321d, a.Id }).ToListAsync()); + Assert.That(list4.Select(o => o.Sql), Is.EqualTo(list4.Select(o => o.Id + 12345.54321d))); + var list5 = await (db.Animals.Select(a => new { Sql = a.Id + 2L, a.Id }).ToListAsync()); + Assert.That(list5.Select(o => o.Sql), Is.EqualTo(list5.Select(o => o.Id + 2L))); + + var list6 = await (db.Products.Select(a => new { Sql = a.UnitPrice + 12345.54321m, a.UnitPrice }).ToListAsync()); + Assert.That(list6.Select(o => o.Sql), Is.EqualTo(list6.Select(o => o.UnitPrice + 12345.54321m))); + var list7 = await (db.Products.Select(a => new { Sql = a.UnitPrice + 12345L, a.UnitPrice }).ToListAsync()); + Assert.That(list7.Select(o => o.Sql), Is.EqualTo(list7.Select(o => o.UnitPrice + 12345L))); + + var list8 = await (db.Animals.Select(a => new { Sql = a.BodyWeight + 12345.54321f, a.BodyWeight }).ToListAsync()); + Assert.That(list8.Select(o => o.Sql), Is.EqualTo(list8.Select(o => o.BodyWeight + 12345.54321f))); + } + + [Test] + public async Task CanSelectSubtractOperatorAsync() + { + var list1 = await (db.Animals.Select(a => new { Sql = a.Id - 5, a.Id }).ToListAsync()); + Assert.That(list1.Select(o => o.Sql), Is.EqualTo(list1.Select(o => o.Id - 5))); + var list2 = await (db.Animals.Select(a => new { Sql = a.Id - 12345.54321m, a.Id }).ToListAsync()); + Assert.That(list2.Select(o => o.Sql), Is.EqualTo(list2.Select(o => o.Id - 12345.54321m))); + var list3 = await (db.Animals.Select(a => new { Sql = a.Id - 12345.54321f, a.Id }).ToListAsync()); + Assert.That(list3.Select(o => o.Sql), Is.EqualTo(list3.Select(o => o.Id - 12345.54321f)).Within(GetTolerance())); + var list4 = await (db.Animals.Select(a => new { Sql = a.Id - 12345.54321d, a.Id }).ToListAsync()); + Assert.That(list4.Select(o => o.Sql), Is.EqualTo(list4.Select(o => o.Id - 12345.54321d))); + var list5 = await (db.Animals.Select(a => new { Sql = a.Id - 2L, a.Id }).ToListAsync()); + Assert.That(list5.Select(o => o.Sql), Is.EqualTo(list5.Select(o => o.Id - 2L))); + + var list6 = await (db.Products.Select(a => new { Sql = a.UnitPrice - 12345.54321m, a.UnitPrice }).ToListAsync()); + Assert.That(list6.Select(o => o.Sql), Is.EqualTo(list6.Select(o => o.UnitPrice - 12345.54321m))); + var list7 = await (db.Products.Select(a => new { Sql = a.UnitPrice - 12345L, a.UnitPrice }).ToListAsync()); + Assert.That(list7.Select(o => o.Sql), Is.EqualTo(list7.Select(o => o.UnitPrice - 12345L))); + + var list8 = await (db.Animals.Select(a => new { Sql = a.BodyWeight - 12345.54321f, a.BodyWeight }).ToListAsync()); + Assert.That(list8.Select(o => o.Sql), Is.EqualTo(list8.Select(o => o.BodyWeight - 12345.54321f))); + } + + private class ObjectDto + { + public object CalculatedValue { get; set; } + + public int OriginalValue { get; set; } + } + + [Test] + public async Task CanSelectConditionalEntityValueWithEntityComparisonComplexAsync() + { + var animal = await (db.Animals.Select( + a => new + { + Parent = a.Father != null || a.Mother != null ? (a.Father ?? a.Mother) : null, + ParentSerialNumber = a.Father != null || a.Mother != null ? (a.Father ?? a.Mother).SerialNumber : null, + Parent2 = a.Mother ?? a.Father, + a.Father, + a.Mother + }) + .FirstOrDefaultAsync(o => o.ParentSerialNumber == "5678")); + + Assert.That(animal, Is.Not.Null); + Assert.That(animal.Father, Is.Not.Null); + Assert.That(animal.Mother, Is.Not.Null); + Assert.That(animal.Parent, Is.Not.Null); + Assert.That(animal.Parent2, Is.Not.Null); + Assert.That(NHibernateUtil.IsInitialized(animal.Parent), Is.True); + Assert.That(NHibernateUtil.IsInitialized(animal.Parent2), Is.True); + Assert.That(NHibernateUtil.IsInitialized(animal.Father), Is.True); + Assert.That(NHibernateUtil.IsInitialized(animal.Mother), Is.True); + } + + [Test] + public async Task CanSelectConditionalEntityValueWithEntityCastAsync() + { + var list = await (db.Animals.Select( + a => new + { + BodyWeight = (double?) (a is Cat + ? (a.Father ?? a.Mother).BodyWeight + : (a is Dog + ? (a.Mother ?? a.Father).BodyWeight + : (a.Father.Father.BodyWeight) + )) + }) + .ToListAsync()); + Assert.That(list, Has.Exactly(1).With.Property("BodyWeight").Not.Null); + } + + [Test] + public async Task CanSelectBinaryClientSideTestAsync() + { + var exception = Assert.ThrowsAsync(() => + { + return db.Animals.Select(a => a.FatherOrMother.BodyWeight + a.BodyWeight).ToListAsync(); + }); + Assert.That(exception.InnerException, Is.TypeOf()); + Assert.That(exception.InnerException.Message, Is.EqualTo( + "Null value cannot be assigned to a value type 'System.Double'. Cast expression '([a].FatherOrMother.BodyWeight + [a].BodyWeight)' to 'System.Nullable`1[System.Double]'.")); + + var list = await (db.Animals.Select(a => (double?) (a.FatherOrMother.BodyWeight + a.BodyWeight)).ToListAsync()); + Assert.That(list, Has.Exactly(5).Null.And.Exactly(1).EqualTo(271d)); + + // Arithmetic operator + var list2 = await (db.Animals.Select(a => new + { + // Left side null + Client = (double?) (a.FatherOrMother.BodyWeight + a.BodyWeight + a.Father.BodyWeight), + Server = (double?) (a.Father ?? a.Mother).BodyWeight + a.BodyWeight + a.Father.BodyWeight, + // Right side null + Client2 = (double?) (a.BodyWeight - a.Father.BodyWeight - a.FatherOrMother.BodyWeight), + Server2 = (double?) a.BodyWeight - a.Father.BodyWeight - (a.Father ?? a.Mother).BodyWeight + }).ToListAsync()); + Assert.That(list2.Select(o => o.Client), Is.EqualTo(list2.Select(o => o.Server))); + Assert.That(list2.Select(o => o.Client2), Is.EqualTo(list2.Select(o => o.Server2))); + + // Boolean logic operator + var list3 = await (db.Users.Select(u => new + { + // Left side null + Client = u.NotMappedUser.Role.IsActive && true, + Server = u.Role.IsActive && true, + // Right side null + Client2 = true && u.NotMappedUser.Role.IsActive, + Server2 = true && u.Role.IsActive + }).ToListAsync()); + Assert.That(list3.Select(o => o.Client), Is.EqualTo(list3.Select(o => o.Server))); + Assert.That(list3.Select(o => o.Client2), Is.EqualTo(list3.Select(o => o.Server2))); + + list3 = await (db.Users.Select(u => new + { + // Left side null + Client = u.NotMappedUser.Role.IsActive || true, + Server = u.Role.IsActive || true, + // Right side null + Client2 = false || u.NotMappedUser.Role.IsActive, + Server2 = false || u.Role.IsActive + }).ToListAsync()); + Assert.That(list3.Select(o => o.Client), Is.EqualTo(list3.Select(o => o.Server))); + Assert.That(list3.Select(o => o.Client2), Is.EqualTo(list3.Select(o => o.Server2))); + + // Comparison operator + list3 = await (db.Users.Select(u => new + { + // Left side null + Client = u.NotMappedUser.Role.Id > 0, + Server = u.Role.Id > 0, + // Right side null + Client2 = 0 < u.NotMappedUser.Role.Id, + Server2 = 0 < u.Role.Id + }).ToListAsync()); + Assert.That(list3.Select(o => o.Client), Is.EqualTo(list3.Select(o => o.Server))); + Assert.That(list3.Select(o => o.Client2), Is.EqualTo(list3.Select(o => o.Server2))); + + // Bitwise boolean operator + var list4 = await (db.Users.Select(u => new + { + // Left side null + Client = (bool?) (u.NotMappedUser.Role.IsActive | true), + Server = (bool?) (u.Role.IsActive | true), + // Right side null + Client2 = (bool?) (true | u.NotMappedUser.Role.IsActive), + Server2 = (bool?) (true | u.Role.IsActive) + }).ToListAsync()); + Assert.That(list4.Select(o => o.Client), Is.EqualTo(list4.Select(o => o.Server))); + Assert.That(list4.Select(o => o.Client2), Is.EqualTo(list4.Select(o => o.Server2))); + + // Bitwise number operator + var list5 = await (db.Users.Select(u => new + { + // Left side null + Client = (int?) (u.NotMappedUser.Role.Id | 5), + Server = (int?) (u.Role.Id | 5), + // Right side null + Client2 = (int?) (5 | u.NotMappedUser.Role.Id), + Server2 = (int?) (5 | u.Role.Id) + }).ToListAsync()); + Assert.That(list5.Select(o => o.Client), Is.EqualTo(list5.Select(o => o.Server))); + Assert.That(list5.Select(o => o.Client2), Is.EqualTo(list5.Select(o => o.Server2))); + + // Coalesce operator + var list6 = await (db.Users.Select(u => new + { + // Left side null + Client = u.NotMappedUser.Role.Name ?? u.NotMappedUser.Name, + Server = u.Role.Name ?? u.Name, + // Right side null + Client2 = u.NotMappedUser.Name ?? u.NotMappedUser.Role.Name, + Server2 = u.Name ?? u.Role.Name, + // Both side null + Client3 = u.NotMappedUser.Role.Name ?? u.NotMappedUser.Role.Name, + Server3 = u.Role.Name ?? u.Role.Name + }).ToListAsync()); + Assert.That(list6.Select(o => o.Client), Is.EqualTo(list6.Select(o => o.Server))); + Assert.That(list6.Select(o => o.Client2), Is.EqualTo(list6.Select(o => o.Server2))); + Assert.That(list6.Select(o => o.Client3), Is.EqualTo(list6.Select(o => o.Server3))); + } + + [Test] + public async Task CanSelectUnaryClientSideTestAsync() + { + var exception = Assert.ThrowsAsync(() => + { + return db.Animals.Select(a => -a.FatherOrMother.BodyWeight).ToListAsync(); + }); + Assert.That(exception.InnerException, Is.TypeOf()); + Assert.That(exception.InnerException.Message, Is.EqualTo( + "Null value cannot be assigned to a value type 'System.Double'. Cast expression '-[a].FatherOrMother.BodyWeight' to 'System.Nullable`1[System.Double]'.")); + + // Negate + var list = await (db.Animals.Select(a => new + { + Client = (double?) -a.FatherOrMother.BodyWeight, + Server = (double?) -((a.Father ?? a.Mother).BodyWeight) + }).ToListAsync()); + Assert.That(list.Select(o => o.Client), Is.EqualTo(list.Select(o => o.Server))); + + // Convert + list = await (db.Animals.Select(a => new + { + Client = (double?) a.FatherOrMother.BodyWeight, + Server = (double?) (a.Father ?? a.Mother).BodyWeight + }).ToListAsync()); + Assert.That(list.Select(o => o.Client), Is.EqualTo(list.Select(o => o.Server))); + + // UnaryPlus + list = await (db.Animals.Select(a => new + { + Client = (double?) +a.FatherOrMother.BodyWeight, + Server = (double?) +((a.Father ?? a.Mother).BodyWeight) + }).ToListAsync()); + Assert.That(list.Select(o => o.Client), Is.EqualTo(list.Select(o => o.Server))); + + // Not + var list2 = await (db.Users.Select(u => new + { + Client = (bool?) !u.NotMappedUser.Role.IsActive, + Server = (bool?) !u.Role.IsActive + }).ToListAsync()); + Assert.That(list2.Select(o => o.Client), Is.EqualTo(list2.Select(o => o.Server))); + + // Convert value type + var list3 = await (db.Users.Select(u => (int?) (u.Role != null ? 5 : 10)).ToListAsync()); + Assert.That(list3, Has.Exactly(3).Not.Null); + + // Convert enum + list3 = await (db.Users.Select(u => (int?) u.Role.CreatedBy.Enum2).ToListAsync()); + Assert.That(list3, Has.Exactly(3).Null); + + // Convert reference type + var list4 = await (db.Animals.Select(a => new + { + Client = (Dog) a.FatherOrMother, + Server = (Dog) (a.Father ?? a.Mother) + }).ToListAsync()); + Assert.That(list4.Select(o => o.Client), Is.EqualTo(list4.Select(o => o.Server))); + + // TypeAs + list4 = await (db.Animals.Select(a => new + { + Client = a.FatherOrMother as Dog, + Server = (a.Father ?? a.Mother) as Dog + }).ToListAsync()); + Assert.That(list4.Select(o => o.Client), Is.EqualTo(list4.Select(o => o.Server))); + + // Convert constant reference type + var list5 = await (db.Animals.Select(a => (Animal) new Dog()).ToListAsync()); + Assert.That(list5, Has.Exactly(6).Not.Null); + } + + [Test] + public async Task CanSelectConditionalClientSideWithNullValueTypeTestAsync() + { + var exception = Assert.ThrowsAsync(() => + { + return db.Animals.Select( + a => new + { + BodyWeight = (string.IsNullOrWhiteSpace(a.Description) + ? a.Mother.Mother.BodyWeight + : a.Father.Mother.BodyWeight) + }) + .ToListAsync(); + }); + Assert.That(exception.InnerException, Is.TypeOf()); + Assert.That(exception.InnerException.Message, Is.EqualTo( + "Null value cannot be assigned to a value type 'System.Double'. " + + "Cast expression 'IIF(IsNullOrWhiteSpace([a].Description), [_3].BodyWeight, [_1].BodyWeight)' to 'System.Nullable`1[System.Double]'.")); + + var list = await (db.Animals.Select( + a => new + { + BodyWeight = (double?) (string.IsNullOrWhiteSpace(a.Description) + ? a.Mother.Mother.BodyWeight + : a.Father.Mother.BodyWeight) + }) + .ToListAsync()); + Assert.That(list, Has.Exactly(0).With.Property("BodyWeight").Not.Null); + + var list2 = await (db.Animals.Select( + a => new + { + BodyWeight = (double?) (string.IsNullOrWhiteSpace(a.Description) + ? a.Mother.Mother.BodyWeight + : 5d) + }) + .ToListAsync()); + Assert.That(list2, Has.Exactly(0).With.Property("BodyWeight").Not.Null); + + var list3 = await (db.Animals.Select( + a => new + { + BodyWeight = (double?) (string.IsNullOrWhiteSpace(a.Description) + ? 5d + : a.Father.Mother.BodyWeight) + }) + .ToListAsync()); + Assert.That(list3, Has.Exactly(6).With.Property("BodyWeight").Not.Null); + + var list4 = await (db.Animals.Select( + a => new + { + BodyWeightHashCode = (int?) ((string.IsNullOrWhiteSpace(a.Description) + ? a.Mother.Mother.BodyWeight + : a.Father.Mother.BodyWeight)).GetHashCode() + }) + .ToListAsync()); + Assert.That(list4, Has.Exactly(0).With.Property("BodyWeightHashCode").Not.Null); + + var list5 = await (db.Animals.Select( + a => new + { + BodyWeight = (double?) (string.IsNullOrWhiteSpace(a.Description) + ? (string.IsNullOrWhiteSpace(a.Description) + ? a.Mother.Mother.BodyWeight + : a.Father.Mother.BodyWeight) + : (string.IsNullOrWhiteSpace(a.Description) + ? a.Mother.Mother.BodyWeight + : a.Father.Mother.BodyWeight)) + }) + .ToListAsync()); + Assert.That(list5, Has.Exactly(0).With.Property("BodyWeight").Not.Null); + + var list6 = await (db.Animals.Select( + a => new + { + Client = a.Father.HasFather ? (double?) null : a.BodyWeight, + Server = a.Father.Father != null ? (double?) null : a.BodyWeight, + }) + .ToListAsync()); + Assert.That(list6.Select(o => o.Client), Is.EqualTo(list6.Select(o => o.Server))); + + var list7 = await (db.Users.Select( + a => new + { + Client = a.NotMappedUser.Role.IsActive ? 1 : 2, + Server = a.Role.IsActive ? 1 : 2 + }) + .ToListAsync()); + Assert.That(list7.Select(o => o.Client), Is.EqualTo(list7.Select(o => o.Server))); + } + + [Test] + public async Task CanExecuteMethodWithNullObjectClientSideTestAsync() + { + var exception = Assert.ThrowsAsync(() => + { + return db.Animals.Select( + a => new + { + a.Id, + FatherId = a.Father.Father.Id + }) + .ToListAsync(); + }); + Assert.That(exception.InnerException, Is.TypeOf()); + Assert.That(exception.InnerException.Message, Is.EqualTo( + "Null value cannot be assigned to a value type 'System.Int32'. Cast expression '[_0].Father.Id' to 'System.Nullable`1[System.Int32]'.")); + + exception = Assert.ThrowsAsync(() => + { + return db.Animals.Select( + a => new + { + a.Id, + FatherIdHashCode = a.Father.Father.Id.GetHashCode() + }) + .ToListAsync(); + }); + Assert.That(exception.InnerException, Is.TypeOf()); + Assert.That(exception.InnerException.Message, Is.EqualTo( + "Null value cannot be assigned to a value type 'System.Int32'. Cast expression '[_1].Id.GetHashCode()' to 'System.Nullable`1[System.Int32]'.")); + + var list = await (db.Animals.Select( + a => new + { + NullableId = (int?) a.Father.Father.Id, + NullableIdHashCode = (int?) a.Father.Father.Id.GetHashCode() + }) + .ToListAsync()); + Assert.That(list, Has.Exactly(0).With.Property("NullableId").Not.Null); + } + + [Test] + public void CanSelectWithIsOperatorAsync() + { + Assert.DoesNotThrowAsync(() => db.Animals.Select(a => a is Dog).ToListAsync()); + Assert.DoesNotThrowAsync(() => db.Animals.Select(a => a.FatherSerialNumber is string).ToListAsync()); + } + + [Test] + public async Task CanSelectNonMappedComponentPropertyAsync() + { + Assert.DoesNotThrowAsync(() => db.Users.Select(u => u.Component.Property3).ToListAsync()); + Assert.DoesNotThrowAsync(() => db.Users.Select(u => u.Component.OtherComponent.OtherProperty2).ToListAsync()); + var list = await (db.Users.Select(u => new + { + u.Component.OtherComponent.OtherProperty1, + OtherProperty3 = u.Component.OtherComponent.OtherProperty2, + u.Component.Property1, + u.Component.Property2, + u.Component.Property3 + }).ToListAsync()); + Assert.That(list.Select(o => o.OtherProperty3), Is.EqualTo(list.Select(o => o.OtherProperty1))); + Assert.That( + list.Select(o => (o.Property1 ?? o.Property2) == null ? null : $"{o.Property1}{o.Property2}"), + Is.EqualTo(list.Select(o => o.Property3))); + } + + [Test] + public void CanSelectWithAnInvocationAsync() + { + Func func = s => s + "postfix"; + Assert.DoesNotThrowAsync(() => db.Animals.Select(a => func(a.SerialNumber)).ToListAsync()); + Assert.DoesNotThrowAsync(() => db.Animals.Select(a => func(a.FatherSerialNumber)).ToListAsync()); + } + + [Test] + public void CanSelectEnumerableAsync() + { + Assert.DoesNotThrowAsync(() => db.Animals.Select(a => new { Enumerable = new[] { a.Id } }).ToListAsync()); + Assert.DoesNotThrowAsync(() => db.Animals.Select(a => new { Enumerable = new[] { a.Id, 1 } }).ToListAsync()); + Assert.DoesNotThrowAsync(() => db.Animals.Select(a => new { Enumerable = new[] { 1 } }).ToListAsync()); + Assert.DoesNotThrowAsync(() => db.Animals.Select(a => new { Enumerable = new[] { a, a.Father, a.Mother } }).ToListAsync()); + Assert.DoesNotThrowAsync(() => db.Animals.Select(a => new + { + Enumerable = new[] + { + new UserDto(a.Id, a.FatherSerialNumber) {RoleName = a.FatherSerialNumber}, + new UserDto(1, a.FatherSerialNumber) {RoleName = a.FatherSerialNumber, InvalidLoginAttempts = 1}, + null, + new UserDto(1, "test") {RoleName = "test", InvalidLoginAttempts = 1}, + new UserDto(1, "test") {Dto2List = {new UserDto2(), new UserDto2()}, Dto2 = {Enum = EnumStoredAsInt32.High}}, + new UserDto(1, a.FatherSerialNumber) + { + Dto2List = {new UserDto2() { Enum = a.Id > 0 ? EnumStoredAsInt32.High : EnumStoredAsInt32.Unspecified }, new UserDto2()}, + Dto2 = {Enum = a.Id > 0 ? EnumStoredAsInt32.High : EnumStoredAsInt32.Unspecified} + } + } + }).ToListAsync()); + Assert.DoesNotThrowAsync(() => db.Animals.Select(a => new { Enumerable = new[] { a.SerialNumber, a.FatherSerialNumber, null } }).ToListAsync()); + Assert.DoesNotThrowAsync(() => db.Animals.Select(a => new { Enumerable = new int[][] { new[] { a.Id }, new[] { 1 }, new[] { a.Id, 1 } } }).ToListAsync()); + Assert.DoesNotThrowAsync(() => db.Animals.Select(a => new { Enumerable = new List { a.Id, 1 } }).ToListAsync()); + Assert.DoesNotThrowAsync(() => db.Animals.Select(a => new { Enumerable = new List(5) { a.Id, 1 } }).ToListAsync()); + Assert.DoesNotThrowAsync(() => db.Animals.Select(a => new { Enumerable = new List(a.Id) { 1 } }).ToListAsync()); + Assert.DoesNotThrowAsync(() => db.Animals.Select(a => new { Enumerable = new List(a.Id) { a.SerialNumber, a.FatherSerialNumber, null } }).ToListAsync()); + Assert.DoesNotThrowAsync(() => db.Animals.Select(a => new + { + Enumerable = new List(a.Id) + { + new UserDto(a.Id, a.FatherSerialNumber) {RoleName = a.FatherSerialNumber}, + new UserDto(1, a.FatherSerialNumber) {RoleName = a.FatherSerialNumber, InvalidLoginAttempts = 1}, + null, + new UserDto(1, "test") {RoleName = "test", InvalidLoginAttempts = 1}, + new UserDto(1, "test") {Dto2List = {new UserDto2(), new UserDto2()}, Dto2 = {Enum = EnumStoredAsInt32.High}}, + new UserDto(1, a.FatherSerialNumber) + { + Dto2List = {new UserDto2() { Enum = a.Id > 0 ? EnumStoredAsInt32.High : EnumStoredAsInt32.Unspecified }, new UserDto2()}, + Dto2 = {Enum = a.Id > 0 ? EnumStoredAsInt32.High : EnumStoredAsInt32.Unspecified} + } + } + }).ToListAsync()); + Assert.DoesNotThrowAsync(() => db.Animals.Select(a => new { Enumerable = new[] { a.SerialNumber, a.FatherSerialNumber, null }[a.Id - a.Id].Length }).ToListAsync()); + Assert.DoesNotThrowAsync(() => db.Animals.Select(a => new { Enumerable = new List { a.SerialNumber, a.FatherSerialNumber, null }[a.Id - a.Id].Length }).ToListAsync()); + Assert.DoesNotThrowAsync(() => db.Animals.Select(a => new + { + Enumerable = new Dictionary + { + { a.SerialNumber, a.FatherSerialNumber }, + { "1", a.Father.SerialNumber }, + { "2", null } + }[a.SerialNumber] + }).ToListAsync()); + } + + [Test] + public async Task CanSelectConditionalSubClassPropertyValueAsync() + { + var animal = await (db.Animals.Select( + a => new + { + Pregnant = a is Mammal ? ((Mammal) a).Pregnant : false + }) + .Where(o => o.Pregnant) + .ToListAsync()); + + Assert.That(animal, Has.Count.EqualTo(1)); + } + [Test] public async Task CanSelectConditionalEntityValueWithEntityComparisonRepeatAsync() { @@ -484,10 +1269,200 @@ public async Task CanCastToCustomRegisteredTypeAsync() Assert.That(await (db.Users.Where(o => (NullableInt32) o.Id == 1).ToListAsync()), Has.Count.EqualTo(1)); } + [Test] + public async Task TestClientSideEvaluationAsync() + { + var list = await (db.Animals.Select(a => new + { + ClientSide = string.IsNullOrEmpty(a.FatherSerialNumber) ? 1 : 0, + ClientSide2 = string.IsNullOrEmpty(a.Father.SerialNumber) ? 1 : 0 + }).ToListAsync()); + Assert.That(list.Select(o => o.ClientSide), Is.EqualTo(list.Select(o => o.ClientSide2))); + + var list2 = await (db.Animals.Select(a => new + { + ClientSide = a.Father.IsProxy(), + ClientSide2 = a.FatherSerialNumber.IsProxy() + }).ToListAsync()); + Assert.That(list2.Select(o => o.ClientSide), Is.EqualTo(list2.Select(o => o.ClientSide2))); + + var list3 = await (db.Orders.Where(o => o.OrderDate.HasValue).Select(o => new + { + ClientSide = o.OrderDate.Value.TimeOfDay.Days, + ClientSide2 = o.OrderDate.Value + }).ToListAsync()); + Assert.That(list3.Select(o => o.ClientSide), Is.EqualTo(list3.Select(o => o.ClientSide2.TimeOfDay.Days))); + + var list4 = await (db.Orders.Where(o => o.OrderDate.HasValue).Select(o => new + { + o.OrderId, + ClientSide = o.OrderDate.Value.TimeOfDay.CompareTo(new TimeSpan(o.OrderId)), + ClientSide2 = o.OrderDate.Value + }).ToListAsync()); + Assert.That(list4.Select(o => o.ClientSide), Is.EqualTo(list4.Select(o => o.ClientSide2.TimeOfDay.CompareTo(new TimeSpan(o.OrderId))))); + } + + [Test] + public async Task TestServerAndClientSideEvaluationComparisonAsync() + { + var list = await (db.Animals.Select( + a => new + { + ServerSide = (int?) a.Father.SerialNumber.Length, + ClientSide = (int?) a.FatherSerialNumber.Length + }).ToListAsync()); + Assert.That(list.Select(o => o.ClientSide), Is.EqualTo(list.Select(o => o.ServerSide))); + + var list1 = await (db.Animals + .Where(a => a.Father.SerialNumber != null) + .Select( + a => new + { + ServerSide = a.Father.SerialNumber.Length, + ClientSide = a.FatherSerialNumber.Length + }) + .ToListAsync()); + Assert.That(list1.Select(o => o.ClientSide), Is.EqualTo(list1.Select(o => o.ServerSide))); + + var clientSide = await (db.Animals.Select(a => a.FatherSerialNumber.Length.ToString()).ToListAsync()); + var serverSide = await (db.Animals.Select(a => a.FatherSerialNumber.Length.ToString()).ToListAsync()); + Assert.That(clientSide, Is.EqualTo(serverSide)); + + var exception = Assert.ThrowsAsync( + () => + { + return db.Animals.Select( + a => new + { + ServerSide = a.Father.SerialNumber.Length + }).ToListAsync(); + }); + Assert.That(exception.InnerException, Is.TypeOf()); + Assert.That(exception.InnerException.Message, Is.EqualTo( + "Null value cannot be assigned to a value type 'System.Int32'. Cast expression '[_0].SerialNumber.Length' to 'System.Nullable`1[System.Int32]'.")); + + exception = Assert.ThrowsAsync( + () => + { + return db.Animals.Select( + a => new + { + ClientSide = a.FatherSerialNumber.Length + }).ToListAsync(); + }); + Assert.That(exception.InnerException, Is.TypeOf()); + Assert.That(exception.InnerException.Message, Is.EqualTo( + "Null value cannot be assigned to a value type 'System.Int32'. Cast expression '[a].FatherSerialNumber.Length' to 'System.Nullable`1[System.Int32]'.")); + + var list2 = await (db.Animals.Select( + a => new + { + ServerSide = a.Father.SerialNumber.Length.ToString(), + ClientSide = a.FatherSerialNumber.Length.ToString() + }).ToListAsync()); + Assert.That(list2.Select(o => o.ClientSide), Is.EqualTo(list2.Select(o => o.ServerSide))); + + var list3 = await (db.Animals.Select( + a => new + { + ServerSide = (int?) a.Father.SerialNumber.Substring(0, ((int?) a.Father.SerialNumber.Length - 1) ?? 0).Length, + ClientSide = (int?) a.FatherSerialNumber.Substring(0, ((int?) a.FatherSerialNumber.Length - 1) ?? 0).Length + }).ToListAsync()); + Assert.That(list3.Select(o => o.ClientSide), Is.EqualTo(list3.Select(o => o.ServerSide))); + + var list4 = await (db.Animals.Select(a => new + { + ServerSide = a.Father.SerialNumber, + ClientSide = a.FatherSerialNumber, + Test = (object) null + }).ToListAsync()); + Assert.That(list4.Select(o => o.ClientSide), Is.EqualTo(list4.Select(o => o.ServerSide))); + + var list5 = await (db.Animals.Select(a => new + { + ServerSide = a.Father.SerialNumber == null, + ClientSide = a.FatherSerialNumber == null + }).ToListAsync()); + Assert.That(list5.Select(o => o.ClientSide), Is.EqualTo(list5.Select(o => o.ServerSide))); + + var list6 = await (db.Animals + .Where(a => a.Father.SerialNumber != null) + .Select( + a => new + { + ServerSide = -a.Father.SerialNumber.Length, + ClientSide = -a.FatherSerialNumber.Length + }).ToListAsync()); + Assert.That(list6.Select(o => o.ClientSide), Is.EqualTo(list6.Select(o => o.ServerSide))); + + var list7 = await (db.Animals + .Select( + a => new + { + ServerSide = a.Father != null ? a.Father.SerialNumber : null, + ClientSide = a.HasFather ? a.FatherSerialNumber : null + }).ToListAsync()); + Assert.That(list7.Select(o => o.ClientSide), Is.EqualTo(list7.Select(o => o.ServerSide))); + + var list8 = await (db.Animals + .Where(a => a is Dog) + .Select( + a => new + { + ServerSide = (long?) (int?) ((Dog) a).Father.SerialNumber.Length, + ClientSide = (long?) (int?) ((Dog) a).FatherSerialNumber.Length + }).ToListAsync()); + Assert.That(list8.Select(o => o.ClientSide), Is.EqualTo(list8.Select(o => o.ServerSide))); + } + public class Wrapper { public T item; public string message; } + + private double GetTolerance() + { + return !Dialect.SupportsIEEE754FloatingPointNumbers || TestDialect.SendsParameterValuesAsStrings + ? 0.1d + : 0d; + } + + private static void AssertOneSelectColumn(IQueryable query) + { + using (var sqlLog = new SqlLogSpy()) + { + // Execute query + foreach (var item in query) { } + Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "as col"), Is.EqualTo(1)); + } + } + + private static string GetSqlSelect(IQueryable query) + { + using (var sqlLog = new SqlLogSpy()) + { + // Execute query + foreach (var item in query) { } + + var sql = sqlLog.GetWholeLog(); + return sql.Substring(0, sql.IndexOf(" from")); + } + } + + private static int FindAllOccurrences(string source, string substring) + { + if (source == null) + { + return 0; + } + int n = 0, count = 0; + while ((n = source.IndexOf(substring, n, StringComparison.InvariantCulture)) != -1) + { + n += substring.Length; + ++count; + } + return count; + } } } diff --git a/src/NHibernate.Test/Async/TypedManyToOne/TypedManyToOneTest.cs b/src/NHibernate.Test/Async/TypedManyToOne/TypedManyToOneTest.cs index 31f380e80a3..a78d2b3afa6 100644 --- a/src/NHibernate.Test/Async/TypedManyToOne/TypedManyToOneTest.cs +++ b/src/NHibernate.Test/Async/TypedManyToOne/TypedManyToOneTest.cs @@ -9,12 +9,15 @@ using System.Collections; +using System.Linq; using NHibernate.Dialect; using NUnit.Framework; +using NHibernate.Linq; namespace NHibernate.Test.TypedManyToOne { using System.Threading.Tasks; + using System.Threading; [TestFixture] public class TypedManyToOneTestAsync : TestCase { @@ -35,38 +38,27 @@ protected override bool AppliesTo(Dialect.Dialect dialect) } [Test] - public async Task TestCreateQueryAsync() + public async Task TestLinqEntityNameQueryAsync() { - var cust = new Customer(); - cust.CustomerId = "abc123"; - cust.Name = "Matt"; - - var ship = new Address(); - ship.Street = "peachtree rd"; - ship.State = "GA"; - ship.City = "ATL"; - ship.Zip = "30326"; - ship.AddressId = new AddressId("SHIPPING", "xyz123"); - ship.Customer = cust; - - var bill = new Address(); - bill.Street = "peachtree rd"; - bill.State = "GA"; - bill.City = "ATL"; - bill.Zip = "30326"; - bill.AddressId = new AddressId("BILLING", "xyz123"); - bill.Customer = cust; - - cust.BillingAddress = bill; - cust.ShippingAddress = ship; - - using (ISession s = Sfi.OpenSession()) - using (ITransaction t = s.BeginTransaction()) + var cust = await (CreateCustomerAsync()); + using (var s = Sfi.OpenSession()) + using (var t = s.BeginTransaction()) { - await (s.PersistAsync(cust)); + var billingNotes = await (s.Query().Select(o => o.BillingAddress.BillingNotes).FirstAsync()); + Assert.That(billingNotes, Is.EqualTo("BillingNotes")); + var shippingNotes = await (s.Query().Select(o => o.ShippingAddress.ShippingNotes).FirstAsync()); + Assert.That(shippingNotes, Is.EqualTo("ShippingNotes")); + await (t.CommitAsync()); } + await (DeleteCustomerAsync(cust)); + } + + [Test] + public async Task TestCreateQueryAsync() + { + var cust = await (CreateCustomerAsync()); using (ISession s = Sfi.OpenSession()) using (ITransaction t = s.BeginTransaction()) { @@ -82,20 +74,7 @@ public async Task TestCreateQueryAsync() await (t.CommitAsync()); } - using (ISession s = Sfi.OpenSession()) - using (ITransaction t = s.BeginTransaction()) - { - await (s.SaveOrUpdateAsync(cust)); - ship = cust.ShippingAddress; - cust.ShippingAddress = null; - await (s.DeleteAsync("ShippingAddress", ship)); - await (s.FlushAsync()); - - Assert.That(await (s.GetAsync("ShippingAddress", ship.AddressId)), Is.Null); - await (s.DeleteAsync(cust)); - - await (t.CommitAsync()); - } + await (DeleteCustomerAsync(cust)); } [Test] @@ -124,5 +103,60 @@ public async Task TestCreateQueryNullAsync() await (t.CommitAsync()); } } + + private async Task CreateCustomerAsync(CancellationToken cancellationToken = default(CancellationToken)) + { + var cust = new Customer(); + cust.CustomerId = "abc123"; + cust.Name = "Matt"; + + var ship = new Address(); + ship.Street = "peachtree rd"; + ship.State = "GA"; + ship.City = "ATL"; + ship.Zip = "30326"; + ship.AddressId = new AddressId("SHIPPING", "xyz123"); + ship.Customer = cust; + ship.ShippingNotes = "ShippingNotes"; + + var bill = new Address(); + bill.Street = "peachtree rd"; + bill.State = "GA"; + bill.City = "ATL"; + bill.Zip = "30326"; + bill.AddressId = new AddressId("BILLING", "xyz123"); + bill.Customer = cust; + bill.BillingNotes = "BillingNotes"; + + cust.BillingAddress = bill; + cust.ShippingAddress = ship; + + using (ISession s = Sfi.OpenSession()) + using (ITransaction t = s.BeginTransaction()) + { + await (s.PersistAsync(cust, cancellationToken)); + await (t.CommitAsync(cancellationToken)); + } + + return cust; + } + + private async Task DeleteCustomerAsync(Customer cust, CancellationToken cancellationToken = default(CancellationToken)) + { + using (var s = Sfi.OpenSession()) + using (var t = s.BeginTransaction()) + { + await (s.SaveOrUpdateAsync(cust, cancellationToken)); + var ship = cust.ShippingAddress; + cust.ShippingAddress = null; + await (s.DeleteAsync("ShippingAddress", ship, cancellationToken)); + await (s.FlushAsync(cancellationToken)); + + Assert.That(await (s.GetAsync("ShippingAddress", ship.AddressId, cancellationToken)), Is.Null); + await (s.DeleteAsync(cust, cancellationToken)); + + await (t.CommitAsync(cancellationToken)); + } + } } } diff --git a/src/NHibernate.Test/DriverTest/DriverNumericTypesFixture.cs b/src/NHibernate.Test/DriverTest/DriverNumericTypesFixture.cs new file mode 100644 index 00000000000..5ef3393d91b --- /dev/null +++ b/src/NHibernate.Test/DriverTest/DriverNumericTypesFixture.cs @@ -0,0 +1,810 @@ +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.ExceptionServices; +using NHibernate.Cfg; +using NHibernate.Cfg.MappingSchema; +using NHibernate.Driver; +using NHibernate.Engine; +using NHibernate.Exceptions; +using NHibernate.Mapping.ByCode; +using NHibernate.SqlTypes; +using NHibernate.Type; +using NHibernate.Util; +using NUnit.Framework; +using Environment = System.Environment; + +namespace NHibernate.Test.DriverTest +{ + public class DriverNumericTypesFixture : TestCaseMappingByCode + { + private static readonly MethodInfo SelectDefinition = + ReflectHelper.GetMethodDefinition(() => Queryable.Select(null, default(Expression>))); + private static readonly MethodInfo FirstDefinition = + ReflectHelper.GetMethodDefinition(() => Queryable.First(null)); + private static readonly MethodInfo LambdaDefinition = + ReflectHelper.GetMethodDefinition(() => Expression.Lambda(null)); + private static readonly ObjectType ObjectTypeInstance = new ObjectType(); + private static readonly Dictionary> DriverNotSupportedTypes = + new Dictionary> + { + { + NHibernateUtil.UInt16, new HashSet + { + typeof(OracleManagedDataClientDriver), + typeof(Sql2008ClientDriver), + typeof(NpgsqlDriver), + typeof(SqlServerCeDriver), + typeof(FirebirdClientDriver), + typeof(OdbcDriver) + } + }, + { + NHibernateUtil.UInt32, new HashSet + { + typeof(OracleManagedDataClientDriver), + typeof(Sql2008ClientDriver), + typeof(NpgsqlDriver), + typeof(SqlServerCeDriver), + typeof(FirebirdClientDriver), + typeof(OdbcDriver) + } + }, + { + NHibernateUtil.UInt64, new HashSet + { + typeof(OracleManagedDataClientDriver), + typeof(Sql2008ClientDriver), + typeof(NpgsqlDriver), + typeof(SqlServerCeDriver), + typeof(FirebirdClientDriver), + typeof(OdbcDriver) + } + } + }; + + private static readonly Dictionary> AllowedDriverDifferentValues = + new Dictionary> + { + { + // This tests should be enabled when MySqlDataDriver will support prepared statements. + typeof(MySqlDataDriver), + new HashSet + { + new OperatorTest('%', nameof(NumericEntity.Float), nameof(NumericEntity.Float), typeof(float), true, false), + new OperatorTest('%', nameof(NumericEntity.Float), nameof(NumericEntity.Float), typeof(float), false, true), + new OperatorTest('-', nameof(NumericEntity.Float), nameof(NumericEntity.Float), typeof(float), true, false), + new OperatorTest('-', nameof(NumericEntity.Float), nameof(NumericEntity.Float), typeof(float), false, true) + } + } + }; + + private static string GetColumnName(string propertyName) + { + return propertyName + "Column"; + } + + private readonly HashSet _ignoreProperties = new HashSet {nameof(NumericEntity.Id)}; + private readonly NumericEntity _originalEntity = new NumericEntity + { + Short = 123, + Integer = 12345, + Long = 1234567L, + UnsignedShort = 123, + UnsignedInteger = 12345, + UnsignedLong = 1234567L, + Currency = 12345.5432m, + Decimal = 1234567.54321m, + DecimalLowScale = 123.32m, + // The floating-point numbers where carefully selected so that they will stay the same when they are saved + // into the database. From testing, MySql seem the most problematic in term of compatibility with .NET + // as the values are differently rounded, which causes TestTypesAndValuesAfterSave to fail. For instance, MySql + // will throw when float.MaxValue is stored (https://stackoverflow.com/a/18832334), due to different rounding. + Double = 3.1415926535897932E+30, + Float = 3.14159E+15f + }; + private List _properties; + + protected override HbmMapping GetMappings() + { + var mapper = new ModelMapper(); + var driverType = ReflectHelper.ClassForName(cfg.GetProperty(Cfg.Environment.ConnectionDriver)); + + mapper.Class(o => + { + o.Table(nameof(NumericEntity)); + o.EntityName(nameof(NumericEntity)); + o.Id(x => x.Id, map => map.Generator(Generators.Native)); + o.Property( + x => x.Short, + map => + { + map.Type(NHibernateUtil.Int16); + map.Column(GetColumnName(nameof(NumericEntity.Short))); + }); + o.Property( + x => x.Integer, + map => + { + map.Type(NHibernateUtil.Int32); + map.Column(GetColumnName(nameof(NumericEntity.Integer))); + }); + o.Property( + x => x.Long, + map => + { + map.Type(NHibernateUtil.Int64); + map.Column(GetColumnName(nameof(NumericEntity.Long))); + }); + + if (DriverNotSupportedTypes[NHibernateUtil.UInt16].Contains(driverType)) + { + _ignoreProperties.Add(nameof(NumericEntity.UnsignedShort)); + } + else + { + o.Property( + x => x.UnsignedShort, + map => + { + map.Type(NHibernateUtil.UInt16); + map.Column(GetColumnName(nameof(NumericEntity.UnsignedShort))); + }); + } + + if (DriverNotSupportedTypes[NHibernateUtil.UInt32].Contains(driverType)) + { + _ignoreProperties.Add(nameof(NumericEntity.UnsignedInteger)); + } + else + { + o.Property( + x => x.UnsignedInteger, + map => + { + map.Type(NHibernateUtil.UInt32); + map.Column(GetColumnName(nameof(NumericEntity.UnsignedInteger))); + }); + } + + if (DriverNotSupportedTypes[NHibernateUtil.UInt64].Contains(driverType)) + { + _ignoreProperties.Add(nameof(NumericEntity.UnsignedLong)); + } + else + { + o.Property( + x => x.UnsignedLong, + map => + { + map.Type(NHibernateUtil.UInt64); + map.Column(GetColumnName(nameof(NumericEntity.UnsignedLong))); + }); + } + + o.Property( + x => x.Decimal, + map => + { + map.Type(NHibernateUtil.Decimal); + map.Column(GetColumnName(nameof(NumericEntity.Decimal))); + }); + o.Property( + x => x.DecimalLowScale, + map => + { + map.Type(NHibernateUtil.Decimal); + map.Precision(5); + map.Scale(2); + map.Column(GetColumnName(nameof(NumericEntity.DecimalLowScale))); + }); + o.Property( + x => x.Currency, + map => + { + map.Type(NHibernateUtil.Currency); + map.Column(GetColumnName(nameof(NumericEntity.Currency))); + }); + o.Property( + x => x.Double, + map => + { + map.Type(NHibernateUtil.Double); + map.Column(GetColumnName(nameof(NumericEntity.Double))); + }); + o.Property( + x => x.Float, + map => + { + map.Type(NHibernateUtil.Single); + map.Column(GetColumnName(nameof(NumericEntity.Float))); + }); + }); + + _properties = typeof(NumericEntity) + .GetProperties() + .Where(o => !_ignoreProperties.Contains(o.Name)) + .Select(o => new PropertyMetadata(o)) + .ToList(); + + return mapper.CompileMappingForAllExplicitlyAddedEntities(); + } + + protected override void Configure(Configuration configuration) + { + configuration.SetProperty(Cfg.Environment.OracleUseBinaryFloatingPointTypes, "true"); + } + + protected override void OnSetUp() + { + using (var s = OpenSession()) + using (var t = s.BeginTransaction()) + { + s.Save(_originalEntity); + t.Commit(); + } + } + + protected override void OnTearDown() + { + base.OnTearDown(); + + using (var s = OpenSession()) + using (var t = s.BeginTransaction()) + { + s.CreateQuery("delete from NumericEntity").ExecuteUpdate(); + t.Commit(); + } + } + + /// + /// Tested drivers: + /// - SqlServerCeDriver (net461) + /// - OracleManagedDataClientDriver (net461, netcoreapp2.0) + /// - OracleClientDriver (netcoreapp2.0) + /// - OdbcDriver - SqlServer (net461, netcoreapp2.0) + /// - Sql2008ClientDriver (net461, netcoreapp2.0) + /// - NpgsqlDriver (net461, netcoreapp2.0) + /// - MySqlDataDriver (net461, netcoreapp2.0) + /// - FirebirdClientDriver (net461, netcoreapp2.0) + /// - SQLite20Driver (net461, netcoreapp2.0) + /// The following drivers fails the test: + /// - OracleManagedDataClientDriver: + /// 1) Property 'Short' returned type is not the same as original + /// Expected: System.Int16 + /// But was: System.Int32 + /// 2) Property 'Integer' returned type is not the same as original + /// Expected: System.Int32 + /// But was: System.Int64 + /// 3) Property 'Long' returned type is not the same as original + /// Expected: System.Int64 + /// But was: System.Decimal + /// 4) Property 'DecimalLowScale' returned type is not the same as original + /// Expected: System.Decimal + /// But was: System.Single + /// - MySqlDataDriver: + /// 1) Property 'Float' value is not the same as original + /// Expected: 3.14159275E+15f + /// But was: 3.14159007E+15f + /// - SQLite20Driver: + /// 1) Property 'UnsignedShort' returned type is not the same as original + /// Expected: System.UInt16 + /// But was: System.Int64 + /// 2) Property 'UnsignedInteger' returned type is not the same as original + /// Expected: System.UInt32 + /// But was: System.Int64> + /// 3) Property 'UnsignedLong' returned type is not the same as original + /// Expected: System.UInt64 + /// But was: System.Int64 + /// 4) Property 'Decimal' returned type is not the same as original + /// Expected: System.Decimal + /// But was: System.Double + /// 5) Property 'DecimalLowScale' returned type is not the same as original + /// Expected: System.Decimal + /// But was: System.Double + /// 6) Property 'Currency' returned type is not the same as original + /// Expected: System.Decimal + /// But was: System.Double + /// 7) Property 'Float' returned type is not the same as original + /// Expected: System.Single + /// But was: System.Double + /// + [Explicit] + public void TestRawTypesAndValuesAfterSave() + { + TestTypesAndValuesAfterSave(true); + } + + [Test] + public void TestTypesAndValuesAfterSave() + { + TestTypesAndValuesAfterSave(false); + } + + private void TestTypesAndValuesAfterSave(bool testRawValue) + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var selectColumns = string.Join(",", _properties.Select(o => o.ColumnName)); + var query = session.CreateSQLQuery($"select {selectColumns} from NumericEntity"); + foreach (var property in _properties) + { + query.AddScalar(property.ColumnName, testRawValue ? ObjectTypeInstance : NHibernateUtil.GuessType(property.Type)); + } + + var result = (object[]) query.UniqueResult(); + Assert.Multiple(() => + { + for (var i = 0; i < _properties.Count; i++) + { + var value = result[i]; + var property = _properties[i]; + Assert.That(value.GetType(), Is.EqualTo(property.Type), $"Property '{property.Name}' returned type is not the same as original"); + Assert.That(value, Is.EqualTo(property.GetValue(_originalEntity)), $"Property '{property.Name}' value is not the same as original"); + } + }); + + transaction.Commit(); + } + } + + /// + /// Tested driver (testRawValue=false was used): + /// - SqlServerCeDriver (net461) + /// - OracleManagedDataClientDriver (net461, netcoreapp2.0) + /// - OracleClientDriver (net461, netcoreapp2.0) + /// - OdbcDriver - SqlServer (net461, netcoreapp2.0) + /// - Sql2008ClientDriver (net461, netcoreapp2.0) + /// - NpgsqlDriver (net461, netcoreapp2.0) + /// - MySqlDataDriver (net461, netcoreapp2.0) + /// - FirebirdClientDriver (net461, netcoreapp2.0) + /// - SQLite20Driver (net461, netcoreapp2.0) + /// The following drivers fails the test: + /// - MySqlDataDriver: + /// 1) Expression 'pFloat % Float' value is not as expected + /// Expected: 0.0f + /// But was: 3.14159007E+15f + /// 2) Expression 'Float % pFloat' value is not as expected + /// Expected: 0.0f + /// But was: 67445760.0f + /// 3) Expression 'pFloat - Float' value is not as expected + /// Expected: 0.0f + /// But was: -67445760.0f + /// 4) Expression 'Float - pFloat' value is not as expected + /// Expected: 0.0f + /// But was: 67445760.0f + /// + /// Whether to execute a sql query without adding casts and without having a that alters + /// the returned value form the driver. Use this in order to check what value and type is returned for a specific arithmetic operation + /// by the driver. + [TestCase(false)] + public void TestArithmeticOperators(bool testRawValue) + { + var tests = new List(); + FillOperationTests(nameof(NumericEntity.Short), nameof(NumericEntity.Short), typeof(int), tests); + FillOperationTests(nameof(NumericEntity.Short), nameof(NumericEntity.Integer), typeof(int), tests); + FillOperationTests(nameof(NumericEntity.Short), nameof(NumericEntity.Long), typeof(long), tests); + FillOperationTests(nameof(NumericEntity.Short), nameof(NumericEntity.UnsignedShort), typeof(int), tests); + FillOperationTests(nameof(NumericEntity.Short), nameof(NumericEntity.UnsignedInteger), typeof(long), tests); + FillOperationTests(nameof(NumericEntity.Short), nameof(NumericEntity.Decimal), typeof(decimal), tests); + FillOperationTests(nameof(NumericEntity.Short), nameof(NumericEntity.DecimalLowScale), typeof(decimal), tests); + FillOperationTests(nameof(NumericEntity.Short), nameof(NumericEntity.Currency), typeof(decimal), tests); + FillOperationTests(nameof(NumericEntity.Short), nameof(NumericEntity.Double), typeof(double), tests); + FillOperationTests(nameof(NumericEntity.Short), nameof(NumericEntity.Float), typeof(float), tests); + + FillOperationTests(nameof(NumericEntity.Integer), nameof(NumericEntity.Integer), typeof(int), tests); + FillOperationTests(nameof(NumericEntity.Integer), nameof(NumericEntity.Long), typeof(long), tests); + FillOperationTests(nameof(NumericEntity.Integer), nameof(NumericEntity.UnsignedShort), typeof(int), tests); + FillOperationTests(nameof(NumericEntity.Integer), nameof(NumericEntity.UnsignedInteger), typeof(long), tests); + FillOperationTests(nameof(NumericEntity.Integer), nameof(NumericEntity.Decimal), typeof(decimal), tests); + FillOperationTests(nameof(NumericEntity.Integer), nameof(NumericEntity.DecimalLowScale), typeof(decimal), tests); + FillOperationTests(nameof(NumericEntity.Integer), nameof(NumericEntity.Currency), typeof(decimal), tests); + FillOperationTests(nameof(NumericEntity.Integer), nameof(NumericEntity.Double), typeof(double), tests); + FillOperationTests(nameof(NumericEntity.Integer), nameof(NumericEntity.Float), typeof(float), tests); + + FillOperationTests(nameof(NumericEntity.Long), nameof(NumericEntity.Long), typeof(long), tests); + FillOperationTests(nameof(NumericEntity.Long), nameof(NumericEntity.UnsignedShort), typeof(long), tests); + FillOperationTests(nameof(NumericEntity.Long), nameof(NumericEntity.UnsignedInteger), typeof(long), tests); + FillOperationTests(nameof(NumericEntity.Long), nameof(NumericEntity.Decimal), typeof(decimal), tests); + FillOperationTests(nameof(NumericEntity.Long), nameof(NumericEntity.DecimalLowScale), typeof(decimal), tests); + FillOperationTests(nameof(NumericEntity.Long), nameof(NumericEntity.Currency), typeof(decimal), tests); + FillOperationTests(nameof(NumericEntity.Long), nameof(NumericEntity.Double), typeof(double), tests); + FillOperationTests(nameof(NumericEntity.Long), nameof(NumericEntity.Float), typeof(float), tests); + + FillOperationTests(nameof(NumericEntity.UnsignedShort), nameof(NumericEntity.UnsignedShort), typeof(int), tests); + FillOperationTests(nameof(NumericEntity.UnsignedShort), nameof(NumericEntity.UnsignedInteger), typeof(uint), tests); + FillOperationTests(nameof(NumericEntity.UnsignedShort), nameof(NumericEntity.UnsignedLong), typeof(ulong), tests); + FillOperationTests(nameof(NumericEntity.UnsignedShort), nameof(NumericEntity.Decimal), typeof(decimal), tests); + FillOperationTests(nameof(NumericEntity.UnsignedShort), nameof(NumericEntity.DecimalLowScale), typeof(decimal), tests); + FillOperationTests(nameof(NumericEntity.UnsignedShort), nameof(NumericEntity.Currency), typeof(decimal), tests); + FillOperationTests(nameof(NumericEntity.UnsignedShort), nameof(NumericEntity.Double), typeof(double), tests); + FillOperationTests(nameof(NumericEntity.UnsignedShort), nameof(NumericEntity.Float), typeof(float), tests); + + FillOperationTests(nameof(NumericEntity.UnsignedInteger), nameof(NumericEntity.UnsignedInteger), typeof(uint), tests); + FillOperationTests(nameof(NumericEntity.UnsignedInteger), nameof(NumericEntity.UnsignedLong), typeof(ulong), tests); + FillOperationTests(nameof(NumericEntity.UnsignedInteger), nameof(NumericEntity.Decimal), typeof(decimal), tests); + FillOperationTests(nameof(NumericEntity.UnsignedInteger), nameof(NumericEntity.DecimalLowScale), typeof(decimal), tests); + FillOperationTests(nameof(NumericEntity.UnsignedInteger), nameof(NumericEntity.Currency), typeof(decimal), tests); + FillOperationTests(nameof(NumericEntity.UnsignedInteger), nameof(NumericEntity.Double), typeof(double), tests); + FillOperationTests(nameof(NumericEntity.UnsignedInteger), nameof(NumericEntity.Float), typeof(float), tests); + + FillOperationTests(nameof(NumericEntity.UnsignedLong), nameof(NumericEntity.UnsignedLong), typeof(ulong), tests); + FillOperationTests(nameof(NumericEntity.UnsignedLong), nameof(NumericEntity.Decimal), typeof(decimal), tests); + FillOperationTests(nameof(NumericEntity.UnsignedLong), nameof(NumericEntity.DecimalLowScale), typeof(decimal), tests); + FillOperationTests(nameof(NumericEntity.UnsignedLong), nameof(NumericEntity.Currency), typeof(decimal), tests); + FillOperationTests(nameof(NumericEntity.UnsignedLong), nameof(NumericEntity.Double), typeof(double), tests); + FillOperationTests(nameof(NumericEntity.UnsignedLong), nameof(NumericEntity.Float), typeof(float), tests); + + FillOperationTests(nameof(NumericEntity.Decimal), nameof(NumericEntity.Decimal), typeof(decimal), tests); + FillOperationTests(nameof(NumericEntity.Decimal), nameof(NumericEntity.DecimalLowScale), typeof(decimal), tests); + FillOperationTests(nameof(NumericEntity.Decimal), nameof(NumericEntity.Currency), typeof(decimal), tests); + + FillOperationTests(nameof(NumericEntity.DecimalLowScale), nameof(NumericEntity.DecimalLowScale), typeof(decimal), tests); + FillOperationTests(nameof(NumericEntity.DecimalLowScale), nameof(NumericEntity.Currency), typeof(decimal), tests); + + FillOperationTests(nameof(NumericEntity.Currency), nameof(NumericEntity.Currency), typeof(decimal), tests); + + FillOperationTests(nameof(NumericEntity.Double), nameof(NumericEntity.Double), typeof(double), tests); + FillOperationTests(nameof(NumericEntity.Double), nameof(NumericEntity.Float), typeof(double), tests); + + FillOperationTests(nameof(NumericEntity.Float), nameof(NumericEntity.Float), typeof(float), tests); + + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + TestOperators(tests, testRawValue, session); + transaction.Commit(); + } + } + + private static void FillOperationTests( + string leftPropertyName, + string rightPropertyName, + System.Type operatorType, + List tests) + { + FillOperationTests('+', leftPropertyName, rightPropertyName, operatorType, tests); + FillOperationTests('-', leftPropertyName, rightPropertyName, operatorType, tests); + FillOperationTests('*', leftPropertyName, rightPropertyName, operatorType, tests); + FillOperationTests('/', leftPropertyName, rightPropertyName, operatorType, tests); + FillOperationTests('%', leftPropertyName, rightPropertyName, operatorType, tests); + } + + private static void FillOperationTests( + char @operator, + string leftPropertyName, + string rightPropertyName, + System.Type operatorType, + List tests) + { + tests.Add(new OperatorTest(@operator, leftPropertyName, rightPropertyName, operatorType, false, false)); + tests.Add(new OperatorTest(@operator, leftPropertyName, rightPropertyName, operatorType, true, false)); + tests.Add(new OperatorTest(@operator, leftPropertyName, rightPropertyName, operatorType, false, true)); + } + + private void TestOperators(List tests, bool testRawValue, ISession session) + { + Console.WriteLine($"Total tests: {tests.Count}"); + Assert.Multiple( + () => + { + var skippedTests = 0; + foreach (var test in tests) + { + var leftProperty = _properties.FirstOrDefault(o => o.Name == test.LeftPropertyName); + var rightProperty = _properties.FirstOrDefault(o => o.Name == test.RightPropertyName); + if (leftProperty == null || + rightProperty == null) + { + skippedTests++; + continue; // Not supported + } + + if (AllowedDriverDifferentValues.TryGetValue(Sfi.ConnectionProvider.Driver.GetType(), out var allowedFailures) && + allowedFailures.Contains(test)) + { + skippedTests++; + continue; + } + + var lambda = GetLambdaExpression(leftProperty, test, rightProperty); + object expectedResult; + try + { + expectedResult = lambda.Compile().DynamicInvoke(_originalEntity); + } + catch (TargetInvocationException e) when (e.InnerException is OverflowException) + { + skippedTests++; + continue; // Skip overflows (e.g. UnsignedShort - UnsignedInteger) + } + + TestOperator(testRawValue, session, leftProperty, test, rightProperty, expectedResult); + } + + Console.WriteLine($"Skipped tests: {skippedTests}"); + }); + } + + private void TestOperator( + bool testRawValue, + ISession session, + PropertyMetadata leftProperty, + OperatorTest test, + PropertyMetadata rightProperty, + object expectedResult) + { + object result; + var selectExpression = + $"{(test.LeftAsParameter ? $"p{test.LeftPropertyName}" : test.LeftPropertyName)} " + + $"{test.Operator} " + + $"{(test.RightAsParameter ? $"p{test.RightPropertyName}" : test.RightPropertyName)}"; + try + { + result = testRawValue + ? ExecuteSqlQuery(true, session, leftProperty, test, rightProperty) + : ExecuteLinqQuery(session, leftProperty, test, rightProperty); + } + catch (GenericADOException e) + { + Assert.Fail($" Expression '{selectExpression}' failed to execute.{Environment.NewLine} {e}"); + return; + } + + // Don't assert the type as there will be a lot of failures, the important thing is that the value is correct. + //Assert.That(result.GetType(), Is.EqualTo(expectedResult.GetType()), $"Expression '{select}' returned type is not as expected"); + try + { + Assert.That(result, Is.EqualTo(expectedResult), $"Expression '{selectExpression}' value is not as expected"); + } + catch (OverflowException) // Can happen when a negative value is returned for an unsigned number + { + // Generate the same message as NUnit + Assert.Fail( + $" Expression '{selectExpression}' value is not as expected." + Environment.NewLine + + $" Expected: {expectedResult}" + Environment.NewLine + + $" But was: {result}" + Environment.NewLine); + } + } + + private object ExecuteSqlQuery( + bool testRawValue, + ISession session, + PropertyMetadata leftProperty, + OperatorTest test, + PropertyMetadata rightProperty) + { + var selectSql = $"{(test.LeftAsParameter ? "?" : GetColumnName(test.LeftPropertyName))} " + + $"{test.Operator} " + + $"{(test.RightAsParameter ? "?" : GetColumnName(test.RightPropertyName))}"; + var query = session.CreateSQLQuery($"select {selectSql} as col1 from NumericEntity") + .AddScalar( + "col1", + testRawValue ? ObjectTypeInstance : NHibernateUtil.GuessType(test.OperandType)); + if (test.LeftAsParameter) + { + query.SetParameter( + 0, + leftProperty.GetValue(_originalEntity), + Sfi.GetClassMetadata(nameof(NumericEntity)).GetPropertyType(test.LeftPropertyName)); + } + else if (test.RightAsParameter) + { + query.SetParameter( + 0, + rightProperty.GetValue(_originalEntity), + Sfi.GetClassMetadata(nameof(NumericEntity)).GetPropertyType(test.RightPropertyName)); + } + + return query.UniqueResult(); + } + + private object ExecuteLinqQuery( + ISession session, + PropertyMetadata leftProperty, + OperatorTest test, + PropertyMetadata rightProperty) + { + var query = session.Query(); + var lambda = GetLambdaExpression(leftProperty, test, rightProperty); + var queryable = SelectDefinition.MakeGenericMethod(typeof(NumericEntity), test.OperandType) + .Invoke(null, new object[] {query, lambda}); + try + { + return FirstDefinition.MakeGenericMethod(test.OperandType) + .Invoke(null, new[] { queryable }); + } + catch (TargetInvocationException e) + { + ExceptionDispatchInfo.Capture(e.InnerException).Throw(); + } + + return null; + } + + private LambdaExpression GetLambdaExpression( + PropertyMetadata leftProperty, + OperatorTest test, + PropertyMetadata rightProperty) + { + var parameter = Expression.Parameter(typeof(NumericEntity), "o"); + var leftExpression = test.LeftAsParameter + ? (Expression) Expression.Constant(leftProperty.GetValue(_originalEntity)) + : Expression.MakeMemberAccess(parameter, leftProperty.PropertyInfo); + if (leftProperty.Type != test.OperandType) + { + leftExpression = Expression.Convert(leftExpression, test.OperandType); + } + + var rightExpression = test.RightAsParameter + ? (Expression) Expression.Constant(rightProperty.GetValue(_originalEntity)) + : Expression.MakeMemberAccess(parameter, rightProperty.PropertyInfo); + if (rightProperty.Type != test.OperandType) + { + rightExpression = Expression.Convert(rightExpression, test.OperandType); + } + + var operatorExpression = GetOperatorExpression(test, leftExpression, rightExpression); + + return (LambdaExpression) LambdaDefinition + .MakeGenericMethod(typeof(Func<,>).MakeGenericType(typeof(NumericEntity), test.OperandType)) + .Invoke(null, new object[] {operatorExpression, new[] {parameter}}); + } + + private static Expression GetOperatorExpression( + OperatorTest test, + Expression leftExpression, + Expression rightExpression) + { + switch (test.Operator) + { + case '+': + return Expression.AddChecked(leftExpression, rightExpression); + case '-': + return Expression.SubtractChecked(leftExpression, rightExpression); + case '*': + return Expression.MultiplyChecked(leftExpression, rightExpression); + case '/': + return Expression.Divide(leftExpression, rightExpression); + case '%': + return Expression.Modulo(leftExpression, rightExpression); + default: + throw new InvalidOperationException("Invalid operator"); + } + } + + public class NumericEntity + { + public virtual int Id { get; set; } + public virtual short Short { get; set; } + public virtual int Integer { get; set; } + public virtual long Long { get; set; } + public virtual ushort UnsignedShort { get; set; } + public virtual uint UnsignedInteger { get; set; } + public virtual ulong UnsignedLong { get; set; } + public virtual decimal Decimal { get; set; } + public virtual decimal DecimalLowScale { get; set; } + public virtual decimal Currency { get; set; } + public virtual double Double { get; set; } + public virtual float Float { get; set; } + } + + private class OperatorTest + { + public OperatorTest( + char @operator, + string leftPropertyName, + string rightPropertyName, + System.Type operandType, + bool leftAsParameter, + bool rightAsParameter) + { + Operator = @operator; + LeftPropertyName = leftPropertyName; + RightPropertyName = rightPropertyName; + OperandType = operandType; + LeftAsParameter = leftAsParameter; + RightAsParameter = rightAsParameter; + } + + public char Operator { get; } + public string LeftPropertyName { get; } + public string RightPropertyName { get; } + public System.Type OperandType { get; } + public bool LeftAsParameter { get; } + public bool RightAsParameter { get; } + + public override bool Equals(object obj) + { + if (!(obj is OperatorTest test)) + { + return false; + } + + return + test.Operator == Operator && + test.LeftPropertyName == LeftPropertyName && + test.RightPropertyName == RightPropertyName && + test.OperandType == OperandType && + test.LeftAsParameter == LeftAsParameter && + test.RightAsParameter == RightAsParameter; + } + + public override int GetHashCode() + { + return + Operator.GetHashCode() ^ + LeftPropertyName.GetHashCode() ^ + RightPropertyName.GetHashCode() ^ + OperandType.GetHashCode() ^ + LeftAsParameter.GetHashCode() ^ + RightAsParameter.GetHashCode(); + } + } + + private class PropertyMetadata + { + private readonly MethodInfo _getterMethodInfo; + + public PropertyMetadata(PropertyInfo propertyInfo) + { + PropertyInfo = propertyInfo; + Name = propertyInfo.Name; + ColumnName = GetColumnName(Name); + Type = propertyInfo.PropertyType; + _getterMethodInfo = propertyInfo.GetMethod; + } + + public string Name { get; } + + public string ColumnName { get; } + + public System.Type Type { get; } + + public PropertyInfo PropertyInfo { get; } + + public object GetValue(object instance) + { + return _getterMethodInfo.Invoke(instance, null); + } + } + + private class ObjectType : PrimitiveType + { + public ObjectType() : base(new SqlType(DbType.Object)) + { + } + + public override string Name => "Object"; + public override System.Type ReturnedClass => typeof(object); + public override void Set(DbCommand cmd, object value, int index, ISessionImplementor session) + { + cmd.Parameters[index].Value = value; + } + + public override object Get(DbDataReader rs, int index, ISessionImplementor session) + { + return rs[index]; + } + + public override object Get(DbDataReader rs, string name, ISessionImplementor session) + { + return rs[name]; + } + + public override System.Type PrimitiveClass => typeof(object); + + public override object DefaultValue => null; + + public override string ObjectToSQLString(object value, Dialect.Dialect dialect) + { + return value?.ToString(); + } + } + } +} diff --git a/src/NHibernate.Test/Linq/ConstantTest.cs b/src/NHibernate.Test/Linq/ConstantTest.cs index 6b693ddbc4a..f75bf3a9070 100644 --- a/src/NHibernate.Test/Linq/ConstantTest.cs +++ b/src/NHibernate.Test/Linq/ConstantTest.cs @@ -217,12 +217,12 @@ public void ConstantInWhereDoesNotCauseManyKeys() select c); var preTransformParameters = new PreTransformationParameters(QueryMode.Select, Sfi); var preTransformResult = NhRelinqQueryParser.PreTransform(q1.Expression, preTransformParameters); - var expression = ExpressionParameterVisitor.Visit(preTransformResult, out var parameters1); - var k1 = ExpressionKeyVisitor.Visit(expression, parameters1); + var parameters1 = ExpressionParameterVisitor.Visit(preTransformResult); + var k1 = ExpressionKeyVisitor.Visit(preTransformResult.Expression, parameters1, Sfi); var preTransformResult2 = NhRelinqQueryParser.PreTransform(q2.Expression, preTransformParameters); - var expression2 = ExpressionParameterVisitor.Visit(preTransformResult2, out var parameters2); - var k2 = ExpressionKeyVisitor.Visit(expression2, parameters2); + var parameters2 = ExpressionParameterVisitor.Visit(preTransformResult2); + var k2 = ExpressionKeyVisitor.Visit(preTransformResult2.Expression, parameters2, Sfi); Assert.That(parameters1, Has.Count.GreaterThan(0), "parameters1"); Assert.That(parameters2, Has.Count.GreaterThan(0), "parameters2"); diff --git a/src/NHibernate.Test/Linq/EnumTests.cs b/src/NHibernate.Test/Linq/EnumTests.cs index 4050c7ddb97..aeea060b51e 100644 --- a/src/NHibernate.Test/Linq/EnumTests.cs +++ b/src/NHibernate.Test/Linq/EnumTests.cs @@ -48,5 +48,42 @@ public void CanQueryOnEnumStoredAsString(EnumStoredAsString type, int expectedCo Assert.AreEqual(expectedCount, query.Count); } + + [Test] + public void ConditionalNavigationProperty() + { + EnumStoredAsString? type = null; + db.Users.Where(o => o.Enum1 == EnumStoredAsString.Large).ToList(); + db.Users.Where(o => EnumStoredAsString.Large != o.Enum1).ToList(); + db.Users.Where(o => (o.NullableEnum1 ?? EnumStoredAsString.Large) == EnumStoredAsString.Medium).ToList(); + db.Users.Where(o => ((o.NullableEnum1 ?? type) ?? o.Enum1) == EnumStoredAsString.Medium).ToList(); + + db.Users.Where(o => (o.NullableEnum1.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) == EnumStoredAsString.Medium).ToList(); + db.Users.Where(o => (o.Enum1 != EnumStoredAsString.Large + ? (o.NullableEnum1.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) + : EnumStoredAsString.Small) == EnumStoredAsString.Medium).ToList(); + + db.Users.Where(o => (o.Enum1 == EnumStoredAsString.Large ? o.Role : o.Role).Name == "test").ToList(); + } + + [Test] + public void CanQueryComplexExpressionOnEnumStoredAsString() + { + var type = EnumStoredAsString.Unspecified; + var query = (from user in db.Users + where (user.NullableEnum1 == EnumStoredAsString.Large + ? EnumStoredAsString.Medium + : user.NullableEnum1 ?? user.Enum1 + ) == type + select new + { + user, + simple = user.Enum1, + condition = user.Enum1 == EnumStoredAsString.Large ? EnumStoredAsString.Medium : user.Enum1, + coalesce = user.NullableEnum1 ?? EnumStoredAsString.Medium + }).ToList(); + + Assert.That(query.Count, Is.EqualTo(0)); + } } } diff --git a/src/NHibernate.Test/Linq/LinqTestCase.cs b/src/NHibernate.Test/Linq/LinqTestCase.cs index e047732d7ad..daf14b9cd18 100755 --- a/src/NHibernate.Test/Linq/LinqTestCase.cs +++ b/src/NHibernate.Test/Linq/LinqTestCase.cs @@ -34,7 +34,8 @@ protected override string[] Mappings "Northwind.Mappings.User.hbm.xml", "Northwind.Mappings.TimeSheet.hbm.xml", "Northwind.Mappings.Animal.hbm.xml", - "Northwind.Mappings.Patient.hbm.xml" + "Northwind.Mappings.Patient.hbm.xml", + "Northwind.Mappings.DynamicUser.hbm.xml" }; } } diff --git a/src/NHibernate.Test/Linq/ParameterTests.cs b/src/NHibernate.Test/Linq/ParameterTests.cs index 920fa565129..cab27fe9dd5 100644 --- a/src/NHibernate.Test/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTests.cs @@ -76,6 +76,34 @@ public void UsingTwoEntityParameters() 2); } + [Test] + public void UsingEntityEnumerableParameterTwice() + { + if (!Dialect.SupportsSubSelects) + { + Assert.Ignore(); + } + + var enumerable = db.DynamicUsers.First(); + AssertTotalParameters( + db.DynamicUsers.Where(o => o == enumerable && o != enumerable), + 1); + } + + [Test] + public void UsingEntityEnumerableListParameterTwice() + { + if (!Dialect.SupportsSubSelects) + { + Assert.Ignore(); + } + + var enumerable = new[] {db.DynamicUsers.First()}; + AssertTotalParameters( + db.DynamicUsers.Where(o => enumerable.Contains(o) && enumerable.Contains(o)), + 1); + } + [Test] public void UsingValueTypeParameterTwice() { diff --git a/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs new file mode 100644 index 00000000000..2cb87bc50b2 --- /dev/null +++ b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs @@ -0,0 +1,444 @@ +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Linq.Dynamic.Core; +using System.Linq.Expressions; +using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Engine.Query; +using NHibernate.Linq; +using NHibernate.Linq.Visitors; +using NHibernate.Type; +using NUnit.Framework; +using Remotion.Linq.Clauses; + +namespace NHibernate.Test.Linq +{ + public class ParameterTypeLocatorTests : LinqTestCase + { + [Test] + public void AddIntegerTest() + { + AssertResults( + new Dictionary> + { + {"2.1", o => o is DoubleType}, + {"5", o => o is Int32Type}, + }, + db.Users.Where(o => o.Id + 5 > 2.1), + db.Users.Where(o => 2.1 < 5 + o.Id) + ); + } + + [Test] + public void AddDecimalTest() + { + AssertResults( + new Dictionary> + { + {"2.1", o => o is DecimalType}, + {"5.2", o => o is DecimalType}, + }, + db.Users.Where(o => o.Id + 5.2m > 2.1m), + db.Users.Where(o => 2.1m < 5.2m + o.Id) + ); + } + + [Test] + public void SubtractFloatTest() + { + AssertResults( + new Dictionary> + { + {"2.1", o => o is DoubleType}, + {"5.2", o => o is SingleType}, + }, + db.Users.Where(o => o.Id - 5.2f > 2.1), + db.Users.Where(o => 2.1 < 5.2f - o.Id) + ); + } + + [Test] + public void GreaterThanTest() + { + AssertResults( + new Dictionary> + { + {"2.1", o => o is Int32Type} + }, + db.Users.Where(o => o.Id > 2.1), + db.Users.Where(o => 2.1 > o.Id) + ); + } + + [Test] + public void EqualStringEnumTest() + { + AssertResults( + new Dictionary> + { + {"3", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => o.Enum1 == EnumStoredAsString.Large), + db.Users.Where(o => EnumStoredAsString.Large == o.Enum1) + ); + } + + [Test] + public void EqualStringTest() + { + AssertResults( + new Dictionary> + { + {"\"London\"", o => o is StringType stringType && stringType.SqlType.Length == 15} + }, + db.Orders.Where(o => o.ShippingAddress.City == "London"), + db.Orders.Where(o => "London" == o.ShippingAddress.City) + ); + } + + [Test] + public void EqualEntityTest() + { + var order = new Order(); + AssertResults( + new Dictionary> + { + { + $"value({typeof(Order).FullName})", + o => o is ManyToOneType manyToOne && manyToOne.Name == typeof(Order).FullName + } + }, + db.Orders.Where(o => o == order), + db.Orders.Where(o => order == o) + ); + } + + [Test] + public void DoubleEqualTest() + { + AssertResults( + new Dictionary> + { + {"3", o => o is EnumStoredAsStringType}, + {"1", o => o is PersistentEnumType} + }, + db.Users.Where(o => o.Enum1 == EnumStoredAsString.Large && o.Enum2 == EnumStoredAsInt32.High), + db.Users.Where(o => EnumStoredAsInt32.High == o.Enum2 && EnumStoredAsString.Large == o.Enum1) + ); + } + + [Test] + public void NotEqualTest() + { + AssertResults( + new Dictionary> + { + {"3", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => o.Enum1 != EnumStoredAsString.Large), + db.Users.Where(o => EnumStoredAsString.Large != o.Enum1) + ); + } + + [Test] + public void DoubleNotEqualTest() + { + AssertResults( + new Dictionary> + { + {"3", o => o is EnumStoredAsStringType}, + {"1", o => o is PersistentEnumType} + }, + db.Users.Where(o => o.Enum1 != EnumStoredAsString.Large || o.NullableEnum2 != EnumStoredAsInt32.High), + db.Users.Where(o => EnumStoredAsInt32.High != o.NullableEnum2 || o.Enum1 != EnumStoredAsString.Large) + ); + } + + [Test] + public void CoalesceTest() + { + AssertResults( + new Dictionary> + { + {"2", o => o is EnumStoredAsStringType}, + {"Large", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => (o.NullableEnum1 ?? EnumStoredAsString.Large) == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == (o.NullableEnum1 ?? EnumStoredAsString.Large)) + ); + } + + [Test] + public void DoubleCoalesceTest() + { + AssertResults( + new Dictionary> + { + {"2", o => o is EnumStoredAsStringType}, + {"Large", o => o is EnumStoredAsStringType}, + }, + db.Users.Where(o => ((o.NullableEnum1 ?? (EnumStoredAsString?) EnumStoredAsString.Large) ?? o.Enum1) == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == ((o.NullableEnum1 ?? (EnumStoredAsString?) EnumStoredAsString.Large) ?? o.Enum1)) + ); + } + + [Test] + public void ConditionalTest() + { + AssertResults( + new Dictionary> + { + {"2", o => o is EnumStoredAsStringType}, + {"Unspecified", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => (o.NullableEnum2.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == (o.NullableEnum2.HasValue ? EnumStoredAsString.Unspecified : o.Enum1)) + ); + } + + [Test] + public void DoubleConditionalTest() + { + AssertResults( + new Dictionary> + { + {"0", o => o is PersistentEnumType}, + {"2", o => o is EnumStoredAsStringType}, + {"Small", o => o is EnumStoredAsStringType}, + {"Unspecified", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => (o.Enum2 != EnumStoredAsInt32.Unspecified + ? (o.NullableEnum2.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) + : EnumStoredAsString.Small) == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == (o.Enum2 != EnumStoredAsInt32.Unspecified + ? EnumStoredAsString.Small + : (o.NullableEnum2.HasValue ? EnumStoredAsString.Unspecified : o.Enum1))) + ); + } + + [Test] + public void CoalesceMemberTest() + { + AssertResults( + new Dictionary> + { + {"2", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => (o.NotMappedUser ?? o).Enum1 == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == (o ?? o.NotMappedUser).Enum1) + ); + } + + [Test] + public void ConditionalMemberTest() + { + AssertResults( + new Dictionary> + { + {"2", o => o is EnumStoredAsStringType}, + {"\"test\"", o => o is AnsiStringType}, + }, + db.Users.Where(o => (o.Name == "test" ? o.NotMappedUser : o).Enum1 == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == (o.Name == "test" ? o : o.NotMappedUser).Enum1) + ); + } + + [Test] + public void DynamicMemberTest() + { + AssertResults( + new Dictionary> + { + {"\"test\"", o => o is AnsiStringType}, + }, + db.DynamicUsers.Where("Properties.Name == @0", "test"), + db.DynamicUsers.Where("@0 == Properties.Name", "test") + ); + } + + [Test] + public void DynamicDictionaryMemberTest() + { + AssertResults( + new Dictionary> + { + {"\"test\"", o => o is AnsiStringType}, + }, +#pragma warning disable CS0252 + db.DynamicUsers.Where(o => o.Settings["Property1"] == "test"), +#pragma warning restore CS0252 +#pragma warning disable CS0253 + db.DynamicUsers.Where(o => "test" == o.Settings["Property1"]) +#pragma warning restore CS0253 + ); + } + + [Test] + public void AssignMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"val\"", o => o is AnsiStringType}, + {"Large", o => o is EnumStoredAsStringType}, + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new User {Name = "val", Enum1 = EnumStoredAsString.Large} + ); + } + + [Test] + public void AssignComponentMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"prop1\"", o => o is AnsiStringType} + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new User {Component = new UserComponent {Property1 = "prop1"}} + ); + } + + [Test] + public void AssignNestedComponentMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"other\"", o => o is AnsiStringType} + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new User + { + Component = new UserComponent {OtherComponent = new UserComponent2 {OtherProperty1 = "other"}} + } + ); + } + + [Test] + public void AnonymousAssignMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"val\"", o => o is AnsiStringType}, + {"Large", o => o is EnumStoredAsStringType}, + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new {Name = "val", Enum1 = EnumStoredAsString.Large} + ); + } + + [Test] + public void AnonymousAssignComponentMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"prop1\"", o => o is AnsiStringType} + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new {Component = new {Property1 = "prop1"}} + ); + } + + [Test] + public void AnonymousAssignNestedComponentMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"other\"", o => o is AnsiStringType} + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new {Component = new {OtherComponent = new {OtherProperty1 = "other"}}} + ); + } + + private void AssertResults( + Dictionary> expectedResults, + params IQueryable[] queries) + { + foreach (var query in queries) + { + AssertResult(expectedResults, query); + } + } + + private void AssertResult( + Dictionary> expectedResults, + IQueryable query) + { + AssertResult(expectedResults, QueryMode.Select, query.Expression, query.Expression.Type); + } + + private void AssertResult( + Dictionary> expectedResults, + QueryMode queryMode, + IQueryable query, + Expression> expression) + { + var dmlExpression = expression != null + ? DmlExpressionRewriter.PrepareExpression(query.Expression, expression) + : query.Expression; + + AssertResult(expectedResults, queryMode, dmlExpression, typeof(T)); + } + + private void AssertResult( + Dictionary> expectedResults, + QueryMode queryMode, + IQueryable query, + Expression> expression) + { + var dmlExpression = expression != null + ? DmlExpressionRewriter.PrepareExpressionFromAnonymous(query.Expression, expression) + : query.Expression; + + AssertResult(expectedResults, queryMode, dmlExpression, typeof(T)); + } + + private void AssertResult( + Dictionary> expectedResults, + QueryMode queryMode, + Expression expression, + System.Type targetType) + { + var result = NhRelinqQueryParser.PreTransform(expression, new PreTransformationParameters(queryMode, Sfi)); + var parameters = ExpressionParameterVisitor.Visit(result); + expression = result.Expression; + var queryModel = NhRelinqQueryParser.Parse(expression); + ParameterTypeLocator.SetParameterTypes(parameters, queryModel, targetType, Sfi); + Assert.That(parameters.Count, Is.EqualTo(expectedResults.Count), "Incorrect number of parameters"); + foreach (var pair in parameters) + { + var origCulture = CultureInfo.CurrentCulture; + try + { + CultureInfo.CurrentCulture = CultureInfo.InvariantCulture; + var expressionText = pair.Key.ToString(); + Assert.That(expectedResults.ContainsKey(expressionText), Is.True, $"{expressionText} constant is not expected"); + Assert.That(expectedResults[expressionText](pair.Value.Type), Is.True, $"Invalid type, actual type: {pair.Value?.Type?.Name ?? "null"}"); + } + finally + { + CultureInfo.CurrentCulture = origCulture; + } + } + } + } +} diff --git a/src/NHibernate.Test/Linq/SelectionTests.cs b/src/NHibernate.Test/Linq/SelectionTests.cs index e68c87c5cac..4654db4e1bf 100644 --- a/src/NHibernate.Test/Linq/SelectionTests.cs +++ b/src/NHibernate.Test/Linq/SelectionTests.cs @@ -3,8 +3,12 @@ using System.Linq; using NHibernate.DomainModel.NHSpecific; using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Driver; +using NHibernate.Exceptions; +using NHibernate.Proxy; using NHibernate.Type; using NUnit.Framework; +using static NHibernate.Linq.ExpressionEvaluation; namespace NHibernate.Test.Linq { @@ -130,7 +134,7 @@ public void CanSelectNestedMemberInitExpression() { InvalidLoginAttempts = user.InvalidLoginAttempts, Dto2 = new UserDto2 - { + { RegisteredAt = user.RegisteredAt, Enum = user.Enum2 }, @@ -153,7 +157,7 @@ public void CanSelectNestedMemberInitWithinNewExpression() user.Name, user.InvalidLoginAttempts, Dto = new UserDto2 - { + { RegisteredAt = user.RegisteredAt, Enum = user.Enum2 }, @@ -301,6 +305,161 @@ public void CanSelectWithAggregateSubQuery() Assert.AreEqual(4, timesheets[2].EntryCount); } + [Test] + public void CanSelectConditional() + { + // SqlServerCeDriver and OdbcDriver have an issue matching the case statements inside select and order by statement, + // when having one or more parameters inside them. Throws with the following error: + // ORDER BY items must appear in the select list if SELECT DISTINCT is specified. + if (!(Sfi.ConnectionProvider.Driver is OdbcDriver) && !(Sfi.ConnectionProvider.Driver is SqlServerCeDriver)) + { + using (var sqlLog = new SqlLogSpy()) + { + var q = db.Orders.Where(o => o.Customer.CustomerId == "test") + .Select(o => o.ShippedTo.Contains("test") ? o.ShippedTo : o.Customer.CompanyName) + .OrderBy(o => o) + .Distinct() + .ToList(); + + Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "case"), Is.EqualTo(2)); + } + } + + using (var sqlLog = new SqlLogSpy()) + { + var q = db.Orders.Where(o => o.Customer.CustomerId == "test") + .Select(o => o.OrderDate.HasValue ? o.OrderDate : o.ShippingDate) + .FirstOrDefault(); + + Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "case"), Is.EqualTo(1)); + } + + using (var sqlLog = new SqlLogSpy()) + { + var q = db.Orders.Where(o => o.Customer.CustomerId == "test") + .Select(o => new + { + Value = o.OrderDate.HasValue + ? o.Customer.CompanyName + : (o.ShippingDate.HasValue + ? o.Shipper.CompanyName + "Shipper" + : o.ShippedTo) + }) + .FirstOrDefault(); + + var log = sqlLog.GetWholeLog(); + Assert.That(FindAllOccurrences(log, "as col"), Is.EqualTo(1)); + } + + using (var sqlLog = new SqlLogSpy()) + { + var q = db.Orders.Where(o => o.Customer.CustomerId == "test") + .Select(o => new + { + Value = o.OrderDate.HasValue + ? o.Customer.CompanyName + : (o.ShippingDate.HasValue + ? o.Shipper.CompanyName + "Shipper" + : null) + }) + .FirstOrDefault(); + + var log = sqlLog.GetWholeLog(); + Assert.That(FindAllOccurrences(log, "as col"), Is.EqualTo(1)); + } + + using (var sqlLog = new SqlLogSpy()) + { + var q = db.Orders.Where(o => o.Customer.CustomerId == "test") + .Select(o => new + { + Value = o.OrderDate.HasValue + ? o.Customer.CompanyName + : (o.ShippingDate.HasValue + ? o.Shipper.CompanyName + "Shipper" + : "default") + }) + .FirstOrDefault(); + + var log = sqlLog.GetWholeLog(); + Assert.That(FindAllOccurrences(log, "as col"), Is.EqualTo(1)); + } + + var defaultValue = "default"; + using (var sqlLog = new SqlLogSpy()) + { + var q = db.Orders.Where(o => o.Customer.CustomerId == "test") + .Select(o => new + { + Value = o.OrderDate.HasValue + ? o.Customer.CompanyName + : (o.ShippingDate.HasValue + ? o.Shipper.CompanyName + "Shipper" + : defaultValue) + }) + .FirstOrDefault(); + + var log = sqlLog.GetWholeLog(); + Assert.That(FindAllOccurrences(log, "as col"), Is.EqualTo(1)); + } + } + + [Test] + public void CanSelectConditionalSubQuery() + { + if (!Dialect.SupportsScalarSubSelects) + Assert.Ignore(Dialect.GetType().Name + " does not support scalar sub-queries"); + + var list = db.Customers + .Select(c => new + { + Date = db.Orders.Where(o => o.Customer.CustomerId == c.CustomerId) + .Select(o => o.OrderDate.HasValue ? o.OrderDate : o.ShippingDate) + .Max() + }) + .ToList(); + Assert.That(list, Has.Count.GreaterThan(0)); + + var list2 = db.Orders + .Select( + o => new + { + UnitPrice = o.Freight.HasValue + ? o.OrderLines.Where(l => l.Discount == 1) + .Select(l => l.Product.UnitPrice.HasValue ? l.Product.UnitPrice : l.UnitPrice) + .Max() + : o.OrderLines.Where(l => l.Discount == 0) + .Select(l => l.Product.UnitPrice.HasValue ? l.Product.UnitPrice : l.UnitPrice) + .Max() + }) + .ToList(); + Assert.That(list2, Has.Count.GreaterThan(0)); + + var list3 = db.Orders + .Select(o => new + { + Date = o.OrderLines.Any(l => o.OrderDate.HasValue) + ? db.Employees + .Select(e => e.BirthDate.HasValue ? e.BirthDate : e.HireDate) + .Max() + : o.Employee.Superior != null ? o.Employee.Superior.BirthDate : o.Employee.BirthDate + }) + .ToList(); + Assert.That(list3, Has.Count.GreaterThan(0)); + + var list4 = db.Orders + .Select(o => new + { + Employee = db.Employees.Any(e => e.Superior != null) + ? db.Employees + .Where(e => e.Superior != null) + .Select(e => e.Superior).FirstOrDefault() + : o.Employee.Superior != null ? o.Employee.Superior : o.Employee + }) + .ToList(); + Assert.That(list4, Has.Count.GreaterThan(0)); + } + [Test, KnownBug("NH-3045")] public void CanSelectFirstElementFromChildCollection() { @@ -410,56 +569,56 @@ public void CanSelectConditionalKnownTypes() if (!Dialect.SupportsScalarSubSelects) Assert.Ignore(Dialect.GetType().Name + " does not support scalar sub-queries"); - var moreThanTwoOrderLinesBool = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? true : false }).ToList(); + var moreThanTwoOrderLinesBool = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? true : false, Param = true }).ToList(); Assert.That(moreThanTwoOrderLinesBool.Count(x => x.HasMoreThanTwo == true), Is.EqualTo(410)); - var moreThanTwoOrderLinesNBool = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? true : (bool?)null }).ToList(); + var moreThanTwoOrderLinesNBool = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? true : (bool?)null, Param = (bool?)null }).ToList(); Assert.That(moreThanTwoOrderLinesNBool.Count(x => x.HasMoreThanTwo == true), Is.EqualTo(410)); - var moreThanTwoOrderLinesShort = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? (short)1 : (short)0 }).ToList(); + var moreThanTwoOrderLinesShort = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? (short)1 : (short)0, Param = (short)0 }).ToList(); Assert.That(moreThanTwoOrderLinesShort.Count(x => x.HasMoreThanTwo == 1), Is.EqualTo(410)); - var moreThanTwoOrderLinesNShort = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? (short?)1 : (short?)null }).ToList(); + var moreThanTwoOrderLinesNShort = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? (short?)1 : (short?)null, Param = (short?)null }).ToList(); Assert.That(moreThanTwoOrderLinesNShort.Count(x => x.HasMoreThanTwo == 1), Is.EqualTo(410)); - var moreThanTwoOrderLinesInt = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1 : 0 }).ToList(); + var moreThanTwoOrderLinesInt = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1 : 0, Param = 1 }).ToList(); Assert.That(moreThanTwoOrderLinesInt.Count(x => x.HasMoreThanTwo == 1), Is.EqualTo(410)); - var moreThanTwoOrderLinesNInt = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1 : (int?)null }).ToList(); + var moreThanTwoOrderLinesNInt = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1 : (int?)null, Param = (int?)null }).ToList(); Assert.That(moreThanTwoOrderLinesNInt.Count(x => x.HasMoreThanTwo == 1), Is.EqualTo(410)); - var moreThanTwoOrderLinesDecimal = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1m : 0m }).ToList(); + var moreThanTwoOrderLinesDecimal = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1m : 0m, Param = 1m }).ToList(); Assert.That(moreThanTwoOrderLinesDecimal.Count(x => x.HasMoreThanTwo == 1m), Is.EqualTo(410)); - var moreThanTwoOrderLinesNDecimal = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1m : (decimal?)null }).ToList(); + var moreThanTwoOrderLinesNDecimal = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1m : (decimal?)null, Param = (decimal?)null }).ToList(); Assert.That(moreThanTwoOrderLinesNDecimal.Count(x => x.HasMoreThanTwo == 1m), Is.EqualTo(410)); - var moreThanTwoOrderLinesSingle = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1f : 0f }).ToList(); + var moreThanTwoOrderLinesSingle = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1f : 0f, Param = 1f }).ToList(); Assert.That(moreThanTwoOrderLinesSingle.Count(x => x.HasMoreThanTwo == 1f), Is.EqualTo(410)); - var moreThanTwoOrderLinesNSingle = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1f : (float?)null }).ToList(); + var moreThanTwoOrderLinesNSingle = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1f : (float?)null, Param = (float?)null }).ToList(); Assert.That(moreThanTwoOrderLinesNSingle.Count(x => x.HasMoreThanTwo == 1f), Is.EqualTo(410)); - var moreThanTwoOrderLinesDouble = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1d : 0d }).ToList(); + var moreThanTwoOrderLinesDouble = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1d : 0d, Param = 1d }).ToList(); Assert.That(moreThanTwoOrderLinesDouble.Count(x => x.HasMoreThanTwo == 1d), Is.EqualTo(410)); - var moreThanTwoOrderLinesNDouble = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1d : (double?)null }).ToList(); + var moreThanTwoOrderLinesNDouble = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? 1d : (double?)null, Param = (double?)null }).ToList(); Assert.That(moreThanTwoOrderLinesNDouble.Count(x => x.HasMoreThanTwo == 1d), Is.EqualTo(410)); - var moreThanTwoOrderLinesString = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? "yes" : "no" }).ToList(); + var moreThanTwoOrderLinesString = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? "yes" : "no", Param = "no" }).ToList(); Assert.That(moreThanTwoOrderLinesString.Count(x => x.HasMoreThanTwo == "yes"), Is.EqualTo(410)); var now = DateTime.Now.Date; - var moreThanTwoOrderLinesDateTime = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? o.OrderDate.Value : now }).ToList(); + var moreThanTwoOrderLinesDateTime = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? o.OrderDate.Value : now, Param = now }).ToList(); Assert.That(moreThanTwoOrderLinesDateTime.Count(x => x.HasMoreThanTwo != now), Is.EqualTo(410)); - var moreThanTwoOrderLinesNDateTime = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? o.OrderDate : null }).ToList(); + var moreThanTwoOrderLinesNDateTime = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? o.OrderDate : null, Param = (DateTime?)null }).ToList(); Assert.That(moreThanTwoOrderLinesNDateTime.Count(x => x.HasMoreThanTwo != null), Is.EqualTo(410)); - var moreThanTwoOrderLinesGuid = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? o.Shipper.Reference : Guid.Empty }).ToList(); + var moreThanTwoOrderLinesGuid = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? o.Shipper.Reference : Guid.Empty, Param = Guid.Empty }).ToList(); Assert.That(moreThanTwoOrderLinesGuid.Count(x => x.HasMoreThanTwo != Guid.Empty), Is.EqualTo(410)); - var moreThanTwoOrderLinesNGuid = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? o.Shipper.Reference : (Guid?)null }).ToList(); + var moreThanTwoOrderLinesNGuid = db.Orders.Select(o => new { Id = o.OrderId, HasMoreThanTwo = o.OrderLines.Count() > 2 ? o.Shipper.Reference : (Guid?)null, Param = (Guid?)null }).ToList(); Assert.That(moreThanTwoOrderLinesNGuid.Count(x => x.HasMoreThanTwo != null), Is.EqualTo(410)); } @@ -492,6 +651,674 @@ public void CanSelectConditionalEntityValueWithEntityComparison() Assert.That(fatherInsteadOfChild, Has.Exactly(2).EqualTo("5678")); } + [Test] + public void CanSelectModulus() + { + var list = db.Animals.Select(a => new { Sql = a.Id % 2.1f, a.Id }).ToList(); + Assert.That(list.Select(o => o.Sql), Is.EqualTo(list.Select(o => o.Id % 2.1f)).Within(GetTolerance())); + var list1 = db.Animals.Select(a => new { Sql = a.Id % 2.1d, a.Id }).ToList(); + Assert.That(list1.Select(o => o.Sql), Is.EqualTo(list1.Select(o => o.Id % 2.1d)).Within(GetTolerance())); + var list2 = db.Animals.Select(a => new { Sql = a.BodyWeight % 2.1f, a.BodyWeight }).ToList(); + Assert.That(list2.Select(o => o.Sql), Is.EqualTo(list2.Select(o => o.BodyWeight % 2.1f)).Within(GetTolerance())); + var list3 = db.Animals.Select(a => new { Sql = a.Id % 2.1m, a.Id }).ToList(); + Assert.That(list3.Select(o => o.Sql), Is.EqualTo(list3.Select(o => o.Id % 2.1m))); + var list4 = db.Animals.Select(a => new { Sql = a.Id % 2, a.Id }).ToList(); + Assert.That(list4.Select(o => o.Sql), Is.EqualTo(list4.Select(o => o.Id % 2))); + var list5 = db.Animals.Select(a => new { Sql = a.Id % 2L, a.Id }).ToList(); + Assert.That(list5.Select(o => o.Sql), Is.EqualTo(list5.Select(o => o.Id % 2L))); + var list7 = db.Animals.Select(a => new { Sql = a.BodyWeight % 2, a.BodyWeight }).ToList(); + Assert.That(list7.Select(o => o.Sql), Is.EqualTo(list7.Select(o => o.BodyWeight % 2))); + var list8 = db.Animals.Select(a => new { Sql = a.BodyWeight % 2L, a.BodyWeight }).ToList(); + Assert.That(list8.Select(o => o.Sql), Is.EqualTo(list8.Select(o => o.BodyWeight % 2L))); + var list9 = db.Products.Select(a => new { Sql = a.UnitPrice % 2L, a.UnitPrice }).ToList(); + Assert.That(list9.Select(o => o.Sql), Is.EqualTo(list9.Select(o => o.UnitPrice % 2L))); + var list10 = db.Products.Select(a => new { Sql = a.UnitPrice % 2, a.UnitPrice }).ToList(); + Assert.That(list10.Select(o => o.Sql), Is.EqualTo(list10.Select(o => o.UnitPrice % 2))); + } + + [Test] + public void CanSelectModulusSameExpression() + { + var list1 = db.Animals.Select(a => new ObjectDto { CalculatedValue = a.Id % 2.1m, OriginalValue = a.Id }).ToList(); + Assert.That(list1.Select(o => o.CalculatedValue), Is.EqualTo(list1.Select(o => o.OriginalValue % 2.1m))); + var list2 = db.Animals.Select(a => new ObjectDto { CalculatedValue = a.Id % 2L, OriginalValue = a.Id }).ToList(); + Assert.That(list2.Select(o => o.CalculatedValue), Is.EqualTo(list2.Select(o => o.OriginalValue % 2L))); + var list3 = db.Animals.Select(a => new ObjectDto { CalculatedValue = a.Id % 2.1f, OriginalValue = a.Id }).ToList(); + Assert.That(list3.Select(o => o.CalculatedValue), Is.EqualTo(list3.Select(o => o.OriginalValue % 2.1f)).Within(GetTolerance())); + var list4 = db.Animals.Select(a => new ObjectDto { CalculatedValue = a.Id % 2.1d, OriginalValue = a.Id }).ToList(); + Assert.That(list4.Select(o => o.CalculatedValue), Is.EqualTo(list4.Select(o => o.OriginalValue % 2.1d)).Within(GetTolerance())); + } + + [Test] + public void CanForceDatabaseEvaluation() + { + var namedParameters = !(Sfi.ConnectionProvider.Driver is OdbcDriver); + Assert.That(GetSqlSelect(db.Animals.Select(a => DatabaseEval(() => 5))), Does.Contain(namedParameters ? "p0" : "?")); + Assert.That(GetSqlSelect(db.Products.Select(a => DatabaseEval(() => a.UnitPrice * 1234.4321m))), Does.Contain("*")); + Assert.That(FindAllOccurrences(GetSqlSelect(db.Products.Select(a => new + { + Server = DatabaseEval(() => a.UnitPrice * 1234.4321m), + Default = a.UnitPrice * 1234.4321m + })), "*"), Is.EqualTo(1)); + } + + [Test] + public void CanForceClientEvaluation() + { + var query = db.Animals.Select(a => ClientEval(() => a.Id + 5)); + Assert.That(GetSqlSelect(query), Does.Not.Contain("+")); + Assert.That(query.ToList(), Is.EqualTo(db.Animals.Select(a => a.Id + 5).ToList())); + + query = db.Animals.Select(a => ClientEval(() => a.SerialNumber.Length)); + Assert.That(GetSqlSelect(query), Does.Not.Contain("len(").And.Not.Contain("length(")); + Assert.That(query.ToList(), Is.EqualTo(db.Animals.Select(a => a.SerialNumber.Length).ToList())); + + var query2 = db.Animals.Select(a => ClientEval(() => a.SerialNumber.Substring(0, 1))); + Assert.That(GetSqlSelect(query2), Does.Not.Contain("substr(").And.Not.Contain("substring(")); + Assert.That(query2.ToList(), Is.EqualTo(db.Animals.Select(a => a.SerialNumber.Substring(0, 1)).ToList())); + + query2 = db.Animals.Select(a => ClientEval(() => a.Id % 2 == 0 ? a.SerialNumber : a.Description)); + Assert.That(GetSqlSelect(query2), Does.Not.Contain("case")); + Assert.That(query2.ToList(), Is.EqualTo(db.Animals.Select(a => a.Id % 2 == 0 ? a.SerialNumber : a.Description).ToList())); + + var query3 = db.Animals.Select(a => new + { + Client = ClientEval(() => a.Id % 2 == 0 ? a.SerialNumber.Substring(0, 1) : a.Description), + Server = a.Id % 2 == 0 ? a.SerialNumber.Substring(0, 1) : a.Description, + }).ToList(); + Assert.That(query3.Select(o => o.Client), Is.EqualTo(query3.Select(o => o.Server))); + } + + [Test] + public void CanSelectMultiplyOperator() + { + var list1 = db.Animals.Select(a => new { Sql = a.Id * 5, a.Id }).ToList(); + Assert.That(list1.Select(o => o.Sql), Is.EqualTo(list1.Select(o => o.Id * 5))); + var list2 = db.Animals.Select(a => new { Sql = a.Id * 12345.54321m, a.Id }).ToList(); + Assert.That(list2.Select(o => o.Sql), Is.EqualTo(list2.Select(o => o.Id * 12345.54321m))); + var list3 = db.Animals.Select(a => new { Sql = a.Id * 123.321f, a.Id }).ToList(); + Assert.That(list3.Select(o => o.Sql), Is.EqualTo(list3.Select(o => o.Id * 123.321f)).Within(GetTolerance())); + var list4 = db.Animals.Select(a => new { Sql = a.Id * 12345.54321d, a.Id }).ToList(); + Assert.That(list4.Select(o => o.Sql), Is.EqualTo(list4.Select(o => o.Id * 12345.54321d)).Within(GetTolerance())); + var list5 = db.Animals.Select(a => new { Sql = a.Id * 2L, a.Id }).ToList(); + Assert.That(list5.Select(o => o.Sql), Is.EqualTo(list5.Select(o => o.Id * 2L))); + + var list6 = db.Products.Select(a => new { Sql = a.UnitPrice * 12345.54321m, a.UnitPrice }).ToList(); + Assert.That(list6.Select(o => o.Sql), Is.EqualTo(list6.Select(o => o.UnitPrice * 12345.54321m))); + var list7 = db.Products.Select(a => new { Sql = a.UnitPrice * 12345L, a.UnitPrice }).ToList(); + Assert.That(list7.Select(o => o.Sql), Is.EqualTo(list7.Select(o => o.UnitPrice * 12345L))); + + var list8 = db.Animals.Select(a => new { Sql = a.BodyWeight * 12345.54321f, a.BodyWeight }).ToList(); + Assert.That(list8.Select(o => o.Sql), Is.EqualTo(list8.Select(o => o.BodyWeight * 12345.54321f))); + } + + [Test] + public void CanSelectDivideOperator() + { + var list1 = db.Animals.Select(a => new { Sql = a.Id / 5, a.Id }).ToList(); + Assert.That(list1.Select(o => o.Sql), Is.EqualTo(list1.Select(o => o.Id / 5))); + var list2 = db.Animals.Select(a => new { Sql = a.Id / 12345.54321m, a.Id }).ToList(); + Assert.That(list2.Select(o => o.Sql), Is.EqualTo(list2.Select(o => o.Id / 12345.54321m))); + var list3 = db.Animals.Select(a => new { Sql = a.Id / 12345.54321f, a.Id }).ToList(); + Assert.That(list3.Select(o => o.Sql), Is.EqualTo(list3.Select(o => o.Id / 12345.54321f)).Within(GetTolerance())); + var list4 = db.Animals.Select(a => new { Sql = a.Id / 12345.54321d, a.Id }).ToList(); + Assert.That(list4.Select(o => o.Sql), Is.EqualTo(list4.Select(o => o.Id / 12345.54321d)).Within(GetTolerance())); + var list5 = db.Animals.Select(a => new { Sql = a.Id / 2L, a.Id }).ToList(); + Assert.That(list5.Select(o => o.Sql), Is.EqualTo(list5.Select(o => o.Id / 2L))); + + var list6 = db.Products.Select(a => new { Sql = a.UnitPrice / 12345.54321m, a.UnitPrice }).ToList(); + Assert.That(list6.Select(o => o.Sql), Is.EqualTo(list6.Select(o => o.UnitPrice / 12345.54321m))); + var list7 = db.Products.Select(a => new { Sql = a.UnitPrice.Value / 12345L, a.UnitPrice }).ToList(); + Assert.That(list7.Select(o => o.Sql), Is.EqualTo(list7.Select(o => o.UnitPrice / 12345L))); + + var list8 = db.Animals.Select(a => new { Sql = a.BodyWeight / 12345.54321f, a.BodyWeight }).ToList(); + Assert.That(list8.Select(o => o.Sql), Is.EqualTo(list8.Select(o => o.BodyWeight / 12345.54321f)).Within(GetTolerance())); + } + + [Test] + public void CanSelectAddOperator() + { + var list1 = db.Animals.Select(a => new { Sql = a.Id + 5, a.Id }).ToList(); + Assert.That(list1.Select(o => o.Sql), Is.EqualTo(list1.Select(o => o.Id + 5))); + var list2 = db.Animals.Select(a => new { Sql = a.Id + 12345.54321m, a.Id }).ToList(); + Assert.That(list2.Select(o => o.Sql), Is.EqualTo(list2.Select(o => o.Id + 12345.54321m))); + var list3 = db.Animals.Select(a => new { Sql = a.Id + 12345.54321f, a.Id }).ToList(); + Assert.That(list3.Select(o => o.Sql), Is.EqualTo(list3.Select(o => o.Id + 12345.54321f)).Within(GetTolerance())); + var list4 = db.Animals.Select(a => new { Sql = a.Id + 12345.54321d, a.Id }).ToList(); + Assert.That(list4.Select(o => o.Sql), Is.EqualTo(list4.Select(o => o.Id + 12345.54321d))); + var list5 = db.Animals.Select(a => new { Sql = a.Id + 2L, a.Id }).ToList(); + Assert.That(list5.Select(o => o.Sql), Is.EqualTo(list5.Select(o => o.Id + 2L))); + + var list6 = db.Products.Select(a => new { Sql = a.UnitPrice + 12345.54321m, a.UnitPrice }).ToList(); + Assert.That(list6.Select(o => o.Sql), Is.EqualTo(list6.Select(o => o.UnitPrice + 12345.54321m))); + var list7 = db.Products.Select(a => new { Sql = a.UnitPrice + 12345L, a.UnitPrice }).ToList(); + Assert.That(list7.Select(o => o.Sql), Is.EqualTo(list7.Select(o => o.UnitPrice + 12345L))); + + var list8 = db.Animals.Select(a => new { Sql = a.BodyWeight + 12345.54321f, a.BodyWeight }).ToList(); + Assert.That(list8.Select(o => o.Sql), Is.EqualTo(list8.Select(o => o.BodyWeight + 12345.54321f))); + } + + [Test] + public void CanSelectSubtractOperator() + { + var list1 = db.Animals.Select(a => new { Sql = a.Id - 5, a.Id }).ToList(); + Assert.That(list1.Select(o => o.Sql), Is.EqualTo(list1.Select(o => o.Id - 5))); + var list2 = db.Animals.Select(a => new { Sql = a.Id - 12345.54321m, a.Id }).ToList(); + Assert.That(list2.Select(o => o.Sql), Is.EqualTo(list2.Select(o => o.Id - 12345.54321m))); + var list3 = db.Animals.Select(a => new { Sql = a.Id - 12345.54321f, a.Id }).ToList(); + Assert.That(list3.Select(o => o.Sql), Is.EqualTo(list3.Select(o => o.Id - 12345.54321f)).Within(GetTolerance())); + var list4 = db.Animals.Select(a => new { Sql = a.Id - 12345.54321d, a.Id }).ToList(); + Assert.That(list4.Select(o => o.Sql), Is.EqualTo(list4.Select(o => o.Id - 12345.54321d))); + var list5 = db.Animals.Select(a => new { Sql = a.Id - 2L, a.Id }).ToList(); + Assert.That(list5.Select(o => o.Sql), Is.EqualTo(list5.Select(o => o.Id - 2L))); + + var list6 = db.Products.Select(a => new { Sql = a.UnitPrice - 12345.54321m, a.UnitPrice }).ToList(); + Assert.That(list6.Select(o => o.Sql), Is.EqualTo(list6.Select(o => o.UnitPrice - 12345.54321m))); + var list7 = db.Products.Select(a => new { Sql = a.UnitPrice - 12345L, a.UnitPrice }).ToList(); + Assert.That(list7.Select(o => o.Sql), Is.EqualTo(list7.Select(o => o.UnitPrice - 12345L))); + + var list8 = db.Animals.Select(a => new { Sql = a.BodyWeight - 12345.54321f, a.BodyWeight }).ToList(); + Assert.That(list8.Select(o => o.Sql), Is.EqualTo(list8.Select(o => o.BodyWeight - 12345.54321f))); + } + + private class ObjectDto + { + public object CalculatedValue { get; set; } + + public int OriginalValue { get; set; } + } + + [Test] + public void CanSelectConditionalEntityValueWithEntityComparisonComplex() + { + var animal = db.Animals.Select( + a => new + { + Parent = a.Father != null || a.Mother != null ? (a.Father ?? a.Mother) : null, + ParentSerialNumber = a.Father != null || a.Mother != null ? (a.Father ?? a.Mother).SerialNumber : null, + Parent2 = a.Mother ?? a.Father, + a.Father, + a.Mother + }) + .FirstOrDefault(o => o.ParentSerialNumber == "5678"); + + Assert.That(animal, Is.Not.Null); + Assert.That(animal.Father, Is.Not.Null); + Assert.That(animal.Mother, Is.Not.Null); + Assert.That(animal.Parent, Is.Not.Null); + Assert.That(animal.Parent2, Is.Not.Null); + Assert.That(NHibernateUtil.IsInitialized(animal.Parent), Is.True); + Assert.That(NHibernateUtil.IsInitialized(animal.Parent2), Is.True); + Assert.That(NHibernateUtil.IsInitialized(animal.Father), Is.True); + Assert.That(NHibernateUtil.IsInitialized(animal.Mother), Is.True); + } + + [Test] + public void CanSelectConditionalEntityValueWithEntityCast() + { + var list = db.Animals.Select( + a => new + { + BodyWeight = (double?) (a is Cat + ? (a.Father ?? a.Mother).BodyWeight + : (a is Dog + ? (a.Mother ?? a.Father).BodyWeight + : (a.Father.Father.BodyWeight) + )) + }) + .ToList(); + Assert.That(list, Has.Exactly(1).With.Property("BodyWeight").Not.Null); + } + + [Test] + public void CanSelectBinaryClientSideTest() + { + var exception = Assert.Throws(() => + { + db.Animals.Select(a => a.FatherOrMother.BodyWeight + a.BodyWeight).ToList(); + }); + Assert.That(exception.InnerException, Is.TypeOf()); + Assert.That(exception.InnerException.Message, Is.EqualTo( + "Null value cannot be assigned to a value type 'System.Double'. Cast expression '([a].FatherOrMother.BodyWeight + [a].BodyWeight)' to 'System.Nullable`1[System.Double]'.")); + + var list = db.Animals.Select(a => (double?) (a.FatherOrMother.BodyWeight + a.BodyWeight)).ToList(); + Assert.That(list, Has.Exactly(5).Null.And.Exactly(1).EqualTo(271d)); + + // Arithmetic operator + var list2 = db.Animals.Select(a => new + { + // Left side null + Client = (double?) (a.FatherOrMother.BodyWeight + a.BodyWeight + a.Father.BodyWeight), + Server = (double?) (a.Father ?? a.Mother).BodyWeight + a.BodyWeight + a.Father.BodyWeight, + // Right side null + Client2 = (double?) (a.BodyWeight - a.Father.BodyWeight - a.FatherOrMother.BodyWeight), + Server2 = (double?) a.BodyWeight - a.Father.BodyWeight - (a.Father ?? a.Mother).BodyWeight + }).ToList(); + Assert.That(list2.Select(o => o.Client), Is.EqualTo(list2.Select(o => o.Server))); + Assert.That(list2.Select(o => o.Client2), Is.EqualTo(list2.Select(o => o.Server2))); + + // Boolean logic operator + var list3 = db.Users.Select(u => new + { + // Left side null + Client = u.NotMappedUser.Role.IsActive && true, + Server = u.Role.IsActive && true, + // Right side null + Client2 = true && u.NotMappedUser.Role.IsActive, + Server2 = true && u.Role.IsActive + }).ToList(); + Assert.That(list3.Select(o => o.Client), Is.EqualTo(list3.Select(o => o.Server))); + Assert.That(list3.Select(o => o.Client2), Is.EqualTo(list3.Select(o => o.Server2))); + + list3 = db.Users.Select(u => new + { + // Left side null + Client = u.NotMappedUser.Role.IsActive || true, + Server = u.Role.IsActive || true, + // Right side null + Client2 = false || u.NotMappedUser.Role.IsActive, + Server2 = false || u.Role.IsActive + }).ToList(); + Assert.That(list3.Select(o => o.Client), Is.EqualTo(list3.Select(o => o.Server))); + Assert.That(list3.Select(o => o.Client2), Is.EqualTo(list3.Select(o => o.Server2))); + + // Comparison operator + list3 = db.Users.Select(u => new + { + // Left side null + Client = u.NotMappedUser.Role.Id > 0, + Server = u.Role.Id > 0, + // Right side null + Client2 = 0 < u.NotMappedUser.Role.Id, + Server2 = 0 < u.Role.Id + }).ToList(); + Assert.That(list3.Select(o => o.Client), Is.EqualTo(list3.Select(o => o.Server))); + Assert.That(list3.Select(o => o.Client2), Is.EqualTo(list3.Select(o => o.Server2))); + + // Bitwise boolean operator + var list4 = db.Users.Select(u => new + { + // Left side null + Client = (bool?) (u.NotMappedUser.Role.IsActive | true), + Server = (bool?) (u.Role.IsActive | true), + // Right side null + Client2 = (bool?) (true | u.NotMappedUser.Role.IsActive), + Server2 = (bool?) (true | u.Role.IsActive) + }).ToList(); + Assert.That(list4.Select(o => o.Client), Is.EqualTo(list4.Select(o => o.Server))); + Assert.That(list4.Select(o => o.Client2), Is.EqualTo(list4.Select(o => o.Server2))); + + // Bitwise number operator + var list5 = db.Users.Select(u => new + { + // Left side null + Client = (int?) (u.NotMappedUser.Role.Id | 5), + Server = (int?) (u.Role.Id | 5), + // Right side null + Client2 = (int?) (5 | u.NotMappedUser.Role.Id), + Server2 = (int?) (5 | u.Role.Id) + }).ToList(); + Assert.That(list5.Select(o => o.Client), Is.EqualTo(list5.Select(o => o.Server))); + Assert.That(list5.Select(o => o.Client2), Is.EqualTo(list5.Select(o => o.Server2))); + + // Coalesce operator + var list6 = db.Users.Select(u => new + { + // Left side null + Client = u.NotMappedUser.Role.Name ?? u.NotMappedUser.Name, + Server = u.Role.Name ?? u.Name, + // Right side null + Client2 = u.NotMappedUser.Name ?? u.NotMappedUser.Role.Name, + Server2 = u.Name ?? u.Role.Name, + // Both side null + Client3 = u.NotMappedUser.Role.Name ?? u.NotMappedUser.Role.Name, + Server3 = u.Role.Name ?? u.Role.Name + }).ToList(); + Assert.That(list6.Select(o => o.Client), Is.EqualTo(list6.Select(o => o.Server))); + Assert.That(list6.Select(o => o.Client2), Is.EqualTo(list6.Select(o => o.Server2))); + Assert.That(list6.Select(o => o.Client3), Is.EqualTo(list6.Select(o => o.Server3))); + } + + [Test] + public void CanSelectUnaryClientSideTest() + { + var exception = Assert.Throws(() => + { + db.Animals.Select(a => -a.FatherOrMother.BodyWeight).ToList(); + }); + Assert.That(exception.InnerException, Is.TypeOf()); + Assert.That(exception.InnerException.Message, Is.EqualTo( + "Null value cannot be assigned to a value type 'System.Double'. Cast expression '-[a].FatherOrMother.BodyWeight' to 'System.Nullable`1[System.Double]'.")); + + // Negate + var list = db.Animals.Select(a => new + { + Client = (double?) -a.FatherOrMother.BodyWeight, + Server = (double?) -((a.Father ?? a.Mother).BodyWeight) + }).ToList(); + Assert.That(list.Select(o => o.Client), Is.EqualTo(list.Select(o => o.Server))); + + // Convert + list = db.Animals.Select(a => new + { + Client = (double?) a.FatherOrMother.BodyWeight, + Server = (double?) (a.Father ?? a.Mother).BodyWeight + }).ToList(); + Assert.That(list.Select(o => o.Client), Is.EqualTo(list.Select(o => o.Server))); + + // UnaryPlus + list = db.Animals.Select(a => new + { + Client = (double?) +a.FatherOrMother.BodyWeight, + Server = (double?) +((a.Father ?? a.Mother).BodyWeight) + }).ToList(); + Assert.That(list.Select(o => o.Client), Is.EqualTo(list.Select(o => o.Server))); + + // Not + var list2 = db.Users.Select(u => new + { + Client = (bool?) !u.NotMappedUser.Role.IsActive, + Server = (bool?) !u.Role.IsActive + }).ToList(); + Assert.That(list2.Select(o => o.Client), Is.EqualTo(list2.Select(o => o.Server))); + + // Convert value type + var list3 = db.Users.Select(u => (int?) (u.Role != null ? 5 : 10)).ToList(); + Assert.That(list3, Has.Exactly(3).Not.Null); + + // Convert enum + list3 = db.Users.Select(u => (int?) u.Role.CreatedBy.Enum2).ToList(); + Assert.That(list3, Has.Exactly(3).Null); + + // Convert reference type + var list4 = db.Animals.Select(a => new + { + Client = (Dog) a.FatherOrMother, + Server = (Dog) (a.Father ?? a.Mother) + }).ToList(); + Assert.That(list4.Select(o => o.Client), Is.EqualTo(list4.Select(o => o.Server))); + + // TypeAs + list4 = db.Animals.Select(a => new + { + Client = a.FatherOrMother as Dog, + Server = (a.Father ?? a.Mother) as Dog + }).ToList(); + Assert.That(list4.Select(o => o.Client), Is.EqualTo(list4.Select(o => o.Server))); + + // Convert constant reference type + var list5 = db.Animals.Select(a => (Animal) new Dog()).ToList(); + Assert.That(list5, Has.Exactly(6).Not.Null); + } + + [Test] + public void CanSelectConditionalClientSideWithNullValueTypeTest() + { + var exception = Assert.Throws(() => + { + db.Animals.Select( + a => new + { + BodyWeight = (string.IsNullOrWhiteSpace(a.Description) + ? a.Mother.Mother.BodyWeight + : a.Father.Mother.BodyWeight) + }) + .ToList(); + }); + Assert.That(exception.InnerException, Is.TypeOf()); + Assert.That(exception.InnerException.Message, Is.EqualTo( + "Null value cannot be assigned to a value type 'System.Double'. " + + "Cast expression 'IIF(IsNullOrWhiteSpace([a].Description), [_3].BodyWeight, [_1].BodyWeight)' to 'System.Nullable`1[System.Double]'.")); + + var list = db.Animals.Select( + a => new + { + BodyWeight = (double?) (string.IsNullOrWhiteSpace(a.Description) + ? a.Mother.Mother.BodyWeight + : a.Father.Mother.BodyWeight) + }) + .ToList(); + Assert.That(list, Has.Exactly(0).With.Property("BodyWeight").Not.Null); + + var list2 = db.Animals.Select( + a => new + { + BodyWeight = (double?) (string.IsNullOrWhiteSpace(a.Description) + ? a.Mother.Mother.BodyWeight + : 5d) + }) + .ToList(); + Assert.That(list2, Has.Exactly(0).With.Property("BodyWeight").Not.Null); + + var list3 = db.Animals.Select( + a => new + { + BodyWeight = (double?) (string.IsNullOrWhiteSpace(a.Description) + ? 5d + : a.Father.Mother.BodyWeight) + }) + .ToList(); + Assert.That(list3, Has.Exactly(6).With.Property("BodyWeight").Not.Null); + + var list4 = db.Animals.Select( + a => new + { + BodyWeightHashCode = (int?) ((string.IsNullOrWhiteSpace(a.Description) + ? a.Mother.Mother.BodyWeight + : a.Father.Mother.BodyWeight)).GetHashCode() + }) + .ToList(); + Assert.That(list4, Has.Exactly(0).With.Property("BodyWeightHashCode").Not.Null); + + var list5 = db.Animals.Select( + a => new + { + BodyWeight = (double?) (string.IsNullOrWhiteSpace(a.Description) + ? (string.IsNullOrWhiteSpace(a.Description) + ? a.Mother.Mother.BodyWeight + : a.Father.Mother.BodyWeight) + : (string.IsNullOrWhiteSpace(a.Description) + ? a.Mother.Mother.BodyWeight + : a.Father.Mother.BodyWeight)) + }) + .ToList(); + Assert.That(list5, Has.Exactly(0).With.Property("BodyWeight").Not.Null); + + var list6 = db.Animals.Select( + a => new + { + Client = a.Father.HasFather ? (double?) null : a.BodyWeight, + Server = a.Father.Father != null ? (double?) null : a.BodyWeight, + }) + .ToList(); + Assert.That(list6.Select(o => o.Client), Is.EqualTo(list6.Select(o => o.Server))); + + var list7 = db.Users.Select( + a => new + { + Client = a.NotMappedUser.Role.IsActive ? 1 : 2, + Server = a.Role.IsActive ? 1 : 2 + }) + .ToList(); + Assert.That(list7.Select(o => o.Client), Is.EqualTo(list7.Select(o => o.Server))); + } + + [Test] + public void CanExecuteMethodWithNullObjectClientSideTest() + { + var exception = Assert.Throws(() => + { + db.Animals.Select( + a => new + { + a.Id, + FatherId = a.Father.Father.Id + }) + .ToList(); + }); + Assert.That(exception.InnerException, Is.TypeOf()); + Assert.That(exception.InnerException.Message, Is.EqualTo( + "Null value cannot be assigned to a value type 'System.Int32'. Cast expression '[_0].Father.Id' to 'System.Nullable`1[System.Int32]'.")); + + exception = Assert.Throws(() => + { + db.Animals.Select( + a => new + { + a.Id, + FatherIdHashCode = a.Father.Father.Id.GetHashCode() + }) + .ToList(); + }); + Assert.That(exception.InnerException, Is.TypeOf()); + Assert.That(exception.InnerException.Message, Is.EqualTo( + "Null value cannot be assigned to a value type 'System.Int32'. Cast expression '[_1].Id.GetHashCode()' to 'System.Nullable`1[System.Int32]'.")); + + var list = db.Animals.Select( + a => new + { + NullableId = (int?) a.Father.Father.Id, + NullableIdHashCode = (int?) a.Father.Father.Id.GetHashCode() + }) + .ToList(); + Assert.That(list, Has.Exactly(0).With.Property("NullableId").Not.Null); + } + + [Test] + public void CanSelectConstant() + { + AssertOneSelectColumn(db.Animals.Select(a => new { Test = a.Id + 1f + 5d })); + AssertOneSelectColumn(db.Animals.Select(a => new { Test = a.Id + 1m + 5 })); + AssertOneSelectColumn(db.Animals.Select(a => new { Test = 1 })); + AssertOneSelectColumn(db.Animals.Select(a => new { Test = "test" })); + AssertOneSelectColumn(db.Animals.Select(a => new { Test = 1 + 5 })); + AssertOneSelectColumn(db.Animals.Select(a => new { Id = a.Id, Test = 1 })); + AssertOneSelectColumn(db.Animals.Select(a => new { a.Id, Test = "test" })); + AssertOneSelectColumn(db.Animals.Select(a => new { Test = 1, Test2 = "test" })); + AssertOneSelectColumn(db.Animals.Select(a => new { Test = a.Id, Test2 = new { Value = "test" }, Test3 = 1 })); + AssertOneSelectColumn(db.Animals.Select(a => new { Test = new UserDto(1, "test") })); + AssertOneSelectColumn(db.Animals.Select(a => new { Test = new UserDto(1, "test"), a.Id })); + AssertOneSelectColumn(db.Animals.Select(a => new { Test = new UserDto(1, "test") { Dto2 = { Enum = EnumStoredAsInt32.High } }, a.Id })); + AssertOneSelectColumn(db.Animals.Select(a => new UserDto(1, "test"))); + AssertOneSelectColumn(db.Animals.Select(a => new UserDto(1, "test") {RoleName = a.Description})); + AssertOneSelectColumn(db.Animals.Select(a => new UserDto(a.Id, "test"))); + AssertOneSelectColumn(db.Animals.Select(a => 1)); + AssertOneSelectColumn(db.Animals.Select(a => "test")); + } + + [Test] + public void CanSelectWithIsOperator() + { + Assert.DoesNotThrow(() => db.Animals.Select(a => a is Dog).ToList()); + Assert.DoesNotThrow(() => db.Animals.Select(a => a.FatherSerialNumber is string).ToList()); + } + + [Test] + public void CanSelectComponentProperty() + { + AssertOneSelectColumn(db.Users.Select(u => u.Component.Property1)); + AssertOneSelectColumn(db.Users.Select(u => u.Component.OtherComponent.OtherProperty1)); + } + + [Test] + public void CanSelectNonMappedComponentProperty() + { + Assert.DoesNotThrow(() => db.Users.Select(u => u.Component.Property3).ToList()); + Assert.DoesNotThrow(() => db.Users.Select(u => u.Component.OtherComponent.OtherProperty2).ToList()); + var list = db.Users.Select(u => new + { + u.Component.OtherComponent.OtherProperty1, + OtherProperty3 = u.Component.OtherComponent.OtherProperty2, + u.Component.Property1, + u.Component.Property2, + u.Component.Property3 + }).ToList(); + Assert.That(list.Select(o => o.OtherProperty3), Is.EqualTo(list.Select(o => o.OtherProperty1))); + Assert.That( + list.Select(o => (o.Property1 ?? o.Property2) == null ? null : $"{o.Property1}{o.Property2}"), + Is.EqualTo(list.Select(o => o.Property3))); + } + + [Test] + public void CanSelectWithAnInvocation() + { + Func func = s => s + "postfix"; + Assert.DoesNotThrow(() => db.Animals.Select(a => func(a.SerialNumber)).ToList()); + Assert.DoesNotThrow(() => db.Animals.Select(a => func(a.FatherSerialNumber)).ToList()); + } + + [Test] + public void CanSelectEnumerable() + { + Assert.DoesNotThrow(() => db.Animals.Select(a => new { Enumerable = new[] { a.Id } }).ToList()); + Assert.DoesNotThrow(() => db.Animals.Select(a => new { Enumerable = new[] { a.Id, 1 } }).ToList()); + Assert.DoesNotThrow(() => db.Animals.Select(a => new { Enumerable = new[] { 1 } }).ToList()); + Assert.DoesNotThrow(() => db.Animals.Select(a => new { Enumerable = new[] { a, a.Father, a.Mother } }).ToList()); + Assert.DoesNotThrow(() => db.Animals.Select(a => new + { + Enumerable = new[] + { + new UserDto(a.Id, a.FatherSerialNumber) {RoleName = a.FatherSerialNumber}, + new UserDto(1, a.FatherSerialNumber) {RoleName = a.FatherSerialNumber, InvalidLoginAttempts = 1}, + null, + new UserDto(1, "test") {RoleName = "test", InvalidLoginAttempts = 1}, + new UserDto(1, "test") {Dto2List = {new UserDto2(), new UserDto2()}, Dto2 = {Enum = EnumStoredAsInt32.High}}, + new UserDto(1, a.FatherSerialNumber) + { + Dto2List = {new UserDto2() { Enum = a.Id > 0 ? EnumStoredAsInt32.High : EnumStoredAsInt32.Unspecified }, new UserDto2()}, + Dto2 = {Enum = a.Id > 0 ? EnumStoredAsInt32.High : EnumStoredAsInt32.Unspecified} + } + } + }).ToList()); + Assert.DoesNotThrow(() => db.Animals.Select(a => new { Enumerable = new[] { a.SerialNumber, a.FatherSerialNumber, null } }).ToList()); + Assert.DoesNotThrow(() => db.Animals.Select(a => new { Enumerable = new int[][] { new[] { a.Id }, new[] { 1 }, new[] { a.Id, 1 } } }).ToList()); + Assert.DoesNotThrow(() => db.Animals.Select(a => new { Enumerable = new List { a.Id, 1 } }).ToList()); + Assert.DoesNotThrow(() => db.Animals.Select(a => new { Enumerable = new List(5) { a.Id, 1 } }).ToList()); + Assert.DoesNotThrow(() => db.Animals.Select(a => new { Enumerable = new List(a.Id) { 1 } }).ToList()); + Assert.DoesNotThrow(() => db.Animals.Select(a => new { Enumerable = new List(a.Id) { a.SerialNumber, a.FatherSerialNumber, null } }).ToList()); + Assert.DoesNotThrow(() => db.Animals.Select(a => new + { + Enumerable = new List(a.Id) + { + new UserDto(a.Id, a.FatherSerialNumber) {RoleName = a.FatherSerialNumber}, + new UserDto(1, a.FatherSerialNumber) {RoleName = a.FatherSerialNumber, InvalidLoginAttempts = 1}, + null, + new UserDto(1, "test") {RoleName = "test", InvalidLoginAttempts = 1}, + new UserDto(1, "test") {Dto2List = {new UserDto2(), new UserDto2()}, Dto2 = {Enum = EnumStoredAsInt32.High}}, + new UserDto(1, a.FatherSerialNumber) + { + Dto2List = {new UserDto2() { Enum = a.Id > 0 ? EnumStoredAsInt32.High : EnumStoredAsInt32.Unspecified }, new UserDto2()}, + Dto2 = {Enum = a.Id > 0 ? EnumStoredAsInt32.High : EnumStoredAsInt32.Unspecified} + } + } + }).ToList()); + Assert.DoesNotThrow(() => db.Animals.Select(a => new { Enumerable = new[] { a.SerialNumber, a.FatherSerialNumber, null }[a.Id - a.Id].Length }).ToList()); + Assert.DoesNotThrow(() => db.Animals.Select(a => new { Enumerable = new List { a.SerialNumber, a.FatherSerialNumber, null }[a.Id - a.Id].Length }).ToList()); + Assert.DoesNotThrow(() => db.Animals.Select(a => new + { + Enumerable = new Dictionary + { + { a.SerialNumber, a.FatherSerialNumber }, + { "1", a.Father.SerialNumber }, + { "2", null } + }[a.SerialNumber] + }).ToList()); + } + + [Test] + public void CanSelectConditionalSubClassPropertyValue() + { + var animal = db.Animals.Select( + a => new + { + Pregnant = a is Mammal ? ((Mammal) a).Pregnant : false + }) + .Where(o => o.Pregnant) + .ToList(); + + Assert.That(animal, Has.Count.EqualTo(1)); + } + [Test] public void CanSelectConditionalEntityValueWithEntityComparisonRepeat() { @@ -523,10 +1350,200 @@ public void CanCastToCustomRegisteredType() Assert.That(db.Users.Where(o => (NullableInt32) o.Id == 1).ToList(), Has.Count.EqualTo(1)); } + [Test] + public void TestClientSideEvaluation() + { + var list = db.Animals.Select(a => new + { + ClientSide = string.IsNullOrEmpty(a.FatherSerialNumber) ? 1 : 0, + ClientSide2 = string.IsNullOrEmpty(a.Father.SerialNumber) ? 1 : 0 + }).ToList(); + Assert.That(list.Select(o => o.ClientSide), Is.EqualTo(list.Select(o => o.ClientSide2))); + + var list2 = db.Animals.Select(a => new + { + ClientSide = a.Father.IsProxy(), + ClientSide2 = a.FatherSerialNumber.IsProxy() + }).ToList(); + Assert.That(list2.Select(o => o.ClientSide), Is.EqualTo(list2.Select(o => o.ClientSide2))); + + var list3 = db.Orders.Where(o => o.OrderDate.HasValue).Select(o => new + { + ClientSide = o.OrderDate.Value.TimeOfDay.Days, + ClientSide2 = o.OrderDate.Value + }).ToList(); + Assert.That(list3.Select(o => o.ClientSide), Is.EqualTo(list3.Select(o => o.ClientSide2.TimeOfDay.Days))); + + var list4 = db.Orders.Where(o => o.OrderDate.HasValue).Select(o => new + { + o.OrderId, + ClientSide = o.OrderDate.Value.TimeOfDay.CompareTo(new TimeSpan(o.OrderId)), + ClientSide2 = o.OrderDate.Value + }).ToList(); + Assert.That(list4.Select(o => o.ClientSide), Is.EqualTo(list4.Select(o => o.ClientSide2.TimeOfDay.CompareTo(new TimeSpan(o.OrderId))))); + } + + [Test] + public void TestServerAndClientSideEvaluationComparison() + { + var list = db.Animals.Select( + a => new + { + ServerSide = (int?) a.Father.SerialNumber.Length, + ClientSide = (int?) a.FatherSerialNumber.Length + }).ToList(); + Assert.That(list.Select(o => o.ClientSide), Is.EqualTo(list.Select(o => o.ServerSide))); + + var list1 = db.Animals + .Where(a => a.Father.SerialNumber != null) + .Select( + a => new + { + ServerSide = a.Father.SerialNumber.Length, + ClientSide = a.FatherSerialNumber.Length + }) + .ToList(); + Assert.That(list1.Select(o => o.ClientSide), Is.EqualTo(list1.Select(o => o.ServerSide))); + + var clientSide = db.Animals.Select(a => a.FatherSerialNumber.Length.ToString()).ToList(); + var serverSide = db.Animals.Select(a => a.FatherSerialNumber.Length.ToString()).ToList(); + Assert.That(clientSide, Is.EqualTo(serverSide)); + + var exception = Assert.Throws( + () => + { + db.Animals.Select( + a => new + { + ServerSide = a.Father.SerialNumber.Length + }).ToList(); + }); + Assert.That(exception.InnerException, Is.TypeOf()); + Assert.That(exception.InnerException.Message, Is.EqualTo( + "Null value cannot be assigned to a value type 'System.Int32'. Cast expression '[_0].SerialNumber.Length' to 'System.Nullable`1[System.Int32]'.")); + + exception = Assert.Throws( + () => + { + db.Animals.Select( + a => new + { + ClientSide = a.FatherSerialNumber.Length + }).ToList(); + }); + Assert.That(exception.InnerException, Is.TypeOf()); + Assert.That(exception.InnerException.Message, Is.EqualTo( + "Null value cannot be assigned to a value type 'System.Int32'. Cast expression '[a].FatherSerialNumber.Length' to 'System.Nullable`1[System.Int32]'.")); + + var list2 = db.Animals.Select( + a => new + { + ServerSide = a.Father.SerialNumber.Length.ToString(), + ClientSide = a.FatherSerialNumber.Length.ToString() + }).ToList(); + Assert.That(list2.Select(o => o.ClientSide), Is.EqualTo(list2.Select(o => o.ServerSide))); + + var list3 = db.Animals.Select( + a => new + { + ServerSide = (int?) a.Father.SerialNumber.Substring(0, ((int?) a.Father.SerialNumber.Length - 1) ?? 0).Length, + ClientSide = (int?) a.FatherSerialNumber.Substring(0, ((int?) a.FatherSerialNumber.Length - 1) ?? 0).Length + }).ToList(); + Assert.That(list3.Select(o => o.ClientSide), Is.EqualTo(list3.Select(o => o.ServerSide))); + + var list4 = db.Animals.Select(a => new + { + ServerSide = a.Father.SerialNumber, + ClientSide = a.FatherSerialNumber, + Test = (object) null + }).ToList(); + Assert.That(list4.Select(o => o.ClientSide), Is.EqualTo(list4.Select(o => o.ServerSide))); + + var list5 = db.Animals.Select(a => new + { + ServerSide = a.Father.SerialNumber == null, + ClientSide = a.FatherSerialNumber == null + }).ToList(); + Assert.That(list5.Select(o => o.ClientSide), Is.EqualTo(list5.Select(o => o.ServerSide))); + + var list6 = db.Animals + .Where(a => a.Father.SerialNumber != null) + .Select( + a => new + { + ServerSide = -a.Father.SerialNumber.Length, + ClientSide = -a.FatherSerialNumber.Length + }).ToList(); + Assert.That(list6.Select(o => o.ClientSide), Is.EqualTo(list6.Select(o => o.ServerSide))); + + var list7 = db.Animals + .Select( + a => new + { + ServerSide = a.Father != null ? a.Father.SerialNumber : null, + ClientSide = a.HasFather ? a.FatherSerialNumber : null + }).ToList(); + Assert.That(list7.Select(o => o.ClientSide), Is.EqualTo(list7.Select(o => o.ServerSide))); + + var list8 = db.Animals + .Where(a => a is Dog) + .Select( + a => new + { + ServerSide = (long?) (int?) ((Dog) a).Father.SerialNumber.Length, + ClientSide = (long?) (int?) ((Dog) a).FatherSerialNumber.Length + }).ToList(); + Assert.That(list8.Select(o => o.ClientSide), Is.EqualTo(list8.Select(o => o.ServerSide))); + } + public class Wrapper { public T item; public string message; } + + private double GetTolerance() + { + return !Dialect.SupportsIEEE754FloatingPointNumbers || TestDialect.SendsParameterValuesAsStrings + ? 0.1d + : 0d; + } + + private static void AssertOneSelectColumn(IQueryable query) + { + using (var sqlLog = new SqlLogSpy()) + { + // Execute query + foreach (var item in query) { } + Assert.That(FindAllOccurrences(sqlLog.GetWholeLog(), "as col"), Is.EqualTo(1)); + } + } + + private static string GetSqlSelect(IQueryable query) + { + using (var sqlLog = new SqlLogSpy()) + { + // Execute query + foreach (var item in query) { } + + var sql = sqlLog.GetWholeLog(); + return sql.Substring(0, sql.IndexOf(" from")); + } + } + + private static int FindAllOccurrences(string source, string substring) + { + if (source == null) + { + return 0; + } + int n = 0, count = 0; + while ((n = source.IndexOf(substring, n, StringComparison.InvariantCulture)) != -1) + { + n += substring.Length; + ++count; + } + return count; + } } } diff --git a/src/NHibernate.Test/Linq/TryGetMappedTests.cs b/src/NHibernate.Test/Linq/TryGetMappedTests.cs index 20610d32bad..e671b0283d8 100644 --- a/src/NHibernate.Test/Linq/TryGetMappedTests.cs +++ b/src/NHibernate.Test/Linq/TryGetMappedTests.cs @@ -774,7 +774,8 @@ private void AssertResult( var expression = query.Expression; var preTransformResult = NhRelinqQueryParser.PreTransform(expression, new PreTransformationParameters(QueryMode.Select, Sfi)); - expression = ExpressionParameterVisitor.Visit(preTransformResult, out var constantToParameterMap); + expression = preTransformResult.Expression; + var constantToParameterMap = ExpressionParameterVisitor.Visit(preTransformResult); var queryModel = NhRelinqQueryParser.Parse(expression); var requiredHqlParameters = new List(); var visitorParameters = new VisitorParameters( @@ -786,6 +787,7 @@ private void AssertResult( QueryMode.Select); if (rewriteQuery) { + QueryModelRewriter.Rewrite(queryModel, visitorParameters); QueryModelVisitor.GenerateHqlQuery( queryModel, visitorParameters, diff --git a/src/NHibernate.Test/NHSpecificTest/GH1526/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/GH1526/Fixture.cs index 5318d771ec2..ee2b82c02f9 100644 --- a/src/NHibernate.Test/NHSpecificTest/GH1526/Fixture.cs +++ b/src/NHibernate.Test/NHSpecificTest/GH1526/Fixture.cs @@ -71,7 +71,7 @@ public void ShouldCreateDifferentKeys_TypeBinaryExpression() private static string GetCacheKey(Expression exp) { - return ExpressionKeyVisitor.Visit(exp, new Dictionary()); + return ExpressionKeyVisitor.Visit(exp, new Dictionary(), null); } } } diff --git a/src/NHibernate.Test/TestDialect.cs b/src/NHibernate.Test/TestDialect.cs index c62c4a7cb28..1e0d5904e04 100644 --- a/src/NHibernate.Test/TestDialect.cs +++ b/src/NHibernate.Test/TestDialect.cs @@ -191,5 +191,10 @@ public bool SupportsSqlType(SqlType sqlType) /// This flag is added to be able to test this feature selectively /// public virtual bool SupportsRowValueConstructorSyntax => _dialect.SupportsRowValueConstructorSyntax; + + /// + /// Whether the parameter values are sent as strings instead of binary. + /// + public virtual bool SendsParameterValuesAsStrings => false; } } diff --git a/src/NHibernate.Test/TestDialects/MySQL5TestDialect.cs b/src/NHibernate.Test/TestDialects/MySQL5TestDialect.cs index 9545219a36f..a7fbe285763 100644 --- a/src/NHibernate.Test/TestDialects/MySQL5TestDialect.cs +++ b/src/NHibernate.Test/TestDialects/MySQL5TestDialect.cs @@ -8,5 +8,18 @@ public MySQL5TestDialect(Dialect.Dialect dialect) } public override bool SupportsAggregateInSubSelect => true; + + /// + /// MySql.Data sends parameters as strings when the query is not prepared. + /// + /// + /// Sending parameters as strings has an impact on and + /// parameters as they can be differently evaluated by the database. For example when there + /// is no e-notation in the string the value will be evaluated as NUMBER type, which may cause + /// issues for and parameters. When there is an e-notation + /// the value will be evaluated as DOUBLE by the database, which may produce unexpected results + /// for parameters. + /// + public override bool SendsParameterValuesAsStrings => true; } } diff --git a/src/NHibernate.Test/TypedManyToOne/Address.cs b/src/NHibernate.Test/TypedManyToOne/Address.cs index 5cb79290600..ab25d0bc7f5 100644 --- a/src/NHibernate.Test/TypedManyToOne/Address.cs +++ b/src/NHibernate.Test/TypedManyToOne/Address.cs @@ -2,14 +2,16 @@ namespace NHibernate.Test.TypedManyToOne { - [Serializable] - public class Address - { - public virtual AddressId AddressId {get; set; } - public virtual string Street { get; set; } - public virtual string City { get; set; } - public virtual string State { get; set; } - public virtual string Zip { get; set; } - public virtual Customer Customer { get; set; } - } + [Serializable] + public class Address + { + public virtual AddressId AddressId { get; set; } + public virtual string Street { get; set; } + public virtual string City { get; set; } + public virtual string State { get; set; } + public virtual string Zip { get; set; } + public virtual Customer Customer { get; set; } + public virtual string BillingNotes { get; set; } + public virtual string ShippingNotes { get; set; } + } } diff --git a/src/NHibernate.Test/TypedManyToOne/Customer.hbm.xml b/src/NHibernate.Test/TypedManyToOne/Customer.hbm.xml index 90822998e19..3f660fbc4c0 100644 --- a/src/NHibernate.Test/TypedManyToOne/Customer.hbm.xml +++ b/src/NHibernate.Test/TypedManyToOne/Customer.hbm.xml @@ -53,7 +53,7 @@ - + - + diff --git a/src/NHibernate.Test/TypedManyToOne/TypedManyToOneTest.cs b/src/NHibernate.Test/TypedManyToOne/TypedManyToOneTest.cs index 4dbf6ce87c0..7c2ff620a2c 100644 --- a/src/NHibernate.Test/TypedManyToOne/TypedManyToOneTest.cs +++ b/src/NHibernate.Test/TypedManyToOne/TypedManyToOneTest.cs @@ -1,4 +1,5 @@ using System.Collections; +using System.Linq; using NHibernate.Dialect; using NUnit.Framework; @@ -24,38 +25,27 @@ protected override bool AppliesTo(Dialect.Dialect dialect) } [Test] - public void TestCreateQuery() + public void TestLinqEntityNameQuery() { - var cust = new Customer(); - cust.CustomerId = "abc123"; - cust.Name = "Matt"; - - var ship = new Address(); - ship.Street = "peachtree rd"; - ship.State = "GA"; - ship.City = "ATL"; - ship.Zip = "30326"; - ship.AddressId = new AddressId("SHIPPING", "xyz123"); - ship.Customer = cust; - - var bill = new Address(); - bill.Street = "peachtree rd"; - bill.State = "GA"; - bill.City = "ATL"; - bill.Zip = "30326"; - bill.AddressId = new AddressId("BILLING", "xyz123"); - bill.Customer = cust; - - cust.BillingAddress = bill; - cust.ShippingAddress = ship; - - using (ISession s = Sfi.OpenSession()) - using (ITransaction t = s.BeginTransaction()) + var cust = CreateCustomer(); + using (var s = Sfi.OpenSession()) + using (var t = s.BeginTransaction()) { - s.Persist(cust); + var billingNotes = s.Query().Select(o => o.BillingAddress.BillingNotes).First(); + Assert.That(billingNotes, Is.EqualTo("BillingNotes")); + var shippingNotes = s.Query().Select(o => o.ShippingAddress.ShippingNotes).First(); + Assert.That(shippingNotes, Is.EqualTo("ShippingNotes")); + t.Commit(); } + DeleteCustomer(cust); + } + + [Test] + public void TestCreateQuery() + { + var cust = CreateCustomer(); using (ISession s = Sfi.OpenSession()) using (ITransaction t = s.BeginTransaction()) { @@ -71,20 +61,7 @@ public void TestCreateQuery() t.Commit(); } - using (ISession s = Sfi.OpenSession()) - using (ITransaction t = s.BeginTransaction()) - { - s.SaveOrUpdate(cust); - ship = cust.ShippingAddress; - cust.ShippingAddress = null; - s.Delete("ShippingAddress", ship); - s.Flush(); - - Assert.That(s.Get("ShippingAddress", ship.AddressId), Is.Null); - s.Delete(cust); - - t.Commit(); - } + DeleteCustomer(cust); } [Test] @@ -113,5 +90,60 @@ public void TestCreateQueryNull() t.Commit(); } } + + private Customer CreateCustomer() + { + var cust = new Customer(); + cust.CustomerId = "abc123"; + cust.Name = "Matt"; + + var ship = new Address(); + ship.Street = "peachtree rd"; + ship.State = "GA"; + ship.City = "ATL"; + ship.Zip = "30326"; + ship.AddressId = new AddressId("SHIPPING", "xyz123"); + ship.Customer = cust; + ship.ShippingNotes = "ShippingNotes"; + + var bill = new Address(); + bill.Street = "peachtree rd"; + bill.State = "GA"; + bill.City = "ATL"; + bill.Zip = "30326"; + bill.AddressId = new AddressId("BILLING", "xyz123"); + bill.Customer = cust; + bill.BillingNotes = "BillingNotes"; + + cust.BillingAddress = bill; + cust.ShippingAddress = ship; + + using (ISession s = Sfi.OpenSession()) + using (ITransaction t = s.BeginTransaction()) + { + s.Persist(cust); + t.Commit(); + } + + return cust; + } + + private void DeleteCustomer(Customer cust) + { + using (var s = Sfi.OpenSession()) + using (var t = s.BeginTransaction()) + { + s.SaveOrUpdate(cust); + var ship = cust.ShippingAddress; + cust.ShippingAddress = null; + s.Delete("ShippingAddress", ship); + s.Flush(); + + Assert.That(s.Get("ShippingAddress", ship.AddressId), Is.Null); + s.Delete(cust); + + t.Commit(); + } + } } } diff --git a/src/NHibernate/Async/Linq/DefaultQueryProvider.cs b/src/NHibernate/Async/Linq/DefaultQueryProvider.cs index 4d0344a27eb..2cc4b8863b4 100644 --- a/src/NHibernate/Async/Linq/DefaultQueryProvider.cs +++ b/src/NHibernate/Async/Linq/DefaultQueryProvider.cs @@ -21,6 +21,7 @@ using NHibernate.Util; using System.Threading.Tasks; using NHibernate.Multi; +using NHibernate.Param; namespace NHibernate.Linq { @@ -103,7 +104,7 @@ public Task ExecuteDmlAsync(QueryMode queryMode, Expression expression, var query = Session.CreateQuery(nhLinqExpression); - SetParameters(query, nhLinqExpression.ParameterValuesByName); + SetParameters(query, nhLinqExpression.NamedParameters); _options?.Apply(query); return query.ExecuteUpdateAsync(cancellationToken); } diff --git a/src/NHibernate/Dialect/Dialect.cs b/src/NHibernate/Dialect/Dialect.cs index 6153309c7ac..0a32143256d 100644 --- a/src/NHibernate/Dialect/Dialect.cs +++ b/src/NHibernate/Dialect/Dialect.cs @@ -2584,6 +2584,21 @@ public virtual string TableTypeString get { return String.Empty; } // for differentiation of mysql storage engines } + /// + /// Whether is stored as a floating point number. + /// + public virtual bool IsDecimalStoredAsFloatingPointNumber => false; + + /// + /// Whether bitwise operators are supported for . + /// + public virtual bool SupportsBitwiseOperatorsOnBoolean => true; + + /// + /// Whether IEEE Standard 754 for floating point numbers is supported + /// + public virtual bool SupportsIEEE754FloatingPointNumbers => true; + /// /// The keyword used to specify a nullable column /// diff --git a/src/NHibernate/Dialect/FirebirdDialect.cs b/src/NHibernate/Dialect/FirebirdDialect.cs index 64b6dc911c3..9f959afd167 100644 --- a/src/NHibernate/Dialect/FirebirdDialect.cs +++ b/src/NHibernate/Dialect/FirebirdDialect.cs @@ -8,6 +8,7 @@ using NHibernate.Engine; using NHibernate.SqlCommand; using NHibernate.Type; +using NHibernate.Util; using Environment = NHibernate.Cfg.Environment; namespace NHibernate.Dialect @@ -43,6 +44,15 @@ public override string AddColumnString get { return "add"; } } + /// + public override void Configure(IDictionary settings) + { + base.Configure(settings); + // We have to match the scale for decimals as they are mapped with scale 5 by default. Without matching, an overflow exception + // will occur when multiplying a decimal column with a parameter. + DefaultCastScale = PropertiesHelper.GetByte(Environment.QueryDefaultCastScale, settings, null) ?? 5; + } + public override string GetSelectSequenceNextValString(string sequenceName) { return string.Format("gen_id({0}, 1 )", sequenceName); @@ -432,7 +442,7 @@ private void OverrideStandardHQLFunctions() { RegisterFunction("current_timestamp", new CurrentTimeStamp()); RegisterFunction("current_date", new NoArgSQLFunction("current_date", NHibernateUtil.LocalDate, false)); - RegisterFunction("length", new StandardSafeSQLFunction("char_length", NHibernateUtil.Int64, 1)); + RegisterFunction("length", new StandardSafeSQLFunction("char_length", NHibernateUtil.Int32, 1)); RegisterFunction("nullif", new StandardSafeSQLFunction("nullif", 2)); RegisterFunction("lower", new StandardSafeSQLFunction("lower", NHibernateUtil.String, 1)); RegisterFunction("upper", new StandardSafeSQLFunction("upper", NHibernateUtil.String, 1)); diff --git a/src/NHibernate/Dialect/Oracle10gDialect.cs b/src/NHibernate/Dialect/Oracle10gDialect.cs index 0495805c4ef..1a48013212e 100644 --- a/src/NHibernate/Dialect/Oracle10gDialect.cs +++ b/src/NHibernate/Dialect/Oracle10gDialect.cs @@ -26,12 +26,17 @@ public override JoinFragment CreateOuterJoinFragment() public override void Configure(IDictionary settings) { - base.Configure(settings); - _useBinaryFloatingPointTypes = PropertiesHelper.GetBoolean( Environment.OracleUseBinaryFloatingPointTypes, settings, false); + + if (_useBinaryFloatingPointTypes) + { + RegisterFunction("mod", new ModulusFunction(true, true)); + } + + base.Configure(settings); } // Avoid registering weighted double type when using binary floating point types @@ -66,5 +71,8 @@ protected override void RegisterFunctions() /// public override bool SupportsCrossJoin => true; + + /// + public override bool SupportsIEEE754FloatingPointNumbers => _useBinaryFloatingPointTypes; } } diff --git a/src/NHibernate/Dialect/Oracle8iDialect.cs b/src/NHibernate/Dialect/Oracle8iDialect.cs index 9bbe3e7dfb9..1fb8b1ede43 100644 --- a/src/NHibernate/Dialect/Oracle8iDialect.cs +++ b/src/NHibernate/Dialect/Oracle8iDialect.cs @@ -249,7 +249,7 @@ protected virtual void RegisterFunctions() RegisterFunction("soundex", new StandardSQLFunction("soundex")); RegisterFunction("upper", new StandardSQLFunction("upper")); RegisterFunction("ascii", new StandardSQLFunction("ascii", NHibernateUtil.Int32)); - RegisterFunction("length", new StandardSQLFunction("length", NHibernateUtil.Int64)); + RegisterFunction("length", new StandardSQLFunction("length", NHibernateUtil.Int32)); RegisterFunction("left", new SQLFunctionTemplate(NHibernateUtil.String, "substr(?1, 1, ?2)")); RegisterFunction("right", new SQLFunctionTemplate(NHibernateUtil.String, "substr(?1, -?2)")); @@ -297,7 +297,7 @@ protected virtual void RegisterFunctions() // Multi-param numeric dialect functions... RegisterFunction("atan2", new StandardSQLFunction("atan2", NHibernateUtil.Double)); RegisterFunction("log", new StandardSQLFunction("log", NHibernateUtil.Int32)); - RegisterFunction("mod", new ModulusFunction(true, true)); + RegisterFunction("mod", new ModulusFunction(true, false)); RegisterFunction("nvl", new StandardSQLFunction("nvl")); RegisterFunction("nvl2", new StandardSQLFunction("nvl2")); RegisterFunction("power", new StandardSQLFunction("power", NHibernateUtil.Double)); @@ -542,6 +542,9 @@ public override bool SupportsCurrentTimestampSelection get { return true; } } + /// + public override bool SupportsIEEE754FloatingPointNumbers => false; + public override IDataBaseSchema GetDataBaseSchema(DbConnection connection) { return new OracleDataBaseSchema(connection); diff --git a/src/NHibernate/Dialect/OracleLiteDialect.cs b/src/NHibernate/Dialect/OracleLiteDialect.cs index eb5af510652..e5f17208a34 100644 --- a/src/NHibernate/Dialect/OracleLiteDialect.cs +++ b/src/NHibernate/Dialect/OracleLiteDialect.cs @@ -80,7 +80,7 @@ public OracleLiteDialect() RegisterFunction("rtrim", new StandardSQLFunction("rtrim")); RegisterFunction("upper", new StandardSQLFunction("upper")); RegisterFunction("ascii", new StandardSQLFunction("ascii", NHibernateUtil.Int32)); - RegisterFunction("length", new StandardSQLFunction("length", NHibernateUtil.Int64)); + RegisterFunction("length", new StandardSQLFunction("length", NHibernateUtil.Int32)); RegisterFunction("to_char", new StandardSQLFunction("to_char", NHibernateUtil.String)); RegisterFunction("to_date", new StandardSQLFunction("to_date", NHibernateUtil.DateTime)); diff --git a/src/NHibernate/Dialect/PostgreSQLDialect.cs b/src/NHibernate/Dialect/PostgreSQLDialect.cs index 0bf2b364eef..026785d9d7b 100644 --- a/src/NHibernate/Dialect/PostgreSQLDialect.cs +++ b/src/NHibernate/Dialect/PostgreSQLDialect.cs @@ -243,6 +243,9 @@ public override SqlString GetLimitString(SqlString queryString, SqlString offset /// public override bool SupportsOuterJoinForUpdate => false; + /// + public override bool SupportsBitwiseOperatorsOnBoolean => false; + public override string GetForUpdateString(string aliases) { return ForUpdateString + " of " + aliases; diff --git a/src/NHibernate/Dialect/SQLiteDialect.cs b/src/NHibernate/Dialect/SQLiteDialect.cs index fa71342185f..eef7191c284 100644 --- a/src/NHibernate/Dialect/SQLiteDialect.cs +++ b/src/NHibernate/Dialect/SQLiteDialect.cs @@ -357,6 +357,9 @@ public override bool GenerateTablePrimaryKeyConstraintForIdentityColumn get { return false; } } + /// + public override bool IsDecimalStoredAsFloatingPointNumber => true; + public override string Qualify(string catalog, string schema, string table) { StringBuilder qualifiedName = new StringBuilder(); diff --git a/src/NHibernate/Dialect/SybaseASA9Dialect.cs b/src/NHibernate/Dialect/SybaseASA9Dialect.cs index dfae7baa471..999cb191801 100644 --- a/src/NHibernate/Dialect/SybaseASA9Dialect.cs +++ b/src/NHibernate/Dialect/SybaseASA9Dialect.cs @@ -73,7 +73,7 @@ public SybaseASA9Dialect() // Override standard HQL function RegisterFunction("current_timestamp", new StandardSQLFunction("current_timestamp", NHibernateUtil.LocalDateTime)); - RegisterFunction("length", new StandardSafeSQLFunction("length", NHibernateUtil.String, 1)); + RegisterFunction("length", new StandardSafeSQLFunction("length", NHibernateUtil.Int32, 1)); RegisterFunction("nullif", new StandardSafeSQLFunction("nullif", 2)); RegisterFunction("lower", new StandardSafeSQLFunction("lower", NHibernateUtil.String, 1)); RegisterFunction("upper", new StandardSafeSQLFunction("upper", NHibernateUtil.String, 1)); diff --git a/src/NHibernate/Driver/MySqlDataDriver.cs b/src/NHibernate/Driver/MySqlDataDriver.cs index efbe9675f9a..cd16c536b91 100644 --- a/src/NHibernate/Driver/MySqlDataDriver.cs +++ b/src/NHibernate/Driver/MySqlDataDriver.cs @@ -56,7 +56,10 @@ public MySqlDataDriver() : base( public override bool SupportsMultipleOpenReaders => false; /// - /// MySql.Data does not support preparing of commands. + /// MySql.Data supports prepared statements but in the current state NHibernate may + /// execute two queries in one command separated by a semicolon (e.g. SELECT LAST_INSERT_ID() when + /// native id generator is used), which throws an exception when + /// is called. /// /// - it is not supported. /// diff --git a/src/NHibernate/Driver/OdbcDriver.cs b/src/NHibernate/Driver/OdbcDriver.cs index cf8df041cea..5ac80336849 100644 --- a/src/NHibernate/Driver/OdbcDriver.cs +++ b/src/NHibernate/Driver/OdbcDriver.cs @@ -78,10 +78,18 @@ private void SetVariableLengthParameterSize(DbParameter dbParam, SqlType sqlType { switch (dbParam.DbType) { - case DbType.AnsiString: + case DbType.StringFixedLength: case DbType.AnsiStringFixedLength: + // For types that are using one character (CharType, AnsiCharType, TrueFalseType, YesNoType and EnumCharType), + // we have to specify the length otherwise sql function like charindex won't work as expected. + if (sqlType.Length == 1) + { + dbParam.Size = sqlType.Length; + } + + break; case DbType.String: - case DbType.StringFixedLength: + case DbType.AnsiString: // NH-4083: do not limit to column length if above 2000. Setting size may trigger conversion from // nvarchar to ntext when size is superior or equal to 2000, causing some queries to fail: // https://stackoverflow.com/q/8569844/1178314 diff --git a/src/NHibernate/Driver/SqlClientDriver.cs b/src/NHibernate/Driver/SqlClientDriver.cs index 682d755472a..1002005395d 100644 --- a/src/NHibernate/Driver/SqlClientDriver.cs +++ b/src/NHibernate/Driver/SqlClientDriver.cs @@ -161,7 +161,9 @@ protected override void InitializeParameter(DbParameter dbParam, string name, Sq { case DbType.AnsiString: case DbType.AnsiStringFixedLength: - dbParam.Size = IsAnsiText(dbParam, sqlType) ? MsSql2000Dialect.MaxSizeForAnsiClob : MsSql2000Dialect.MaxSizeForLengthLimitedAnsiString; + dbParam.Size = IsAnsiText(dbParam, sqlType) + ? MsSql2000Dialect.MaxSizeForAnsiClob + : IsChar(dbParam, sqlType) ? sqlType.Length : MsSql2000Dialect.MaxSizeForLengthLimitedAnsiString; break; case DbType.Binary: dbParam.Size = IsBlob(dbParam, sqlType) ? MsSql2000Dialect.MaxSizeForBlob : MsSql2000Dialect.MaxSizeForLengthLimitedBinary; @@ -174,7 +176,9 @@ protected override void InitializeParameter(DbParameter dbParam, string name, Sq break; case DbType.String: case DbType.StringFixedLength: - dbParam.Size = IsText(dbParam, sqlType) ? MsSql2000Dialect.MaxSizeForClob : MsSql2000Dialect.MaxSizeForLengthLimitedString; + dbParam.Size = IsText(dbParam, sqlType) + ? MsSql2000Dialect.MaxSizeForClob + : IsChar(dbParam, sqlType) ? sqlType.Length : MsSql2000Dialect.MaxSizeForLengthLimitedString; break; case DbType.DateTime2: dbParam.Size = MsSql2000Dialect.MaxDateTime2; @@ -283,6 +287,18 @@ protected static bool IsBlob(DbParameter dbParam, SqlType sqlType) return (sqlType is BinaryBlobSqlType) || ((DbType.Binary == dbParam.DbType) && sqlType.LengthDefined && (sqlType.Length > MsSql2000Dialect.MaxSizeForLengthLimitedBinary)); } + /// + /// Interprets if a parameter is a character (for the purposes of setting its default size) + /// + /// The parameter + /// The of the parameter + /// True, if the parameter should be interpreted as a character, otherwise False + protected static bool IsChar(DbParameter dbParam, SqlType sqlType) + { + return (DbType.StringFixedLength == dbParam.DbType || DbType.AnsiStringFixedLength == dbParam.DbType) && + sqlType.LengthDefined && sqlType.Length == 1; + } + public override IResultSetsCommand GetResultSetsCommand(ISessionImplementor session) { return new BasicResultSetsCommand(session); diff --git a/src/NHibernate/Driver/SqlServerCeDriver.cs b/src/NHibernate/Driver/SqlServerCeDriver.cs index eb4f03316ea..0b6a4ad93bc 100644 --- a/src/NHibernate/Driver/SqlServerCeDriver.cs +++ b/src/NHibernate/Driver/SqlServerCeDriver.cs @@ -75,6 +75,12 @@ public override IResultSetsCommand GetResultSetsCommand(Engine.ISessionImplement protected override void InitializeParameter(DbParameter dbParam, string name, SqlType sqlType) { base.InitializeParameter(dbParam, name, AdjustSqlType(sqlType)); + // For types that are using one character (CharType, AnsiCharType, TrueFalseType, YesNoType and EnumCharType), + // we have to specify the length otherwise sql function like charindex won't work as expected. + if (sqlType.LengthDefined && sqlType.Length == 1) + { + dbParam.Size = sqlType.Length; + } AdjustDbParamTypeForLargeObjects(dbParam, sqlType); } diff --git a/src/NHibernate/Hql/Ast/ANTLR/ASTQueryTranslatorFactory.cs b/src/NHibernate/Hql/Ast/ANTLR/ASTQueryTranslatorFactory.cs index 7b9e937bd18..e7b95eed2cb 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/ASTQueryTranslatorFactory.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/ASTQueryTranslatorFactory.cs @@ -1,6 +1,7 @@ using System.Collections.Generic; using NHibernate.Engine; using NHibernate.Hql.Ast.ANTLR.Tree; +using NHibernate.Linq; using NHibernate.Util; namespace NHibernate.Hql.Ast.ANTLR @@ -16,15 +17,24 @@ public class ASTQueryTranslatorFactory : IQueryTranslatorFactory { public IQueryTranslator[] CreateQueryTranslators(IQueryExpression queryExpression, string collectionRole, bool shallow, IDictionary filters, ISessionFactoryImplementor factory) { - return CreateQueryTranslators(queryExpression.Translate(factory, collectionRole != null), queryExpression.Key, collectionRole, shallow, filters, factory); + return CreateQueryTranslators(queryExpression, queryExpression.Translate(factory, collectionRole != null), queryExpression.Key, collectionRole, shallow, filters, factory); } - static IQueryTranslator[] CreateQueryTranslators(IASTNode ast, string queryIdentifier, string collectionRole, bool shallow, IDictionary filters, ISessionFactoryImplementor factory) + static IQueryTranslator[] CreateQueryTranslators( + IQueryExpression queryExpression, + IASTNode ast, + string queryIdentifier, + string collectionRole, + bool shallow, + IDictionary filters, + ISessionFactoryImplementor factory) { var polymorphicParsers = AstPolymorphicProcessor.Process(ast, factory); var translators = polymorphicParsers - .ToArray(hql => new QueryTranslatorImpl(queryIdentifier, hql, filters, factory)); + .ToArray(hql => queryExpression is NhLinqExpression linqExpression + ? new QueryTranslatorImpl(queryIdentifier, hql, filters, factory, linqExpression.NamedParameters) + : new QueryTranslatorImpl(queryIdentifier, hql, filters, factory)); foreach (var translator in translators) { diff --git a/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs b/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs index 37e5eaffc6e..d6f3a7f0861 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs @@ -36,7 +36,7 @@ public partial class HqlSqlWalker private string _statementTypeName; private int _positionalParameterCount; private int _parameterCount; - private readonly NullableDictionary _namedParameters = new NullableDictionary(); + private readonly NullableDictionary _namedParameterLocations = new NullableDictionary(); private readonly List _parameters = new List(); private FromClause _currentFromClause; private SelectClause _selectClause; @@ -54,6 +54,7 @@ public partial class HqlSqlWalker private readonly LiteralProcessor _literalProcessor; private readonly IDictionary _tokenReplacements; + private readonly IDictionary _namedParameters; private JoinType _impliedJoinType; @@ -64,17 +65,30 @@ public partial class HqlSqlWalker private int numberOfParametersInSetClause; private Stack clauseStack=new Stack(); - public HqlSqlWalker(QueryTranslatorImpl qti, - ISessionFactoryImplementor sfi, - ITreeNodeStream input, - IDictionary tokenReplacements, - string collectionRole) + public HqlSqlWalker( + QueryTranslatorImpl qti, + ISessionFactoryImplementor sfi, + ITreeNodeStream input, + IDictionary tokenReplacements, + string collectionRole) + : this(qti, sfi, input, tokenReplacements, null, collectionRole) + { + } + + internal HqlSqlWalker( + QueryTranslatorImpl qti, + ISessionFactoryImplementor sfi, + ITreeNodeStream input, + IDictionary tokenReplacements, + IDictionary namedParameters, + string collectionRole) : this(input) { _sessionFactoryHelper = new SessionFactoryHelperExtensions(sfi); _qti = qti; _literalProcessor = new LiteralProcessor(this); _tokenReplacements = tokenReplacements; + _namedParameters = namedParameters; _collectionFilterRole = collectionRole; } @@ -122,7 +136,7 @@ public ISet QuerySpaces public IDictionary NamedParameters { - get { return _namedParameters; } + get { return _namedParameterLocations; } } internal SessionFactoryHelperExtensions SessionFactoryHelper @@ -1033,13 +1047,20 @@ IASTNode GenerateNamedParameter(IASTNode delimiterNode, IASTNode nameNode) ); parameter.HqlParameterSpecification = paramSpec; + if (_namedParameters != null && _namedParameters.TryGetValue(name, out var namedParameter)) + { + // Add the parameter type information so that we are able to calculate functions return types + // when the parameter is used as an argument. + parameter.ExpectedType = namedParameter.Type; + } + _parameters.Add(paramSpec); return parameter; } IASTNode GeneratePositionalParameter(IASTNode inputNode) { - if (_namedParameters.Count > 0) + if (_namedParameterLocations.Count > 0) { // NH TODO: remove this limitation throw new SemanticException("cannot define positional parameter after any named parameters have been defined"); @@ -1171,15 +1192,15 @@ public void AddQuerySpaces(string[] spaces) private void TrackNamedParameterPositions(string name) { int loc = _parameterCount++; - object o = _namedParameters[name]; + object o = _namedParameterLocations[name]; if ( o == null ) { - _namedParameters.Add(name, loc); + _namedParameterLocations.Add(name, loc); } else if (o is int) { List list = new List(4) {(int) o, loc}; - _namedParameters[name] = list; + _namedParameterLocations[name] = list; } else { diff --git a/src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs b/src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs index 2e27559d1dd..bcf3dc14e11 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs @@ -29,7 +29,8 @@ public partial class QueryTranslatorImpl : IFilterTranslator private readonly string _queryIdentifier; private readonly IASTNode _stageOneAst; private readonly ISessionFactoryImplementor _factory; - + private readonly IDictionary _namedParameters; + private bool _shallowQuery; private bool _compiled; private IDictionary _enabledFilters; @@ -47,10 +48,28 @@ public partial class QueryTranslatorImpl : IFilterTranslator /// Currently enabled filters /// The session factory constructing this translator instance. public QueryTranslatorImpl( - string queryIdentifier, - IASTNode parsedQuery, - IDictionary enabledFilters, - ISessionFactoryImplementor factory) + string queryIdentifier, + IASTNode parsedQuery, + IDictionary enabledFilters, + ISessionFactoryImplementor factory) + : this(queryIdentifier, parsedQuery, enabledFilters, factory, null) + { + } + + /// + /// Creates a new AST-based query translator. + /// + /// The query-identifier (used in stats collection) + /// The hql query to translate + /// Currently enabled filters + /// The session factory constructing this translator instance. + /// The named parameters information. + internal QueryTranslatorImpl( + string queryIdentifier, + IASTNode parsedQuery, + IDictionary enabledFilters, + ISessionFactoryImplementor factory, + IDictionary namedParameters) { _queryIdentifier = queryIdentifier; _stageOneAst = parsedQuery; @@ -58,6 +77,7 @@ public QueryTranslatorImpl( _shallowQuery = false; _enabledFilters = enabledFilters; _factory = factory; + _namedParameters = namedParameters; } /// @@ -434,7 +454,7 @@ private static IStatementExecutor BuildAppropriateStatementExecutor(IStatement s private HqlSqlTranslator Analyze(string collectionRole) { - var translator = new HqlSqlTranslator(_stageOneAst, this, _factory, _tokenReplacements, collectionRole); + var translator = new HqlSqlTranslator(_stageOneAst, this, _factory, _tokenReplacements, _namedParameters, collectionRole); translator.Translate(); @@ -548,15 +568,23 @@ internal class HqlSqlTranslator private readonly QueryTranslatorImpl _qti; private readonly ISessionFactoryImplementor _sfi; private readonly IDictionary _tokenReplacements; + private readonly IDictionary _namedParameters; private readonly string _collectionRole; private IStatement _resultAst; - public HqlSqlTranslator(IASTNode ast, QueryTranslatorImpl qti, ISessionFactoryImplementor sfi, IDictionary tokenReplacements, string collectionRole) + public HqlSqlTranslator( + IASTNode ast, + QueryTranslatorImpl qti, + ISessionFactoryImplementor sfi, + IDictionary tokenReplacements, + IDictionary namedParameters, + string collectionRole) { _inputAst = ast; _qti = qti; _sfi = sfi; _tokenReplacements = tokenReplacements; + _namedParameters = namedParameters; _collectionRole = collectionRole; } @@ -576,7 +604,7 @@ public IStatement Translate() var nodes = new BufferedTreeNodeStream(_inputAst); - var hqlSqlWalker = new HqlSqlWalker(_qti, _sfi, nodes, _tokenReplacements, _collectionRole); + var hqlSqlWalker = new HqlSqlWalker(_qti, _sfi, nodes, _tokenReplacements, _namedParameters, _collectionRole); hqlSqlWalker.TreeAdaptor = new HqlSqlWalkerTreeAdaptor(hqlSqlWalker); try diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs index fd91e09fd3a..2ff3a10f790 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs @@ -63,25 +63,13 @@ private IASTNode GetHighOperand() private static void Check(IASTNode check, IASTNode first, IASTNode second) { - var expectedTypeAwareNode = check as IExpectedTypeAwareNode; - if (expectedTypeAwareNode != null) + if (!(check is IExpectedTypeAwareNode expectedTypeAwareNode) || + expectedTypeAwareNode.ExpectedType != null) { - IType expectedType = null; - var firstNode = first as SqlNode; - if (firstNode != null) - { - expectedType = firstNode.DataType; - } - if (expectedType == null) - { - var secondNode = second as SqlNode; - if (secondNode != null) - { - expectedType = secondNode.DataType; - } - } - expectedTypeAwareNode.ExpectedType = expectedType; + return; } + + expectedTypeAwareNode.ExpectedType = (first as SqlNode)?.DataType ?? (second as SqlNode)?.DataType; } } } diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryArithmeticOperatorNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryArithmeticOperatorNode.cs index 9b706facb49..60ca24ca379 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryArithmeticOperatorNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryArithmeticOperatorNode.cs @@ -32,32 +32,34 @@ public void Initialize() IType lhType = (lhs is SqlNode) ? ((SqlNode)lhs).DataType : null; IType rhType = (rhs is SqlNode) ? ((SqlNode)rhs).DataType : null; - if (lhs is IExpectedTypeAwareNode && rhType != null) + TrySetExpectedType(lhs, rhType, true); + TrySetExpectedType(rhs, lhType, false); + } + + private void TrySetExpectedType(IASTNode operand, IType otherOperandType, bool leftHandOperand) + { + if (!(operand is IExpectedTypeAwareNode typeAwareNode) || + otherOperandType == null || + typeAwareNode.ExpectedType != null) { - IType expectedType; + return; + } + + IType expectedType = null; - // we have something like : "? [op] rhs" - if (IsDateTimeType(rhType)) + // we have something like : "lhs [op] ?" or "? [op] rhs" + if (IsDateTimeType(otherOperandType)) + { + if (leftHandOperand) { // more specifically : "? [op] datetime" // 1) if the operator is MINUS, the param needs to be of // some datetime type // 2) if the operator is PLUS, the param needs to be of // some numeric type - expectedType = Type == HqlSqlWalker.PLUS ? NHibernateUtil.Double : rhType; + expectedType = Type == HqlSqlWalker.PLUS ? NHibernateUtil.Double : otherOperandType; } - else - { - expectedType = rhType; - } - ((IExpectedTypeAwareNode)lhs).ExpectedType = expectedType; - } - else if (rhs is ParameterNode && lhType != null) - { - IType expectedType = null; - - // we have something like : "lhs [op] ?" - if (IsDateTimeType(lhType)) + else if (Type == HqlSqlWalker.PLUS) { // more specifically : "datetime [op] ?" // 1) if the operator is MINUS, we really cannot determine @@ -65,17 +67,15 @@ public void Initialize() // numeric would be valid // 2) if the operator is PLUS, the param needs to be of // some numeric type - if (Type == HqlSqlWalker.PLUS) - { - expectedType = NHibernateUtil.Double; - } - } - else - { - expectedType = lhType; + expectedType = NHibernateUtil.Double; } - ((IExpectedTypeAwareNode)rhs).ExpectedType = expectedType; } + else + { + expectedType = otherOperandType; + } + + typeAwareNode.ExpectedType = expectedType; } public override IType DataType diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs index bf0560dfc76..cae4b920ec8 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs @@ -65,15 +65,14 @@ public virtual void Initialize() rhsType = lhsType; } - var lshExpectedTypeAwareNode = lhs as IExpectedTypeAwareNode; - if (lshExpectedTypeAwareNode != null) + if (lhs is IExpectedTypeAwareNode lshTypeAwareNode && lshTypeAwareNode.ExpectedType == null) { - lshExpectedTypeAwareNode.ExpectedType = rhsType; + lshTypeAwareNode.ExpectedType = rhsType; } - var rshExpectedTypeAwareNode = rhs as IExpectedTypeAwareNode; - if (rshExpectedTypeAwareNode != null) + + if (rhs is IExpectedTypeAwareNode rshTypeAwareNode && rshTypeAwareNode.ExpectedType == null) { - rshExpectedTypeAwareNode.ExpectedType = lhsType; + rshTypeAwareNode.ExpectedType = lhsType; } MutateRowValueConstructorSyntaxesIfNecessary( lhsType, rhsType ); diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/InLogicOperatorNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/InLogicOperatorNode.cs index 0ad4e404bda..f0b9856f76a 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/InLogicOperatorNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/InLogicOperatorNode.cs @@ -47,11 +47,12 @@ public override void Initialize() IASTNode inListChild = inList.GetChild(0); while (inListChild != null) { - var expectedTypeAwareNode = inListChild as IExpectedTypeAwareNode; - if (expectedTypeAwareNode != null) + if (inListChild is IExpectedTypeAwareNode expectedTypeAwareNode && + expectedTypeAwareNode.ExpectedType == null) { expectedTypeAwareNode.ExpectedType = lhsType; } + inListChild = inListChild.NextSibling; } } diff --git a/src/NHibernate/Hql/Ast/HqlTreeNode.cs b/src/NHibernate/Hql/Ast/HqlTreeNode.cs index 8967174bde3..475fe637769 100755 --- a/src/NHibernate/Hql/Ast/HqlTreeNode.cs +++ b/src/NHibernate/Hql/Ast/HqlTreeNode.cs @@ -1,4 +1,5 @@ using System; +using System.CodeDom; using System.Collections.Generic; using System.Linq; using NHibernate.Hql.Ast.ANTLR; @@ -204,6 +205,23 @@ internal HqlQuery(IASTFactory factory, params HqlStatement[] children) public class HqlIdent : HqlExpression { + private static readonly Dictionary SupportedIdentTypes = new Dictionary + { + {TypeCode.Boolean, "bool"}, + {TypeCode.Int16, "short"}, + {TypeCode.Int32, "integer"}, + {TypeCode.Int64, "long"}, + {TypeCode.UInt16, "ushort"}, + {TypeCode.UInt32, "uint"}, + {TypeCode.UInt64, "ulong"}, + {TypeCode.Decimal, "decimal"}, + {TypeCode.Single, "single"}, + {TypeCode.DateTime, "datetime"}, + {TypeCode.String, "string"}, + {TypeCode.Char, "char"}, + {TypeCode.Double, "double"} + }; + internal HqlIdent(IASTFactory factory, string ident) : base(HqlSqlWalker.IDENT, ident, factory) { @@ -212,72 +230,37 @@ internal HqlIdent(IASTFactory factory, string ident) internal HqlIdent(IASTFactory factory, System.Type type) : base(HqlSqlWalker.IDENT, "", factory) { - type = type.UnwrapIfNullable(); - - switch (System.Type.GetTypeCode(type)) + if (!TryGetTypeName(type, out var typeName)) { - case TypeCode.Boolean: - SetText("bool"); - break; - case TypeCode.Int16: - SetText("short"); - break; - case TypeCode.Int32: - SetText("integer"); - break; - case TypeCode.Int64: - SetText("long"); - break; - case TypeCode.Decimal: - SetText("decimal"); - break; - case TypeCode.Single: - SetText("single"); - break; - case TypeCode.DateTime: - SetText("datetime"); - break; - case TypeCode.String: - SetText("string"); - break; - case TypeCode.Double: - SetText("double"); - break; - default: - if (type == typeof(Guid)) - { - SetText("guid"); - break; - } - if (type == typeof(DateTimeOffset)) - { - SetText("datetimeoffset"); - break; - } - throw new NotSupportedException(string.Format("Don't currently support idents of type {0}", type.Name)); + throw new NotSupportedException($"Don't currently support idents of type {type.Name}"); } + + SetText(typeName); } internal static bool SupportsType(System.Type type) + { + return TryGetTypeName(type, out _); + } + + private static bool TryGetTypeName(System.Type type, out string typeName) { type = type.UnwrapIfNullable(); - switch (System.Type.GetTypeCode(type)) + if (SupportedIdentTypes.TryGetValue(System.Type.GetTypeCode(type), out typeName)) { - case TypeCode.Boolean: - case TypeCode.Int16: - case TypeCode.Int32: - case TypeCode.Int64: - case TypeCode.Decimal: - case TypeCode.Single: - case TypeCode.DateTime: - case TypeCode.String: - case TypeCode.Double: - return true; - default: - return - type == typeof(Guid) || - type == typeof(DateTimeOffset); + return true; + } + + if (type == typeof(Guid)) + { + typeName = "guid"; + } + else if (type == typeof(DateTimeOffset)) + { + typeName = "datetimeoffset"; } + + return typeName != null; } } diff --git a/src/NHibernate/Impl/AbstractQueryImpl.cs b/src/NHibernate/Impl/AbstractQueryImpl.cs index 9ff4c712b0d..ba46b665466 100644 --- a/src/NHibernate/Impl/AbstractQueryImpl.cs +++ b/src/NHibernate/Impl/AbstractQueryImpl.cs @@ -142,7 +142,8 @@ protected internal virtual IType DetermineType(int paramPosition, object paramVa protected internal virtual IType DetermineType(int paramPosition, object paramValue) { - IType type = parameterMetadata.GetOrdinalParameterExpectedType(paramPosition + 1) ?? GuessType(paramValue); + IType type = parameterMetadata.GetOrdinalParameterExpectedType(paramPosition + 1) ?? + ParameterHelper.GuessType(paramValue, session.Factory); return type; } @@ -154,67 +155,15 @@ protected internal virtual IType DetermineType(string paramName, object paramVal protected internal virtual IType DetermineType(string paramName, object paramValue) { - IType type = parameterMetadata.GetNamedParameterExpectedType(paramName) ?? GuessType(paramValue); + IType type = parameterMetadata.GetNamedParameterExpectedType(paramName) ?? + ParameterHelper.GuessType(paramValue, session.Factory); return type; } protected internal virtual IType DetermineType(string paramName, System.Type clazz) { - IType type = parameterMetadata.GetNamedParameterExpectedType(paramName) ?? GuessType(clazz); - return type; - } - - /// - /// Guesses the from the param's value. - /// - /// The object to guess the of. - /// An for the object. - /// - /// Thrown when the param is null because the - /// can't be guess from a null value. - /// - private IType GuessType(object param) - { - if (param == null) - { - throw new ArgumentNullException("param", "The IType can not be guessed for a null value."); - } - - System.Type clazz = NHibernateProxyHelper.GetClassWithoutInitializingProxy(param); - return GuessType(clazz); - } - - /// - /// Guesses the from the . - /// - /// The to guess the of. - /// An for the . - /// - /// Thrown when the clazz is null because the - /// can't be guess from a null type. - /// - private IType GuessType(System.Type clazz) - { - if (clazz == null) - { - throw new ArgumentNullException("clazz", "The IType can not be guessed for a null value."); - } - - var type = TypeFactory.HeuristicType(clazz); - if (type == null || type is SerializableType) - { - if (session.Factory.TryGetEntityPersister(clazz.FullName) != null) - { - return NHibernateUtil.Entity(clazz); - } - - if (type == null) - { - throw new HibernateException( - "Could not determine a type for class: " + clazz.AssemblyQualifiedName); - } - } - + IType type = parameterMetadata.GetNamedParameterExpectedType(paramName) ?? + ParameterHelper.GuessType(clazz, session.Factory); return type; } @@ -310,7 +259,11 @@ public IQuery SetParameter(int position, T val) { CheckPositionalParameter(position); - return SetParameter(position, val, parameterMetadata.GetOrdinalParameterExpectedType(position + 1) ?? GuessType(typeof(T))); + return SetParameter( + position, + val, + parameterMetadata.GetOrdinalParameterExpectedType(position + 1) ?? + ParameterHelper.GuessType(typeof(T), session.Factory)); } private void CheckPositionalParameter(int position) @@ -327,7 +280,11 @@ private void CheckPositionalParameter(int position) public IQuery SetParameter(string name, T val) { - return SetParameter(name, val, parameterMetadata.GetNamedParameterExpectedType(name) ?? GuessType(typeof (T))); + return SetParameter( + name, + val, + parameterMetadata.GetNamedParameterExpectedType(name) ?? + ParameterHelper.GuessType(typeof(T), session.Factory)); } public IQuery SetParameter(string name, object val) @@ -792,7 +749,12 @@ public IQuery SetParameterList(string name, IEnumerable vals) } object firstValue = vals.Cast().FirstOrDefault(); - SetParameterList(name, vals, firstValue == null ? GuessType(vals.GetCollectionElementType()) : DetermineType(name, firstValue)); + SetParameterList( + name, + vals, + firstValue == null + ? ParameterHelper.GuessType(vals.GetCollectionElementType(), session.Factory) + : DetermineType(name, firstValue)); return this; } diff --git a/src/NHibernate/Linq/Clauses/NhOuterJoinClause.cs b/src/NHibernate/Linq/Clauses/NhOuterJoinClause.cs index 6955afd936b..672c4b03388 100644 --- a/src/NHibernate/Linq/Clauses/NhOuterJoinClause.cs +++ b/src/NHibernate/Linq/Clauses/NhOuterJoinClause.cs @@ -31,6 +31,11 @@ public IBodyClause Clone(CloneContext cloneContext) return new NhOuterJoinClause(JoinClause.Clone(cloneContext)); } + public override string ToString() + { + return $"outer {JoinClause}"; + } + protected override void Accept(INhQueryModelVisitor visitor, QueryModel queryModel, int index) { if (visitor is INhQueryModelVisitorExtended queryModelVisitorExtended) diff --git a/src/NHibernate/Linq/DefaultQueryProvider.cs b/src/NHibernate/Linq/DefaultQueryProvider.cs index c8de5a37a5e..912b640a951 100644 --- a/src/NHibernate/Linq/DefaultQueryProvider.cs +++ b/src/NHibernate/Linq/DefaultQueryProvider.cs @@ -11,6 +11,7 @@ using NHibernate.Util; using System.Threading.Tasks; using NHibernate.Multi; +using NHibernate.Param; namespace NHibernate.Linq { @@ -211,7 +212,7 @@ protected virtual NhLinqExpression PrepareQuery(Expression expression, out IQuer query = Session.CreateFilter(Collection, nhLinqExpression); } - SetParameters(query, nhLinqExpression.ParameterValuesByName); + SetParameters(query, nhLinqExpression.NamedParameters); _options?.Apply(query); SetResultTransformerAndAdditionalCriteria(query, nhLinqExpression, nhLinqExpression.ParameterValuesByName); @@ -252,38 +253,19 @@ protected virtual object ExecuteQuery(NhLinqExpression nhLinqExpression, IQuery #pragma warning restore 618 } - private static void SetParameters(IQuery query, IDictionary> parameters) + private static void SetParameters(IQuery query, IDictionary parameters) { foreach (var parameterName in query.NamedParameters) { - var param = parameters[parameterName]; - - if (param.Item1 == null) + // The parameter type will be taken from the parameter metadata + var parameter = parameters[parameterName]; + if (parameter.IsCollection) { - if (typeof(IEnumerable).IsAssignableFrom(param.Item2.ReturnedClass) && - param.Item2.ReturnedClass != typeof(string)) - { - query.SetParameterList(parameterName, null, param.Item2); - } - else - { - query.SetParameter(parameterName, null, param.Item2); - } + query.SetParameterList(parameter.Name, (IEnumerable) parameter.Value); } else { - if (param.Item1 is IEnumerable && !(param.Item1 is string)) - { - query.SetParameterList(parameterName, (IEnumerable)param.Item1); - } - else if (param.Item2 != null) - { - query.SetParameter(parameterName, param.Item1, param.Item2); - } - else - { - query.SetParameter(parameterName, param.Item1); - } + query.SetParameter(parameter.Name, parameter.Value); } } } @@ -310,7 +292,7 @@ public int ExecuteDml(QueryMode queryMode, Expression expression) var query = Session.CreateQuery(nhLinqExpression); - SetParameters(query, nhLinqExpression.ParameterValuesByName); + SetParameters(query, nhLinqExpression.NamedParameters); _options?.Apply(query); return query.ExecuteUpdate(); } diff --git a/src/NHibernate/Linq/ExpressionEvaluation.cs b/src/NHibernate/Linq/ExpressionEvaluation.cs new file mode 100644 index 00000000000..2e1ec07781a --- /dev/null +++ b/src/NHibernate/Linq/ExpressionEvaluation.cs @@ -0,0 +1,36 @@ +using System; +using System.Linq.Expressions; + +namespace NHibernate.Linq +{ + /// + /// Contains methods that can be used to force database or client evaluation of an expression in a select statement + /// when using NHibernate Linq query provider. + /// + public static class ExpressionEvaluation + { + /// + /// Forces client evaluation of an expression in a select statement when using NHibernate Linq query provider. + /// + /// The return type of . + /// The expression to force client evaluation. + /// When the method is used outside NHibernate Linq query. + [NoPreEvaluation] + public static T ClientEval(Expression> expression) + { + throw new InvalidOperationException("The method should be used inside NHibernate Linq query"); + } + + /// + /// Forces database evaluation of an expression in a select statement when using NHibernate Linq query provider. + /// + /// The return type of . + /// The expression to force database evaluation. + /// When the method is used outside NHibernate Linq query. + [NoPreEvaluation] + public static T DatabaseEval(Expression> expression) + { + throw new InvalidOperationException("The method should be used inside NHibernate Linq query"); + } + } +} diff --git a/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs b/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs index 53955d136a6..3de1d27eda2 100644 --- a/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs +++ b/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs @@ -14,5 +14,12 @@ public abstract class BaseHqlGeneratorForMethod : IHqlGeneratorForMethod, IHqlGe public abstract HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor); public virtual bool AllowsNullableReturnType(MethodInfo method) => true; + + /// + public virtual bool TryGetCollectionParameter(MethodCallExpression expression, out ConstantExpression collectionParameter) + { + collectionParameter = null; + return false; + } } } diff --git a/src/NHibernate/Linq/Functions/GetValueOrDefaultGenerator.cs b/src/NHibernate/Linq/Functions/GetValueOrDefaultGenerator.cs index 33cb12c2c6c..cc0fa7202b9 100644 --- a/src/NHibernate/Linq/Functions/GetValueOrDefaultGenerator.cs +++ b/src/NHibernate/Linq/Functions/GetValueOrDefaultGenerator.cs @@ -42,5 +42,12 @@ private static HqlExpression GetRhs(MethodInfo method, ReadOnlyCollection !method.ReturnType.IsValueType; + + /// + public bool TryGetCollectionParameter(MethodCallExpression expression, out ConstantExpression collectionParameter) + { + collectionParameter = null; + return false; + } } } diff --git a/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs b/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs index 73ad8b3d9e4..3ef7583ee6b 100644 --- a/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs +++ b/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs @@ -18,6 +18,14 @@ public interface IHqlGeneratorForMethod internal interface IHqlGeneratorForMethodExtended { bool AllowsNullableReturnType(MethodInfo method); + + /// + /// Try getting a collection parameter from . + /// + /// The method call expression. + /// Output parameter for the retrieved collection parameter. + /// Whether collection parameter was retrieved. + bool TryGetCollectionParameter(MethodCallExpression expression, out ConstantExpression collectionParameter); } internal static class HqlGeneratorForMethodExtensions @@ -33,6 +41,21 @@ public static bool AllowsNullableReturnType(this IHqlGeneratorForMethod generato return true; } + // 6.0 TODO: Remove + public static bool TryGetCollectionParameters( + this IHqlGeneratorForMethod generator, + MethodCallExpression expression, + out ConstantExpression collectionParameter) + { + if (generator is IHqlGeneratorForMethodExtended extendedGenerator) + { + return extendedGenerator.TryGetCollectionParameter(expression, out collectionParameter); + } + + collectionParameter = null; + return false; + } + // 6.0 TODO: merge into IHqlGeneratorForMethod /// /// Should pre-evaluation be allowed for this method? diff --git a/src/NHibernate/Linq/Functions/QueryableGenerator.cs b/src/NHibernate/Linq/Functions/QueryableGenerator.cs index f007fa22592..eaf44c64f4a 100644 --- a/src/NHibernate/Linq/Functions/QueryableGenerator.cs +++ b/src/NHibernate/Linq/Functions/QueryableGenerator.cs @@ -155,6 +155,15 @@ public CollectionContainsGenerator() public override bool AllowsNullableReturnType(MethodInfo method) => false; + /// + public override bool TryGetCollectionParameter(MethodCallExpression expression, out ConstantExpression collectionParameter) + { + var argument = expression.Method.IsStatic ? expression.Arguments[0] : expression.Object; + collectionParameter = argument as ConstantExpression; + + return collectionParameter != null; + } + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) { // TODO - alias generator diff --git a/src/NHibernate/Linq/Functions/StringGenerator.cs b/src/NHibernate/Linq/Functions/StringGenerator.cs index 31edf4a42d1..9605ff8f3d8 100644 --- a/src/NHibernate/Linq/Functions/StringGenerator.cs +++ b/src/NHibernate/Linq/Functions/StringGenerator.cs @@ -59,6 +59,13 @@ public IHqlGeneratorForMethod GetMethodGenerator(MethodInfo method) } public bool AllowsNullableReturnType(MethodInfo method) => false; + + /// + public bool TryGetCollectionParameter(MethodCallExpression expression, out ConstantExpression collectionParameter) + { + collectionParameter = null; + return false; + } } public class LengthGenerator : BaseHqlGeneratorForProperty diff --git a/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs b/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs index cf55ac05a7e..e0fc5bf6324 100644 --- a/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs +++ b/src/NHibernate/Linq/NestedSelects/NestedSelectRewriter.cs @@ -225,7 +225,7 @@ private static Expression GetIdentifier(ISessionFactory sessionFactory, Expressi var classMetadata = sessionFactory.GetClassMetadata(expression.Type); if (classMetadata == null) - return Expression.Constant(null); + return null; var propertyName=classMetadata.IdentifierPropertyName; NHibernate.Type.EmbeddedComponentType componentType; diff --git a/src/NHibernate/Linq/NestedSelects/SelectClauseRewriter.cs b/src/NHibernate/Linq/NestedSelects/SelectClauseRewriter.cs index aa1a869c611..a5830b0f25c 100644 --- a/src/NHibernate/Linq/NestedSelects/SelectClauseRewriter.cs +++ b/src/NHibernate/Linq/NestedSelects/SelectClauseRewriter.cs @@ -23,7 +23,11 @@ public SelectClauseRewriter(Expression parameter, ICollection this.expressions = expressions; this.parameter = parameter; this.tuple = tuple; - this.expressions.Add(new ExpressionHolder { Expression = expression, Tuple = tuple }); //ID placeholder + if (expression != null) + { + this.expressions.Add(new ExpressionHolder { Expression = expression, Tuple = tuple }); //ID placeholder + } + _dictionary = dictionary; } @@ -59,4 +63,4 @@ private Expression AddAndConvertExpression(Expression expression) expression.Type); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/NhLinqExpression.cs b/src/NHibernate/Linq/NhLinqExpression.cs index ad39397fd71..9b23500ffb5 100644 --- a/src/NHibernate/Linq/NhLinqExpression.cs +++ b/src/NHibernate/Linq/NhLinqExpression.cs @@ -34,6 +34,8 @@ public class NhLinqExpression : IQueryExpression, ICacheableQueryExpression protected virtual QueryMode QueryMode { get; } + internal IDictionary NamedParameters { get; } + private readonly Expression _expression; private readonly IDictionary _constantToParameterMap; @@ -56,12 +58,12 @@ internal NhLinqExpression(QueryMode queryMode, Expression expression, ISessionFa // referenced from the main query. LinqLogging.LogExpression("Expression (partially evaluated)", _expression); - _expression = ExpressionParameterVisitor.Visit(preTransformResult, out _constantToParameterMap); + _constantToParameterMap = ExpressionParameterVisitor.Visit(preTransformResult); ParameterValuesByName = _constantToParameterMap.Values.Distinct().ToDictionary(p => p.Name, - p => System.Tuple.Create(p.Value, p.Type)); - - Key = ExpressionKeyVisitor.Visit(_expression, _constantToParameterMap); + p => System.Tuple.Create(p.Value, p.Type)); + NamedParameters = _constantToParameterMap.Values.Distinct().ToDictionary(p => p.Name); + Key = ExpressionKeyVisitor.Visit(_expression, _constantToParameterMap, sessionFactory); Type = _expression.Type; @@ -88,8 +90,10 @@ public IASTNode Translate(ISessionFactoryImplementor sessionFactory, bool filter var requiredHqlParameters = new List(); var queryModel = NhRelinqQueryParser.Parse(_expression); queryModel.TransformExpressions(TransparentIdentifierRemovingExpressionVisitor.ReplaceTransparentIdentifiers); + ParameterTypeLocator.SetParameterTypes(_constantToParameterMap, queryModel, TargetType, sessionFactory, true); var visitorParameters = new VisitorParameters(sessionFactory, _constantToParameterMap, requiredHqlParameters, new QuerySourceNamer(), TargetType, QueryMode); + QueryModelRewriter.Rewrite(queryModel, visitorParameters); ExpressionToHqlTranslationResults = QueryModelVisitor.GenerateHqlQuery(queryModel, visitorParameters, true, ReturnType); diff --git a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs index 99a8e009571..9826ffdbf06 100644 --- a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs +++ b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs @@ -29,8 +29,8 @@ private AddJoinsReWriter(ISessionFactoryImplementor sessionFactory, QueryModel q { _sessionFactory = sessionFactory; var joiner = new Joiner(queryModel, AddJoin); - _memberExpressionJoinDetector = new MemberExpressionJoinDetector(this, joiner); - _whereJoinDetector = new WhereJoinDetector(this, joiner); + _memberExpressionJoinDetector = new MemberExpressionJoinDetector(this, joiner, _sessionFactory); + _whereJoinDetector = new WhereJoinDetector(this, joiner, _sessionFactory); } public static void ReWrite(QueryModel queryModel, VisitorParameters parameters) diff --git a/src/NHibernate/Linq/ReWriters/SimplifyConditionalRewriter.cs b/src/NHibernate/Linq/ReWriters/SimplifyConditionalRewriter.cs new file mode 100644 index 00000000000..183cc3c0750 --- /dev/null +++ b/src/NHibernate/Linq/ReWriters/SimplifyConditionalRewriter.cs @@ -0,0 +1,33 @@ +using NHibernate.Linq.Clauses; +using NHibernate.Linq.Visitors; +using Remotion.Linq; +using Remotion.Linq.Clauses; + +namespace NHibernate.Linq.ReWriters +{ + internal class SimplifyConditionalRewriter : NhQueryModelVisitorBase + { + private static readonly SimplifyConditionalVisitor ConditionalVisitor = new SimplifyConditionalVisitor(); + private static readonly SimplifyConditionalRewriter Instance = new SimplifyConditionalRewriter(); + + public static void Rewrite(QueryModel queryModel) + { + Instance.VisitQueryModel(queryModel); + } + + public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index) + { + whereClause.Predicate = ConditionalVisitor.Visit(whereClause.Predicate); + } + + public override void VisitNhHavingClause(NhHavingClause havingClause, QueryModel queryModel, int index) + { + havingClause.Predicate = ConditionalVisitor.Visit(havingClause.Predicate); + } + + public override void VisitNhWithClause(NhWithClause withClause, QueryModel queryModel, int index) + { + withClause.Predicate = ConditionalVisitor.Visit(withClause.Predicate); + } + } +} diff --git a/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs index ef4981d2aec..140ec065c4f 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs @@ -1,13 +1,19 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Collections.ObjectModel; using System.Dynamic; using System.Linq; using System.Linq.Expressions; using System.Reflection; using System.Runtime.CompilerServices; using System.Text; +using NHibernate.Engine; using NHibernate.Param; +using NHibernate.Type; +using NHibernate.Util; +using Remotion.Linq.Clauses; +using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Parsing; namespace NHibernate.Linq.Visitors @@ -22,22 +28,68 @@ namespace NHibernate.Linq.Visitors public class ExpressionKeyVisitor : RelinqExpressionVisitor { private readonly IDictionary _constantToParameterMap; + private readonly ISessionFactoryImplementor _sessionFactory; readonly StringBuilder _string = new StringBuilder(); + private QueryModelKeyVisitor _queryModelKeyVisitor; - private ExpressionKeyVisitor(IDictionary constantToParameterMap) + private ExpressionKeyVisitor( + IDictionary constantToParameterMap, + ISessionFactoryImplementor sessionFactory) { _constantToParameterMap = constantToParameterMap; + _sessionFactory = sessionFactory; } + private QueryModelKeyVisitor QueryModelKeyVisitor => + _queryModelKeyVisitor ?? (_queryModelKeyVisitor = new QueryModelKeyVisitor(this, _string)); + + // Since v5.3 + [Obsolete("Use the overload with ISessionFactoryImplementor parameter")] public static string Visit(Expression expression, IDictionary parameters) { - var visitor = new ExpressionKeyVisitor(parameters); + var visitor = new ExpressionKeyVisitor(parameters, null); visitor.Visit(expression); return visitor.ToString(); } + /// + /// Generates the key for the expression. + /// + /// The expression. + /// The session factory. + /// Parameters found in . + /// The key for the expression. + public static string Visit( + Expression rootExpression, + IDictionary parameters, + ISessionFactoryImplementor sessionFactory) + { + var visitor = new ExpressionKeyVisitor(parameters, sessionFactory); + visitor.Visit(rootExpression); + + return visitor.ToString(); + } + + /// + /// Generates the key for a child expression based on the of one of its parents. + /// + /// The child expression. + /// The query parameters for the parent expression. + /// The session factory. + /// The key for the child expression. + internal static string VisitChild( + Expression childExpression, + IDictionary parameters, + ISessionFactoryImplementor sessionFactory) + { + var visitor = new ExpressionKeyVisitor(parameters, sessionFactory); + visitor.Visit(childExpression); + + return visitor.ToString(); + } + public override string ToString() { return _string.ToString(); @@ -45,23 +97,22 @@ public override string ToString() protected override Expression VisitBinary(BinaryExpression expression) { - if (expression.Method != null) - { - _string.Append(expression.Method.DeclaringType.Name); - _string.Append("."); - VisitMethod(expression.Method); - } - else + if (expression.NodeType == ExpressionType.ArrayIndex) { - _string.Append(expression.NodeType); + Visit(expression.Left); + _string.Append("["); + Visit(expression.Right); + _string.Append("]"); + + return expression; } _string.Append("("); - Visit(expression.Left); - _string.Append(", "); + _string.Append(" "); + _string.Append(expression.NodeType); + _string.Append(" "); Visit(expression.Right); - _string.Append(")"); return expression; @@ -69,11 +120,13 @@ protected override Expression VisitBinary(BinaryExpression expression) protected override Expression VisitConditional(ConditionalExpression expression) { + _string.Append("IIF("); Visit(expression.Test); - _string.Append(" ? "); + _string.Append(","); Visit(expression.IfTrue); - _string.Append(" : "); + _string.Append(","); Visit(expression.IfFalse); + _string.Append(")"); return expression; } @@ -86,62 +139,92 @@ protected override Expression VisitConstant(ConstantExpression expression) throw new InvalidOperationException("Cannot visit a constant without a constant to parameter map."); if (_constantToParameterMap.TryGetValue(expression, out param)) { - // Nulls generate different query plans. X = variable generates a different query depending on if variable is null or not. - if (param.Value == null) - { - _string.Append("NULL"); - } - else - { - var value = param.Value as IEnumerable; - if (value != null && !(value is string) && !value.Cast().Any()) - { - _string.Append("EmptyList"); - } - else - { - _string.Append(param.Name); - } - } + VisitParameter(param); } else { - if (expression.Value == null) - { - _string.Append("NULL"); - } - else - { - var value = expression.Value as IEnumerable; - if (value != null && !(value is string) && !(value is IQueryable)) - { - _string.Append("{"); - _string.Append(String.Join(",", value.Cast())); - _string.Append("}"); - } - else - { - _string.Append(expression.Value); - } - } + VisitConstantValue(expression.Value); } return base.VisitConstant(expression); } - private T AppendCommas(T expression) where T : Expression + private void VisitConstantValue(object value) { - Visit(expression); - _string.Append(", "); + if (value == null) + { + _string.Append("NULL"); + return; + } - return expression; + if (value is string) + { + _string.Append('"'); + _string.Append(value); + _string.Append('"'); + return; + } + + if (value is IEnumerable enumerable && !(value is IQueryable)) + { + _string.Append("{"); + _string.Append(string.Join(",", enumerable.Cast())); + _string.Append("}"); + return; + } + + // When MappedAs is used we have to put all sql types information in the key in order to + // distinct when different precisions/sizes are used. + if (_sessionFactory != null && value is IType type) + { + _string.Append(type.Name); + _string.Append('['); + _string.Append(string.Join(",", type.SqlTypes(_sessionFactory).Select(o => o.ToString()))); + _string.Append(']'); + return; + } + + var stringValue = value.ToString(); + if (stringValue == value.GetType().ToString()) + { + _string.Append("value("); + _string.Append(stringValue); + _string.Append(')'); + return; + } + + _string.Append(value); + } + + private void VisitParameter(NamedParameter param) + { + // Nulls generate different query plans. X = variable generates a different query depending on if variable is null or not. + if (param.Value == null) + { + _string.Append("NULL"); + return; + } + + if (param.IsCollection && !((IEnumerable) param.Value).Cast().Any()) + { + _string.Append("EmptyList"); + } + else + { + _string.Append(param.Name); + } + + // Add the type in order to avoid invalid parameter conversions (string -> char) + _string.Append("<"); + _string.Append(param.Value.GetType()); + _string.Append(">"); } protected override Expression VisitLambda(Expression expression) { _string.Append('('); - Visit(expression.Parameters, AppendCommas); + Visit(expression.Parameters, ','); _string.Append(") => ("); Visit(expression.Body); _string.Append(')'); @@ -149,9 +232,36 @@ protected override Expression VisitLambda(Expression expression) return expression; } + protected override Expression VisitListInit(ListInitExpression expression) + { + Visit(expression.NewExpression); + _string.Append(" {"); + Visit(expression.Initializers, VisitElementInit, ','); + _string.Append('}'); + + return expression; + } + + protected override Expression VisitRuntimeVariables(RuntimeVariablesExpression node) + { + _string.Append('('); + Visit(node.Variables, ','); + _string.Append(')'); + + return node; + } + protected override Expression VisitMember(MemberExpression expression) { - base.VisitMember(expression); + if (expression.Expression != null) + { + Visit(expression.Expression); + } + else + { + // Static members + _string.Append(expression.Member.DeclaringType.Name); + } _string.Append('.'); _string.Append(expression.Member.Name); @@ -159,25 +269,127 @@ protected override Expression VisitMember(MemberExpression expression) return expression; } + protected override Expression VisitMemberInit(MemberInitExpression expression) + { + if (expression.NewExpression.Arguments.Count == 0 && + expression.NewExpression.Type.Name.Contains('<')) + { + // Anonymous type + _string.Append("new"); + } + else + { + Visit(expression.NewExpression); + } + + _string.Append(" {"); + Visit(expression.Bindings, VisitMemberBinding, ','); + _string.Append('}'); + + return expression; + } + + protected override MemberAssignment VisitMemberAssignment(MemberAssignment assignment) + { + _string.Append(assignment.Member.Name); + _string.Append(" = "); + Visit(assignment.Expression); + + return assignment; + } + + protected override MemberListBinding VisitMemberListBinding(MemberListBinding binding) + { + _string.Append(binding.Member.Name); + _string.Append(" = {"); + Visit(binding.Initializers, VisitElementInit, ','); + _string.Append('}'); + + return binding; + } + + protected override MemberMemberBinding VisitMemberMemberBinding(MemberMemberBinding binding) + { + _string.Append(binding.Member.Name); + _string.Append(" = {"); + Visit(binding.Bindings, VisitMemberBinding, ','); + _string.Append('}'); + + return binding; + } + + protected override ElementInit VisitElementInit(ElementInit initializer) + { + _string.Append(initializer.AddMethod); + Visit(initializer.Arguments, ',', '(', ')'); + + return initializer; + } + + protected override Expression VisitInvocation(InvocationExpression expression) + { +#if NETCOREAPP2_0 + if (ExpressionsHelper.TryGetDynamicMemberBinder(expression, out var memberBinder)) + { + Visit(expression.Arguments[1]); + FormatBinder(memberBinder); + return expression; + } +#endif + + _string.Append("Invoke("); + Visit(expression.Expression); + Visit(expression.Arguments, ',', null, ')'); + + return expression; + } + protected override Expression VisitMethodCall(MethodCallExpression expression) { - Visit(expression.Object); - _string.Append('.'); + if (expression.Object != null) + { + Visit(expression.Object); + _string.Append('.'); + } + VisitMethod(expression.Method); - _string.Append('('); - ExpressionVisitor.Visit(expression.Arguments, AppendCommas); - _string.Append(')'); + Visit(expression.Arguments, ',', '(', ')'); return expression; } + protected override Expression VisitNewArray(NewArrayExpression node) + { + switch (node.NodeType) + { + case ExpressionType.NewArrayBounds: + // new MyType[](expr1, expr2) + _string.Append("new "); + _string.Append(node.Type); + Visit(node.Expressions, ',', '(', ')'); + break; + case ExpressionType.NewArrayInit: + // new [] {expr1, expr2} + _string.Append("new [] "); + Visit(node.Expressions, ',', '{', '}'); + break; + } + + return node; + } + protected override Expression VisitNew(NewExpression expression) { _string.Append("new "); - _string.Append(expression.Constructor.DeclaringType.AssemblyQualifiedName); - _string.Append('('); - Visit(expression.Arguments, AppendCommas); - _string.Append(')'); + _string.Append(GetTypeName(expression.Constructor.DeclaringType)); + var preVisitAction = expression.Members != null + ? i => + { + _string.Append(expression.Members[i].Name); + _string.Append(" = "); + } + : (Action) null; + Visit(expression.Arguments, ',', '(', ')', preVisitAction); return expression; } @@ -191,10 +403,19 @@ protected override Expression VisitParameter(ParameterExpression expression) protected override Expression VisitTypeBinary(TypeBinaryExpression expression) { - _string.Append("IsType("); + _string.Append("("); Visit(expression.Expression); - _string.Append(", "); - _string.Append(expression.TypeOperand.AssemblyQualifiedName); + switch (expression.NodeType) + { + case ExpressionType.TypeIs: + _string.Append(" Is "); + break; + case ExpressionType.TypeEqual: + _string.Append(" TypeEqual "); + break; + } + + _string.Append(GetTypeName(expression.TypeOperand)); _string.Append(")"); return expression; @@ -205,21 +426,99 @@ protected override Expression VisitUnary(UnaryExpression expression) _string.Append(expression.NodeType); _string.Append('('); Visit(expression.Operand); + + switch (expression.NodeType) + { + case ExpressionType.TypeAs: + _string.Append(" As "); + _string.Append(GetTypeName(expression.Type)); + _string.Append(')'); + break; + case ExpressionType.Convert: + case ExpressionType.ConvertChecked: + _string.Append(", "); + _string.Append(GetTypeName(expression.Type)); + _string.Append(')'); + break; + default: + _string.Append(')'); + break; + } + + return expression; + } + + protected override Expression VisitDefault(DefaultExpression node) + { + _string.Append("default("); + _string.Append(GetTypeName(node.Type)); + _string.Append(')'); + + return node; + } + + protected override Expression VisitIndex(IndexExpression node) + { + if (node.Object != null) + { + Visit(node.Object); + } + else + { + _string.Append(node.Indexer.DeclaringType.Name); + } + if (node.Indexer != null) + { + _string.Append('.'); + _string.Append(node.Indexer.Name); + } + + Visit(node.Arguments, ',', '[', ']'); + + return node; + } + + protected override Expression VisitExtension(Expression expression) + { + _string.Append(expression.GetType()); + _string.Append('('); + base.VisitExtension(expression); _string.Append(')'); return expression; } - protected override Expression VisitQuerySourceReference(Remotion.Linq.Clauses.Expressions.QuerySourceReferenceExpression expression) + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { - _string.Append(expression.ReferencedQuerySource.ItemName); + // When parameters are involved we have to traverse the reference query source in order + // to replace constant expressions with parameter names + if (_constantToParameterMap != null) + { + QueryModelKeyVisitor.VisitQuerySource(expression.ReferencedQuerySource); + } + else + { + _string.Append(expression.ReferencedQuerySource); + } + return expression; } protected override Expression VisitDynamic(DynamicExpression expression) { + Visit(expression.Arguments, ',', '(', ')'); FormatBinder(expression.Binder); - Visit(expression.Arguments, AppendCommas); + + return expression; + } + + protected override Expression VisitSubQuery(SubQueryExpression expression) + { + _string.Append("SubQuery"); + _string.Append('('); + QueryModelKeyVisitor.VisitQueryModel(expression.QueryModel); + _string.Append(')'); + return expression; } @@ -229,11 +528,60 @@ private void VisitMethod(MethodInfo methodInfo) if (methodInfo.IsGenericMethod) { _string.Append('['); - _string.Append(string.Join(",", methodInfo.GetGenericArguments().Select(a => a.AssemblyQualifiedName))); + _string.Append(string.Join(",", methodInfo.GetGenericArguments().Select(GetTypeName))); _string.Append(']'); } } + private void Visit(ReadOnlyCollection nodes, char separator, Action preVisitAction = null) + where T : Expression + { + Visit(nodes, separator, null, null, preVisitAction); + } + + private void Visit( + ReadOnlyCollection nodes, + char separator, + char? openSymbol, + char? closeSymbol, + Action preVisitAction = null) + where T : Expression + { + if (openSymbol.HasValue) + { + _string.Append(openSymbol.Value); + } + + for (var i = 0; i < nodes.Count; i++) + { + if (i > 0) + { + _string.Append(separator); + } + + preVisitAction?.Invoke(i); + Visit(nodes[i]); + } + + if (closeSymbol.HasValue) + { + _string.Append(closeSymbol.Value); + } + } + + private void Visit(ReadOnlyCollection nodes, Func elementVisitor, char separator) + { + for (var i = 0; i < nodes.Count; i++) + { + if (i > 0) + { + _string.Append(separator); + } + + elementVisitor(nodes[i]); + } + } + private void FormatBinder(CallSiteBinder binder) { switch (binder) @@ -279,5 +627,12 @@ private void FormatBinder(CallSiteBinder binder) break; } } + + internal static string GetTypeName(System.Type type) + { + return type.Namespace == "System" + ? type.FullName + : type.AssemblyQualifiedName; + } } } diff --git a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs index 45134248a51..f6a9e5de43f 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs @@ -1,9 +1,11 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; using System.Reflection; using NHibernate.Engine; +using NHibernate.Linq.Functions; using NHibernate.Param; using NHibernate.Type; using NHibernate.Util; @@ -18,23 +20,18 @@ public class ExpressionParameterVisitor : RelinqExpressionVisitor { private readonly Dictionary _parameters = new Dictionary(); private readonly Dictionary _variableParameters = new Dictionary(); + private readonly HashSet _collectionParameters = new HashSet(); private readonly IDictionary _queryVariables; private readonly ISessionFactoryImplementor _sessionFactory; + private readonly ILinqToHqlGeneratorsRegistry _functionRegistry; - private static readonly MethodInfo QueryableSkipDefinition = - ReflectHelper.FastGetMethodDefinition(Queryable.Skip, default(IQueryable), 0); - private static readonly MethodInfo QueryableTakeDefinition = - ReflectHelper.FastGetMethodDefinition(Queryable.Take, default(IQueryable), 0); - private static readonly MethodInfo EnumerableSkipDefinition = - ReflectHelper.FastGetMethodDefinition(Enumerable.Skip, default(IEnumerable), 0); - private static readonly MethodInfo EnumerableTakeDefinition = - ReflectHelper.FastGetMethodDefinition(Enumerable.Take, default(IEnumerable), 0); - - private readonly ICollection _pagingMethods = new HashSet - { - QueryableSkipDefinition, QueryableTakeDefinition, - EnumerableSkipDefinition, EnumerableTakeDefinition - }; + private static readonly HashSet PagingMethods = new HashSet + { + ReflectionCache.EnumerableMethods.SkipDefinition, + ReflectionCache.EnumerableMethods.TakeDefinition, + ReflectionCache.QueryableMethods.SkipDefinition, + ReflectionCache.QueryableMethods.TakeDefinition + }; // Since v5.3 [Obsolete("Please use overload with preTransformationResult parameter instead.")] @@ -47,6 +44,7 @@ public ExpressionParameterVisitor(PreTransformationResult preTransformationResul { _sessionFactory = preTransformationResult.SessionFactory; _queryVariables = preTransformationResult.QueryVariables; + _functionRegistry = _sessionFactory.Settings.LinqToHqlGeneratorsRegistry; } // Since v5.3 @@ -59,22 +57,19 @@ public static IDictionary Visit(Expression e return visitor._parameters; } - public static Expression Visit( - PreTransformationResult preTransformationResult, - out IDictionary parameters) + public static IDictionary Visit(PreTransformationResult preTransformationResult) { var visitor = new ExpressionParameterVisitor(preTransformationResult); - var expression = visitor.Visit(preTransformationResult.Expression); - parameters = visitor._parameters; - - return expression; + visitor.Visit(preTransformationResult.Expression); + return visitor._parameters; } protected override Expression VisitMethodCall(MethodCallExpression expression) { - if (expression.Method.Name == nameof(LinqExtensionMethods.MappedAs) && expression.Method.DeclaringType == typeof(LinqExtensionMethods)) + if (VisitorUtil.IsMappedAs(expression.Method)) { var rawParameter = Visit(expression.Arguments[0]); + // TODO 6.0: Remove below code and return expression as this logic is now inside ConstantTypeLocator var parameter = rawParameter as ConstantExpression; var type = expression.Arguments[1] as ConstantExpression; if (parameter == null) @@ -95,10 +90,10 @@ protected override Expression VisitMethodCall(MethodCallExpression expression) ? expression.Method.GetGenericMethodDefinition() : expression.Method; - if (_pagingMethods.Contains(method) && !_sessionFactory.Dialect.SupportsVariableLimit) + if (PagingMethods.Contains(method) && !_sessionFactory.Dialect.SupportsVariableLimit) { - //TODO: find a way to make this code cleaner var query = Visit(expression.Arguments[0]); + //TODO 6.0: Remove the below code and return expression var arg = expression.Arguments[1]; if (query == expression.Arguments[0]) @@ -107,6 +102,13 @@ protected override Expression VisitMethodCall(MethodCallExpression expression) return Expression.Call(null, expression.Method, query, arg); } + if (_functionRegistry != null && + _functionRegistry.TryGetGenerator(method, out var generator) && + generator.TryGetCollectionParameters(expression, out var collectionParameter)) + { + _collectionParameters.Add(collectionParameter); + } + if (VisitorUtil.IsDynamicComponentDictionaryGetter(expression, _sessionFactory)) { return expression; @@ -115,6 +117,20 @@ protected override Expression VisitMethodCall(MethodCallExpression expression) return base.VisitMethodCall(expression); } +#if NETCOREAPP2_0 + protected override Expression VisitInvocation(InvocationExpression expression) + { + if (ExpressionsHelper.TryGetDynamicMemberBinder(expression, out _)) + { + // Avoid adding System.Runtime.CompilerServices.CallSite instance as a parameter + base.Visit(expression.Arguments[1]); + return expression; + } + + return base.VisitInvocation(expression); + } +#endif + protected override Expression VisitConstant(ConstantExpression expression) { if (!_parameters.ContainsKey(expression) && !typeof(IQueryable).IsAssignableFrom(expression.Type) && !IsNullObject(expression)) @@ -125,11 +141,14 @@ protected override Expression VisitConstant(ConstantExpression expression) // We have a bit more information about the null parameter value. // Figure out a type so that HQL doesn't break on the null. (Related to NH-2430) + // In v5.3 types are calculated by ConstantTypeLocator, this logic is only for back compatibility. + // TODO 6.0: Remove if (expression.Value == null) type = NHibernateUtil.GuessType(expression.Type); // Constant characters should be sent as strings - if (expression.Type == typeof(char)) + // TODO 6.0: Remove + if (_queryVariables == null && expression.Type == typeof(char)) { value = value.ToString(); } @@ -144,13 +163,13 @@ protected override Expression VisitConstant(ConstantExpression expression) _queryVariables.TryGetValue(expression, out var variable) && !_variableParameters.TryGetValue(variable, out parameter)) { - parameter = new NamedParameter("p" + (_parameters.Count + 1), value, type); + parameter = CreateParameter(expression, value, type); _variableParameters.Add(variable, parameter); } if (parameter == null) { - parameter = new NamedParameter("p" + (_parameters.Count + 1), value, type); + parameter = CreateParameter(expression, value, type); } _parameters.Add(expression, parameter); @@ -161,6 +180,15 @@ protected override Expression VisitConstant(ConstantExpression expression) return base.VisitConstant(expression); } + private NamedParameter CreateParameter(ConstantExpression expression, object value, IType type) + { + return new NamedParameter( + "p" + (_parameters.Count + 1), + value, + type, + _collectionParameters.Contains(expression)); + } + private static bool IsNullObject(ConstantExpression expression) { return expression.Type == typeof(Object) && expression.Value == null; diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index cd9cd49eadb..97fbbc938ea 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Data; using System.Dynamic; using System.Linq; @@ -19,6 +20,13 @@ namespace NHibernate.Linq.Visitors { public class HqlGeneratorExpressionVisitor : IHqlExpressionVisitor { + private static readonly HashSet IntegerTypes = new HashSet + { + typeof(short), typeof(ushort), + typeof(int), typeof(uint), + typeof(long), typeof(ulong) + }; + private readonly HqlTreeBuilder _hqlTreeBuilder = new HqlTreeBuilder(); private readonly VisitorParameters _parameters; private readonly ILinqToHqlGeneratorsRegistry _functionRegistry; @@ -226,18 +234,14 @@ private HqlTreeNode VisitNhNominated(NhNominatedExpression nhNominatedExpression private HqlTreeNode VisitInvocationExpression(InvocationExpression expression) { - //This is an ugly workaround for dynamic expressions. - //Unfortunately we can not tap into the expression tree earlier to intercept the dynamic expression - if (expression.Arguments.Count == 2 && - expression.Arguments[0] is ConstantExpression constant && - constant.Value is CallSite site && - site.Binder is GetMemberBinder binder) +#if NETCOREAPP2_0 + if (ExpressionsHelper.TryGetDynamicMemberBinder(expression, out var binder)) { return _hqlTreeBuilder.Dot( VisitExpression(expression.Arguments[1]).AsExpression(), _hqlTreeBuilder.Ident(binder.Name)); } - +#endif return VisitExpression(expression.Expression); } @@ -295,8 +299,7 @@ protected HqlTreeNode VisitNhSum(NhSumExpression expression) protected HqlTreeNode VisitNhDistinct(NhDistinctExpression expression) { - var visitor = new HqlGeneratorExpressionVisitor(_parameters); - return _hqlTreeBuilder.ExpressionSubTreeHolder(_hqlTreeBuilder.Distinct(), visitor.VisitExpression(expression.Expression)); + return _hqlTreeBuilder.ExpressionSubTreeHolder(_hqlTreeBuilder.Distinct(), VisitExpression(expression.Expression)); } protected HqlTreeNode VisitQuerySourceReferenceExpression(QuerySourceReferenceExpression expression) @@ -341,6 +344,7 @@ protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression) return _hqlTreeBuilder.BooleanOr(lhs.ToBooleanExpression(), rhs.ToBooleanExpression()); case ExpressionType.Add: + case ExpressionType.AddChecked: if (expression.Left.Type == typeof (string) && expression.Right.Type == typeof(string)) { return _hqlTreeBuilder.MethodCall("concat", lhs, rhs); @@ -348,16 +352,46 @@ protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression) return _hqlTreeBuilder.Add(lhs, rhs); case ExpressionType.Subtract: + case ExpressionType.SubtractChecked: return _hqlTreeBuilder.Subtract(lhs, rhs); case ExpressionType.Multiply: + case ExpressionType.MultiplyChecked: return _hqlTreeBuilder.Multiply(lhs, rhs); case ExpressionType.Divide: + // In some databases (e.g. Oracle) division of two integer produces a decimal value where in .NET + // the result is an integer. Use floor method if exists otherwise cast in order to prevent ORA-01406 + // error in Oracle and to simulate what .NET does. We cannot use always cast method as in some + // databases (e.g. Oracle) it rounds the result. + if (IntegerTypes.Contains(expression.Left.Type.UnwrapIfNullable()) && + IntegerTypes.Contains(expression.Right.Type.UnwrapIfNullable())) + { + return _parameters.SessionFactory.SQLFunctionRegistry.FindSQLFunction("floor") != null + ? (HqlTreeNode) _hqlTreeBuilder.MethodCall("floor", _hqlTreeBuilder.Divide(lhs, rhs)) + : _hqlTreeBuilder.Cast(_hqlTreeBuilder.Divide(lhs, rhs), expression.Type); + } + + // In Oracle division can return a number with up to 40 digits, which cannot be retrieved from the data reader due to the lack of such + // numeric type in .NET (this does not apply for binary_float and binary_double). In order to avoid that we have to add a cast to trim + // the number so that it can be converted into a .NET numeric type. + // We have to avoid casting for other dialects as in some databases (e.g. MySql) it may not return the same number. + if (!_parameters.SessionFactory.Dialect.SupportsIEEE754FloatingPointNumbers) + { + return _hqlTreeBuilder.Cast(_hqlTreeBuilder.Divide(lhs, rhs), expression.Type); + } + return _hqlTreeBuilder.Divide(lhs, rhs); case ExpressionType.Modulo: - return _hqlTreeBuilder.MethodCall("mod", lhs, rhs); + var modFunction = _parameters.SessionFactory.SQLFunctionRegistry.FindSQLFunction("mod"); + var modReturnType = modFunction?.GetEffectiveReturnType( + ExpressionsHelper.GetTypes(_parameters, expression.Left, expression.Right), + _parameters.SessionFactory, + true); + return IsCastRequired(modReturnType, TypeFactory.GetDefaultTypeFor(expression.Type), out _) + ? (HqlTreeNode) _hqlTreeBuilder.Cast(_hqlTreeBuilder.MethodCall("mod", lhs, rhs), expression.Type) + : _hqlTreeBuilder.TransparentCast(_hqlTreeBuilder.MethodCall("mod", lhs, rhs), expression.Type); case ExpressionType.LessThan: return _hqlTreeBuilder.LessThan(lhs, rhs); diff --git a/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs b/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs index 580ba3cf00c..019769fccb1 100644 --- a/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs +++ b/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs @@ -19,16 +19,18 @@ internal class MemberExpressionJoinDetector : RelinqExpressionVisitor { private readonly IIsEntityDecider _isEntityDecider; private readonly IJoiner _joiner; + private readonly ISessionFactoryImplementor _sessionFactory; private bool _requiresJoinForNonIdentifier; private bool _preventJoinsInConditionalTest; private bool _hasIdentifier; private int _memberExpressionDepth; - public MemberExpressionJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner) + public MemberExpressionJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner, ISessionFactoryImplementor sessionFactory) { _isEntityDecider = isEntityDecider; _joiner = joiner; + _sessionFactory = sessionFactory; } protected override Expression VisitMember(MemberExpression expression) @@ -55,7 +57,7 @@ protected override Expression VisitMember(MemberExpression expression) ((_requiresJoinForNonIdentifier && !_hasIdentifier) || _memberExpressionDepth > 0) && _joiner.CanAddJoin(expression)) { - var key = ExpressionKeyVisitor.Visit(expression, null); + var key = ExpressionKeyVisitor.Visit(expression, null, _sessionFactory); return _joiner.AddJoin(result, key); } diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs new file mode 100644 index 00000000000..34326640169 --- /dev/null +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -0,0 +1,321 @@ +using System.Collections.Generic; +using System.Dynamic; +using System.Linq.Expressions; +using NHibernate.Engine; +using NHibernate.Param; +using NHibernate.Type; +using NHibernate.Util; +using Remotion.Linq; +using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Parsing; + +namespace NHibernate.Linq.Visitors +{ + /// + /// Locates parameter actual type based on its usage. + /// + public static class ParameterTypeLocator + { + /// + /// List of for which the should be related to the other side + /// of a (e.g. o.MyEnum == MyEnum.Option -> MyEnum.Option should have o.MyEnum as a related + /// ). + /// + private static readonly HashSet ValidBinaryExpressionTypes = new HashSet + { + ExpressionType.Equal, + ExpressionType.NotEqual, + ExpressionType.GreaterThanOrEqual, + ExpressionType.GreaterThan, + ExpressionType.LessThan, + ExpressionType.LessThanOrEqual, + ExpressionType.Coalesce, + ExpressionType.Assign + }; + + /// + /// List of for which the should be copied across + /// as related (e.g. (o.MyEnum ?? MyEnum.Option) == MyEnum.Option2 -> MyEnum.Option2 should have o.MyEnum as a related + /// ). + /// + private static readonly HashSet NonVoidOperators = new HashSet + { + ExpressionType.Coalesce, + ExpressionType.Conditional + }; + + /// + /// Set query parameter types based on the given query model. + /// + /// The query parameters. + /// The query model. + /// The target entity type. + /// The session factory. + public static void SetParameterTypes( + IDictionary parameters, + QueryModel queryModel, + System.Type targetType, + ISessionFactoryImplementor sessionFactory) + { + SetParameterTypes(parameters, queryModel, targetType, sessionFactory, false); + } + + internal static void SetParameterTypes( + IDictionary parameters, + QueryModel queryModel, + System.Type targetType, + ISessionFactoryImplementor sessionFactory, + bool removeMappedAsCalls) + { + if (parameters.Count == 0) + { + return; + } + + var visitor = new ConstantTypeLocatorVisitor(removeMappedAsCalls, targetType, parameters, sessionFactory); + queryModel.TransformExpressions(visitor.Visit); + + foreach (var pair in visitor.ConstantExpressions) + { + var type = pair.Value; + var constantExpression = pair.Key; + if (!parameters.TryGetValue(constantExpression, out var namedParameter)) + { + continue; + } + + if (type != null) + { + // MappedAs was used + namedParameter.Type = type; + continue; + } + + // In order to get the actual type we have to check first the related member expressions, as + // an enum is translated in a numeric type when used in a BinaryExpression and also it can be mapped as string. + // By getting the type from a related member expression we also get the correct length in case of StringType + // or precision when having a DecimalType. + if (visitor.RelatedExpressions.TryGetValue(constantExpression, out var memberExpressions)) + { + foreach (var memberExpression in memberExpressions) + { + if (ExpressionsHelper.TryGetMappedType( + sessionFactory, + memberExpression, + out type, + out _, + out _, + out _)) + { + break; + } + } + } + + // No related MemberExpressions was found, guess the type by value or its type when null. + if (type == null) + { + type = constantExpression.Value != null + ? ParameterHelper.TryGuessType(constantExpression.Value, sessionFactory, namedParameter.IsCollection) + : ParameterHelper.TryGuessType(constantExpression.Type, sessionFactory, namedParameter.IsCollection); + } + + namedParameter.Type = type; + } + } + + private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor + { + private readonly bool _removeMappedAsCalls; + private readonly System.Type _targetType; + private readonly IDictionary _parameters; + private readonly ISessionFactoryImplementor _sessionFactory; + public readonly Dictionary ConstantExpressions = + new Dictionary(); + public readonly Dictionary> RelatedExpressions = + new Dictionary>(); + + public ConstantTypeLocatorVisitor( + bool removeMappedAsCalls, + System.Type targetType, + IDictionary parameters, + ISessionFactoryImplementor sessionFactory) + { + _removeMappedAsCalls = removeMappedAsCalls; + _targetType = targetType; + _sessionFactory = sessionFactory; + _parameters = parameters; + } + + protected override Expression VisitBinary(BinaryExpression node) + { + node = (BinaryExpression) base.VisitBinary(node); + if (!ValidBinaryExpressionTypes.Contains(node.NodeType)) + { + return node; + } + + var left = Unwrap(node.Left); + var right = Unwrap(node.Right); + if (node.NodeType == ExpressionType.Assign) + { + VisitAssign(left, right); + } + else + { + AddRelatedExpression(node, left, right); + AddRelatedExpression(node, right, left); + } + + return node; + } + + protected override Expression VisitConditional(ConditionalExpression node) + { + node = (ConditionalExpression) base.VisitConditional(node); + var ifTrue = Unwrap(node.IfTrue); + var ifFalse = Unwrap(node.IfFalse); + AddRelatedExpression(node, ifTrue, ifFalse); + AddRelatedExpression(node, ifFalse, ifTrue); + + return node; + } + + protected override Expression VisitMethodCall(MethodCallExpression node) + { + if (VisitorUtil.IsMappedAs(node.Method)) + { + var rawParameter = Visit(node.Arguments[0]); + var parameter = rawParameter as ConstantExpression; + var type = node.Arguments[1] as ConstantExpression; + if (parameter == null) + throw new HibernateException( + $"{nameof(LinqExtensionMethods.MappedAs)} must be called on an expression which can be evaluated as " + + $"{nameof(ConstantExpression)}. It was call on {rawParameter?.GetType().Name ?? "null"} instead."); + if (type == null) + throw new HibernateException( + $"{nameof(LinqExtensionMethods.MappedAs)} type must be supplied as {nameof(ConstantExpression)}. " + + $"It was {node.Arguments[1]?.GetType().Name ?? "null"} instead."); + + ConstantExpressions[parameter] = (IType) type.Value; + + return _removeMappedAsCalls + ? rawParameter + : node; + } + + return base.VisitMethodCall(node); + } + + protected override Expression VisitConstant(ConstantExpression node) + { + if (node.Value is IEntityNameProvider || RelatedExpressions.ContainsKey(node) || !_parameters.ContainsKey(node)) + { + return node; + } + + RelatedExpressions.Add(node, new HashSet()); + ConstantExpressions.Add(node, null); + return node; + } + + public override Expression Visit(Expression node) + { + if (node is SubQueryExpression subQueryExpression) + { + subQueryExpression.QueryModel.TransformExpressions(Visit); + } + + return base.Visit(node); + } + + private void VisitAssign(Expression leftNode, Expression rightNode) + { + // Insert and Update statements have assign expressions, where the left side is a parameter and its name + // represents the property path to be assigned + if (!(leftNode is ParameterExpression parameterExpression) || + !(rightNode is ConstantExpression constantExpression)) + { + return; + } + + var entityName = _sessionFactory.TryGetGuessEntityName(_targetType); + if (entityName == null) + { + return; + } + + var persister = _sessionFactory.GetEntityPersister(entityName); + ConstantExpressions[constantExpression] = persister.EntityMetamodel.GetPropertyType(parameterExpression.Name); + } + + private void AddRelatedExpression(Expression node, Expression left, Expression right) + { + if (left.NodeType == ExpressionType.MemberAccess || + IsDynamicMember(left) || + left is QuerySourceReferenceExpression) + { + AddRelatedExpression(right, left); + if (NonVoidOperators.Contains(node.NodeType)) + { + AddRelatedExpression(node, left); + } + } + + // Copy all found MemberExpressions to the other side + // (e.g. (o.Prop ?? constant1) == constant2 -> copy o.Prop to constant2) + if (RelatedExpressions.TryGetValue(left, out var set)) + { + foreach (var nestedMemberExpression in set) + { + AddRelatedExpression(right, nestedMemberExpression); + if (NonVoidOperators.Contains(node.NodeType)) + { + AddRelatedExpression(node, nestedMemberExpression); + } + } + } + } + + private void AddRelatedExpression(Expression expression, Expression relatedExpression) + { + if (!RelatedExpressions.TryGetValue(expression, out var set)) + { + set = new HashSet(); + RelatedExpressions.Add(expression, set); + } + + set.Add(relatedExpression); + } + + private bool IsDynamicMember(Expression expression) + { + switch (expression) + { +#if NETCOREAPP2_0 + case InvocationExpression invocationExpression: + // session.Query().Where("Properties.Name == @0", "First Product") + return ExpressionsHelper.TryGetDynamicMemberBinder(invocationExpression, out _); +#endif + case DynamicExpression dynamicExpression: + return dynamicExpression.Binder is GetMemberBinder; + case MethodCallExpression methodCallExpression: + // session.Query() where p.Properties["Name"] == "First Product" select p + return VisitorUtil.TryGetPotentialDynamicComponentDictionaryMember(methodCallExpression, out _); + default: + return false; + } + } + + private static Expression Unwrap(Expression expression) + { + if (expression is UnaryExpression unaryExpression) + { + return unaryExpression.Operand; + } + + return expression; + } + } + } +} diff --git a/src/NHibernate/Linq/Visitors/QueryModelKeyVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelKeyVisitor.cs new file mode 100644 index 00000000000..fa8783e81dc --- /dev/null +++ b/src/NHibernate/Linq/Visitors/QueryModelKeyVisitor.cs @@ -0,0 +1,269 @@ +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Text; +using NHibernate.Linq.Clauses; +using Remotion.Linq; +using Remotion.Linq.Clauses; +using Remotion.Linq.Clauses.ResultOperators; +using Remotion.Linq.EagerFetching; + +namespace NHibernate.Linq.Visitors +{ + internal class QueryModelKeyVisitor : NhQueryModelVisitorBase, INhQueryModelVisitorExtended + { + private static readonly QueryModel DefaultQueryModel = new QueryModel( + new MainFromClause( + "x", + typeof(QueryModelVisitor), + Expression.Constant(0)), + new SelectClause(Expression.Constant(0))); + + private readonly ExpressionKeyVisitor _keyVisitor; + private readonly StringBuilder _string; + private HashSet _processedSources; + + public QueryModelKeyVisitor(ExpressionKeyVisitor keyVisitor, StringBuilder stringBuilder) + { + _keyVisitor = keyVisitor; + _string = stringBuilder; + } + + public override void VisitQueryModel(QueryModel queryModel) + { + if (queryModel.IsIdentityQuery()) + { + _keyVisitor.Visit(queryModel.MainFromClause.FromExpression); + } + else + { + VisitMainFromClause(queryModel.MainFromClause, queryModel); + VisitBodyClauses(queryModel.BodyClauses, queryModel); + VisitSelectClause(queryModel.SelectClause, queryModel); + } + + VisitResultOperators(queryModel.ResultOperators, queryModel); + } + + public override void VisitGroupJoinClause(GroupJoinClause groupJoinClause, QueryModel queryModel, int index) + { + VisitJoinClause(groupJoinClause.JoinClause, queryModel, index); + _string.Append(" into "); + _string.Append(groupJoinClause.ItemType.Name); + _string.Append(" "); + _string.Append(groupJoinClause.ItemName); + } + + public override void VisitJoinClause(JoinClause joinClause, QueryModel queryModel, int index) + { + VisitJoin(joinClause.ItemType, joinClause.ItemName, joinClause.InnerSequence); + _string.Append(" on "); + _keyVisitor.Visit(joinClause.OuterKeySelector); + _string.Append(" equals "); + _keyVisitor.Visit(joinClause.InnerKeySelector); + } + + public override void VisitMainFromClause(MainFromClause fromClause, QueryModel queryModel) + { + VisitFromClauseBase(fromClause); + } + + public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, QueryModel queryModel, int index) + { + VisitFromClauseBase(fromClause); + } + + public override void VisitNhHavingClause(NhHavingClause havingClause, QueryModel queryModel, int index) + { + _string.Append(" having "); + _keyVisitor.Visit(havingClause.Predicate); + } + + public override void VisitSelectClause(SelectClause selectClause, QueryModel queryModel) + { + _string.Append(" select "); + _keyVisitor.Visit(selectClause.Selector); + } + + public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index) + { + _string.Append(" where "); + _keyVisitor.Visit(whereClause.Predicate); + } + + public override void VisitNhJoinClause(NhJoinClause joinClause, QueryModel queryModel, int index) + { + VisitJoin(joinClause.ItemType, joinClause.ItemName, joinClause.FromExpression); + } + + public void VisitNhOuterJoinClause(NhOuterJoinClause nhOuterJoinClause, QueryModel queryModel, int index) + { + _string.Append(" outer"); + VisitJoinClause(nhOuterJoinClause.JoinClause, queryModel, index); + } + + public override void VisitNhWithClause(NhWithClause nhWhereClause, QueryModel queryModel, int index) + { + _string.Append(" with "); + _keyVisitor.Visit(nhWhereClause.Predicate); + } + + public override void VisitOrderByClause(OrderByClause orderByClause, QueryModel queryModel, int index) + { + _string.Append(" orderby "); + base.VisitOrderByClause(orderByClause, queryModel, index); + } + + public override void VisitOrdering(Ordering ordering, QueryModel queryModel, OrderByClause orderByClause, int index) + { + if (index > 0) + { + _string.Append(','); + } + + _keyVisitor.Visit(ordering.Expression); + _string.Append(ordering.OrderingDirection == OrderingDirection.Asc ? " asc" : " desc"); + } + + public override void VisitResultOperator(ResultOperatorBase resultOperator, QueryModel queryModel, int index) + { + _string.Append(" => "); + + // Custom visitors for operators that do not expose all information with TransformExpressions method in order to mimic their ToString method + switch (resultOperator) + { + case CastResultOperator castResult: + VisitTypeChangeOperator("Cast", castResult.CastItemType); + break; + case FetchRequestBase fetchBase: + VisitFetchRequestBase(fetchBase); + break; + case ChoiceResultOperatorBase operatorBase: + VisitChoiceResultOperatorBase(operatorBase); + break; + case OfTypeResultOperator ofTypeResult: + VisitTypeChangeOperator("OfType", ofTypeResult.SearchedItemType); + break; + default: + VisitResultOperatorBase(resultOperator); + break; + } + } + + public void VisitQuerySource(IQuerySource querySource) + { + if (!AddReferencedQuerySource(querySource)) + { + _string.Append(querySource.ItemName); + return; + } + + switch (querySource) + { + case MainFromClause mainFromClause: + VisitMainFromClause(mainFromClause, DefaultQueryModel); + break; + case ResultOperatorBase resultOperator: + VisitResultOperator(resultOperator, DefaultQueryModel, 0); + break; + case NhClauseBase nhClauseBase: + nhClauseBase.Accept(this, DefaultQueryModel, 0); + break; + case IBodyClause bodyClause: + bodyClause.Accept(this, DefaultQueryModel, 0); + break; + default: + throw new NotSupportedException($"Unknown query source {querySource}"); + } + } + + private void VisitResultOperatorBase(ResultOperatorBase resultOperator) + { + _string.Append(resultOperator.GetType().Name.Replace("ResultOperator", "(")); + var index = 0; + resultOperator.TransformExpressions( + expression => + { + if (expression == null) + { + return null; + } + + if (index > 0) + { + _string.Append(','); + } + + _keyVisitor.Visit(expression); + index++; + + return expression; + }); + + _string.Append(')'); + } + + private void VisitChoiceResultOperatorBase(ChoiceResultOperatorBase operatorBase) + { + _string.Append(operatorBase.GetType().Name.Replace("ResultOperator", "")); + if (operatorBase.ReturnDefaultWhenEmpty) + { + _string.Append("OrDefault"); + } + + _string.Append("()"); + } + + private void VisitFetchRequestBase(FetchRequestBase fetchBase) + { + _string.Append("Fetch ("); + _string.Append(ExpressionKeyVisitor.GetTypeName(fetchBase.RelationMember.DeclaringType)); + _string.Append('.'); + _string.Append(fetchBase.RelationMember.Name); + _string.Append(')'); + + foreach (var innerFetch in fetchBase.InnerFetchRequests) + { + VisitFetchRequestBase(innerFetch); + } + } + + private void VisitTypeChangeOperator(string name, System.Type type) + { + _string.Append(name); + _string.Append('<'); + _string.Append(ExpressionKeyVisitor.GetTypeName(type)); + _string.Append(">()"); + } + + private void VisitJoin(System.Type itemType, string itemName, Expression expression) + { + _string.Append(" join "); + _string.Append(itemType.Name); + _string.Append(" "); + _string.Append(itemName); + _string.Append(" in "); + _keyVisitor.Visit(expression); + } + + private void VisitFromClauseBase(FromClauseBase fromClause) + { + _string.Append(" from "); + _string.Append(fromClause.ItemType.Name); + _string.Append(" "); + _string.Append(fromClause.ItemName); + _string.Append(" in "); + _keyVisitor.Visit(fromClause.FromExpression); + } + + private bool AddReferencedQuerySource(IQuerySource querySource) + { + if (_processedSources == null) + { + _processedSources = new HashSet(); + } + + return _processedSources.Add(querySource); + } + } +} diff --git a/src/NHibernate/Linq/Visitors/QueryModelRewriter.cs b/src/NHibernate/Linq/Visitors/QueryModelRewriter.cs new file mode 100644 index 00000000000..cc41ea8c981 --- /dev/null +++ b/src/NHibernate/Linq/Visitors/QueryModelRewriter.cs @@ -0,0 +1,119 @@ +using System.Linq.Expressions; +using NHibernate.Linq.GroupBy; +using NHibernate.Linq.GroupJoin; +using NHibernate.Linq.NestedSelects; +using NHibernate.Linq.ReWriters; +using Remotion.Linq; +using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Parsing; + +namespace NHibernate.Linq.Visitors +{ + public static class QueryModelRewriter + { + /// + /// Rewrites the given and all that are found inside it. + /// + /// The query model to rewrite. + /// The visitor parameters. + public static void Rewrite(QueryModel rootQueryModel, VisitorParameters parameters) + { + Rewrite(rootQueryModel, parameters, true); + // Rewrite sub-queries + var rewriter = new SubqueryRewriterVisitor(parameters); + rootQueryModel.TransformExpressions(rewriter.Visit); + } + + internal static void Rewrite(QueryModel queryModel, VisitorParameters parameters, bool root) + { + // Expand conditionals in subquery FROM clauses into multiple subqueries + if (root) + { + // This expander works recursively + SubQueryConditionalExpander.ReWrite(queryModel); + } + + NestedSelectRewriter.ReWrite(queryModel, parameters.SessionFactory); + + // Remove unnecessary body operators + RemoveUnnecessaryBodyOperators.ReWrite(queryModel); + + // Merge aggregating result operators (distinct, count, sum etc) into the select clause + MergeAggregatingResultsRewriter.ReWrite(queryModel); + + // Swap out non-aggregating group-bys + NonAggregatingGroupByRewriter.ReWrite(queryModel); + + // Rewrite aggregate group-by statements + AggregatingGroupByRewriter.ReWrite(queryModel); + + // Rewrite aggregating group-joins + AggregatingGroupJoinRewriter.ReWrite(queryModel); + + // Rewrite non-aggregating group-joins + NonAggregatingGroupJoinRewriter.ReWrite(queryModel); + + SubQueryFromClauseFlattener.ReWrite(queryModel); + + // Rewrite left-joins + LeftJoinRewriter.ReWrite(queryModel); + + // Rewrite paging + PagingRewriter.ReWrite(queryModel); + + // Flatten pointless subqueries + QueryReferenceExpressionFlattener.ReWrite(queryModel); + + // Flatten array index access to query references + ArrayIndexExpressionFlattener.ReWrite(queryModel); + + // Add joins for references + AddJoinsReWriter.ReWrite(queryModel, parameters); + + // Expand coalesced and conditional joins to their logical equivalents + ConditionalQueryReferenceExpander.ReWrite(queryModel); + + // Move OrderBy clauses to end + MoveOrderByToEndRewriter.ReWrite(queryModel); + + // Give a rewriter provided by the session factory a chance to + // rewrite the query. + var rewriterFactory = parameters.SessionFactory.Settings.QueryModelRewriterFactory; + var customVisitor = rewriterFactory?.CreateVisitor(parameters); + customVisitor?.VisitQueryModel(queryModel); + + // rewrite any operators that should be applied on the outer query + // by flattening out the sub-queries that they are located in + parameters.AddQueryModelRewriterResult(queryModel, ResultOperatorRewriter.Rewrite(queryModel)); + + // Remove conditional expressions where they can be reduced to just their IfTrue or IfFalse part. + SimplifyConditionalRewriter.Rewrite(queryModel); + + // Identify and name query sources + QuerySourceIdentifier.Visit(parameters.QuerySourceNamer, queryModel); + } + + private class SubqueryRewriterVisitor : RelinqExpressionVisitor + { + private readonly VisitorParameters _parameters; + + public SubqueryRewriterVisitor(VisitorParameters parameters) + { + _parameters = parameters; + } + + public override Expression Visit(Expression node) + { + if (node?.NodeType == ExpressionType.Extension && + node is SubQueryExpression subQueryExpression && + !_parameters.QueryModelRewriterResults.ContainsKey(subQueryExpression.QueryModel)) + { + Rewrite(subQueryExpression.QueryModel, _parameters, false); + subQueryExpression.QueryModel.TransformExpressions(Visit); + } + + return base.Visit(node); + } + } + } +} diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index 040e9b38932..615da6d830c 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -7,8 +7,6 @@ using NHibernate.Linq.Clauses; using NHibernate.Linq.Expressions; using NHibernate.Linq.GroupBy; -using NHibernate.Linq.GroupJoin; -using NHibernate.Linq.NestedSelects; using NHibernate.Linq.ResultOperators; using NHibernate.Linq.ReWriters; using NHibernate.Linq.Visitors.ResultOperatorProcessors; @@ -30,73 +28,14 @@ public class QueryModelVisitor : NhQueryModelVisitorBase, INhQueryModelVisitor, public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel queryModel, VisitorParameters parameters, bool root, NhLinqExpressionReturnType? rootReturnType) { - // Expand conditionals in subquery FROM clauses into multiple subqueries - if (root) + // Rewrite the query model in case it was not yet rewritten + if (!parameters.QueryModelRewriterResults.TryGetValue(queryModel, out var result)) { - // This expander works recursively - SubQueryConditionalExpander.ReWrite(queryModel); + // TODO 6.0: Throw an exception + QueryModelRewriter.Rewrite(queryModel, parameters, root); + result = parameters.QueryModelRewriterResults[queryModel]; } - NestedSelectRewriter.ReWrite(queryModel, parameters.SessionFactory); - - // Remove unnecessary body operators - RemoveUnnecessaryBodyOperators.ReWrite(queryModel); - - // Merge aggregating result operators (distinct, count, sum etc) into the select clause - MergeAggregatingResultsRewriter.ReWrite(queryModel); - - // Swap out non-aggregating group-bys - NonAggregatingGroupByRewriter.ReWrite(queryModel); - - // Rewrite aggregate group-by statements - AggregatingGroupByRewriter.ReWrite(queryModel); - - // Rewrite aggregating group-joins - AggregatingGroupJoinRewriter.ReWrite(queryModel); - - // Rewrite non-aggregating group-joins - NonAggregatingGroupJoinRewriter.ReWrite(queryModel); - - SubQueryFromClauseFlattener.ReWrite(queryModel); - - // Rewrite left-joins - LeftJoinRewriter.ReWrite(queryModel); - - // Rewrite paging - PagingRewriter.ReWrite(queryModel); - - // Flatten pointless subqueries - QueryReferenceExpressionFlattener.ReWrite(queryModel); - - // Flatten array index access to query references - ArrayIndexExpressionFlattener.ReWrite(queryModel); - - // Add joins for references - AddJoinsReWriter.ReWrite(queryModel, parameters); - - // Expand coalesced and conditional joins to their logical equivalents - ConditionalQueryReferenceExpander.ReWrite(queryModel); - - // Move OrderBy clauses to end - MoveOrderByToEndRewriter.ReWrite(queryModel); - - // Give a rewriter provided by the session factory a chance to - // rewrite the query. - var rewriterFactory = parameters.SessionFactory.Settings.QueryModelRewriterFactory; - if (rewriterFactory != null) - { - var customVisitor = rewriterFactory.CreateVisitor(parameters); - if (customVisitor != null) - customVisitor.VisitQueryModel(queryModel); - } - - // rewrite any operators that should be applied on the outer query - // by flattening out the sub-queries that they are located in - var result = ResultOperatorRewriter.Rewrite(queryModel); - - // Identify and name query sources - QuerySourceIdentifier.Visit(parameters.QuerySourceNamer, queryModel); - var visitor = new QueryModelVisitor(parameters, root, queryModel, rootReturnType) { RewrittenOperatorResult = result, @@ -490,9 +429,6 @@ private void VisitDeleteClause(Expression expression) public override void VisitWhereClause(WhereClause whereClause, QueryModel queryModel, int index) { - var visitor = new SimplifyConditionalVisitor(); - whereClause.Predicate = visitor.Visit(whereClause.Predicate); - // Visit the predicate to build the query var expression = HqlGeneratorExpressionVisitor.Visit(whereClause.Predicate, VisitorParameters).ToBooleanExpression(); _hqlTree.AddWhereClause(expression); @@ -558,9 +494,6 @@ public override void VisitGroupJoinClause(GroupJoinClause groupJoinClause, Query public override void VisitNhHavingClause(NhHavingClause havingClause, QueryModel queryModel, int index) { - var visitor = new SimplifyConditionalVisitor(); - havingClause.Predicate = visitor.Visit(havingClause.Predicate); - // Visit the predicate to build the query var expression = HqlGeneratorExpressionVisitor.Visit(havingClause.Predicate, VisitorParameters).ToBooleanExpression(); _hqlTree.AddHavingClause(expression); @@ -568,9 +501,6 @@ public override void VisitNhHavingClause(NhHavingClause havingClause, QueryModel public override void VisitNhWithClause(NhWithClause withClause, QueryModel queryModel, int index) { - var visitor = new SimplifyConditionalVisitor(); - withClause.Predicate = visitor.Visit(withClause.Predicate); - // Visit the predicate to build the query var expression = HqlGeneratorExpressionVisitor.Visit(withClause.Predicate, VisitorParameters).ToBooleanExpression(); _hqlTree.AddWhereClause(expression); diff --git a/src/NHibernate/Linq/Visitors/SelectClauseNominator.cs b/src/NHibernate/Linq/Visitors/SelectClauseNominator.cs index 10d7bd07d3c..2867455062f 100644 --- a/src/NHibernate/Linq/Visitors/SelectClauseNominator.cs +++ b/src/NHibernate/Linq/Visitors/SelectClauseNominator.cs @@ -1,11 +1,12 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Linq.Expressions; -using NHibernate.Engine; +using NHibernate.Dialect.Function; +using NHibernate.Hql.Ast; using NHibernate.Linq.Functions; using NHibernate.Linq.Expressions; using NHibernate.Util; -using Remotion.Linq.Parsing; namespace NHibernate.Linq.Visitors { @@ -13,10 +14,34 @@ namespace NHibernate.Linq.Visitors /// Analyze the select clause to determine what parts can be translated /// fully to HQL, and some other properties of the clause. /// - class SelectClauseHqlNominator : RelinqExpressionVisitor + class SelectClauseHqlNominator { + private static readonly HashSet DateTypes = new HashSet + { + typeof(DateTime), typeof(DateTimeOffset), typeof(TimeSpan) + }; + + private static readonly HashSet ArithmeticOperations = new HashSet + { + ExpressionType.Add, ExpressionType.AddChecked, + ExpressionType.Subtract, ExpressionType.SubtractChecked, + ExpressionType.Multiply, ExpressionType.MultiplyChecked, + ExpressionType.Divide + }; + + private static readonly HashSet BitwiseOperations = new HashSet + { + ExpressionType.And, + ExpressionType.Or, + ExpressionType.ExclusiveOr, + ExpressionType.OnesComplement, + ExpressionType.LeftShift, + ExpressionType.RightShift + }; + private readonly ILinqToHqlGeneratorsRegistry _functionRegistry; - private readonly ISessionFactoryImplementor _sessionFactory; + private readonly VisitorParameters _parameters; + private bool _forceClientSide; /// /// The expression parts that can be converted to pure HQL. @@ -31,155 +56,383 @@ class SelectClauseHqlNominator : RelinqExpressionVisitor /// public bool ContainsUntranslatedMethodCalls { get; private set; } - private bool _canBeCandidate; - Stack _stateStack; - public SelectClauseHqlNominator(VisitorParameters parameters) { + _parameters = parameters; _functionRegistry = parameters.SessionFactory.Settings.LinqToHqlGeneratorsRegistry; - _sessionFactory = parameters.SessionFactory; } - internal Expression Nominate(Expression expression) + internal void Nominate(Expression expression) { HqlCandidates = new HashSet(); ContainsUntranslatedMethodCalls = false; - _canBeCandidate = true; - _stateStack = new Stack(); - _stateStack.Push(false); - - return Visit(expression); + CanBeEvaluatedInHql(expression); } - public override Expression Visit(Expression expression) + private bool CanBeEvaluatedInHql(Expression expression) { - if (expression == null) - return null; - - if (expression is NhNominatedExpression nominatedExpression) + // Do client side evaluation for constants + if (expression == null || expression.NodeType == ExpressionType.Constant) { - // Add the nominated clause and strip the nominator wrapper from the select expression - var innerExpression = nominatedExpression.Expression; - HqlCandidates.Add(innerExpression); - return innerExpression; + return true; } - var projectConstantsInHql = _stateStack.Peek() || expression.NodeType == ExpressionType.Equal || IsRegisteredFunction(expression); + bool canBeEvaluated; + switch (expression.NodeType) + { + case ExpressionType.Add: + case ExpressionType.AddChecked: + case ExpressionType.Divide: + case ExpressionType.Modulo: + case ExpressionType.Multiply: + case ExpressionType.MultiplyChecked: + case ExpressionType.Power: + case ExpressionType.Subtract: + case ExpressionType.SubtractChecked: + case ExpressionType.And: + case ExpressionType.Or: + case ExpressionType.ExclusiveOr: + case ExpressionType.LeftShift: + case ExpressionType.RightShift: + case ExpressionType.AndAlso: + case ExpressionType.OrElse: + case ExpressionType.Equal: + case ExpressionType.NotEqual: + case ExpressionType.GreaterThanOrEqual: + case ExpressionType.GreaterThan: + case ExpressionType.LessThan: + case ExpressionType.LessThanOrEqual: + case ExpressionType.Coalesce: + case ExpressionType.ArrayIndex: + canBeEvaluated = CanBeEvaluatedInHql((BinaryExpression) expression); + break; + case ExpressionType.Conditional: + canBeEvaluated = CanBeEvaluatedInHql((ConditionalExpression) expression); + break; + case ExpressionType.Call: + canBeEvaluated = CanBeEvaluatedInHql((MethodCallExpression) expression); + break; + case ExpressionType.ArrayLength: + case ExpressionType.Convert: + case ExpressionType.ConvertChecked: + case ExpressionType.Negate: + case ExpressionType.NegateChecked: + case ExpressionType.Not: + case ExpressionType.Quote: + case ExpressionType.TypeAs: + case ExpressionType.UnaryPlus: + canBeEvaluated = CanBeEvaluatedInHql(((UnaryExpression) expression)); + break; + case ExpressionType.MemberAccess: + canBeEvaluated = CanBeEvaluatedInHql((MemberExpression) expression); + break; + case ExpressionType.Extension: + if (expression is NhNominatedExpression nominatedExpression) + { + expression = nominatedExpression.Expression; + } + + canBeEvaluated = true; // Sub queries cannot be executed client side + break; + case ExpressionType.MemberInit: + canBeEvaluated = CanBeEvaluatedInHql((MemberInitExpression) expression); + break; + case ExpressionType.NewArrayInit: + case ExpressionType.NewArrayBounds: + canBeEvaluated = CanBeEvaluatedInHql((NewArrayExpression) expression); + break; + case ExpressionType.ListInit: + canBeEvaluated = CanBeEvaluatedInHql((ListInitExpression) expression); + break; + case ExpressionType.New: + canBeEvaluated = CanBeEvaluatedInHql((NewExpression) expression); + break; + case ExpressionType.Dynamic: + canBeEvaluated = CanBeEvaluatedInHql((DynamicExpression) expression); + break; + case ExpressionType.Invoke: + canBeEvaluated = CanBeEvaluatedInHql((InvocationExpression) expression); + break; + case ExpressionType.TypeIs: + canBeEvaluated = CanBeEvaluatedInHql(((TypeBinaryExpression) expression).Expression); + break; + default: + canBeEvaluated = true; + break; + } - // Set some flags, unless we already have proper values for them: - // projectConstantsInHql if they are inside a method call executed server side. - // ContainsUntranslatedMethodCalls if a method call must be executed locally. - var isMethodCall = expression.NodeType == ExpressionType.Call; - if (isMethodCall && (!projectConstantsInHql || !ContainsUntranslatedMethodCalls)) + if (canBeEvaluated) { - var isRegisteredFunction = IsRegisteredFunction(expression); - projectConstantsInHql = projectConstantsInHql || isRegisteredFunction; - ContainsUntranslatedMethodCalls = ContainsUntranslatedMethodCalls || !isRegisteredFunction; + HqlCandidates.Add(expression); } - _stateStack.Push(projectConstantsInHql); - bool saveCanBeCandidate = _canBeCandidate; - _canBeCandidate = true; + return canBeEvaluated; + } - try + private bool CanBeEvaluatedInHql(MethodCallExpression methodExpression) + { + if (VisitorUtil.TryGetEvalExpression(methodExpression, out var expression)) { - if (CanBeEvaluatedInHqlStatementShortcut(expression)) + if (methodExpression.Method.Name == nameof(ExpressionEvaluation.DatabaseEval)) { HqlCandidates.Add(expression); - return expression; + return false; } - expression = base.Visit(expression); - - if (_canBeCandidate) + if (_forceClientSide) { - if (CanBeEvaluatedInHqlSelectStatement(expression, projectConstantsInHql)) - { - HqlCandidates.Add(expression); - } - else - { - _canBeCandidate = false; - } + throw new InvalidOperationException( + $"{nameof(ExpressionEvaluation.ClientEval)} cannot be used inside another {nameof(ExpressionEvaluation.ClientEval)}."); } + + _forceClientSide = true; + CanBeEvaluatedInHql(expression); + _forceClientSide = false; + return false; } - finally + + var canBeEvaluated = _functionRegistry.TryGetGenerator(methodExpression.Method, out var methodGenerator) && !_forceClientSide; + canBeEvaluated &= methodExpression.Object == null || // Is static or extension method + // Does not ignore the parameter it belongs to + methodGenerator?.IgnoreInstance(methodExpression.Method) == true || + ( + // Does not belong to a parameter + methodExpression.Object.NodeType != ExpressionType.Constant && + CanBeEvaluatedInHql(methodExpression.Object) + ); + foreach (var argumentExpression in methodExpression.Arguments) { - _stateStack.Pop(); - _canBeCandidate = _canBeCandidate && saveCanBeCandidate; + // If one of the arguments cannot be converted to hql we have to execute the method on the client side + canBeEvaluated &= CanBeEvaluatedInHql(argumentExpression); } - return expression; + ContainsUntranslatedMethodCalls |= !canBeEvaluated; + return canBeEvaluated; } - private bool IsRegisteredFunction(Expression expression) + private bool CanBeEvaluatedInHql(MemberExpression memberExpression) { - if (expression.NodeType == ExpressionType.Call) + if (!CanBeEvaluatedInHql(memberExpression.Expression)) + { + return false; + } + + // Check for a mapped property e.g. Count + if (_functionRegistry.TryGetGenerator(memberExpression.Member, out _)) + { + return !_forceClientSide; + } + + // Check whether the member is mapped. TryGetEntityName will return not return the entity name when the + // member is part of a composite element of a collection, so check if the type was found. + return ExpressionsHelper.TryGetMappedType(_parameters.SessionFactory, memberExpression, out _, out _, out _, out _); + } + + private bool CanBeEvaluatedInHql(ConditionalExpression conditionalExpression) + { + var canBeEvaluated = CanBeEvaluatedInHql(conditionalExpression.Test); + // In Oracle, when a query that selects a parameter is executed multiple times with different parameter types, + // will fail to get the value from the data reader. e.g. select case when then @p0 else @p1 end. + // In order to prevent that, we have to execute only the condition on the server side and do the rest on the client side. + if (canBeEvaluated && + conditionalExpression.IfTrue.NodeType == ExpressionType.Constant && + conditionalExpression.IfFalse.NodeType == ExpressionType.Constant) + { + return false; + } + + canBeEvaluated &= (CanBeEvaluatedInHql(conditionalExpression.IfTrue) && HqlIdent.SupportsType(conditionalExpression.IfTrue.Type)) & + (CanBeEvaluatedInHql(conditionalExpression.IfFalse) && HqlIdent.SupportsType(conditionalExpression.IfFalse.Type)); + + return !_forceClientSide && canBeEvaluated; + } + + private bool CanBeEvaluatedInHql(UnaryExpression unaryExpression) + { + return CanBeEvaluatedInHql(unaryExpression.Operand) && + CanBitwiseOperationBeEvaluatedInHql(unaryExpression.NodeType, unaryExpression.Operand); + } + + private bool CanBeEvaluatedInHql(BinaryExpression binaryExpression) + { + var canBeEvaluated = CanBeEvaluatedInHql(binaryExpression.Left) & + CanBeEvaluatedInHql(binaryExpression.Right); + if (!canBeEvaluated || _forceClientSide) + { + return false; + } + + // Subtract dates on the client side as the result varies when executed on the server side. + // In Sql Server when using datetime2 subtract is not possible. + // In Oracle a number is returned that represents the difference between the two in days. + if ((binaryExpression.NodeType == ExpressionType.Subtract || binaryExpression.NodeType == ExpressionType.SubtractChecked) && + ContainsAnyOfTypes(DateTypes, binaryExpression.Left, binaryExpression.Right)) + { + return false; + } + + if (!CanBitwiseOperationBeEvaluatedInHql( + binaryExpression.NodeType, + binaryExpression.Left, + binaryExpression.Right)) + { + return false; + } + + if (!CanArithmeticOperationBeEvaluatedInHql(binaryExpression)) + { + return false; + } + + // Concatenation of strings can be only done on the server side when the left and right side types match. + if (binaryExpression.NodeType == ExpressionType.Add && + (binaryExpression.Left.Type == typeof(string) || binaryExpression.Right.Type == typeof(string))) + { + return binaryExpression.Left.Type == binaryExpression.Right.Type; + } + + if (binaryExpression.NodeType == ExpressionType.Modulo) { - var methodCallExpression = (MethodCallExpression) expression; - if (_functionRegistry.TryGetGenerator(methodCallExpression.Method, out var methodGenerator)) + var sqlFunction = _parameters.SessionFactory.SQLFunctionRegistry.FindSQLFunction("mod"); + if (sqlFunction == null || !(sqlFunction is ISQLFunctionExtended extendedSqlFunction)) { - // is static or extension method - return methodCallExpression.Object == null || - // does not belong to parameter - methodCallExpression.Object.NodeType != ExpressionType.Constant || - // does not ignore the parameter it belongs to - methodGenerator.IgnoreInstance(methodCallExpression.Method); + return false; // Fallback to old behavior } + + var arguments = ExpressionsHelper.GetTypes(_parameters, binaryExpression.Left, binaryExpression.Right); + return extendedSqlFunction.GetEffectiveReturnType(arguments, _parameters.SessionFactory, false) != null; } - else if (expression is NhSumExpression || - expression is NhCountExpression || - expression is NhAverageExpression || - expression is NhMaxExpression || - expression is NhMinExpression) + + return true; + } + + private bool CanBitwiseOperationBeEvaluatedInHql(ExpressionType expressionType, params Expression[] expressions) + { + if (!BitwiseOperations.Contains(expressionType) || !ContainsType(typeof(bool), expressions)) { return true; } - return false; + + return _parameters.SessionFactory.Dialect.SupportsBitwiseOperatorsOnBoolean; } - private bool CanBeEvaluatedInHqlSelectStatement(Expression expression, bool projectConstantsInHql) + private bool CanArithmeticOperationBeEvaluatedInHql(BinaryExpression expression) { - // HQL can't do New or Member Init - if (expression.NodeType == ExpressionType.MemberInit || - expression.NodeType == ExpressionType.New || - expression.NodeType == ExpressionType.NewArrayInit || - expression.NodeType == ExpressionType.NewArrayBounds) + if (!ContainsType(typeof(decimal), expression.Left, expression.Right)) + { + return true; + } + + // Some databases (e.g. SQLite) stores decimals as a floating point number, which may cause incorrect results when using + // any arithmetic operation. + if (_parameters.SessionFactory.Dialect.IsDecimalStoredAsFloatingPointNumber && ArithmeticOperations.Contains(expression.NodeType)) { return false; } - // Constants will only be evaluated in HQL if they're inside a method call - if (expression.NodeType == ExpressionType.Constant) + // Divide and Multiply operator on decimals produce different results when executed on server side, due to different precisions. + // In order to achieve the best precision possible, do the calculation on the client. + if (expression.NodeType == ExpressionType.Divide || + expression.NodeType == ExpressionType.Multiply || + expression.NodeType == ExpressionType.MultiplyChecked) { - return projectConstantsInHql; + return false; } - if (expression.NodeType == ExpressionType.Call) + return true; + } + + private bool CanBeEvaluatedInHql(MemberInitExpression memberInitExpression) + { + CanBeEvaluatedInHql(memberInitExpression.NewExpression); + VisitMemberBindings(memberInitExpression.Bindings); + return false; + } + + private bool CanBeEvaluatedInHql(DynamicExpression dynamicExpression) + { + foreach (var argument in dynamicExpression.Arguments) { - // Depends if it's in the function registry - return IsRegisteredFunction(expression); + CanBeEvaluatedInHql(argument); } - if (expression.NodeType == ExpressionType.Conditional) + return false; + } + + private bool CanBeEvaluatedInHql(ListInitExpression listInitExpression) + { + CanBeEvaluatedInHql(listInitExpression.NewExpression); + foreach (var initializer in listInitExpression.Initializers) { - // Theoretically, any conditional that returns a CAST-able primitive should be constructable in HQL. - // The type needs to be CAST-able because HQL wraps the CASE clause in a CAST and only supports - // certain types (as defined by the HqlIdent constructor that takes a System.Type as the second argument). - // However, this may still not cover all cases, so to limit the nomination of conditional expressions, - // we will only consider those which are already getting constants projected into them. - return projectConstantsInHql; + foreach (var listInitArgument in initializer.Arguments) + { + CanBeEvaluatedInHql(listInitArgument); + } } - return !(expression is MemberExpression memberExpression) || // Assume all is good - // Nominate only expressions that represent a mapped property or a translatable method call - ExpressionsHelper.TryGetMappedType(_sessionFactory, expression, out _, out _, out _, out _) || - _functionRegistry.TryGetGenerator(memberExpression.Member, out _); + return false; + } + private bool CanBeEvaluatedInHql(NewArrayExpression newArrayExpression) + { + foreach (var arrayExpression in newArrayExpression.Expressions) + { + CanBeEvaluatedInHql(arrayExpression); + } + + return false; + } + + private bool CanBeEvaluatedInHql(InvocationExpression invocationExpression) + { + foreach (var argument in invocationExpression.Arguments) + { + CanBeEvaluatedInHql(argument); + } + + return false; + } + + private bool CanBeEvaluatedInHql(NewExpression newExpression) + { + foreach (var argument in newExpression.Arguments) + { + CanBeEvaluatedInHql(argument); + } + + return false; + } + + private void VisitMemberBindings(IEnumerable bindings) + { + foreach (var binding in bindings) + { + switch (binding) + { + case MemberAssignment assignment: + CanBeEvaluatedInHql(assignment.Expression); + break; + case MemberListBinding listBinding: + foreach (var argument in listBinding.Initializers.SelectMany(o => o.Arguments)) + { + CanBeEvaluatedInHql(argument); + } + + break; + case MemberMemberBinding memberBinding: + VisitMemberBindings(memberBinding.Bindings); + break; + } + } + } + + private static bool ContainsAnyOfTypes(HashSet types, params Expression[] expressions) + { + return expressions.Any(o => types.Contains(o.Type.UnwrapIfNullable())); } - private static bool CanBeEvaluatedInHqlStatementShortcut(Expression expression) + private static bool ContainsType(System.Type type, params Expression[] expressions) { - return expression is NhCountExpression; + return expressions.Any(o => type == o.Type.UnwrapIfNullable()); } } } diff --git a/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs b/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs index df1cdfb3daa..925a31252ec 100644 --- a/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs +++ b/src/NHibernate/Linq/Visitors/SelectClauseVisitor.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -12,14 +13,37 @@ namespace NHibernate.Linq.Visitors { public class SelectClauseVisitor : RelinqExpressionVisitor { + private static readonly MethodInfo UnwrapNullableDefinition = ReflectHelper.GetMethodDefinition(() => UnwrapNullable(default, null)); + private static readonly HashSet ComparisonOperators = new HashSet + { + ExpressionType.Equal, + ExpressionType.NotEqual, + ExpressionType.GreaterThanOrEqual, + ExpressionType.GreaterThan, + ExpressionType.LessThan, + ExpressionType.LessThanOrEqual + }; + private readonly HqlTreeBuilder _hqlTreeBuilder = new HqlTreeBuilder(); private HashSet _hqlNodes; + private readonly Dictionary _hqlNodeKeyIndexes = new Dictionary(); private readonly ParameterExpression _inputParameter; private readonly VisitorParameters _parameters; private int _iColumn; private List _hqlTreeNodes = new List(); private readonly HqlGeneratorExpressionVisitor _hqlVisitor; + /// + /// Expressions for which we cannot alter their types to nullable: + /// - Root expression + /// - Arguments + /// - Member assignments + /// - Array elements + /// - List elements + /// - Test conditions + /// + private readonly HashSet _requireTypeMatching = new HashSet(); + public SelectClauseVisitor(System.Type inputType, VisitorParameters parameters) { _inputParameter = Expression.Parameter(inputType, "input"); @@ -44,9 +68,12 @@ public void VisitSelector(Expression expression) // Find the sub trees that can be expressed purely in HQL var nominator = new SelectClauseHqlNominator(_parameters); - expression = nominator.Nominate(expression); + nominator.Nominate(expression); _hqlNodes = nominator.HqlCandidates; + // Strip the nominator wrapper from the select expression + expression = UnwrapExpression(expression); + // Linq2SQL ignores calls to local methods. Linq2EF seems to not support // calls to local methods at all. For NHibernate we support local methods, // but prevent their use together with server-side distinct, since it may @@ -55,13 +82,19 @@ public void VisitSelector(Expression expression) throw new NotSupportedException("Cannot use distinct on result that depends on methods for which no SQL equivalent exist."); // Now visit the tree + AddRequiredTypeMatchingExpression(expression); var projection = Visit(expression); - - if ((projection != expression) && !_hqlNodes.Contains(expression)) + if ((projection != expression && !_hqlNodes.Contains(expression)) || _hqlTreeNodes.Count == 0) { ProjectionExpression = Expression.Lambda(projection, _inputParameter); } + // When having only constants in the select clause we need to add one column in order to have a valid sql statement + if (_hqlTreeNodes.Count == 0) + { + _hqlTreeNodes.Add(_hqlVisitor.Visit(Expression.Constant(1)).AsExpression()); + } + // Handle any boolean results in the output nodes _hqlTreeNodes = _hqlTreeNodes.ConvertAll(node => node.ToArithmeticExpression()); @@ -73,6 +106,8 @@ public void VisitSelector(Expression expression) } } + #region Overrides + public override Expression Visit(Expression expression) { if (expression == null) @@ -80,18 +115,404 @@ public override Expression Visit(Expression expression) return null; } + expression = UnwrapExpression(expression); if (_hqlNodes.Contains(expression)) { - // Pure HQL evaluation - _hqlTreeNodes.Add(_hqlVisitor.Visit(expression).AsExpression()); + // In order to avoid selecting the same expressions multiple times, calculate the expression key + // and use the same column index for them. + var key = ExpressionKeyVisitor.VisitChild(expression, _parameters.ConstantToParameterMap, _parameters.SessionFactory); + if (!_hqlNodeKeyIndexes.TryGetValue(key, out var index)) + { + index = _iColumn++; + _hqlNodeKeyIndexes.Add(key, index); + // Pure HQL evaluation + _hqlTreeNodes.Add(_hqlVisitor.Visit(expression).AsExpression()); + } + + var input = Expression.ArrayIndex(_inputParameter, Expression.Constant(index)); - return Convert(Expression.ArrayIndex(_inputParameter, Expression.Constant(_iColumn++)), expression.Type); + // When the value of the _inputParameter is a value type we need to make an additional null check + // in order to prevent a NRE when trying to cast null to a value type. + // (e.g. Select(o => (int?) o.ManyToOne.Id), where Id is a value type but can be null when ManyToOne is null) + if (!expression.Type.IsNullableOrReference()) + { + return UnwrapIfTypeMatchingRequired( + CreateNullConditional(expression.Type, input, input.Type, arg => Convert(arg, expression.Type)), + expression); + } + + return Convert(input, expression.Type); } // Can't handle this node with HQL. Just recurse down, and emit the expression return base.Visit(expression); } + protected override Expression VisitBinary(BinaryExpression node) + { + var leftNode = Visit(node.Left); + var rightNode = Visit(node.Right); + + if (leftNode.Type == node.Left.Type && rightNode.Type == node.Right.Type) + { + return node.Update(leftNode, node.Conversion, rightNode); + } + + // Rules for ?.: + // - All arithmetic operations return null when one or both sides are null + // - Sides of a boolean logical operation return false in case they are null + Expression test; + Expression defaultValue; + var leftVariable = Expression.Variable(leftNode.Type, "left"); + var rightVariable = Expression.Variable(rightNode.Type, "right"); + + if (leftNode.Type.IsNullableOrReference() && rightNode.Type.IsNullableOrReference()) + { + test = Expression.MakeBinary(ExpressionType.OrElse, + Expression.Equal(leftVariable, Expression.Default(leftNode.Type)), + Expression.Equal(rightVariable, Expression.Default(rightNode.Type)) + ); + } + else if (leftNode.Type.IsNullableOrReference()) + { + test = Expression.Equal(leftVariable, Expression.Default(leftNode.Type)); + } + else if (rightNode.Type.IsNullableOrReference()) + { + test = Expression.Equal(rightVariable, Expression.Default(rightNode.Type)); + } + else + { + test = Expression.Constant(false); + } + + // OrElse and AndAlso logical operators never return null even when one of the sides is null, so we have to use false for sides that are null. + // As AndAlso require both side to be true we don't need to apply the operator when one of the sides is null. + if (node.NodeType == ExpressionType.OrElse) + { + defaultValue = Expression.MakeBinary( + node.NodeType, + leftVariable.Type.IsNullable() + ? Expression.Coalesce(leftVariable, Expression.Constant(false)) + : (Expression) leftVariable, + rightVariable.Type.IsNullable() + ? Expression.Coalesce(rightVariable, Expression.Constant(false)) + : (Expression) rightVariable, + node.IsLiftedToNull, + node.Method + ); + } + // Comparison operators never return null when one of the sides is null + else if (node.NodeType == ExpressionType.AndAlso || ComparisonOperators.Contains(node.NodeType)) + { + defaultValue = Expression.Default(typeof(bool)); + } + else + { + defaultValue = Expression.Default(node.Type.GetNullableType()); + } + + // var left = ; + // var right = ; + // return left == default || right == default ? default() : left right; + return UnwrapIfTypeMatchingRequired( + Expression.Block( + new[] { leftVariable, rightVariable }, + Expression.Assign(leftVariable, leftNode), + Expression.Assign(rightVariable, rightNode), + Expression.Condition( + test, + defaultValue, + ConvertIfNeeded( + Expression.MakeBinary( + node.NodeType, + ConvertIfNeeded(leftVariable, node.Left.Type), + ConvertIfNeeded(rightVariable, node.Right.Type), + node.IsLiftedToNull, + node.Method + ), + defaultValue.Type + ) + ) + ), + node); + } + + protected override Expression VisitConditional(ConditionalExpression node) + { + // Rules for ?.: + // - Test expression return false in case the result is null + var testNode = Visit(node.Test); + var ifTrueNode = Visit(node.IfTrue); + var ifFalseNode = Visit(node.IfFalse); + + // Test expressions in sql never return null as case with zero and one is used. In order + // to simulate it on the client we have to add a coalesce expression when the test + // expression is null and return false instead. (e.g. Select(o => o.ManyToOne.Bool ? o.Prop1 : o.Prop2), + // where ManyToOne can be null) + if (testNode.Type.IsNullable()) + { + testNode = Expression.Coalesce(testNode, Expression.Constant(false)); + } + + if (ifTrueNode.Type == node.Type && ifFalseNode.Type == node.Type) + { + return node.Update(testNode, ifTrueNode, ifFalseNode); + } + + if (node.Type == ifFalseNode.Type) + { + ifFalseNode = Expression.Convert(ifFalseNode, ifTrueNode.Type); + } + else if (node.Type == ifTrueNode.Type) + { + ifTrueNode = Expression.Convert(ifTrueNode, ifFalseNode.Type); + } + + return UnwrapIfTypeMatchingRequired(Expression.Condition(testNode, ifTrueNode, ifFalseNode), node); + } + + protected override Expression VisitDynamic(DynamicExpression node) + { + return node.Update(VisitArguments(node) ?? (IEnumerable) node.Arguments); + } + + protected override Expression VisitInvocation(InvocationExpression node) + { + var args = VisitArguments(node) ?? (IEnumerable) node.Arguments; + return node.Update(node.Expression, args); + } + + // Override the original implementation to visit arguments first + protected override Expression VisitMethodCall(MethodCallExpression node) + { + var args = VisitArguments(node) ?? (IEnumerable) node.Arguments; + if (node.Object == null) // Static method + { + return node.Update(node.Object, args); + } + + var obj = Visit(node.Object); + if (!obj.Type.IsNullableOrReference()) + { + return node.Update(obj, args); + } + + return UnwrapIfTypeMatchingRequired( + CreateNullConditional(node.Type, obj, node.Object.Type, arg => Expression.Call(arg, node.Method, args)), + node); + } + + protected override Expression VisitMember(MemberExpression node) + { + var expression = Visit(node.Expression); + if (expression == null) + { + return node; + } + + if (!expression.Type.IsNullableOrReference()) + { + return node.Update(expression); + } + + return UnwrapIfTypeMatchingRequired( + CreateNullConditional(node.Type, expression, node.Expression.Type, arg => Expression.MakeMemberAccess(arg, node.Member)), + node); + } + + protected override Expression VisitNewArray(NewArrayExpression node) + { + return node.Update(VisitExpressions(node.Expressions) ?? (IEnumerable) node.Expressions); + } + + protected override Expression VisitNew(NewExpression node) + { + return node.Update(VisitExpressions(node.Arguments) ?? (IEnumerable) node.Arguments); + } + + protected override Expression VisitTypeBinary(TypeBinaryExpression node) + { + return node.Update(Visit(node.Expression)); + } + + protected override Expression VisitUnary(UnaryExpression node) + { + // Rules for ?.: + // - When the operand is null then the result of the unary operator is also null (except for Not) + // - When Not operand is null then the result is false + var operand = Visit(node.Operand); + if (node.NodeType == ExpressionType.Convert || node.NodeType == ExpressionType.ConvertChecked) + { + if (operand.Type == node.Type) + { + return operand; // Cast was already done + } + + if (operand.Type.IsNullableOrReference()) + { + return node.Update(operand); // A NRE will never occur for casting to reference types + } + } + // Not is transformed into case with zero and one, so when the expression is null we have to return false + // to match server evaluation. In order to return false on null, use true for the Not operator. + else if (node.NodeType == ExpressionType.Not && operand.Type.IsNullable()) + { + operand = Expression.Coalesce(operand, Expression.Constant(true)); + } + + if (!operand.Type.IsNullableOrReference() || node.NodeType == ExpressionType.TypeAs) + { + return node.Update(operand); + } + + return UnwrapIfTypeMatchingRequired( + CreateNullConditional(node.Type, operand, node.Operand.Type, arg => Expression.MakeUnary(node.NodeType, arg, node.Type, node.Method)), + node); + } + + protected override Expression VisitMemberInit(MemberInitExpression node) + { + return node.Update(VisitAndConvert(node.NewExpression, nameof(VisitMemberInit)), Visit(node.Bindings, VisitMemberBinding)); + } + + protected override Expression VisitListInit(ListInitExpression node) + { + return node.Update(VisitAndConvert(node.NewExpression, nameof(VisitListInit)), Visit(node.Initializers, VisitElementInit)); + } + + protected override ElementInit VisitElementInit(ElementInit node) + { + return node.Update(VisitExpressions(node.Arguments) ?? (IEnumerable) node.Arguments); + } + + protected override MemberAssignment VisitMemberAssignment(MemberAssignment node) + { + AddRequiredTypeMatchingExpression(node.Expression); + return node.Update(Visit(node.Expression)); + } + + #endregion + + private Expression[] VisitExpressions(ReadOnlyCollection nodes) + { + Expression[] newNodes = null; + for (int i = 0, n = nodes.Count; i < n; i++) + { + var curNode = nodes[i]; + AddRequiredTypeMatchingExpression(curNode); + var node = Visit(curNode); + if (newNodes != null) + { + newNodes[i] = node; + } + else if (!ReferenceEquals(node, curNode)) + { + newNodes = new Expression[n]; + for (var j = 0; j < i; j++) + { + newNodes[j] = nodes[j]; + } + + newNodes[i] = node; + } + } + + return newNodes; + } + + private Expression[] VisitArguments(IArgumentProvider nodes) + { + Expression[] newNodes = null; + for (int i = 0, n = nodes.ArgumentCount; i < n; i++) + { + var curNode = nodes.GetArgument(i); + AddRequiredTypeMatchingExpression(curNode); + var node = Visit(curNode); + if (newNodes != null) + { + newNodes[i] = node; + } + else if (!ReferenceEquals(node, curNode)) + { + newNodes = new Expression[n]; + for (var j = 0; j < i; j++) + { + newNodes[j] = nodes.GetArgument(j); + } + + newNodes[i] = node; + } + } + + return newNodes; + } + + private void AddRequiredTypeMatchingExpression(Expression expression) + { + _requireTypeMatching.Add(UnwrapExpression(expression)); + } + + private Expression UnwrapExpression(Expression expression) + { + if (expression is NhNominatedExpression nominatedExpression) + { + return nominatedExpression.Expression; + } + + if (expression is MethodCallExpression methodExpression && + VisitorUtil.TryGetEvalExpression(methodExpression, out var evalExpression)) + { + return evalExpression; + } + + return expression; + } + + private static Expression ConvertIfNeeded(Expression node, System.Type type) + { + return node.Type == type + ? node + : Expression.Convert(node, type); + } + + private Expression UnwrapIfTypeMatchingRequired(Expression expression, Expression node) + { + return _requireTypeMatching.Contains(node) && expression.Type.IsNullable() + ? CallUnwrapNullable(expression, node) + : expression; + } + + /// + /// Adds a null check (simulates ?. operator) for client side evaluations in order to prevent NRE and makes it consistent with the server + /// side evaluation. + /// + /// The root expression. + /// The operand to check for value. + /// The non transformed operand to match the type. + /// The expression that will be used when the value is not null. + /// The transformed expression. + private static BlockExpression CreateNullConditional(System.Type nodeType, Expression operand, System.Type originalOperandType, Func getOperationFunc) + { + // Simulate ?. operator by using an ExpressionBlock + // var value = ; + // return value == default() ? default() : operation(value) + var valueVariable = Expression.Variable(operand.Type, "value"); + var returnType = nodeType.GetNullableType(); + return Expression.Block( + new[] { valueVariable }, + Expression.Assign(valueVariable, operand), // Assign to a variable to avoid evaluating the operand multiple times + Expression.Condition( + Expression.Equal(valueVariable, Expression.Default(operand.Type)), + Expression.Default(returnType), + ConvertIfNeeded( + getOperationFunc(ConvertIfNeeded(valueVariable, originalOperandType)), // Cast will be done only for nullable types + returnType + ) + ) + ); + } + private static readonly MethodInfo ConvertChangeType = ReflectHelper.FastGetMethod(System.Convert.ChangeType, default(object), default(System.Type)); @@ -108,6 +529,24 @@ private static Expression Convert(Expression expression, System.Type type) return Expression.Convert(expression, type); } + + private static MethodCallExpression CallUnwrapNullable(Expression valueExpression, Expression node) + { + return Expression.Call(UnwrapNullableDefinition.MakeGenericMethod(node.Type), + valueExpression, + Expression.Constant(node)); + } + + private static TValueType UnwrapNullable(TValueType? value, Expression node) where TValueType : struct + { + if (!value.HasValue) + { + throw new InvalidOperationException( + $"Null value cannot be assigned to a value type '{typeof(TValueType)}'. Cast expression '{node}' to '{typeof(TValueType?)}'."); + } + + return value.Value; + } } // Since v5 diff --git a/src/NHibernate/Linq/Visitors/VisitorParameters.cs b/src/NHibernate/Linq/Visitors/VisitorParameters.cs index a4d2a1f2c65..deccd45c59a 100644 --- a/src/NHibernate/Linq/Visitors/VisitorParameters.cs +++ b/src/NHibernate/Linq/Visitors/VisitorParameters.cs @@ -2,7 +2,9 @@ using System.Linq.Expressions; using NHibernate.Engine; using NHibernate.Engine.Query; +using NHibernate.Linq.ReWriters; using NHibernate.Param; +using Remotion.Linq; namespace NHibernate.Linq.Visitors { @@ -23,6 +25,14 @@ public class VisitorParameters public QueryMode RootQueryMode { get; } + internal Dictionary QueryModelRewriterResults { get; } + = new Dictionary(); + + internal void AddQueryModelRewriterResult(QueryModel queryModel, ResultOperatorRewriterResult rewriterResult) + { + QueryModelRewriterResults.Add(queryModel, rewriterResult); + } + public VisitorParameters( ISessionFactoryImplementor sessionFactory, IDictionary constantToParameterMap, @@ -39,4 +49,4 @@ public VisitorParameters( RootQueryMode = rootQueryMode; } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Visitors/VisitorUtil.cs b/src/NHibernate/Linq/Visitors/VisitorUtil.cs index 22ac89dd0aa..c68681d0b2e 100644 --- a/src/NHibernate/Linq/Visitors/VisitorUtil.cs +++ b/src/NHibernate/Linq/Visitors/VisitorUtil.cs @@ -13,25 +13,12 @@ public static class VisitorUtil { public static bool IsDynamicComponentDictionaryGetter(MethodInfo method, Expression targetObject, IEnumerable arguments, ISessionFactory sessionFactory, out string memberName) { - memberName = null; - - // A dynamic component must be an IDictionary with a string key. - - if (method.Name != "get_Item" || !typeof(IDictionary).IsAssignableFrom(targetObject.Type) && !typeof(IDictionary).IsAssignableFrom(targetObject.Type)) - return false; - - var key = arguments.First() as ConstantExpression; - if (key == null || key.Type != typeof(string)) - return false; - - // The potential member name - memberName = (string)key.Value; - - // Need the owning member (the dictionary). - var member = targetObject as MemberExpression; - if (member == null) + if (!TryGetPotentialDynamicComponentDictionaryMember(method, targetObject, arguments, out memberName)) + { return false; + } + var member = (MemberExpression) targetObject; var memberPath = member.Member.Name; var metaData = sessionFactory.GetClassMetadata(member.Expression.Type); @@ -131,5 +118,54 @@ public static string GetMemberPath(this MemberExpression memberExpression) } return path; } + + internal static bool TryGetPotentialDynamicComponentDictionaryMember(MethodCallExpression expression, out string memberName) + { + return TryGetPotentialDynamicComponentDictionaryMember( + expression.Method, + expression.Object, + expression.Arguments, + out memberName); + } + + internal static bool TryGetPotentialDynamicComponentDictionaryMember( + MethodInfo method, + Expression targetObject, + IEnumerable arguments, + out string memberName) + { + memberName = null; + // A dynamic component must be an IDictionary with a string key. + if (method.Name != "get_Item" || + targetObject.NodeType != ExpressionType.MemberAccess || // Need the owning member (the dictionary). + !(arguments.First() is ConstantExpression key) || + key.Type != typeof(string) || + (!typeof(IDictionary).IsAssignableFrom(targetObject.Type) && !typeof(IDictionary).IsAssignableFrom(targetObject.Type))) + { + return false; + } + + // The potential member name + memberName = (string) key.Value; + return true; + } + + internal static bool IsMappedAs(MethodInfo methodInfo) + { + return methodInfo.Name == nameof(LinqExtensionMethods.MappedAs) && + methodInfo.DeclaringType == typeof(LinqExtensionMethods); + } + + internal static bool TryGetEvalExpression(MethodCallExpression methodExpression, out Expression expression) + { + if (methodExpression.Method.DeclaringType != typeof(ExpressionEvaluation)) + { + expression = null; + return false; + } + + expression = ((LambdaExpression) ((UnaryExpression) methodExpression.Arguments[0]).Operand).Body; + return true; + } } } diff --git a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs index 886d4e0e2b1..689457a7403 100644 --- a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs +++ b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs @@ -62,6 +62,7 @@ internal class WhereJoinDetector : RelinqExpressionVisitor // TODO: There are a number of types of expressions that we didn't handle here due to time constraints. For example, the ?: operator could be checked easily. private readonly IIsEntityDecider _isEntityDecider; private readonly IJoiner _joiner; + private readonly ISessionFactoryImplementor _sessionFactory; private readonly Stack _handled = new Stack(); @@ -71,10 +72,11 @@ internal class WhereJoinDetector : RelinqExpressionVisitor // The following is used for member expressions traversal. private int _memberExpressionDepth; - internal WhereJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner) + internal WhereJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner, ISessionFactoryImplementor sessionFactory) { _isEntityDecider = isEntityDecider; _joiner = joiner; + _sessionFactory = sessionFactory; } public Expression Transform(Expression expression) @@ -329,7 +331,7 @@ protected override Expression VisitMember(MemberExpression expression) { // Don't add joins for things like a.B == a.C where B and C are entities. // We only need to join B when there's something like a.B.D. - var key = ExpressionKeyVisitor.Visit(expression, null); + var key = ExpressionKeyVisitor.Visit(expression, null, _sessionFactory); if (_memberExpressionDepth > 0 && _joiner.CanAddJoin(expression)) { diff --git a/src/NHibernate/Param/NamedParameter.cs b/src/NHibernate/Param/NamedParameter.cs index b42f69925f0..a9a9b67de2b 100644 --- a/src/NHibernate/Param/NamedParameter.cs +++ b/src/NHibernate/Param/NamedParameter.cs @@ -5,16 +5,24 @@ namespace NHibernate.Param public class NamedParameter { public NamedParameter(string name, object value, IType type) + : this(name, value, type, false) + { + } + + internal NamedParameter(string name, object value, IType type, bool isCollection) { Name = name; Value = value; Type = type; + IsCollection = isCollection; } public string Name { get; private set; } public object Value { get; internal set; } public IType Type { get; internal set; } + public virtual bool IsCollection { get; } + public bool Equals(NamedParameter other) { if (ReferenceEquals(null, other)) @@ -38,4 +46,4 @@ public override int GetHashCode() return (Name != null ? Name.GetHashCode() : 0); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index 08a60aeeb66..2980faeab7d 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Dynamic; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -15,6 +16,7 @@ using NHibernate.Type; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Parsing; namespace NHibernate.Util { @@ -30,6 +32,33 @@ public static MemberInfo DecodeMemberAccessExpression(Expressi return ((MemberExpression)expression.Body).Member; } +#if NETCOREAPP2_0 + /// + /// Try to retrieve from a reduced expression. + /// + /// The reduced dynamic expression. + /// The out binder parameter. + /// Whether the binder was found. + internal static bool TryGetDynamicMemberBinder(InvocationExpression expression, out GetMemberBinder memberBinder) + { + // This is an ugly workaround for dynamic expressions in .NET Core. In .NET Core a dynamic expression is reduced + // when first visited by a expression visitor that is not a DynamicExpressionVisitor, where in .NET Framework it is never reduced. + // As RelinqExpressionVisitor does not extend DynamicExpressionVisitor, we will always have a reduced dynamic expression in .NET Core. + // Unfortunately we can not tap into the expression tree earlier to intercept the dynamic expression + if (expression.Arguments.Count == 2 && + expression.Arguments[0] is ConstantExpression constant && + constant.Value is CallSite site && + site.Binder is GetMemberBinder binder) + { + memberBinder = binder; + return true; + } + + memberBinder = null; + return false; + } +#endif + /// /// Check whether the given expression represent a variable. /// @@ -96,6 +125,21 @@ internal static IType GetType(VisitorParameters parameters, Expression expressio return expression.Type == typeof(object) ? null : TypeFactory.HeuristicType(expression.Type); } + /// + /// Get the mapped types for the given expressions. + /// + /// The query parameters. + /// The expressions. + /// An enumerable of mapped types or when a mapped type was not + /// found or an item type of is . + internal static IEnumerable GetTypes(VisitorParameters parameters, params Expression[] expressions) + { + foreach (var expression in expressions) + { + yield return GetType(parameters, expression); + } + } + /// /// Try to get the mapped nullability from the given expression. /// @@ -635,6 +679,34 @@ protected override Expression VisitMember(MemberExpression node) return base.Visit(node.Expression); } +#if NETCOREAPP2_0 + protected override Expression VisitInvocation(InvocationExpression node) + { + if (TryGetDynamicMemberBinder(node, out var binder)) + { + _memberPaths.Push(new MemberMetadata(binder.Name, _convertType, _hasIndexer)); + _convertType = null; + _hasIndexer = false; + return base.Visit(node.Arguments[1]); + } + + return base.VisitInvocation(node); + } +#endif + + protected override Expression VisitDynamic(DynamicExpression node) + { + if (node.Binder is GetMemberBinder binder) + { + _memberPaths.Push(new MemberMetadata(binder.Name, _convertType, _hasIndexer)); + _convertType = null; + _hasIndexer = false; + return base.Visit(node.Arguments[0]); + } + + return Visit(node); + } + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression node) { if (node.ReferencedQuerySource is IFromClause fromClause) @@ -721,6 +793,14 @@ protected override Expression VisitMethodCall(MethodCallExpression node) ); } + if (VisitorUtil.TryGetPotentialDynamicComponentDictionaryMember(node, out var memberName)) + { + _memberPaths.Push(new MemberMetadata(memberName, _convertType, _hasIndexer)); + _convertType = null; + _hasIndexer = false; + return base.Visit(node.Object); + } + return Visit(node); } diff --git a/src/NHibernate/Util/ParameterHelper.cs b/src/NHibernate/Util/ParameterHelper.cs new file mode 100644 index 00000000000..d0b6bd14625 --- /dev/null +++ b/src/NHibernate/Util/ParameterHelper.cs @@ -0,0 +1,139 @@ +using System; +using System.Collections; +using System.Linq; +using NHibernate.Engine; +using NHibernate.Proxy; +using NHibernate.Type; + +namespace NHibernate.Util +{ + internal static class ParameterHelper + { + /// + /// Guesses the from the param's value. + /// + /// The object to guess the of. + /// The session factory to search for entity persister. + /// Whether is a collection. + /// An for the object. + /// + /// Thrown when the param is null because the + /// can't be guess from a null value. + /// + public static IType TryGuessType(object param, ISessionFactoryImplementor sessionFactory, bool isCollection) + { + if (param == null) + { + return null; + } + + if (param is IEnumerable enumerable && isCollection) + { + var firstValue = enumerable.Cast().FirstOrDefault(); + return firstValue == null + ? TryGuessType(enumerable.GetCollectionElementType(), sessionFactory) + : TryGuessType(firstValue, sessionFactory, false); + } + + var clazz = NHibernateProxyHelper.GetClassWithoutInitializingProxy(param); + return TryGuessType(clazz, sessionFactory); + } + + /// + /// Guesses the from the param's value. + /// + /// The object to guess the of. + /// The session factory to search for entity persister. + /// An for the object. + /// + /// Thrown when the param is null because the + /// can't be guess from a null value. + /// + public static IType GuessType(object param, ISessionFactoryImplementor sessionFactory) + { + if (param == null) + { + throw new ArgumentNullException(nameof(param), "The IType can not be guessed for a null value."); + } + + System.Type clazz = NHibernateProxyHelper.GetClassWithoutInitializingProxy(param); + return GuessType(clazz, sessionFactory); + } + + /// + /// Guesses the from the . + /// + /// The to guess the of. + /// The session factory to search for entity persister. + /// Whether is a collection. + /// An for the . + /// + /// Thrown when the clazz is null because the + /// can't be guess from a null type. + /// + public static IType TryGuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory, bool isCollection) + { + if (clazz == null) + { + return null; + } + + if (isCollection) + { + return TryGuessType(ReflectHelper.GetCollectionElementType(clazz), sessionFactory, false); + } + + return TryGuessType(clazz, sessionFactory); + } + + /// + /// Guesses the from the . + /// + /// The to guess the of. + /// The session factory to search for entity persister. + /// An for the . + /// + /// Thrown when the clazz is null because the + /// can't be guess from a null type. + /// + public static IType GuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory) + { + if (clazz == null) + { + throw new ArgumentNullException(nameof(clazz), "The IType can not be guessed for a null value."); + } + + return TryGuessType(clazz, sessionFactory) ?? + throw new HibernateException("Could not determine a type for class: " + clazz.AssemblyQualifiedName); + } + + /// + /// Guesses the from the . + /// + /// The to guess the of. + /// The session factory to search for entity persister. + /// An for the . + /// + /// Thrown when the clazz is null because the + /// can't be guess from a null type. + /// + public static IType TryGuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory) + { + if (clazz == null) + { + return null; + } + + var type = TypeFactory.HeuristicType(clazz); + if (type == null || type is SerializableType) + { + if (sessionFactory.TryGetEntityPersister(clazz.FullName) != null) + { + return NHibernateUtil.Entity(clazz); + } + } + + return type; + } + } +} diff --git a/src/NHibernate/Util/ReflectionCache.cs b/src/NHibernate/Util/ReflectionCache.cs index c40a395f98d..47fde15950d 100644 --- a/src/NHibernate/Util/ReflectionCache.cs +++ b/src/NHibernate/Util/ReflectionCache.cs @@ -54,6 +54,11 @@ internal static class EnumerableMethods internal static readonly MethodInfo ToListDefinition = ReflectHelper.FastGetMethodDefinition(Enumerable.ToList, default(IEnumerable)); + + internal static readonly MethodInfo SkipDefinition = + ReflectHelper.FastGetMethodDefinition(Enumerable.Skip, default(IEnumerable), default(int)); + internal static readonly MethodInfo TakeDefinition = + ReflectHelper.FastGetMethodDefinition(Enumerable.Take, default(IEnumerable), default(int)); } internal static class MethodBaseMethods @@ -215,6 +220,11 @@ internal static class QueryableMethods ReflectHelper.FastGetMethodDefinition(Queryable.Average, default(IQueryable), default(Expression>)); internal static readonly MethodInfo AverageWithSelectorOfNullableDecimalDefinition = ReflectHelper.FastGetMethodDefinition(Queryable.Average, default(IQueryable), default(Expression>)); + + internal static readonly MethodInfo SkipDefinition = + ReflectHelper.FastGetMethodDefinition(Queryable.Skip, default(IQueryable), default(int)); + internal static readonly MethodInfo TakeDefinition = + ReflectHelper.FastGetMethodDefinition(Queryable.Take, default(IQueryable), default(int)); } internal static class TypeMethods diff --git a/src/NHibernate/Util/TypeExtensions.cs b/src/NHibernate/Util/TypeExtensions.cs index 71bf854957f..b324059d0fa 100644 --- a/src/NHibernate/Util/TypeExtensions.cs +++ b/src/NHibernate/Util/TypeExtensions.cs @@ -48,5 +48,15 @@ internal static System.Type UnwrapIfNullable(this System.Type type) return type; } + + internal static System.Type GetNullableType(this System.Type type) + { + if (type.IsNullable()) + { + return type; + } + + return type.IsValueType ? typeof(Nullable<>).MakeGenericType(type) : type; + } } }