Skip to content

Commit

Permalink
Merge pull request #278 from Lombiq/issue/OFFI-77
Browse files Browse the repository at this point in the history
issue/OFFI-77: Invoice calculations API
  • Loading branch information
wAsnk authored Aug 16, 2024
2 parents 2111aa6 + dba6518 commit 78bcf2a
Showing 1 changed file with 68 additions and 21 deletions.
89 changes: 68 additions & 21 deletions Lombiq.HelpfulLibraries.Common/Extensions/EnumerableExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#nullable enable

using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading.Tasks;
Expand Down Expand Up @@ -39,7 +41,7 @@ public static TAccumulate AggregateSeed<TSource, TAccumulate>(
/// <param name="beforeFirst">The action to perform before the first item's <paramref name="action"/>.</param>
/// <typeparam name="T">The type of the items in <paramref name="source"/>.</typeparam>
/// <returns><see langword="true"/> if the <paramref name="source"/> had at least one item.</returns>
public static bool ForEach<T>(this IEnumerable<T> source, Action<T> action, Action<T> beforeFirst = null)
public static bool ForEach<T>(this IEnumerable<T> source, Action<T> action, Action<T>? beforeFirst = null)
{
bool any = false;

Expand Down Expand Up @@ -165,18 +167,31 @@ public static IList<T> AsList<T>(this IEnumerable<T> collection) =>

/// <summary>
/// Transforms the specified <paramref name="collection"/> with the <paramref name="select"/> function and returns
/// the items that are not null. Or if the <paramref name="where"/> function is given then those that return <see
/// langword="true"/> with it.
/// the items that return <see langword="true"/> when passed to the <paramref name="where"/> function.
/// </summary>
public static IEnumerable<TOut> SelectWhere<TIn, TOut>(
this IEnumerable<TIn> collection,
Func<TIn, TOut> select,
Func<TOut, bool> where = null)
Func<TOut, bool> where)
{
foreach (var item in collection)
{
var converted = select(item);
if (where.Invoke(converted)) yield return converted;
}
}

/// <summary>
/// Transforms the specified <paramref name="collection"/> with the <paramref name="select"/> function and returns
/// the items that are not null.
/// </summary>
public static IEnumerable<TOut> SelectWhere<TIn, TOut>(this IEnumerable<TIn> collection, Func<TIn, TOut?> select)
where TOut : notnull
{
foreach (var item in collection ?? [])
foreach (var item in collection)
{
var converted = select(item);
if (where?.Invoke(converted) ?? converted is not null) yield return converted;
if (converted is not null) yield return converted;
}
}

Expand All @@ -192,6 +207,7 @@ public static Dictionary<TKey, TValue> ToDictionaryOverwrite<TIn, TKey, TValue>(
this IEnumerable<TIn> collection,
Func<TIn, TKey> keySelector,
Func<TIn, TValue> valueSelector)
where TKey : notnull
{
var dictionary = new Dictionary<TKey, TValue>();
foreach (var item in collection) dictionary[keySelector(item)] = valueSelector(item);
Expand All @@ -208,7 +224,8 @@ public static Dictionary<TKey, TValue> ToDictionaryOverwrite<TIn, TKey, TValue>(
Justification = "This is the point of the method.")]
public static Dictionary<TKey, TIn> ToDictionaryOverwrite<TIn, TKey>(
this IEnumerable<TIn> collection,
Func<TIn, TKey> keySelector) =>
Func<TIn, TKey> keySelector)
where TKey : notnull =>
ToDictionaryOverwrite(collection, keySelector, item => item);

