如何:实现表达式目录树访问器

更新:2007 年 11 月

本主题中的代码示例是表达式目录树访问器的一个实现。此类旨在被继承,以便创建更专门的类,用于实现需要遍历、检查或复制表达式目录树的功能。

下面的主题使用此类:“如何:修改表达式目录树”和“演练:创建 IQueryable LINQ 提供程序”。

这两个主题均包含用于创建本主题中介绍的表达式目录树访问器基类的专门访问器子类的代码。

示例

在表达式目录树访问器实现中,应首先调用 Visit 方法。该方法根据表达式的类型,将传入的表达式调度到类中某个更专门的访问器方法。专门的访问器方法访问传递给它们的表达式的子树。如果某个子表达式在被访问后发生了更改(例如被派生类的重写方法更改),则专门的访问器方法会创建一个包含子树中更改的新表达式;否则,这些方法返回传递给它们的表达式。这种递归行为会生成一个新的表达式目录树,它要么与传递到 Visit 的原始表达式相同,要么是其修改后的版本。

''' <summary>
''' Base class for visitors that traverse an expression tree. Each Visit* method
''' visits the children of the node it is given and returns the node unchanged when
''' no child changed, or a newly built node when a child was replaced.
''' Derived classes override the Visit* methods they need.
''' </summary>
Public MustInherit Class ExpressionVisitor

    Protected Sub New()
    End Sub

    ''' <summary>
    ''' Dispatches <paramref name="exp"/> to the specialized visitor method that
    ''' matches its node type.
    ''' </summary>
    ''' <param name="exp">The expression to visit; Nothing is returned as is.</param>
    ''' <returns>The expression returned by the specialized visitor method.</returns>
    ''' <exception cref="Exception">The node type is not handled by this visitor.</exception>
    Protected Overridable Function Visit(ByVal exp As Expression) As Expression
        If exp Is Nothing Then
            Return exp
        End If

        Select Case exp.NodeType
            Case ExpressionType.Negate, _
                 ExpressionType.NegateChecked, _
                 ExpressionType.UnaryPlus, _
                 ExpressionType.Not, _
                 ExpressionType.Convert, _
                 ExpressionType.ConvertChecked, _
                 ExpressionType.ArrayLength, _
                 ExpressionType.Quote, _
                 ExpressionType.TypeAs
                Return Me.VisitUnary(CType(exp, UnaryExpression))
            Case ExpressionType.Add, _
                 ExpressionType.AddChecked, _
                 ExpressionType.Subtract, _
                 ExpressionType.SubtractChecked, _
                 ExpressionType.Multiply, _
                 ExpressionType.MultiplyChecked, _
                 ExpressionType.Divide, _
                 ExpressionType.Modulo, _
                 ExpressionType.Power, _
                 ExpressionType.And, _
                 ExpressionType.AndAlso, _
                 ExpressionType.Or, _
                 ExpressionType.OrElse, _
                 ExpressionType.LessThan, _
                 ExpressionType.LessThanOrEqual, _
                 ExpressionType.GreaterThan, _
                 ExpressionType.GreaterThanOrEqual, _
                 ExpressionType.Equal, _
                 ExpressionType.NotEqual, _
                 ExpressionType.Coalesce, _
                 ExpressionType.ArrayIndex, _
                 ExpressionType.RightShift, _
                 ExpressionType.LeftShift, _
                 ExpressionType.ExclusiveOr
                Return Me.VisitBinary(CType(exp, BinaryExpression))
            Case ExpressionType.TypeIs
                Return Me.VisitTypeIs(CType(exp, TypeBinaryExpression))
            Case ExpressionType.Conditional
                Return Me.VisitConditional(CType(exp, ConditionalExpression))
            Case ExpressionType.Constant
                Return Me.VisitConstant(CType(exp, ConstantExpression))
            Case ExpressionType.Parameter
                Return Me.VisitParameter(CType(exp, ParameterExpression))
            Case ExpressionType.MemberAccess
                Return Me.VisitMemberAccess(CType(exp, MemberExpression))
            Case ExpressionType.Call
                Return Me.VisitMethodCall(CType(exp, MethodCallExpression))
            Case ExpressionType.Lambda
                Return Me.VisitLambda(CType(exp, LambdaExpression))
            Case ExpressionType.New
                Return Me.VisitNew(CType(exp, NewExpression))
            Case ExpressionType.NewArrayInit, _
                 ExpressionType.NewArrayBounds
                Return Me.VisitNewArray(CType(exp, NewArrayExpression))
            Case ExpressionType.Invoke
                Return Me.VisitInvocation(CType(exp, InvocationExpression))
            Case ExpressionType.MemberInit
                Return Me.VisitMemberInit(CType(exp, MemberInitExpression))
            Case ExpressionType.ListInit
                Return Me.VisitListInit(CType(exp, ListInitExpression))
            Case Else
                ' ToString() puts the member name in the message; concatenating the
                ' enum value directly would insert its numeric value instead.
                Throw New Exception("Unhandled expression type: '" & exp.NodeType.ToString() & "'")
        End Select
    End Function

    ''' <summary>
    ''' Dispatches a member binding to the visitor method for its binding type.
    ''' </summary>
    ''' <exception cref="Exception">The binding type is not handled by this visitor.</exception>
    Protected Overridable Function VisitBinding(ByVal binding As MemberBinding) As MemberBinding
        Select Case binding.BindingType
            Case MemberBindingType.Assignment
                Return Me.VisitMemberAssignment(CType(binding, MemberAssignment))
            Case MemberBindingType.MemberBinding
                Return Me.VisitMemberMemberBinding(CType(binding, MemberMemberBinding))
            Case MemberBindingType.ListBinding
                Return Me.VisitMemberListBinding(CType(binding, MemberListBinding))
            Case Else
                Throw New Exception("Unhandled binding type '" & binding.BindingType.ToString() & "'")
        End Select
    End Function

    ''' <summary>
    ''' Visits the arguments of an element initializer and rebuilds the initializer
    ''' if any argument changed.
    ''' </summary>
    Protected Overridable Function VisitElementInitializer(ByVal initializer As ElementInit) _
        As ElementInit

        Dim arguments = Me.VisitExpressionList(initializer.Arguments)

        If arguments IsNot initializer.Arguments Then
            Return Expression.ElementInit(initializer.AddMethod, arguments)
        End If

        Return initializer
    End Function

    ''' <summary>
    ''' Visits the operand of a unary expression and rebuilds the node if the
    ''' operand changed.
    ''' </summary>
    Protected Overridable Function VisitUnary(ByVal u As UnaryExpression) As Expression
        Dim operand = Me.Visit(u.Operand)

        If operand IsNot u.Operand Then
            Return Expression.MakeUnary(u.NodeType, operand, u.Type, u.Method)
        End If

        Return u
    End Function

    ''' <summary>
    ''' Visits the left operand, right operand and conversion of a binary expression
    ''' and rebuilds the node if any of them changed.
    ''' </summary>
    Protected Overridable Function VisitBinary(ByVal b As BinaryExpression) As Expression
        Dim left = Me.Visit(b.Left)
        Dim right = Me.Visit(b.Right)
        Dim conversion = Me.Visit(b.Conversion)

        If left IsNot b.Left OrElse right IsNot b.Right OrElse conversion IsNot b.Conversion Then

            ' A coalesce node that carries a conversion lambda is rebuilt with the
            ' dedicated factory so the conversion is preserved; every other node
            ' goes through MakeBinary.
            If b.NodeType = ExpressionType.Coalesce AndAlso b.Conversion IsNot Nothing Then
                Return Expression.Coalesce(left, right, _
                                           TryCast(conversion, LambdaExpression))
            Else
                Return Expression.MakeBinary(b.NodeType, left, right, _
                                             b.IsLiftedToNull, b.Method)
            End If
        End If

        Return b
    End Function

    ''' <summary>
    ''' Visits the tested expression of a type test and rebuilds the node if it changed.
    ''' </summary>
    Protected Overridable Function VisitTypeIs(ByVal b As TypeBinaryExpression) As Expression
        Dim expr = Me.Visit(b.Expression)

        If expr IsNot b.Expression Then
            Return Expression.TypeIs(expr, b.TypeOperand)
        End If

        Return b
    End Function

    ''' <summary>
    ''' Returns the constant unchanged; a constant has no children to visit.
    ''' </summary>
    Protected Overridable Function VisitConstant(ByVal c As ConstantExpression) As Expression
        Return c
    End Function

    ''' <summary>
    ''' Visits the test and both branches of a conditional expression and rebuilds
    ''' the node if any of them changed.
    ''' </summary>
    Protected Overridable Function VisitConditional(ByVal c As ConditionalExpression) As Expression
        Dim test = Me.Visit(c.Test)
        Dim ifTrue = Me.Visit(c.IfTrue)
        Dim ifFalse = Me.Visit(c.IfFalse)

        If test IsNot c.Test OrElse ifTrue IsNot c.IfTrue OrElse ifFalse IsNot c.IfFalse Then
            Return Expression.Condition(test, ifTrue, ifFalse)
        End If

        Return c
    End Function

    ''' <summary>
    ''' Returns the parameter unchanged; a parameter has no children to visit.
    ''' </summary>
    Protected Overridable Function VisitParameter(ByVal p As ParameterExpression) As Expression
        Return p
    End Function

    ''' <summary>
    ''' Visits the object expression of a member access and rebuilds the node if it changed.
    ''' </summary>
    Protected Overridable Function VisitMemberAccess(ByVal m As MemberExpression) As Expression
        Dim exp = Me.Visit(m.Expression)

        If exp IsNot m.Expression Then
            Return Expression.MakeMemberAccess(exp, m.Member)
        End If

        Return m
    End Function

    ''' <summary>
    ''' Visits the target object and the arguments of a method call and rebuilds
    ''' the call if any of them changed.
    ''' </summary>
    Protected Overridable Function VisitMethodCall(ByVal m As MethodCallExpression) As Expression
        Dim obj = Me.Visit(m.Object)
        Dim args = Me.VisitExpressionList(m.Arguments)

        If obj IsNot m.Object OrElse args IsNot m.Arguments Then
            Return Expression.Call(obj, m.Method, args)
        End If

        Return m
    End Function

    ''' <summary>
    ''' Visits every expression in a list.
    ''' </summary>
    ''' <returns>
    ''' <paramref name="original"/> itself when no element changed; otherwise a new
    ''' read-only collection holding the visited elements.
    ''' </returns>
    Protected Overridable Function VisitExpressionList( _
        ByVal original As ReadOnlyCollection(Of Expression)) As ReadOnlyCollection(Of Expression)

        ' The copy is created lazily, on the first element that comes back changed.
        Dim list As List(Of Expression) = Nothing
        Dim n = original.Count

        For i = 0 To n - 1
            Dim p = Me.Visit(original(i))

            If list IsNot Nothing Then
                list.Add(p)
            ElseIf p IsNot original(i) Then
                list = New List(Of Expression)(n)

                ' Carry over the unchanged elements that precede the first change.
                For j = 0 To i - 1
                    list.Add(original(j))
                Next j
                list.Add(p)
            End If
        Next i

        If list IsNot Nothing Then
            Return list.AsReadOnly()
        End If

        Return original
    End Function

    ''' <summary>
    ''' Visits the expression assigned to a member and rebuilds the assignment if it changed.
    ''' </summary>
    Protected Overridable Function VisitMemberAssignment(ByVal assignment As MemberAssignment) _
        As MemberAssignment

        Dim e = Me.Visit(assignment.Expression)

        If e IsNot assignment.Expression Then
            Return Expression.Bind(assignment.Member, e)
        End If

        Return assignment
    End Function

    ''' <summary>
    ''' Visits the nested bindings of a member-member binding and rebuilds it if any changed.
    ''' </summary>
    Protected Overridable Function VisitMemberMemberBinding(ByVal binding As MemberMemberBinding) _
        As MemberMemberBinding

        Dim bindings = Me.VisitBindingList(binding.Bindings)

        If bindings IsNot binding.Bindings Then
            Return Expression.MemberBind(binding.Member, bindings)
        End If

        Return binding
    End Function

    ''' <summary>
    ''' Visits the element initializers of a member-list binding and rebuilds it if any changed.
    ''' </summary>
    Protected Overridable Function VisitMemberListBinding(ByVal binding As MemberListBinding) _
        As MemberListBinding

        Dim initializers = Me.VisitElementInitializerList(binding.Initializers)

        If initializers IsNot binding.Initializers Then
            Return Expression.ListBind(binding.Member, initializers)
        End If

        Return binding
    End Function

    ''' <summary>
    ''' Visits every binding in a list.
    ''' </summary>
    ''' <returns>
    ''' <paramref name="original"/> itself when no binding changed; otherwise a new
    ''' list holding the visited bindings.
    ''' </returns>
    Protected Overridable Function VisitBindingList( _
        ByVal original As ReadOnlyCollection(Of MemberBinding)) As IEnumerable(Of MemberBinding)

        ' The copy is created lazily, on the first binding that comes back changed.
        Dim list As List(Of MemberBinding) = Nothing
        Dim n = original.Count

        For i = 0 To n - 1
            Dim b = Me.VisitBinding(original(i))

            If list IsNot Nothing Then
                list.Add(b)
            ElseIf b IsNot original(i) Then
                list = New List(Of MemberBinding)(n)
                For j = 0 To i - 1
                    list.Add(original(j))
                Next j
                list.Add(b)
            End If

        Next i

        If list IsNot Nothing Then
            Return list
        End If

        Return original
    End Function

    ''' <summary>
    ''' Visits every element initializer in a list.
    ''' </summary>
    ''' <returns>
    ''' <paramref name="original"/> itself when no initializer changed; otherwise a
    ''' new list holding the visited initializers.
    ''' </returns>
    Protected Overridable Function VisitElementInitializerList( _
        ByVal original As ReadOnlyCollection(Of ElementInit)) As IEnumerable(Of ElementInit)

        ' The copy is created lazily, on the first initializer that comes back changed.
        Dim list As List(Of ElementInit) = Nothing
        Dim n = original.Count

        For i = 0 To n - 1
            Dim init = Me.VisitElementInitializer(original(i))
            If list IsNot Nothing Then
                list.Add(init)
            ElseIf init IsNot original(i) Then
                list = New List(Of ElementInit)(n)
                For j = 0 To i - 1
                    list.Add(original(j))
                Next j
                list.Add(init)
            End If
        Next i

        If list IsNot Nothing Then
            Return list
        End If

        Return original
    End Function

    ''' <summary>
    ''' Visits the body of a lambda and rebuilds the lambda, with the same delegate
    ''' type and parameters, if the body changed.
    ''' </summary>
    Protected Overridable Function VisitLambda(ByVal lambda As LambdaExpression) As Expression
        Dim body = Me.Visit(lambda.Body)

        If body IsNot lambda.Body Then
            Return Expression.Lambda(lambda.Type, body, lambda.Parameters)
        End If
        Return lambda
    End Function

    ''' <summary>
    ''' Visits the constructor arguments of a New expression and rebuilds it if any
    ''' argument changed.
    ''' </summary>
    Protected Overridable Function VisitNew(ByVal nex As NewExpression) As NewExpression
        Dim args = Me.VisitExpressionList(nex.Arguments)

        If args IsNot nex.Arguments Then
            ' Pass the member list along only when the original node has one.
            If nex.Members IsNot Nothing Then
                Return Expression.[New](nex.Constructor, args, nex.Members)
            Else
                Return Expression.[New](nex.Constructor, args)
            End If
        End If

        Return nex
    End Function

    ''' <summary>
    ''' Visits the New expression and the bindings of a member initializer and
    ''' rebuilds the node if either changed.
    ''' </summary>
    Protected Overridable Function VisitMemberInit(ByVal init As MemberInitExpression) As Expression
        Dim n = Me.VisitNew(init.NewExpression)
        Dim bindings = Me.VisitBindingList(init.Bindings)

        If n IsNot init.NewExpression OrElse bindings IsNot init.Bindings Then
            Return Expression.MemberInit(n, bindings)
        End If

        Return init
    End Function

    ''' <summary>
    ''' Visits the New expression and the element initializers of a list
    ''' initializer and rebuilds the node if either changed.
    ''' </summary>
    Protected Overridable Function VisitListInit(ByVal init As ListInitExpression) As Expression
        Dim n = Me.VisitNew(init.NewExpression)
        Dim initializers = Me.VisitElementInitializerList(init.Initializers)

        If n IsNot init.NewExpression OrElse initializers IsNot init.Initializers Then
            Return Expression.ListInit(n, initializers)
        End If

        Return init
    End Function

    ''' <summary>
    ''' Visits the expressions of an array creation node and rebuilds it, with the
    ''' factory matching its node type, if any expression changed.
    ''' </summary>
    Protected Overridable Function VisitNewArray(ByVal na As NewArrayExpression) As Expression
        Dim exprs = Me.VisitExpressionList(na.Expressions)
        If exprs IsNot na.Expressions Then
            If na.NodeType = ExpressionType.NewArrayInit Then
                Return Expression.NewArrayInit(na.Type.GetElementType(), exprs)
            Else
                Return Expression.NewArrayBounds(na.Type.GetElementType(), exprs)
            End If
        End If

        Return na
    End Function

    ''' <summary>
    ''' Visits the arguments and then the invoked expression of an invocation and
    ''' rebuilds the node if any of them changed.
    ''' </summary>
    Protected Overridable Function VisitInvocation(ByVal iv As InvocationExpression) As Expression
        Dim args = Me.VisitExpressionList(iv.Arguments)
        Dim expr = Me.Visit(iv.Expression)

        If args IsNot iv.Arguments OrElse expr IsNot iv.Expression Then
            Return Expression.Invoke(expr, args)
        End If

        Return iv
    End Function
End Class
/// <summary>
/// Base class for visitors that traverse an expression tree. Each Visit* method
/// visits the children of the node it is given and returns the node unchanged when
/// no child changed, or a newly built node when a child was replaced.
/// Derived classes override the Visit* methods they need.
/// </summary>
public abstract class ExpressionVisitor
{
    protected ExpressionVisitor()
    {
    }

    /// <summary>
    /// Dispatches <paramref name="exp"/> to the specialized visitor method that
    /// matches its node type.
    /// </summary>
    /// <param name="exp">The expression to visit; null is returned as is.</param>
    /// <returns>The expression returned by the specialized visitor method.</returns>
    /// <exception cref="Exception">The node type is not handled by this visitor.</exception>
    protected virtual Expression Visit(Expression exp)
    {
        if (exp == null)
        {
            return exp;
        }

        switch (exp.NodeType)
        {
            case ExpressionType.Negate:
            case ExpressionType.NegateChecked:
            case ExpressionType.UnaryPlus:
            case ExpressionType.Not:
            case ExpressionType.Convert:
            case ExpressionType.ConvertChecked:
            case ExpressionType.ArrayLength:
            case ExpressionType.Quote:
            case ExpressionType.TypeAs:
                return this.VisitUnary((UnaryExpression)exp);
            case ExpressionType.Add:
            case ExpressionType.AddChecked:
            case ExpressionType.Subtract:
            case ExpressionType.SubtractChecked:
            case ExpressionType.Multiply:
            case ExpressionType.MultiplyChecked:
            case ExpressionType.Divide:
            case ExpressionType.Modulo:
            case ExpressionType.Power:
            case ExpressionType.And:
            case ExpressionType.AndAlso:
            case ExpressionType.Or:
            case ExpressionType.OrElse:
            case ExpressionType.LessThan:
            case ExpressionType.LessThanOrEqual:
            case ExpressionType.GreaterThan:
            case ExpressionType.GreaterThanOrEqual:
            case ExpressionType.Equal:
            case ExpressionType.NotEqual:
            case ExpressionType.Coalesce:
            case ExpressionType.ArrayIndex:
            case ExpressionType.RightShift:
            case ExpressionType.LeftShift:
            case ExpressionType.ExclusiveOr:
                return this.VisitBinary((BinaryExpression)exp);
            case ExpressionType.TypeIs:
                return this.VisitTypeIs((TypeBinaryExpression)exp);
            case ExpressionType.Conditional:
                return this.VisitConditional((ConditionalExpression)exp);
            case ExpressionType.Constant:
                return this.VisitConstant((ConstantExpression)exp);
            case ExpressionType.Parameter:
                return this.VisitParameter((ParameterExpression)exp);
            case ExpressionType.MemberAccess:
                return this.VisitMemberAccess((MemberExpression)exp);
            case ExpressionType.Call:
                return this.VisitMethodCall((MethodCallExpression)exp);
            case ExpressionType.Lambda:
                return this.VisitLambda((LambdaExpression)exp);
            case ExpressionType.New:
                return this.VisitNew((NewExpression)exp);
            case ExpressionType.NewArrayInit:
            case ExpressionType.NewArrayBounds:
                return this.VisitNewArray((NewArrayExpression)exp);
            case ExpressionType.Invoke:
                return this.VisitInvocation((InvocationExpression)exp);
            case ExpressionType.MemberInit:
                return this.VisitMemberInit((MemberInitExpression)exp);
            case ExpressionType.ListInit:
                return this.VisitListInit((ListInitExpression)exp);
            default:
                throw new Exception(string.Format("Unhandled expression type: '{0}'", exp.NodeType));
        }
    }

    /// <summary>
    /// Dispatches a member binding to the visitor method for its binding type.
    /// </summary>
    /// <exception cref="Exception">The binding type is not handled by this visitor.</exception>
    protected virtual MemberBinding VisitBinding(MemberBinding binding)
    {
        switch (binding.BindingType)
        {
            case MemberBindingType.Assignment:
                return this.VisitMemberAssignment((MemberAssignment)binding);
            case MemberBindingType.MemberBinding:
                return this.VisitMemberMemberBinding((MemberMemberBinding)binding);
            case MemberBindingType.ListBinding:
                return this.VisitMemberListBinding((MemberListBinding)binding);
            default:
                throw new Exception(string.Format("Unhandled binding type '{0}'", binding.BindingType));
        }
    }

    /// <summary>
    /// Visits the arguments of an element initializer and rebuilds the initializer
    /// if any argument changed.
    /// </summary>
    protected virtual ElementInit VisitElementInitializer(ElementInit initializer)
    {
        ReadOnlyCollection<Expression> arguments = this.VisitExpressionList(initializer.Arguments);
        if (arguments != initializer.Arguments)
        {
            return Expression.ElementInit(initializer.AddMethod, arguments);
        }
        return initializer;
    }

    /// <summary>
    /// Visits the operand of a unary expression and rebuilds the node if the
    /// operand changed.
    /// </summary>
    protected virtual Expression VisitUnary(UnaryExpression u)
    {
        Expression operand = this.Visit(u.Operand);
        if (operand != u.Operand)
        {
            return Expression.MakeUnary(u.NodeType, operand, u.Type, u.Method);
        }
        return u;
    }

    /// <summary>
    /// Visits the left operand, right operand and conversion of a binary expression
    /// and rebuilds the node if any of them changed.
    /// </summary>
    protected virtual Expression VisitBinary(BinaryExpression b)
    {
        Expression left = this.Visit(b.Left);
        Expression right = this.Visit(b.Right);
        Expression conversion = this.Visit(b.Conversion);
        if (left != b.Left || right != b.Right || conversion != b.Conversion)
        {
            // A coalesce node that carries a conversion lambda is rebuilt with the
            // dedicated factory so the conversion is preserved; every other node
            // goes through MakeBinary.
            if (b.NodeType == ExpressionType.Coalesce && b.Conversion != null)
            {
                return Expression.Coalesce(left, right, conversion as LambdaExpression);
            }
            else
            {
                return Expression.MakeBinary(b.NodeType, left, right, b.IsLiftedToNull, b.Method);
            }
        }
        return b;
    }

    /// <summary>
    /// Visits the tested expression of a type test and rebuilds the node if it changed.
    /// </summary>
    protected virtual Expression VisitTypeIs(TypeBinaryExpression b)
    {
        Expression expr = this.Visit(b.Expression);
        if (expr != b.Expression)
        {
            return Expression.TypeIs(expr, b.TypeOperand);
        }
        return b;
    }

    /// <summary>
    /// Returns the constant unchanged; a constant has no children to visit.
    /// </summary>
    protected virtual Expression VisitConstant(ConstantExpression c)
    {
        return c;
    }

    /// <summary>
    /// Visits the test and both branches of a conditional expression and rebuilds
    /// the node if any of them changed.
    /// </summary>
    protected virtual Expression VisitConditional(ConditionalExpression c)
    {
        Expression test = this.Visit(c.Test);
        Expression ifTrue = this.Visit(c.IfTrue);
        Expression ifFalse = this.Visit(c.IfFalse);
        if (test != c.Test || ifTrue != c.IfTrue || ifFalse != c.IfFalse)
        {
            return Expression.Condition(test, ifTrue, ifFalse);
        }
        return c;
    }

    /// <summary>
    /// Returns the parameter unchanged; a parameter has no children to visit.
    /// </summary>
    protected virtual Expression VisitParameter(ParameterExpression p)
    {
        return p;
    }

    /// <summary>
    /// Visits the object expression of a member access and rebuilds the node if it changed.
    /// </summary>
    protected virtual Expression VisitMemberAccess(MemberExpression m)
    {
        Expression exp = this.Visit(m.Expression);
        if (exp != m.Expression)
        {
            return Expression.MakeMemberAccess(exp, m.Member);
        }
        return m;
    }

    /// <summary>
    /// Visits the target object and the arguments of a method call and rebuilds
    /// the call if any of them changed.
    /// </summary>
    protected virtual Expression VisitMethodCall(MethodCallExpression m)
    {
        Expression obj = this.Visit(m.Object);
        IEnumerable<Expression> args = this.VisitExpressionList(m.Arguments);
        if (obj != m.Object || args != m.Arguments)
        {
            return Expression.Call(obj, m.Method, args);
        }
        return m;
    }

    /// <summary>
    /// Visits every expression in a list.
    /// </summary>
    /// <returns>
    /// <paramref name="original"/> itself when no element changed; otherwise a new
    /// read-only collection holding the visited elements.
    /// </returns>
    protected virtual ReadOnlyCollection<Expression> VisitExpressionList(ReadOnlyCollection<Expression> original)
    {
        // The copy is created lazily, on the first element that comes back changed.
        List<Expression> list = null;
        for (int i = 0, n = original.Count; i < n; i++)
        {
            Expression p = this.Visit(original[i]);
            if (list != null)
            {
                list.Add(p);
            }
            else if (p != original[i])
            {
                list = new List<Expression>(n);
                // Carry over the unchanged elements that precede the first change.
                for (int j = 0; j < i; j++)
                {
                    list.Add(original[j]);
                }
                list.Add(p);
            }
        }
        if (list != null)
        {
            return list.AsReadOnly();
        }
        return original;
    }

    /// <summary>
    /// Visits the expression assigned to a member and rebuilds the assignment if it changed.
    /// </summary>
    protected virtual MemberAssignment VisitMemberAssignment(MemberAssignment assignment)
    {
        Expression e = this.Visit(assignment.Expression);
        if (e != assignment.Expression)
        {
            return Expression.Bind(assignment.Member, e);
        }
        return assignment;
    }

    /// <summary>
    /// Visits the nested bindings of a member-member binding and rebuilds it if any changed.
    /// </summary>
    protected virtual MemberMemberBinding VisitMemberMemberBinding(MemberMemberBinding binding)
    {
        IEnumerable<MemberBinding> bindings = this.VisitBindingList(binding.Bindings);
        if (bindings != binding.Bindings)
        {
            return Expression.MemberBind(binding.Member, bindings);
        }
        return binding;
    }

    /// <summary>
    /// Visits the element initializers of a member-list binding and rebuilds it if any changed.
    /// </summary>
    protected virtual MemberListBinding VisitMemberListBinding(MemberListBinding binding)
    {
        IEnumerable<ElementInit> initializers = this.VisitElementInitializerList(binding.Initializers);
        if (initializers != binding.Initializers)
        {
            return Expression.ListBind(binding.Member, initializers);
        }
        return binding;
    }

    /// <summary>
    /// Visits every binding in a list.
    /// </summary>
    /// <returns>
    /// <paramref name="original"/> itself when no binding changed; otherwise a new
    /// list holding the visited bindings.
    /// </returns>
    protected virtual IEnumerable<MemberBinding> VisitBindingList(ReadOnlyCollection<MemberBinding> original)
    {
        // The copy is created lazily, on the first binding that comes back changed.
        List<MemberBinding> list = null;
        for (int i = 0, n = original.Count; i < n; i++)
        {
            MemberBinding b = this.VisitBinding(original[i]);
            if (list != null)
            {
                list.Add(b);
            }
            else if (b != original[i])
            {
                list = new List<MemberBinding>(n);
                for (int j = 0; j < i; j++)
                {
                    list.Add(original[j]);
                }
                list.Add(b);
            }
        }
        if (list != null)
        {
            return list;
        }
        return original;
    }

    /// <summary>
    /// Visits every element initializer in a list.
    /// </summary>
    /// <returns>
    /// <paramref name="original"/> itself when no initializer changed; otherwise a
    /// new list holding the visited initializers.
    /// </returns>
    protected virtual IEnumerable<ElementInit> VisitElementInitializerList(ReadOnlyCollection<ElementInit> original)
    {
        // The copy is created lazily, on the first initializer that comes back changed.
        List<ElementInit> list = null;
        for (int i = 0, n = original.Count; i < n; i++)
        {
            ElementInit init = this.VisitElementInitializer(original[i]);
            if (list != null)
            {
                list.Add(init);
            }
            else if (init != original[i])
            {
                list = new List<ElementInit>(n);
                for (int j = 0; j < i; j++)
                {
                    list.Add(original[j]);
                }
                list.Add(init);
            }
        }
        if (list != null)
        {
            return list;
        }
        return original;
    }

    /// <summary>
    /// Visits the body of a lambda and rebuilds the lambda, with the same delegate
    /// type and parameters, if the body changed.
    /// </summary>
    protected virtual Expression VisitLambda(LambdaExpression lambda)
    {
        Expression body = this.Visit(lambda.Body);
        if (body != lambda.Body)
        {
            return Expression.Lambda(lambda.Type, body, lambda.Parameters);
        }
        return lambda;
    }

    /// <summary>
    /// Visits the constructor arguments of a new-object expression and rebuilds it
    /// if any argument changed.
    /// </summary>
    protected virtual NewExpression VisitNew(NewExpression nex)
    {
        IEnumerable<Expression> args = this.VisitExpressionList(nex.Arguments);
        if (args != nex.Arguments)
        {
            // Pass the member list along only when the original node has one.
            if (nex.Members != null)
            {
                return Expression.New(nex.Constructor, args, nex.Members);
            }
            else
            {
                return Expression.New(nex.Constructor, args);
            }
        }
        return nex;
    }

    /// <summary>
    /// Visits the new-object expression and the bindings of a member initializer
    /// and rebuilds the node if either changed.
    /// </summary>
    protected virtual Expression VisitMemberInit(MemberInitExpression init)
    {
        NewExpression n = this.VisitNew(init.NewExpression);
        IEnumerable<MemberBinding> bindings = this.VisitBindingList(init.Bindings);
        if (n != init.NewExpression || bindings != init.Bindings)
        {
            return Expression.MemberInit(n, bindings);
        }
        return init;
    }

    /// <summary>
    /// Visits the new-object expression and the element initializers of a list
    /// initializer and rebuilds the node if either changed.
    /// </summary>
    protected virtual Expression VisitListInit(ListInitExpression init)
    {
        NewExpression n = this.VisitNew(init.NewExpression);
        IEnumerable<ElementInit> initializers = this.VisitElementInitializerList(init.Initializers);
        if (n != init.NewExpression || initializers != init.Initializers)
        {
            return Expression.ListInit(n, initializers);
        }
        return init;
    }

    /// <summary>
    /// Visits the expressions of an array creation node and rebuilds it, with the
    /// factory matching its node type, if any expression changed.
    /// </summary>
    protected virtual Expression VisitNewArray(NewArrayExpression na)
    {
        IEnumerable<Expression> exprs = this.VisitExpressionList(na.Expressions);
        if (exprs != na.Expressions)
        {
            if (na.NodeType == ExpressionType.NewArrayInit)
            {
                return Expression.NewArrayInit(na.Type.GetElementType(), exprs);
            }
            else
            {
                return Expression.NewArrayBounds(na.Type.GetElementType(), exprs);
            }
        }
        return na;
    }

    /// <summary>
    /// Visits the arguments and then the invoked expression of an invocation and
    /// rebuilds the node if any of them changed.
    /// </summary>
    protected virtual Expression VisitInvocation(InvocationExpression iv)
    {
        IEnumerable<Expression> args = this.VisitExpressionList(iv.Arguments);
        Expression expr = this.Visit(iv.Expression);
        if (args != iv.Arguments || expr != iv.Expression)
        {
            return Expression.Invoke(expr, args);
        }
        return iv;
    }
}
说明:

在此实现中,作为访问表达式目录树起点的 Visit 方法具有 protected(在 Visual Basic 中为 Protected)访问修饰符。这意味着,若要从该类及其派生类的外部访问此方法,必须创建一个调用 Visit 的 public(在 Visual Basic 中为 Public)方法。在访问器中提供这样一个 public(在 Visual Basic 中为 Public)方法,可以使入口点对调用方更加明显。

编译代码

  • 添加对 System.Core.dll 的引用(如果在项目中尚未引用它的话)。

  • 对于 System.Collections.Generic、System.Collections.ObjectModel 和 System.Linq.Expressions 命名空间,添加 using 指令(在 Visual Basic 中为 Imports 语句)。

请参见

任务

如何:修改表达式目录树

演练:创建 IQueryable LINQ 提供程序

概念

表达式目录树