How to extend LINQ

All LINQ-based methods follow one of two similar patterns: they take an enumerable sequence and return either a different sequence or a single value. This consistent shape lets you extend LINQ by writing methods that follow the same shape. In fact, the .NET libraries have gained new methods in many .NET releases since LINQ was first introduced. In this article, you see examples of extending LINQ by writing your own methods that follow the same pattern.

Add custom methods for LINQ queries

You extend the set of methods that you use for LINQ queries by adding extension methods to the IEnumerable<T> interface. For example, in addition to the standard average or maximum operations, you can create a custom aggregate method to compute a single value from a sequence of values. You can also create a method that works as a custom filter or a specific data transform for a sequence of values and returns a new sequence. Built-in examples of such sequence-returning methods are Distinct, Skip, and Reverse.

When you extend the IEnumerable<T> interface, you can apply your custom methods to any enumerable collection. For more information, see Extension Methods.

An aggregate method computes a single value from a set of values. LINQ provides several aggregate methods, including Average, Min, and Max. You can create your own aggregate method by adding an extension method to the IEnumerable<T> interface.

The following code example shows how to create an extension method called Median to compute a median for a sequence of numbers of type double.

public static class EnumerableExtension
{
    public static double Median(this IEnumerable<double>? source)
    {
        if (source is null || !source.Any())
        {
            throw new InvalidOperationException("Cannot compute median for a null or empty set.");
        }

        var sortedList =
            source.OrderBy(number => number).ToList();

        int itemIndex = sortedList.Count / 2;

        if (sortedList.Count % 2 == 0)
        {
            // Even number of items.
            return (sortedList[itemIndex] + sortedList[itemIndex - 1]) / 2;
        }
        else
        {
            // Odd number of items.
            return sortedList[itemIndex];
        }
    }
}

You call this extension method for any enumerable collection in the same way you call other aggregate methods from the IEnumerable<T> interface.

The following code example shows how to use the Median method for an array of type double.

double[] numbers = [1.9, 2, 8, 4, 5.7, 6, 7.2, 0];
var query = numbers.Median();

Console.WriteLine($"double: Median = {query}");
// This code produces the following output:
//     double: Median = 4.85

You can overload your aggregate method so that it accepts sequences of various types. The standard approach is to create an overload for each type. Another approach is to create an overload that takes a generic type and convert it to a specific type by using a delegate. You can also combine both approaches.

You can create a specific overload for each type that you want to support. The following code example shows an overload of the Median method for the int type.

// int overload
public static double Median(this IEnumerable<int> source) =>
    (from number in source select (double)number).Median();

You can now call the Median overloads for both integer and double types, as shown in the following code:

double[] numbers1 = [1.9, 2, 8, 4, 5.7, 6, 7.2, 0];
var query1 = numbers1.Median();

Console.WriteLine($"double: Median = {query1}");

int[] numbers2 = [1, 2, 3, 4, 5];
var query2 = numbers2.Median();

Console.WriteLine($"int: Median = {query2}");
// This code produces the following output:
//     double: Median = 4.85
//     int: Median = 3

You can also create an overload that accepts a generic sequence of objects. This overload takes a delegate as a parameter and uses it to convert a sequence of objects of a generic type to a specific type.

The following code shows an overload of the Median method that takes the Func<T,TResult> delegate as a parameter. This delegate takes an object of generic type T and returns an object of type double.

// generic overload
public static double Median<T>(
    this IEnumerable<T> numbers, Func<T, double> selector) =>
    (from num in numbers select selector(num)).Median();

You can now call the Median method for a sequence of objects of any type. If the type doesn't have its own method overload, you have to pass a delegate parameter. In C#, you can use a lambda expression for this purpose. Also, in Visual Basic only, if you use the Aggregate or Group By clause instead of the method call, you can pass any value or expression that is in the scope of this clause.

The following example code shows how to call the Median method for an array of integers and an array of strings. For strings, the median for the lengths of strings in the array is calculated. The example shows how to pass the Func<T,TResult> delegate parameter to the Median method for each case.

int[] numbers3 = [1, 2, 3, 4, 5];

/*
    You can use the num => num lambda expression as a parameter for the Median method
    so that the compiler will implicitly convert its value to double.
    If there is no implicit conversion, the compiler will display an error message.
*/
var query3 = numbers3.Median(num => num);

Console.WriteLine($"int: Median = {query3}");

string[] numbers4 = ["one", "two", "three", "four", "five"];

// With the generic overload, you can also use numeric properties of objects.
var query4 = numbers4.Median(str => str.Length);

Console.WriteLine($"string: Median = {query4}");
// This code produces the following output:
//     int: Median = 3
//     string: Median = 4
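The generic overload isn't limited to built-in types. The following sketch applies it to a hypothetical Reading record; the record and its sample values are illustrative assumptions, not part of the original example. The two Median overloads are repeated in condensed form so that the sample compiles on its own.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical sample data: each reading pairs a sensor name with a temperature.
Reading[] readings =
[
    new("north", 18.5),
    new("south", 21.0),
    new("east", 19.5),
    new("west", 25.0)
];

// The selector picks the numeric property to aggregate.
// Sorted values are 18.5, 19.5, 21.0, 25.0, so the median is (19.5 + 21.0) / 2 = 20.25.
double median = readings.Median(r => r.Celsius);
Console.WriteLine($"Reading: Median = {median}");

// Hypothetical record type used only for this illustration.
public record Reading(string Sensor, double Celsius);

public static class EnumerableExtension
{
    // Repeated from the article (condensed) so that this sample is self-contained.
    public static double Median(this IEnumerable<double>? source)
    {
        if (source is null || !source.Any())
        {
            throw new InvalidOperationException("Cannot compute median for a null or empty set.");
        }

        var sortedList = source.OrderBy(number => number).ToList();
        int itemIndex = sortedList.Count / 2;

        return sortedList.Count % 2 == 0
            ? (sortedList[itemIndex] + sortedList[itemIndex - 1]) / 2
            : sortedList[itemIndex];
    }

    // Generic overload: projects each element to double, then reuses the overload above.
    public static double Median<T>(
        this IEnumerable<T> numbers, Func<T, double> selector) =>
        numbers.Select(selector).Median();
}
```

Because the selector is an ordinary delegate, the same call works with any property or computed value that converts implicitly to double.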

You can extend the IEnumerable<T> interface with a custom query method that returns a sequence of values. In this case, the method must return a collection of type IEnumerable<T>. Such methods can be used to apply filters or data transforms to a sequence of values.

The following example shows how to create an extension method named AlternateElements that returns every other element in a collection, starting from the first element.

// Extension method for the IEnumerable<T> interface.
// The method returns every other element of a sequence.
public static IEnumerable<T> AlternateElements<T>(this IEnumerable<T> source)
{
    int index = 0;
    foreach (T element in source)
    {
        if (index % 2 == 0)
        {
            yield return element;
        }

        index++;
    }
}

You can call this extension method for any enumerable collection just as you would call other methods from the IEnumerable<T> interface, as shown in the following code:

string[] strings = ["a", "b", "c", "d", "e"];

var query5 = strings.AlternateElements();

foreach (var element in query5)
{
    Console.WriteLine(element);
}
// This code produces the following output:
//     a
//     c
//     e
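The same pattern extends to methods that take extra parameters. The following sketch generalizes AlternateElements into a hypothetical EveryNth method; the method name, the step parameter, and the argument checks are assumptions made for illustration. It also shows that an iterator-based method uses deferred execution: the sequence isn't read until the result is enumerated.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

int[] values = [10, 20, 30, 40, 50, 60, 70];

// Nothing is enumerated yet. The iterator body runs only when the
// result is consumed, here by string.Join.
IEnumerable<int> everyThird = values.EveryNth(3);
Console.WriteLine(string.Join(", ", everyThird));
// This code produces the following output:
//     10, 40, 70

public static class SequenceExtensions
{
    // Hypothetical generalization of AlternateElements: returns every
    // n-th element of the sequence, starting from the first element.
    public static IEnumerable<T> EveryNth<T>(this IEnumerable<T> source, int step)
    {
        // Validate eagerly, then hand off to the iterator, so that bad
        // arguments throw at the call site instead of at first enumeration.
        ArgumentNullException.ThrowIfNull(source);
        if (step < 1)
        {
            throw new ArgumentOutOfRangeException(nameof(step), "Step must be at least 1.");
        }

        return Iterate(source, step);

        static IEnumerable<T> Iterate(IEnumerable<T> items, int n)
        {
            int index = 0;
            foreach (T element in items)
            {
                if (index % n == 0)
                {
                    yield return element;
                }

                index++;
            }
        }
    }
}
```

Splitting the method into an eager wrapper and a local iterator function is a common design choice for iterator methods, because code inside a yield-based method doesn't run until the first MoveNext call.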

Group results by contiguous keys

The following example shows how to group elements into chunks that represent subsequences of contiguous keys. For example, assume that you're given the following sequence of key-value pairs:

Key Value
A We
A think
A that
B LINQ
C is
A really
B cool
B !

The following groups are created in this order:

  1. We, think, that
  2. LINQ
  3. is
  4. really
  5. cool, !

The solution is implemented as a thread-safe extension method that returns its results in a streaming manner. It produces its groups as it moves through the source sequence. Unlike the group or orderby operators, it can begin returning groups to the caller before reading the entire sequence. The following example shows both the extension method and the client code that uses it:

public static class ChunkExtensions
{
    public static IEnumerable<IGrouping<TKey, TSource>> ChunkBy<TSource, TKey>(
            this IEnumerable<TSource> source,
            Func<TSource, TKey> keySelector) =>
                source.ChunkBy(keySelector, EqualityComparer<TKey>.Default);

    public static IEnumerable<IGrouping<TKey, TSource>> ChunkBy<TSource, TKey>(
            this IEnumerable<TSource> source,
            Func<TSource, TKey> keySelector,
            IEqualityComparer<TKey> comparer)
    {
        // Flag to signal end of source sequence.
        const bool noMoreSourceElements = true;

        // Auto-generated iterator for the source array.
        IEnumerator<TSource>? enumerator = source.GetEnumerator();

        // Move to the first element in the source sequence.
        if (!enumerator.MoveNext())
        {
            yield break;        // source collection is empty
        }

        while (true)
        {
            var key = keySelector(enumerator.Current);

            Chunk<TKey, TSource> current = new(key, enumerator, value => comparer.Equals(key, keySelector(value)));

            yield return current;

            if (current.CopyAllChunkElements() == noMoreSourceElements)
            {
                yield break;
            }
        }
    }
}

public static class GroupByContiguousKeys
{
    // The source sequence.
    static readonly KeyValuePair<string, string>[] list = [
        new("A", "We"),
        new("A", "think"),
        new("A", "that"),
        new("B", "LINQ"),
        new("C", "is"),
        new("A", "really"),
        new("B", "cool"),
        new("B", "!")
    ];

    // Query variable declared as class member to be available
    // on different threads.
    static readonly IEnumerable<IGrouping<string, KeyValuePair<string, string>>> query =
        list.ChunkBy(p => p.Key);

    public static void GroupByContiguousKeys1()
    {
        // ChunkBy returns IGrouping objects, therefore a nested
        // foreach loop is required to access the elements in each "chunk".
        foreach (var item in query)
        {
            Console.WriteLine($"Group key = {item.Key}");
            foreach (var inner in item)
            {
                Console.WriteLine($"\t{inner.Value}");
            }
        }
    }
}

ChunkExtensions class

In the ChunkExtensions class shown earlier, the while (true) loop in the ChunkBy method iterates through the source sequence and creates a Chunk object for each run of contiguous keys. On each pass, the iterator advances to the first element of the next chunk in the source sequence. This loop corresponds to the outer foreach loop that executes the query. In that loop, the code does the following actions:

  1. Get the key for the current Chunk and assign it to the key variable. The source iterator consumes the source sequence until it finds an element with a key that doesn't match.
  2. Make a new Chunk (group) object, and store it in the current variable. It has one ChunkItem, a copy of the current source element.
  3. Return that Chunk. A Chunk is an IGrouping<TKey,TSource>, which is the element type of the sequence that the ChunkBy method returns. At this point, the Chunk holds only the first element of its run in the source sequence. The remaining elements are copied only when the client code iterates over this chunk with foreach. See Chunk.GetEnumerator for more info.
  4. Check to see if:
    • The chunk has a copy of all its source elements, or
    • The iterator reached the end of the source sequence.
  5. When the caller has enumerated all the chunk items, the Chunk.GetEnumerator method has copied all chunk items. If the Chunk.GetEnumerator loop didn't enumerate all elements in the chunk, do it now to avoid corrupting the iterator for clients that might be calling it on a separate thread.
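The loop structure in the steps above can be easier to follow in a simplified variant. The following sketch is not the article's implementation: the hypothetical ChunkByBuffered method buffers each chunk in a List, returns tuples instead of IGrouping objects, and isn't thread-safe. It still produces chunks one at a time, which the endless source sequence demonstrates.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// An endless source: 0, 1, 2, ... A query that had to read the whole
// sequence before returning a group would never complete.
static IEnumerable<int> Naturals()
{
    for (int i = 0; ; i++)
    {
        yield return i;
    }
}

// The key is i / 3, so the chunks are [0, 1, 2], [3, 4, 5], [6, 7, 8], ...
foreach (var chunk in Naturals().ChunkByBuffered(i => i / 3).Take(3))
{
    Console.WriteLine($"{chunk.Key}: {string.Join(", ", chunk.Items)}");
}
// This code produces the following output:
//     0: 0, 1, 2
//     1: 3, 4, 5
//     2: 6, 7, 8

public static class SimpleChunkExtensions
{
    // Hypothetical simplified variant of ChunkBy (an assumption for illustration).
    public static IEnumerable<(TKey Key, IReadOnlyList<TSource> Items)> ChunkByBuffered<TSource, TKey>(
        this IEnumerable<TSource> source,
        Func<TSource, TKey> keySelector)
    {
        var comparer = EqualityComparer<TKey>.Default;
        using IEnumerator<TSource> enumerator = source.GetEnumerator();

        // An empty source produces no chunks.
        if (!enumerator.MoveNext())
        {
            yield break;
        }

        bool hasMore = true;
        while (hasMore)
        {
            // The key of the first element defines the chunk.
            TKey key = keySelector(enumerator.Current);
            var items = new List<TSource> { enumerator.Current };

            // Copy elements until the key changes or the source ends.
            while ((hasMore = enumerator.MoveNext())
                   && comparer.Equals(key, keySelector(enumerator.Current)))
            {
                items.Add(enumerator.Current);
            }

            // Hand the finished chunk to the caller, then resume with the next key.
            yield return (key, items);
        }
    }
}
```

The trade-off is that this variant reads a whole chunk before returning it, whereas the Chunk class in the article returns a chunk after reading only its first element and copies the rest on demand.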

Chunk class

The Chunk class is a contiguous group of one or more source elements that have the same key. A Chunk has a key and a list of ChunkItem objects, which are copies of the elements in the source sequence:

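// The [DisallowNull] attribute used below requires this using directive.
using System.Diagnostics.CodeAnalysis;

class Chunk<TKey, TSource> : IGrouping<TKey, TSource>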
{
    // INVARIANT: DoneCopyingChunk == true ||
    //   (predicate != null && predicate(enumerator.Current) && current.Value == enumerator.Current)

    // A Chunk has a linked list of ChunkItems, which represent the elements in the current chunk. Each ChunkItem
    // has a reference to the next ChunkItem in the list.
    class ChunkItem
    {
        public ChunkItem(TSource value) => Value = value;
        public readonly TSource Value;
        public ChunkItem? Next;
    }

    public TKey Key { get; }

    // Stores a reference to the enumerator for the source sequence
    private IEnumerator<TSource> enumerator;

    // A reference to the predicate that is used to compare keys.
    private Func<TSource, bool> predicate;

    // Stores the contents of the first source element that
    // belongs with this chunk.
    private readonly ChunkItem head;

    // End of the list. It is repositioned each time a new
    // ChunkItem is added.
    private ChunkItem? tail;

    // Flag to indicate the source iterator has reached the end of the source sequence.
    internal bool isLastSourceElement;

    // Private object for thread synchronization
    private readonly object m_Lock;

    // REQUIRES: enumerator != null && predicate != null
    public Chunk(TKey key, [DisallowNull] IEnumerator<TSource> enumerator, [DisallowNull] Func<TSource, bool> predicate)
    {
        Key = key;
        this.enumerator = enumerator;
        this.predicate = predicate;

        // A Chunk always contains at least one element.
        head = new ChunkItem(enumerator.Current);

        // The end and beginning are the same until the list contains > 1 elements.
        tail = head;

        m_Lock = new object();
    }

    // Indicates that all chunk elements have been copied to the list of ChunkItems.
    private bool DoneCopyingChunk => tail == null;

    // Adds one ChunkItem to the current group
    // REQUIRES: !DoneCopyingChunk && lock(this)
    private void CopyNextChunkElement()
    {
        // Try to advance the iterator on the source sequence.
        isLastSourceElement = !enumerator.MoveNext();

        // If we are (a) at the end of the source, or (b) at the end of the current chunk
        // then null out the enumerator and predicate for reuse with the next chunk.
        if (isLastSourceElement || !predicate(enumerator.Current))
        {
            enumerator = default!;
            predicate = default!;
        }
        else
        {
            tail!.Next = new ChunkItem(enumerator.Current);
        }

        // tail will be null if we are at the end of the chunk elements
        // This check is made in DoneCopyingChunk.
        tail = tail!.Next;
    }

    // Called after the end of the last chunk was reached.
    internal bool CopyAllChunkElements()
    {
        while (true)
        {
            lock (m_Lock)
            {
                if (DoneCopyingChunk)
                {
                    return isLastSourceElement;
                }
                else
                {
                    CopyNextChunkElement();
                }
            }
        }
    }

    // Stays just one step ahead of the client requests.
    public IEnumerator<TSource> GetEnumerator()
    {
        // Specify the initial element to enumerate.
        ChunkItem? current = head;

        // There should always be at least one ChunkItem in a Chunk.
        while (current != null)
        {
            // Yield the current item in the list.
            yield return current.Value;

            // Copy the next item from the source sequence,
            // if we are at the end of our local list.
            lock (m_Lock)
            {
                if (current == tail)
                {
                    CopyNextChunkElement();
                }
            }

            // Move to the next ChunkItem in the list.
            current = current.Next;
        }
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() => GetEnumerator();
}

Each ChunkItem has a reference to the next ChunkItem in the list. The list has a head, which stores the contents of the first source element that belongs to this chunk, and a tail, which marks the end of the list. The tail is repositioned each time a new ChunkItem is added. The CopyNextChunkElement method sets the tail to null if the key of the next element doesn't match the current chunk's key, or if there are no more elements in the source.
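The head and tail bookkeeping can be seen in isolation in the following sketch. The AppendOnlyList class and its members are hypothetical and single-threaded; they only mirror the idea of a fixed head, a moving tail, and a null tail that marks completion, and they aren't part of the Chunk class.

```csharp
using System;
using System.Collections.Generic;

var list = new AppendOnlyList<string>("We");
list.Append("think");
list.Append("that");
list.Complete();

Console.WriteLine(string.Join(", ", list.Items()));
Console.WriteLine(list.IsComplete);
// This code produces the following output:
//     We, think, that
//     True

// Hypothetical illustration of a linked list with a head and a tail.
public sealed class AppendOnlyList<T>
{
    private sealed class Node
    {
        public Node(T value) => Value = value;
        public readonly T Value;
        public Node? Next;
    }

    // First element. It never changes after construction.
    private readonly Node head;

    // Last element, or null once the list is complete.
    private Node? tail;

    public AppendOnlyList(T first)
    {
        head = new Node(first);

        // In a one-element list, the head and the tail are the same node.
        tail = head;
    }

    // Plays the role of DoneCopyingChunk in the Chunk class.
    public bool IsComplete => tail is null;

    public void Append(T value)
    {
        if (tail is null)
        {
            throw new InvalidOperationException("The list is complete.");
        }

        tail.Next = new Node(value);

        // Reposition the tail on the node that was just added.
        tail = tail.Next;
    }

    // A null tail marks the end of the list.
    public void Complete() => tail = null;

    public IEnumerable<T> Items()
    {
        for (Node? node = head; node is not null; node = node.Next)
        {
            yield return node.Value;
        }
    }
}
```

The Chunk class reaches the same end state differently: it assigns tail.Next to tail, which yields null when no new ChunkItem was linked.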

The CopyNextChunkElement method of the Chunk class adds one ChunkItem to the current group of items. It tries to advance the iterator on the source sequence. If the MoveNext() method returns false, the source sequence is exhausted and isLastSourceElement is set to true.

The ChunkBy method calls the CopyAllChunkElements method after it returns a chunk and the caller requests the next one. The method copies any elements of the chunk that the caller hasn't enumerated yet, and then returns the value of isLastSourceElement. When the private DoneCopyingChunk property is true and isLastSourceElement is false, the return value signals the outer iterator to continue iterating. When isLastSourceElement is true, the source sequence is exhausted and the outer iterator stops.

The inner foreach loop invokes the GetEnumerator method of the Chunk class. This method stays one element ahead of the client requests. It adds the next element of the chunk only after the client requests the previous last element in the list.