Step configuration for Databricks

Azure ML SDK v1 includes the DatabricksStep class, which lets us run Databricks workloads as part of a larger Azure Machine Learning pipeline. Here are some tips for solving common challenges when configuring DatabricksStep.

Identifying the Databricks script directory

Creating a Databricks step object requires both a directory name and the name of the file to be executed. The issue is that uploaded files and folders are placed in locations with randomized names. As a result, Python imports of modules from that directory don't work.

The workaround is to read AZUREML_SCRIPT_DIRECTORY_NAME from the script's arguments (the argument is provided automatically, but you need to parse it), build the directory path under /dbfs, and append that path to sys.path before doing imports:

# args is the argparse.Namespace returned by parsing the script arguments
# The parser must declare the argument: parser.add_argument("--AZUREML_SCRIPT_DIRECTORY_NAME")
root = os.path.join("/dbfs", args.AZUREML_SCRIPT_DIRECTORY_NAME)

# Make modules in the uploaded script directory importable
sys.path.append(root)
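
The workaround above can be wrapped in a small helper that is called once at the top of the entry script. This is a minimal sketch, not part of the Azure ML SDK: the function name add_script_directory_to_path and its dbfs_root parameter are hypothetical, and it assumes the uploaded directory is reachable under /dbfs, as in the snippet above.

```python
import argparse
import os
import sys


def add_script_directory_to_path(argv=None, dbfs_root="/dbfs"):
    """Parse --AZUREML_SCRIPT_DIRECTORY_NAME and put that directory on sys.path.

    Hypothetical helper: the name and the dbfs_root parameter are illustrative.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--AZUREML_SCRIPT_DIRECTORY_NAME", required=True)
    # parse_known_args leaves the other script arguments untouched
    args, _ = parser.parse_known_args(argv)

    root = os.path.join(dbfs_root, args.AZUREML_SCRIPT_DIRECTORY_NAME)
    # Avoid adding the same directory twice if the helper is called again
    if root not in sys.path:
        sys.path.append(root)
    return root
```

Call it before any import that relies on modules in the uploaded directory; with argv left as None, argparse reads the arguments from sys.argv.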

Using environment variables

Environment variables need to be initialized prior to using the Run class.

Run.get_context() cannot be used as-is in a Databricks step because the environment variables it requires are not set. The values arrive as script arguments instead, so they should be copied into the environment explicitly in code. For example, you can use the following code:

import argparse
import os

from azureml.core import Run


def get_current_run():
    """
    Retrieves current Azure ML run for a Databricks step.
    """
    parser = argparse.ArgumentParser()

    # These arguments are filled in for all Databricks jobs and can be used to build the run context
    parser.add_argument("--AZUREML_RUN_TOKEN")
    parser.add_argument("--AZUREML_RUN_TOKEN_EXPIRY")
    parser.add_argument("--AZUREML_RUN_ID")
    parser.add_argument("--AZUREML_ARM_SUBSCRIPTION")
    parser.add_argument("--AZUREML_ARM_RESOURCEGROUP")
    parser.add_argument("--AZUREML_ARM_WORKSPACE_NAME")
    parser.add_argument("--AZUREML_ARM_PROJECT_NAME")
    parser.add_argument("--AZUREML_SERVICE_ENDPOINT")
    parser.add_argument("--AZUREML_WORKSPACE_ID")
    parser.add_argument("--AZUREML_EXPERIMENT_ID")

    args, _ = parser.parse_known_args()

    os.environ["AZUREML_RUN_TOKEN"] = args.AZUREML_RUN_TOKEN
    os.environ["AZUREML_RUN_TOKEN_EXPIRY"] = args.AZUREML_RUN_TOKEN_EXPIRY
    os.environ["AZUREML_RUN_ID"] = args.AZUREML_RUN_ID
    os.environ["AZUREML_ARM_SUBSCRIPTION"] = args.AZUREML_ARM_SUBSCRIPTION
    os.environ["AZUREML_ARM_RESOURCEGROUP"] = args.AZUREML_ARM_RESOURCEGROUP
    os.environ["AZUREML_ARM_WORKSPACE_NAME"] = args.AZUREML_ARM_WORKSPACE_NAME
    os.environ["AZUREML_ARM_PROJECT_NAME"] = args.AZUREML_ARM_PROJECT_NAME
    os.environ["AZUREML_SERVICE_ENDPOINT"] = args.AZUREML_SERVICE_ENDPOINT
    os.environ["AZUREML_WORKSPACE_ID"] = args.AZUREML_WORKSPACE_ID
    os.environ["AZUREML_EXPERIMENT_ID"] = args.AZUREML_EXPERIMENT_ID

    run = Run.get_context(allow_offline=False)

    return run
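
The function above spells out each variable name three times. A more compact sketch, assuming the same argument names, loops over a list instead. The helper name export_azureml_args and the AZUREML_ARGS constant are hypothetical; the helper only copies the arguments into os.environ, and Run.get_context(allow_offline=False) would still be called afterwards as shown above. Unlike the original, it skips arguments that were not supplied rather than failing on a None value.

```python
import argparse
import os

# Argument names taken from get_current_run() above
AZUREML_ARGS = [
    "AZUREML_RUN_TOKEN",
    "AZUREML_RUN_TOKEN_EXPIRY",
    "AZUREML_RUN_ID",
    "AZUREML_ARM_SUBSCRIPTION",
    "AZUREML_ARM_RESOURCEGROUP",
    "AZUREML_ARM_WORKSPACE_NAME",
    "AZUREML_ARM_PROJECT_NAME",
    "AZUREML_SERVICE_ENDPOINT",
    "AZUREML_WORKSPACE_ID",
    "AZUREML_EXPERIMENT_ID",
]


def export_azureml_args(argv=None):
    """Copy the AZUREML_* script arguments into os.environ (hypothetical helper)."""
    parser = argparse.ArgumentParser()
    for name in AZUREML_ARGS:
        parser.add_argument(f"--{name}")
    # Ignore any arguments that belong to the script itself
    args, _ = parser.parse_known_args(argv)

    exported = {}
    for name in AZUREML_ARGS:
        value = getattr(args, name)
        # os.environ only accepts strings, so skip arguments that were not passed
        if value is not None:
            os.environ[name] = value
            exported[name] = value
    return exported
```

Returning the exported values makes it easy to log or check which variables were actually set before requesting the run context.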