Skip to content

Commit

Permalink
[#266] 添加GetCompareValueByShardingKey方法优化当出现大量In时Expression的or函数或者and…
Browse files Browse the repository at this point in the history
…函数拼接导致stackoverflow异常
  • Loading branch information
xuejmnet committed Apr 20, 2024
1 parent faeba51 commit b730c4e
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 9 deletions.
4 changes: 4 additions & 0 deletions samples/Sample.AutoCreateIfPresent/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
})
.UseConfig(o =>
{
// o.CacheEntrySize;
// o.CacheModelLockConcurrencyLevel
// o.CacheModelLockObjectSeconds
// o.CacheItemPriority
o.ThrowIfQueryRouteNotMatch = false;
o.AddDefaultDataSource("ds0", "server=127.0.0.1;port=3306;database=shardingTest;userid=root;password=root;");
o.UseShardingQuery((conn, b) =>
Expand Down
13 changes: 12 additions & 1 deletion samples/Sample.MySql/Controllers/WeatherForecastController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ public IQueryable<SysTest> GetAll()
[HttpGet]
public async Task<IActionResult> Getxx()
{

var test = new Test();
test.UtcTime=DateTime.Now;
await _defaultTableDbContext.AddAsync(test);
Expand All @@ -141,7 +142,7 @@ public async Task<IActionResult> Getxx()
public async Task<IActionResult> Get()
{
var s = Guid.NewGuid().ToString();
var page =await _defaultTableDbContext.Set<SysUserLogByMonth>().Where(o=>o.Id==s).OrderByDescending(o=>o.Time).ToShardingPageAsync(1,2);
// var page =await _defaultTableDbContext.Set<SysUserLogByMonth>().Include().ThenInclude().Where(o=>o.Id==s).OrderByDescending(o=>o.Time).ToShardingPageAsync(1,2);
// var virtualDataSource = _shardingRuntimeContext.GetVirtualDataSource();
// virtualDataSource.AddPhysicDataSource(new DefaultPhysicDataSource("2023", "xxxxxxxx", false));
// var dataSourceRouteManager = _shardingRuntimeContext.GetDataSourceRouteManager();
Expand Down Expand Up @@ -571,5 +572,15 @@ public void get11()
unShardingDbContext2.SaveChanges();
dbContextTransaction.Commit();
}

[HttpGet]
public async Task<IActionResult> get131()
{
var list = new List<string>();
var idList = Enumerable.Range(1,50000).Select(o=>o.ToString()).ToList();
var sysUserMods = _defaultTableDbContext.Set<SysUserMod>()
.Where(o=>idList.Contains(o.Id)).ToList();
return Ok();
}
}
}
9 changes: 9 additions & 0 deletions samples/Sample.MySql/Shardings/SysUserModVirtualTableRoute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ public override void Configure(EntityMetadataTableBuilder<SysUserMod> builder)
builder.ShardingProperty(o => o.Id);
}

public override object GetCompareValueByShardingKey(object shardingKey, string shardingPropertyName)
{
if ("Id".Equals(shardingPropertyName))
{
return ShardingKeyToTail(shardingKey);
}
return base.GetCompareValueByShardingKey(shardingKey, shardingPropertyName);
}

// protected override List<TableRouteUnit> AfterShardingRouteUnitFilter(DataSourceRouteResult dataSourceRouteResult, List<TableRouteUnit> shardingRouteUnits)
// {
// //拦截
Expand Down
2 changes: 1 addition & 1 deletion samples/Sample.MySql/Startup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env)
{
app.UseDeveloperExceptionPage();
}
app.ApplicationServices.UseAutoTryCompensateTable();
using (var scope = app.ApplicationServices.CreateScope())
{
var unShardingDbContext = scope.ServiceProvider.GetService<UnShardingDbContext>();
Expand All @@ -236,6 +235,7 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env)
// var virtualTableRoute = (SysUserLogByMonthRoute)tableRouteManager.GetRoute(typeof(SysUserLogByMonth));
// virtualTableRoute.Append("2021");
}
app.ApplicationServices.UseAutoTryCompensateTable();
// var shardingRuntimeContext = app.ApplicationServices.GetRequiredService<IShardingRuntimeContext>();
// var entityMetadataManager = shardingRuntimeContext.GetEntityMetadataManager();
// var entityMetadata = entityMetadataManager.TryGet<SysUserMod>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public abstract class AbstractShardingOperatorVirtualDataSourceRoute<TEntity, TK
protected override List<string> DoRouteWithPredicate(List<string> allDataSourceNames, IQueryable queryable)
{
//获取路由后缀表达式
var routeParseExpression = ShardingUtil.GetRouteParseExpression(queryable, EntityMetadata, GetRouteFilter, false);
var routeParseExpression = ShardingUtil.GetRouteParseExpression(queryable, EntityMetadata, GetRouteFilter,GetCompareValueByShardingKey, false);
//表达式缓存编译
// var filter = CachingCompile(routeParseExpression);
var filter = routeParseExpression.GetRoutePredicate();
Expand All @@ -33,6 +33,11 @@ protected override List<string> DoRouteWithPredicate(List<string> allDataSourceN
return dataSources;
}

public virtual object GetCompareValueByShardingKey(object shardingKey, string shardingPropertyName)
{
return shardingKey;
}


/// <summary>
/// 如何路由到具体表 shardingKeyValue:分表的值, 返回结果:如果返回true表示返回该表 第一个参数 tail 第二参数是否返回该物理表
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public abstract class AbstractShardingOperatorVirtualTableRoute<TEntity, TKey> :
protected override List<TableRouteUnit> DoRouteWithPredicate(DataSourceRouteResult dataSourceRouteResult, IQueryable queryable)
{
//获取路由后缀表达式
var routeParseExpression = ShardingUtil.GetRouteParseExpression(queryable, EntityMetadata, GetRouteFilter,true);
var routeParseExpression = ShardingUtil.GetRouteParseExpression(queryable, EntityMetadata, GetRouteFilter,GetCompareValueByShardingKey,true);
//表达式缓存编译
// var filter =CachingCompile(routeParseExpression);
var filter =routeParseExpression.GetRoutePredicate();
Expand All @@ -40,6 +40,10 @@ protected override List<TableRouteUnit> DoRouteWithPredicate(DataSourceRouteResu
return sqlRouteUnits;
}

public virtual object GetCompareValueByShardingKey(object shardingKey, string shardingPropertyName)
{
return shardingKey;
}

/// <summary>
/// 如何路由到具体表 shardingKeyValue:分表的值, 返回结果:如果返回true表示返回该表 第一个参数 tail 第二参数是否返回该物理表
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public class QueryableRouteShardingTableDiscoverVisitor : ShardingExpressionVisi

private readonly EntityMetadata _entityMetadata;
private readonly Func<object, ShardingOperatorEnum, string, Func<string, bool>> _keyToTailWithFilter;
private readonly Func<object, string, object> _compareValueByKey;

/// <summary>
/// 是否是分表路由
Expand All @@ -63,10 +64,11 @@ public class QueryableRouteShardingTableDiscoverVisitor : ShardingExpressionVisi
private RoutePredicateExpression _where = RoutePredicateExpression.Default;

public QueryableRouteShardingTableDiscoverVisitor(EntityMetadata entityMetadata,
Func<object, ShardingOperatorEnum, string, Func<string, bool>> keyToTailWithFilter, bool shardingTableRoute)
Func<object, ShardingOperatorEnum, string, Func<string, bool>> keyToTailWithFilter,Func<object,string,object> compareValueByKey, bool shardingTableRoute)
{
_entityMetadata = entityMetadata;
_keyToTailWithFilter = keyToTailWithFilter;
_compareValueByKey = compareValueByKey;
_shardingTableRoute = shardingTableRoute;
}

Expand Down Expand Up @@ -378,8 +380,15 @@ private RoutePredicateExpression ResolveInFunc(MethodCallExpression methodCallEx

if (arrayObject is IEnumerable enumerableObj)
{
var compareSet = new HashSet<object>();
foreach (var shardingValue in enumerableObj)
{
var compareValueByKey = _compareValueByKey(shardingValue,shardingPredicateResult.ShardingPropertyName);
if (!compareSet.Add(compareValueByKey))
{
continue;
}

var eq = _keyToTailWithFilter(shardingValue,
@in ? ShardingOperatorEnum.Equal : ShardingOperatorEnum.NotEqual,
shardingPredicateResult.ShardingPropertyName);
Expand Down
4 changes: 2 additions & 2 deletions src/ShardingCore/Utils/ShardingUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ public class ShardingUtil
/// <param name="keyToTailExpression"></param>
/// <param name="shardingTableRoute">sharding table or data source</param>
/// <returns></returns>
public static RoutePredicateExpression GetRouteParseExpression(IQueryable queryable, EntityMetadata entityMetadata, Func<object, ShardingOperatorEnum,string, Func<string, bool>> keyToTailExpression,bool shardingTableRoute)
public static RoutePredicateExpression GetRouteParseExpression(IQueryable queryable, EntityMetadata entityMetadata, Func<object, ShardingOperatorEnum,string, Func<string, bool>> keyToTailExpression,Func<object,string,object> compareValueByKey,bool shardingTableRoute)
{

QueryableRouteShardingTableDiscoverVisitor visitor = new QueryableRouteShardingTableDiscoverVisitor(entityMetadata, keyToTailExpression, shardingTableRoute);
QueryableRouteShardingTableDiscoverVisitor visitor = new QueryableRouteShardingTableDiscoverVisitor(entityMetadata, keyToTailExpression,compareValueByKey, shardingTableRoute);

visitor.Visit(queryable.Expression);

Expand Down
6 changes: 5 additions & 1 deletion test/ShardingCore.CommonTest/ShardingDataSourceMod.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ public ShardingDataSourceMod()
_allDataSources = Enumerable.Range(0, 10).Select(o => o.ToString()).ToList();
}

public static object GetCompareValueByShardingKey(object shardingKey, string shardingPropertyName)
{
return shardingKey;
}
public static Func<string, bool> GetRouteFilter(object shardingValue, ShardingOperatorEnum shardingOperator,
string propertyName)
{
Expand All @@ -47,7 +51,7 @@ public static Func<string, bool> GetRouteFilter(object shardingValue, ShardingOp
private void TestId(IQueryable<TestEntity> queryable, string[] dataSourceNames)
{
var routePredicateExpression =
ShardingUtil.GetRouteParseExpression(queryable, _testEntityMetadata, GetRouteFilter, false);
ShardingUtil.GetRouteParseExpression(queryable, _testEntityMetadata, GetRouteFilter,GetCompareValueByShardingKey, false);
Assert.NotNull(routePredicateExpression);
var routePredicate = routePredicateExpression.GetRoutePredicate();

Expand Down
7 changes: 6 additions & 1 deletion test/ShardingCore.CommonTest/ShardingTableTime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ public ShardingTableTime()
//[20220101....20220120]
_allTables = Enumerable.Range(0,20).Select(o=>dateTime.AddDays(o).ToString("yyyyMMdd")).ToList();
}

public static object GetCompareValueByShardingKey(object shardingKey, string shardingPropertyName)
{
return shardingKey;
}
public static Func<string, bool> GetRouteFilter(object shardingValue, ShardingOperatorEnum shardingOperator,
string propertyName)
{
Expand Down Expand Up @@ -60,7 +65,7 @@ public static Func<string, bool> GetRouteFilter(object shardingValue, ShardingOp

private void TestId(IQueryable<TestTimeEntity> queryable, string[] tables)
{
var routePredicateExpression = ShardingUtil.GetRouteParseExpression(queryable,_testEntityMetadata,GetRouteFilter,true);
var routePredicateExpression = ShardingUtil.GetRouteParseExpression(queryable,_testEntityMetadata,GetRouteFilter,GetCompareValueByShardingKey,true);
Assert.NotNull(routePredicateExpression);
var routePredicate = routePredicateExpression.GetRoutePredicate();

Expand Down

0 comments on commit b730c4e

Please sign in to comment.