Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

issue/OFFI-77: Invoice calculations API #278

Merged
merged 6 commits into from
Aug 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 68 additions & 21 deletions Lombiq.HelpfulLibraries.Common/Extensions/EnumerableExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#nullable enable

using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading.Tasks;
Expand Down Expand Up @@ -39,7 +41,7 @@ public static TAccumulate AggregateSeed<TSource, TAccumulate>(
/// <param name="beforeFirst">The action to perform before the first item's <paramref name="action"/>.</param>
/// <typeparam name="T">The type of the items in <paramref name="source"/>.</typeparam>
/// <returns><see langword="true"/> if the <paramref name="source"/> had at least one item.</returns>
public static bool ForEach<T>(this IEnumerable<T> source, Action<T> action, Action<T> beforeFirst = null)
public static bool ForEach<T>(this IEnumerable<T> source, Action<T> action, Action<T>? beforeFirst = null)
{
bool any = false;

Expand Down Expand Up @@ -165,18 +167,31 @@ public static IList<T> AsList<T>(this IEnumerable<T> collection) =>

/// <summary>
/// Transforms the specified <paramref name="collection"/> with the <paramref name="select"/> function and returns
/// the items that are not null. Or if the <paramref name="where"/> function is given then those that return <see
/// langword="true"/> with it.
/// the items that return <see langword="true"/> when passed to the <paramref name="where"/> function.
/// </summary>
public static IEnumerable<TOut> SelectWhere<TIn, TOut>(
this IEnumerable<TIn> collection,
Func<TIn, TOut> select,
Func<TOut, bool> where = null)
Func<TOut, bool> where)
{
foreach (var item in collection)
{
var converted = select(item);
if (where.Invoke(converted)) yield return converted;
}
}

/// <summary>
/// Transforms the specified <paramref name="collection"/> with the <paramref name="select"/> function and returns
/// the items that are not null.
/// </summary>
public static IEnumerable<TOut> SelectWhere<TIn, TOut>(this IEnumerable<TIn> collection, Func<TIn, TOut?> select)
where TOut : notnull
{
foreach (var item in collection ?? [])
foreach (var item in collection)
{
var converted = select(item);
if (where?.Invoke(converted) ?? converted is not null) yield return converted;
if (converted is not null) yield return converted;
}
}

Expand All @@ -192,6 +207,7 @@ public static Dictionary<TKey, TValue> ToDictionaryOverwrite<TIn, TKey, TValue>(
this IEnumerable<TIn> collection,
Func<TIn, TKey> keySelector,
Func<TIn, TValue> valueSelector)
where TKey : notnull
{
var dictionary = new Dictionary<TKey, TValue>();
foreach (var item in collection) dictionary[keySelector(item)] = valueSelector(item);
Expand All @@ -208,7 +224,8 @@ public static Dictionary<TKey, TValue> ToDictionaryOverwrite<TIn, TKey, TValue>(
Justification = "This is the point of the method.")]
public static Dictionary<TKey, TIn> ToDictionaryOverwrite<TIn, TKey>(
this IEnumerable<TIn> collection,
Func<TIn, TKey> keySelector) =>
Func<TIn, TKey> keySelector)
where TKey : notnull =>
ToDictionaryOverwrite(collection, keySelector, item => item);