/// <summary>
Expand All @@ -221,7 +238,7 @@ public static Dictionary<TKey, TIn> ToDictionaryOverwrite<TIn, TKey>(
/// after grouping.
/// </para>
/// </remarks>
public static IEnumerable<TItem> Unique<TItem, TKey>(
public static IEnumerable<TItem?> Unique<TItem, TKey>(
this IEnumerable<TItem> collection,
Func<TItem, TKey> keySelector) =>
collection.GroupBy(keySelector).Select(group => group.FirstOrDefault());
Expand All @@ -230,7 +247,7 @@ public static IEnumerable<TItem> Unique<TItem, TKey>(
/// Returns the <paramref name="collection"/> without any duplicate items picking the first of each when sorting by
/// <paramref name="orderBySelector"/>.
/// </summary>
public static IEnumerable<TItem> Unique<TItem, TKey, TOrder>(
public static IEnumerable<TItem?> Unique<TItem, TKey, TOrder>(
this IEnumerable<TItem> collection,
Func<TItem, TKey> keySelector,
Func<TItem, TOrder> orderBySelector) =>
Expand All @@ -242,7 +259,7 @@ public static IEnumerable<TItem> Unique<TItem, TKey, TOrder>(
/// Returns the <paramref name="collection"/> without any duplicate items picking the last of each when sorting by
/// <paramref name="orderBySelector"/>.
/// </summary>
public static IEnumerable<TItem> UniqueDescending<TItem, TKey, TOrder>(
public static IEnumerable<TItem?> UniqueDescending<TItem, TKey, TOrder>(
this IEnumerable<TItem> collection,
Func<TItem, TKey> keySelector,
Func<TItem, TOrder> orderBySelector) =>
Expand All @@ -254,11 +271,13 @@ public static IEnumerable<TItem> UniqueDescending<TItem, TKey, TOrder>(
/// Returns a string that joins the string collection. It excludes null or empty items if there are any.
/// </summary>
/// <returns>The concatenated texts if there are any nonempty, otherwise <see langword="null"/>.</returns>
public static string JoinNotNullOrEmpty(this IEnumerable<string> strings, string separator = ",")
public static string? JoinNotNullOrEmpty(this IEnumerable<string>? strings, string separator = ",")
{
var filteredStrings = strings?.Where(text => !string.IsNullOrWhiteSpace(text)).ToList();
var filteredStrings = (strings ?? [])
.Where(text => !string.IsNullOrWhiteSpace(text))
.ToList();

return filteredStrings?.Count > 0
return filteredStrings.Count > 0
? string.Join(separator, filteredStrings)
: null;
}
Expand All @@ -271,7 +290,7 @@ public static string JoinNotNullOrEmpty(this IEnumerable<string> strings, string
/// <returns>
/// A new <see cref="string"/> that concatenates all values with the <paramref name="separator"/> provided.
/// </returns>
public static string Join(this IEnumerable<string> values, string separator = " ") =>
public static string Join(this IEnumerable<string>? values, string separator = " ") =>
string.Join(separator, values ?? []);

/// <summary>
Expand All @@ -293,14 +312,14 @@ public static IEnumerable<T> WhereNot<T>(this IEnumerable<T> collection, Func<T,
/// Returns <paramref name="collection"/> if it's not <see langword="null"/>, otherwise <see
/// cref="Enumerable.Empty{TResult}"/>.
/// </summary>
public static IEnumerable<T> EmptyIfNull<T>(this IEnumerable<T> collection) =>
public static IEnumerable<T> EmptyIfNull<T>(this IEnumerable<T>? collection) =>
collection ?? [];

/// <summary>
/// Returns <paramref name="array"/> if it's not <see langword="null"/>, otherwise <see
/// cref="Array.Empty{TResult}"/>.
/// </summary>
public static IEnumerable<T> EmptyIfNull<T>(this T[] array) =>
public static IEnumerable<T> EmptyIfNull<T>(this T[]? array) =>
array ?? [];

/// <summary>
Expand Down Expand Up @@ -333,7 +352,7 @@ public static IEnumerable<TResult> Select<TKey, TValue, TResult>(
/// Similar to <see cref="Enumerable.Cast{TResult}"/>, but it checks if the types are correct first, and filters out
/// the ones that couldn't be cast. The optional <paramref name="predicate"/> can filter the cast items.
/// </summary>
public static IEnumerable<T> CastWhere<T>(this IEnumerable enumerable, Func<T, bool> predicate = null)
public static IEnumerable<T> CastWhere<T>(this IEnumerable enumerable, Func<T, bool>? predicate = null)
{
if (enumerable is IEnumerable<T> alreadyCast)
{
Expand All @@ -342,7 +361,7 @@ public static IEnumerable<T> CastWhere<T>(this IEnumerable enumerable, Func<T, b
: alreadyCast.Where(predicate);
}

static IEnumerable<T> Iterate(IEnumerable enumerable, Func<T, bool> predicate)
static IEnumerable<T> Iterate(IEnumerable enumerable, Func<T, bool>? predicate)
{
foreach (var item in enumerable)
{
Expand Down Expand Up @@ -398,10 +417,38 @@ public static Task InvokeFirstOrCompletedAsync<T>(this IEnumerable<T> enumerable
/// If the <paramref name="enumerable"/> is not empty, invokes the <paramref name="funcAsync"/> on the first item
/// and returns its result, otherwise returns <see langword="default"/> for <typeparamref name="TResult"/>.
/// </summary>
public static Task<TResult> InvokeFirstOrDefaultAsync<TItem, TResult>(
public static Task<TResult?> InvokeFirstOrDefaultAsync<TItem, TResult>(
this IEnumerable<TItem> enumerable,
Func<TItem, Task<TResult>> funcAsync) =>
Func<TItem, Task<TResult?>> funcAsync) =>
enumerable.FirstOrDefault() is { } item
? funcAsync(item)
: Task.FromResult(default(TResult));
: Task.FromResult(default(TResult?));

/// <summary>
/// Splits the provided <paramref name="enumerable"/> into two.
/// </summary>
/// <param name="enumerable">The original collection to be tested.</param>
/// <param name="leftPredicate">
/// Tests each item of <paramref name="enumerable"/>. If returns <see langword="true"/>, the
/// item is added to the left collection, otherwise added to the right collection.
/// </param>
/// <typeparam name="T">The type of the items in <paramref name="enumerable"/>.</typeparam>
/// <returns>A tuple of two collections. Each item in <paramref name="enumerable"/> is in one of them.</returns>
public static (IList<T> Left, IList<T> Right) Fork<T>(
this IEnumerable<T>? enumerable,
Func<T, bool> leftPredicate)
{
var left = new List<T>();
var right = new List<T>();

if (enumerable is null) return (left, right);

foreach (var item in enumerable)
{
var target = leftPredicate(item) ? left : right;
target.Add(item);
}

return (left, right);
}
}

0 comments on commit 78bcf2a

Please sign in to comment.