CLR Trigger for Automatic Registration of UDXs on CREATE ASSEMBLY

This code is provided "AS IS" with no warranties, and confers no rights.

On 11/30/2005 I made some small modifications so that the code complies with standard naming conventions.

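Save the following code as AutoRegisterTrigger.cs and compile it into a class library DLL (for example: csc /target:library AutoRegisterTrigger.cs).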

--------------------------------------------------------

 using System;
using System;
using Microsoft.SqlServer.Server;
using System.Data;
using System.Data.SqlTypes;
using System.Data.SqlClient;
using System.Xml;
using System.Reflection;
using System.Text.RegularExpressions;

namespace Microsoft.SqlServer.Sample

{

    /// <summary>
    /// CLR DDL trigger that auto-registers the SQL-callable contents of a newly
    /// created assembly.
    /// </summary>
    /// <remarks>
    /// After CREATE ASSEMBLY, the assembly is loaded and every type or public static
    /// method carrying SqlUserDefinedType, SqlUserDefinedAggregate, SqlProcedure,
    /// SqlTrigger or SqlFunction is turned into the matching CREATE statement.
    /// Problems raised as AutoRegisterException, and SqlExceptions raised by the
    /// generated CREATE statements, are reported through SqlContext.Pipe and do not
    /// stop the remaining types/methods from being processed.
    /// </remarks>
    public class AutoRegister
    {
        // Names that end up inside [ ] brackets: word characters and dots only, so a
        // name can never contain the closing bracket. (.NET regex does not support
        // POSIX classes such as [:alnum:]; \w covers letters, digits and '_'.)
        private const String IdentifierPattern = @"^[\w.]+\z";

        // Trigger Target/Event text is spliced into the statement unbracketed and
        // legitimately contains spaces, commas and brackets ("FOR INSERT, UPDATE",
        // "ALL SERVER", "[dbo].[T]"), but must never contain ';', quotes,
        // parentheses or comment markers.
        private const String ClausePattern = @"^[\w., \[\]]+\z";

        [SqlTrigger(Name = "AutoRegister", Event = "FOR CREATE_ASSEMBLY", Target = "DATABASE")]
        public static void AutoRegisterTrigger()
        {
            // Only meaningful when fired by CREATE ASSEMBLY.
            SqlTriggerContext triggContext = SqlContext.TriggerContext;
            if (triggContext.TriggerAction != TriggerAction.CreateAssembly)
            {
                throw new InvalidOperationException("This trigger is only meaningful for CREATE ASSEMBLY.");
            }

            String asmName = GetAssemblyName(triggContext);
            String clrName = GetClrName(asmName);

            // Load the assembly that was just stored in the database.
            Assembly a = Assembly.Load(clrName);

            foreach (Type type in a.GetTypes())
            {
                RegisterTypeLevelObjects(asmName, type);
                RegisterMethodLevelObjects(asmName, type);
            }
        }

        /// <summary>Reads the ObjectName element (the SQL name of the new assembly) from EVENTDATA.</summary>
        /// <exception cref="InvalidOperationException">No ObjectName element was found.</exception>
        private static String GetAssemblyName(SqlTriggerContext triggContext)
        {
            String asmName = null;
            using (XmlReader eventReader = triggContext.EventData.CreateReader())
            {
                if (eventReader.ReadToFollowing("ObjectName"))
                {
                    asmName = eventReader.ReadElementContentAsString();
                }
            }

            if (asmName == null)
            {
                throw new InvalidOperationException("Could not find the assembly name.");
            }
            return asmName;
        }

        /// <summary>Looks up the CLR (strong) name of a database assembly in sys.assemblies.</summary>
        /// <exception cref="InvalidOperationException">The assembly has no row in sys.assemblies.</exception>
        private static String GetClrName(String asmName)
        {
            using (SqlConnection myConnection = new SqlConnection("Context connection = true"))
            {
                myConnection.Open();
                using (SqlCommand myCommand = new SqlCommand("SELECT clr_name FROM sys.assemblies WHERE name = @AsmName", myConnection))
                {
                    myCommand.Parameters.Add("@AsmName", SqlDbType.NVarChar).Value = asmName;

                    // 'as' also turns DBNull into null.
                    String clrName = myCommand.ExecuteScalar() as String;
                    if (clrName == null)
                    {
                        throw new InvalidOperationException("Could not find assembly " + asmName + " in sys.assemblies.");
                    }
                    return clrName;
                }
            }
        }

        /// <summary>Registers UDTs and UDAs declared by attributes on the type itself.</summary>
        private static void RegisterTypeLevelObjects(String asmName, Type type)
        {
            try
            {
                foreach (Attribute att in Attribute.GetCustomAttributes(type))
                {
                    // Exact type comparison: subclasses of these attributes are not
                    // treated as the attribute itself.
                    Type attType = att.GetType();
                    if (attType == typeof(SqlUserDefinedTypeAttribute))
                    {
                        RegisterUDT(asmName, type, (SqlUserDefinedTypeAttribute)att);
                    }
                    else if (attType == typeof(SqlUserDefinedAggregateAttribute))
                    {
                        RegisterUDA(asmName, type, (SqlUserDefinedAggregateAttribute)att);
                    }
                }
            }
            catch (AutoRegisterException e)
            {
                SqlContext.Pipe.Send("[" + asmName + "].[" + type.Name + "] : " + e.Message);
            }
        }

        /// <summary>Registers procedures, triggers and UDFs declared on public static methods of the type.</summary>
        private static void RegisterMethodLevelObjects(String asmName, Type type)
        {
            foreach (MethodInfo methodInfo in type.GetMethods())
            {
                if (!methodInfo.IsPublic || !methodInfo.IsStatic)
                {
                    continue;
                }

                try
                {
                    foreach (Attribute att in methodInfo.GetCustomAttributes(false))
                    {
                        // Exact type comparison: SqlMethodAttribute derives from
                        // SqlFunctionAttribute and must not be registered as a UDF.
                        Type attType = att.GetType();
                        if (attType == typeof(SqlProcedureAttribute))
                        {
                            RegisterProcedure(asmName, type, (SqlProcedureAttribute)att, methodInfo);
                        }
                        if (attType == typeof(SqlTriggerAttribute))
                        {
                            RegisterTrigger(asmName, type, (SqlTriggerAttribute)att, methodInfo);
                        }
                        if (attType == typeof(SqlFunctionAttribute))
                        {
                            RegisterUDF(asmName, type, (SqlFunctionAttribute)att, methodInfo);
                        }
                    }
                }
                catch (AutoRegisterException e)
                {
                    SqlContext.Pipe.Send("[" + asmName + "].[" + type.Name + "].[" + methodInfo.Name + "] : " + e.Message);
                }
            }
        }

        private static void RegisterUDT(String asmName, Type type, SqlUserDefinedTypeAttribute sqlUserDefinedTypeAttribute)
        {
            String name = CheckedName(sqlUserDefinedTypeAttribute.Name, type.Name);

            Register("CREATE TYPE " + QuoteName(name) + " EXTERNAL NAME " + ClassExternalName(asmName, type));
        }

        private static void RegisterUDA(String asmName, Type type, SqlUserDefinedAggregateAttribute sqlUserDefinedAggregateAttribute)
        {
            String name = CheckedName(sqlUserDefinedAggregateAttribute.Name, type.Name);

            String parameterClause = "";
            String returnClause = "";
            bool hasAccumulateMethod = false;

            // The aggregate's input type comes from Accumulate(x), its result type from Terminate().
            foreach (MethodInfo methodInfo in type.GetMethods())
            {
                if (methodInfo.Name.Equals("Accumulate"))
                {
                    ParameterInfo[] pis = methodInfo.GetParameters();
                    if (pis.Length != 1)
                    {
                        throw new InvalidNumberOfArgumentsException();
                    }

                    ParameterInfo pi = pis[0];

                    // An aggregate input cannot be an OUTPUT (ref/out) parameter.
                    if (pi.ParameterType.IsByRef)
                    {
                        throw new InvalidArgumentByReferenceException(pi.ParameterType.ToString());
                    }

                    String paramType = NETtoSQLmap(pi.ParameterType.ToString());
                    if (paramType == null)
                    {
                        throw new InvalidTypeException(pi.ParameterType.ToString());
                    }

                    // paramType is produced by NETtoSQLmap, only the name needs validating.
                    if (!IsValid(pi.Name))
                    {
                        throw new InvalidNameException(pi.Name);
                    }

                    parameterClause += " (@" + pi.Name + " " + paramType + ")";
                    hasAccumulateMethod = true;
                }
                else if (methodInfo.Name.Equals("Terminate"))
                {
                    String returnType = NETtoSQLmap(methodInfo.ReturnType.ToString());
                    if (returnType == null)
                    {
                        throw new InvalidReturnTypeException(methodInfo.ReturnType.ToString());
                    }
                    returnClause = " RETURNS " + returnType + " ";
                }
            }

            if (returnClause.Length == 0)
            {
                throw new MissingTerminateMethodException();
            }
            if (!hasAccumulateMethod)
            {
                throw new MissingAccumulateMethodException();
            }

            Register("CREATE AGGREGATE " + QuoteName(name) + parameterClause + returnClause
                + "EXTERNAL NAME " + ClassExternalName(asmName, type));
        }

        private static void RegisterProcedure(String asmName, Type type, SqlProcedureAttribute sqlProcedureAttribute, MethodInfo methodInfo)
        {
            String name = CheckedName(sqlProcedureAttribute.Name, methodInfo.Name);

            Register("CREATE PROCEDURE " + QuoteName(name) + GetParameterString(methodInfo)
                + " AS EXTERNAL NAME " + MethodExternalName(asmName, type, methodInfo));
        }

        private static void RegisterTrigger(String asmName, Type type, SqlTriggerAttribute sqlTriggerAttribute, MethodInfo methodInfo)
        {
            String name = CheckedName(sqlTriggerAttribute.Name, methodInfo.Name);

            // Target and Event are inserted verbatim (not bracketed), so they get the
            // stricter clause check instead of the identifier check.
            String target = sqlTriggerAttribute.Target;
            if (target == null)
            {
                throw new NoTargetException();
            }
            if (!IsValidClause(target))
            {
                throw new InvalidNameException(target);
            }

            String trEvent = sqlTriggerAttribute.Event;
            if (trEvent == null)
            {
                throw new NoEventException();
            }
            if (!IsValidClause(trEvent))
            {
                throw new InvalidNameException(trEvent);
            }

            Register("CREATE TRIGGER " + QuoteName(name) + " ON " + target + " WITH EXECUTE AS CALLER "
                + trEvent + " AS EXTERNAL NAME " + MethodExternalName(asmName, type, methodInfo));
        }

        private static void RegisterUDF(String asmName, Type type, SqlFunctionAttribute sqlFunctionAttribute, MethodInfo methodInfo)
        {
            String name = CheckedName(sqlFunctionAttribute.Name, methodInfo.Name);

            String parameters = GetParameterString(methodInfo);
            // Unlike procedures, T-SQL functions require "()" even with no parameters.
            if (parameters.Length == 0)
            {
                parameters = "() ";
            }

            String returnType = NETtoSQLmap(methodInfo.ReturnType.ToString());
            if (returnType == null)
            {
                throw new InvalidReturnTypeException(methodInfo.ReturnType.ToString());
            }

            Register("CREATE FUNCTION " + QuoteName(name) + parameters + "RETURNS " + returnType
                + " AS EXTERNAL NAME " + MethodExternalName(asmName, type, methodInfo));
        }

        /// <summary>
        /// Picks the SQL name (attribute Name, else the CLR name) and rejects it if it
        /// is not a valid identifier or already exists in sys.objects.
        /// </summary>
        private static String CheckedName(String attributeName, String defaultName)
        {
            String name = (attributeName != null) ? attributeName : defaultName;

            if (!IsValid(name))
            {
                throw new InvalidNameException(name);
            }
            if (IsRegistered(name))
            {
                throw new ObjectRegisteredException(name);
            }
            return name;
        }

        /// <summary>Wraps an identifier in [ ], doubling any embedded ']'.</summary>
        private static String QuoteName(String identifier)
        {
            return "[" + identifier.Replace("]", "]]") + "]";
        }

        /// <summary>Builds [assembly].[Namespace.Class] for EXTERNAL NAME clauses.</summary>
        private static String ClassExternalName(String asmName, Type type)
        {
            String className = (type.Namespace != null) ? type.Namespace + "." + type.Name : type.Name;
            return QuoteName(asmName) + "." + QuoteName(className);
        }

        /// <summary>Builds [assembly].[Namespace.Class].[Method] for EXTERNAL NAME clauses.</summary>
        private static String MethodExternalName(String asmName, Type type, MethodInfo methodInfo)
        {
            return ClassExternalName(asmName, type) + "." + QuoteName(methodInfo.Name);
        }

        /// <summary>
        /// Maps the Type.ToString() name of a CLR parameter/return type to the T-SQL
        /// type used in the generated statement, or returns null when unmapped.
        /// </summary>
        /// <remarks>
        /// Variable-length types carry an explicit length: in a T-SQL parameter or
        /// RETURNS declaration a bare "nvarchar"/"varbinary" means length 1.
        /// </remarks>
        private static String NETtoSQLmap(String typeName)
        {
            switch (typeName)
            {
                // System.Data.SqlTypes
                case "System.Data.SqlTypes.SqlBoolean": return "bit";
                case "System.Data.SqlTypes.SqlBinary": return "varbinary(8000)"; // other choices: binary, image, timestamp
                case "System.Data.SqlTypes.SqlBytes": return "varbinary(max)";
                case "System.Data.SqlTypes.SqlByte": return "tinyint";
                case "System.Data.SqlTypes.SqlChars": return "nvarchar(max)"; // other choices: char, nchar, text, ntext, varchar
                case "System.Data.SqlTypes.SqlDateTime": return "datetime"; // smalldatetime would round to the minute
                case "System.Data.SqlTypes.SqlDecimal": return "decimal"; // NOTE: bare decimal means decimal(18,0)
                case "System.Data.SqlTypes.SqlDouble": return "float";
                case "System.Data.SqlTypes.SqlGuid": return "uniqueidentifier";
                case "System.Data.SqlTypes.SqlInt16": return "smallint";
                case "System.Data.SqlTypes.SqlInt32": return "int";
                case "System.Data.SqlTypes.SqlInt64": return "bigint";
                case "System.Data.SqlTypes.SqlMoney": return "money"; // other choice: smallmoney
                case "System.Data.SqlTypes.SqlSingle": return "real";
                case "System.Data.SqlTypes.SqlString": return "nvarchar(4000)"; // other choices: char, nchar, text, ntext, varchar
                case "System.Data.SqlTypes.SqlXml": return "xml";

                // CLR primitives. Type.ToString() yields the full type name
                // ("System.Int32"), never the C# keyword ("int").
                case "System.Boolean": return "bit";
                case "System.Byte": return "tinyint";
                case "System.Int16": return "smallint";
                case "System.Int32": return "int";
                case "System.Int64": return "bigint";
                case "System.Single": return "real";
                case "System.Double": return "float";
                case "System.Decimal": return "decimal";
                case "System.DateTime": return "datetime";
                case "System.Guid": return "uniqueidentifier";
                case "System.String": return "nvarchar(4000)";
                case "System.Byte[]": return "varbinary(8000)";

                default: return null;
            }
        }

        /// <summary>Returns true if sys.objects already contains an object with this name.</summary>
        /// <remarks>
        /// NOTE: user-defined types live in sys.types and database-level DDL triggers
        /// are not schema-scoped, so neither appears in sys.objects. A clash for those
        /// surfaces as the SqlException reported by Register instead.
        /// </remarks>
        private static bool IsRegistered(String name)
        {
            using (SqlConnection myConnection = new SqlConnection("context connection = true"))
            {
                myConnection.Open();
                using (SqlCommand myCommand = new SqlCommand("SELECT 1 FROM sys.objects WHERE name = @Name", myConnection))
                {
                    myCommand.Parameters.Add("@Name", SqlDbType.NVarChar).Value = name;
                    return myCommand.ExecuteScalar() != null;
                }
            }
        }

        /// <summary>
        /// Executes one generated CREATE statement on the context connection, echoing
        /// it to the client on success or sending the SqlException message on failure.
        /// </summary>
        private static void Register(String commandString)
        {
            using (SqlConnection myConnection = new SqlConnection("context connection = true"))
            {
                myConnection.Open();
                using (SqlCommand comm = new SqlCommand(commandString, myConnection))
                {
                    try
                    {
                        comm.ExecuteNonQuery();
                        SqlContext.Pipe.Send("AutoRegister: " + commandString);
                    }
                    catch (SqlException e)
                    {
                        SqlContext.Pipe.Send(e.Message);
                    }
                }
            }
        }

        /// <summary>
        /// True when <paramref name="s"/> is a non-empty identifier made only of word
        /// characters and dots.
        /// </summary>
        /// <remarks>
        /// Attribute-supplied names (Name properties, parameter names) are checked
        /// here. Namespace, class and method names come from the compiler, and the
        /// assembly name comes from the DBA; those are bracket-escaped by QuoteName
        /// rather than validated.
        /// </remarks>
        private static bool IsValid(String s)
        {
            return !String.IsNullOrEmpty(s) && Regex.IsMatch(s, IdentifierPattern);
        }

        /// <summary>True when a trigger Target/Event clause uses only identifier characters, dots, commas, spaces and brackets.</summary>
        private static bool IsValidClause(String s)
        {
            return !String.IsNullOrEmpty(s) && Regex.IsMatch(s, ClausePattern);
        }

        /// <summary>
        /// Builds the T-SQL parameter list for a method, e.g. " (@a int, @b int OUTPUT) ",
        /// or "" when the method has no parameters. ref/out parameters become OUTPUT.
        /// </summary>
        private static String GetParameterString(MethodInfo methodInfo)
        {
            ParameterInfo[] pis = methodInfo.GetParameters();
            if (pis.Length == 0)
            {
                return "";
            }

            String[] declarations = new String[pis.Length];
            for (int i = 0; i < pis.Length; i++)
            {
                ParameterInfo pi = pis[i];

                // ref/out parameters are by-ref types such as "System.Int32&";
                // map the element type and mark the parameter OUTPUT.
                Type paramType = pi.ParameterType;
                bool isOutput = paramType.IsByRef;
                if (isOutput)
                {
                    paramType = paramType.GetElementType();
                }

                String sqlType = NETtoSQLmap(paramType.ToString());
                if (sqlType == null)
                {
                    throw new InvalidTypeException(paramType.ToString());
                }
                if (isOutput)
                {
                    sqlType += " OUTPUT";
                }

                // sqlType is produced by NETtoSQLmap, only the name needs validating.
                if (!IsValid(pi.Name))
                {
                    throw new InvalidNameException(pi.Name);
                }

                declarations[i] = "@" + pi.Name + " " + sqlType;
            }

            return " (" + String.Join(", ", declarations) + ") ";
        }
    }

    /// <summary>
    /// Base class for per-object registration failures. AutoRegisterTrigger catches
    /// it and reports the message through SqlContext.Pipe instead of failing the
    /// whole CREATE ASSEMBLY.
    /// </summary>
    /// <remarks>
    /// Derives from Exception rather than ApplicationException, as recommended by
    /// the .NET Framework Design Guidelines.
    /// </remarks>
    internal class AutoRegisterException : Exception
    {
        /// <param name="Message">Problem description; " UDX not registered." is appended.</param>
        internal AutoRegisterException(String Message)
            : base(Message + " UDX not registered.")
        {
        }
    }

    /// <summary>Raised when a parameter's CLR type has no T-SQL type mapping.</summary>
    internal class InvalidTypeException : AutoRegisterException
    {
        /// <param name="sType">Name of the CLR type that could not be mapped.</param>
        internal InvalidTypeException(String sType)
            : base(String.Concat("Argument type ", sType, " cannot be mapped to a SQL Type."))
        {
        }
    }

    /// <summary>Raised when a method's CLR return type has no T-SQL type mapping.</summary>
    internal class InvalidReturnTypeException : AutoRegisterException
    {
        /// <param name="sType">Name of the CLR return type that could not be mapped.</param>
        internal InvalidReturnTypeException(String sType)
            : base(String.Concat("Return type ", sType, " cannot be mapped to a SQL Type."))
        {
        }
    }

    /// <summary>Raised when an object with the target SQL name already exists.</summary>
    internal class ObjectRegisteredException : AutoRegisterException
    {
        /// <param name="Name">The SQL name that is already taken.</param>
        internal ObjectRegisteredException(String Name)
            : base(String.Concat(Name, " is already registered."))
        {
        }
    }

    /// <summary>Raised when a name or clause fails validation before being spliced into T-SQL.</summary>
    internal class InvalidNameException : AutoRegisterException
    {
        /// <param name="Name">The rejected text.</param>
        internal InvalidNameException(String Name)
            : base(String.Concat(Name, " is not a valid name."))
        {
        }
    }

    /// <summary>Raised when a SqlTrigger attribute does not specify its Event.</summary>
    internal class NoEventException : AutoRegisterException
    {
        private const String MissingEventMessage =
            "Missing the 'Event' argument for the SqlTriggerAttribute constructor.";

        internal NoEventException()
            : base(MissingEventMessage)
        {
        }
    }

    /// <summary>Raised when a SqlTrigger attribute does not specify its Target.</summary>
    internal class NoTargetException : AutoRegisterException
    {
        private const String MissingTargetMessage =
            "Missing the 'Target' argument for the SqlTriggerAttribute constructor.";

        internal NoTargetException()
            : base(MissingTargetMessage)
        {
        }
    }

    /// <summary>Raised when an aggregate's Accumulate method does not take exactly one argument.</summary>
    internal class InvalidNumberOfArgumentsException : AutoRegisterException
    {
        private const String WrongArgumentCountMessage = "Wrong number of arguments.";

        internal InvalidNumberOfArgumentsException()
            : base(WrongArgumentCountMessage)
        {
        }
    }

    /// <summary>Raised when an aggregate's Accumulate argument is declared ref/out.</summary>
    internal class InvalidArgumentByReferenceException : AutoRegisterException
    {
        /// <param name="Argument">The by-ref parameter type name (e.g. "System.Int32&amp;").</param>
        internal InvalidArgumentByReferenceException(String Argument)
            // The space before "passed" was missing, gluing it onto the type name.
            : base("Expected argument " + Argument + " passed by value, passed by reference instead.")
        {
        }
    }

    /// <summary>Raised when a user-defined aggregate type has no public Terminate method.</summary>
    internal class MissingTerminateMethodException : AutoRegisterException
    {
        private const String MissingTerminateMessage =
            "The user defined aggregate is missing a 'Terminate' method.";

        internal MissingTerminateMethodException()
            : base(MissingTerminateMessage)
        {
        }
    }

    /// <summary>Raised when a user-defined aggregate type has no public Accumulate method.</summary>
    internal class MissingAccumulateMethodException : AutoRegisterException
    {
        private const String MissingAccumulateMessage =
            "The user defined aggregate is missing an 'Accumulate' method.";

        internal MissingAccumulateMethodException()
            : base(MissingAccumulateMessage)
        {
        }
    }
}

-------------------------------------

Then run the following T-SQL batch, replacing the assembly path with the location of your compiled DLL:

-------------------------------------

 CREATE ASSEMBLY AutoRegisterAsm
FROM -- path to your dll goes here
WITH PERMISSION_SET = SAFE
GO

CREATE TRIGGER AutoRegisterTrigger
ON DATABASE
FOR CREATE_ASSEMBLY
AS EXTERNAL NAME AutoRegisterAsm.[Microsoft.SqlServer.Sample.AutoRegister].AutoRegisterTrigger
GO

ENABLE TRIGGER AutoRegisterTrigger ON DATABASE
GO

- Miles Trochesset, Microsoft SQL Server