/// <summary>
Expand All @@ -221,7 +238,7 @@ public static Dictionary<TKey, TIn> ToDictionaryOverwrite<TIn, TKey>(
/// after grouping.
/// </para>
/// </remarks>
public static IEnumerable<TItem> Unique<TItem, TKey>(
public static IEnumerable<TItem?> Unique<TItem, TKey>(
this IEnumerable<TItem> collection,
Func<TItem, TKey> keySelector) =>
collection.GroupBy(keySelector).Select(group => group.FirstOrDefault());
Expand All @@ -230,7 +247,7 @@ public static IEnumerable<TItem> Unique<TItem, TKey>(
/// Returns the <paramref name="collection"/> without any duplicate items picking the first of each when sorting by
/// <paramref name="orderBySelector"/>.
/// </summary>
public static IEnumerable<TItem> Unique<TItem, TKey, TOrder>(
public static IEnumerable<TItem?> Unique<TItem, TKey, TOrder>(
this IEnumerable<TItem> collection,
Func<TItem, TKey> keySelector,
Func<TItem, TOrder> orderBySelector) =>
Expand All @@ -242,7 +259,7 @@ public static IEnumerable<TItem> Unique<TItem, TKey, TOrder>(
/// Returns the <paramref name="collection"/> without any duplicate items picking the last of each when sorting by
/// <paramref name="orderBySelector"/>.
/// </summary>
public static IEnumerable<TItem> UniqueDescending<TItem, TKey, TOrder>(
public static IEnumerable<TItem?> UniqueDescending<TItem, TKey, TOrder>(
this IEnumerable<TItem> collection,
Func<TItem, TKey> keySelector,
Func<TItem, TOrder> orderBySelector) =>
Expand All @@ -254,11 +271,13 @@ public static IEnumerable<TItem> UniqueDescending<TItem, TKey, TOrder>(
/// Returns a string that joins the string collection. It excludes null or empty items if there are any.
/// </summary>
/// <returns>The concatenated texts if there are any nonempty, otherwise <see langword="null"/>.</returns>
public static string JoinNotNullOrEmpty(this IEnumerable<string> strings, string separator = ",")
public static string? JoinNotNullOrEmpty(this IEnumerable<string>? strings, string separator = ",")
{
var filteredStrings = strings?.Where(text => !string.IsNullOrWhiteSpace(text)).ToList();
var filteredStrings = (strings ?? [])
.Where(text => !string.IsNullOrWhiteSpace(text))
.ToList();

return filteredStrings?.Count > 0
return filteredStrings.Count > 0
? string.Join(separator, filteredStrings)
: null;
}
Expand All @@ -271,7 +290,7 @@ public static string JoinNotNullOrEmpty(this IEnumerable<string> strings, string
/// <returns>
/// A new <see cref="string"/> that concatenates all values with the <paramref name="separator"/> provided.
/// </returns>
public static string Join(this IEnumerable<string> values, string separator = " ") =>
public static string Join(this IEnumerable<string>? values, string separator = " ") =>
string.Join(separator, values ?? []);

/// <summary>
Expand All @@ -293,14 +312,14 @@ public static IEnumerable<T> WhereNot<T>(this IEnumerable<T> collection, Func<T,
/// Returns <paramref name="collection"/> if it's not <see langword="null"/>, otherwise <see
/// cref="Enumerable.Empty{TResult}"/>.
/// </summary>
public static IEnumerable<T> EmptyIfNull<T>(this IEnumerable<T> collection) =>
public static IEnumerable<T> EmptyIfNull<T>(this IEnumerable<T>? collection) =>
collection ?? [];

/// <summary>
/// Returns <paramref name="array"/> if it's not <see langword="null"/>, otherwise <see
/// cref="Array.Empty{TResult}"/>.
/// </summary>
public static IEnumerable<T> EmptyIfNull<T>(this T[] array) =>
public static IEnumerable<T> EmptyIfNull<T>(this T[]? array) =>
array ?? [];

/// <summary>
Expand Down Expand Up @@ -333,7 +352,7 @@ public static IEnumerable<TResult> Select<TKey, TValue, TResult>(
/// Similar to <see cref="Enumerable.Cast{TResult}"/>, but it checks if the types are correct first, and filters out
/// the ones that couldn't be cast. The optional <paramref name="predicate"/> can filter the cast items.
/// </summary>
public static IEnumerable<T> CastWhere<T>(this IEnumerable enumerable, Func<T, bool> predicate = null)
public static IEnumerable<T> CastWhere<T>(this IEnumerable enumerable, Func<T, bool>? predicate = null)
{
if (enumerable is IEnumerable<T> alreadyCast)
{
Expand All @@ -342,7 +361,7 @@ public static IEnumerable<T> CastWhere<T>(this IEnumerable enumerable, Func<T, b
: alreadyCast.Where(predicate);
}

static IEnumerable<T> Iterate(IEnumerable enumerable, Func<T, bool> predicate)
static IEnumerable<T> Iterate(IEnumerable enumerable, Func<T, bool>? predicate)
{
foreach (var item in enumerable)
{
Expand Down Expand Up @@ -398,10 +417,38 @@ public static Task InvokeFirstOrCompletedAsync<T>(this IEnumerable<T> enumerable
/// If the <paramref name="enumerable"/> is not empty, invokes the <paramref name="funcAsync"/> on the first item
/// and returns its result, otherwise returns <see langword="default"/> for <typeparamref name="TResult"/>.
/// </summary>
public static Task<TResult> InvokeFirstOrDefaultAsync<TItem, TResult>(
public static Task<TResult?> InvokeFirstOrDefaultAsync<TItem, TResult>(
this IEnumerable<TItem> enumerable,
Func<TItem, Task<TResult>> funcAsync) =>
Func<TItem, Task<TResult?>> funcAsync) =>
enumerable.FirstOrDefault() is { } item
? funcAsync(item)
: Task.FromResult(default(TResult));
: Task.FromResult(default(TResult?));

/// <summary>
/// Splits the provided <paramref name="enumerable"/> into two.
/// </summary>
/// <param name="enumerable">The original collection to be tested.</param>
/// <param name="leftPredicate">
/// Tests each item of <paramref name="enumerable"/>. If returns <see langword="true"/>, the
/// item is added to the left collection, otherwise added to the right collection.
/// </param>
/// <typeparam name="T">The type of the items in <paramref name="enumerable"/>.</typeparam>
/// <returns>A tuple of two collections. Each item in <paramref name="enumerable"/> is in one of them.</returns>
public static (IList<T> Left, IList<T> Right) Fork<T>(
this IEnumerable<T>? enumerable,
Func<T, bool> leftPredicate)
{
var left = new List<T>();
var right = new List<T>();

if (enumerable is null) return (left, right);

foreach (var item in enumerable)
{
var target = leftPredicate(item) ? left : right;
target.Add(item);
}

return (left, right);
}
}