I'm trying to create a hybrid pipeline for ml purposes. Firstly, I run the data preparation step on my local machine, then I want to upload the new dataset to azure and run the train on azure.
If I run the train step manually in Azure there is no error. But if I run the train step in a pipeline(still on azure) an error occurs:
UserErrorException: Pipeline input expected azure.ai.ml.Input or primitive types (str, bool, int or float) but received type <class 'main.Train'>.
The error occurs in cnc_miling_pipelinev2 function
class Train(BaseEstimator, TransformerMixin):
def __init__(self,train_src_dir,train_data_csv_path,test_data_csv_path):
self.train_src_dir = train_src_dir
self.train_data_csv_path = train_data_csv_path
self.test_data_csv_path = test_data_csv_path
self.azure_authenticate()
def azure_authenticate(self):
try:
credential = DefaultAzureCredential()
# Check if given credential can get token successfully.
credential.get_token("https://management.azure.com/.default")
except Exception as ex:
# Fall back to InteractiveBrowserCredential in case DefaultAzureCredential not work
credential = InteractiveBrowserCredential()
self.ml_client = MLClient.from_config(credential=credential)
def fit(self,X=None, y=None):
return self
def data_load(self,):
self.train_data_csv = Data(
name="train_data_csv",
path=self.train_data_csv_path,
type=AssetTypes.URI_FILE,
description="train_data_csv",
tags={"source_type": "local", "source": "AzureML CNC blob"},
version="1.2.9")
self.train_data_csv = self.ml_client.data.create_or_update(self.train_data_csv)
print(
f"Dataset with name {self.train_data_csv.name} was registered to workspace, the dataset version is {self.train_data_csv.version}"
)
self.test_data_csv = Data(
name="test_data_csv",
path=self.test_data_csv_path,
type=AssetTypes.URI_FILE,
description="test_data_csv",
tags={"source_type": "local", "source": "AzureML CNC blob"},
version="1.2.9")
self.test_data_csv = self.ml_client.data.create_or_update(self.test_data_csv)
print(
f"Dataset with name {self.test_data_csv.name} was registered to workspace, the dataset version is {self.test_data_csv.version}"
)
@dsl.pipeline(compute="x",description="CNC prep train pipeline v2",)
def cnc_miling_pipelinev2(self,
pipeline_train_data,
pipeline_test_data,
pipeline_job_learning_rate,
pipeline_job_registered_model_name):
print("HELLO")
# using data_prep_function like a python call with its own inputs
# using train_func like a python call with its own inputs
train_job = self.train_component(
train_data=pipeline_train_data, # note: using outputs from previous step
test_data=pipeline_test_data, # note: using outputs from previous step
learning_rate=pipeline_job_learning_rate, # note: using a pipeline input as parameter
registered_model_name=pipeline_job_registered_model_name
)
# a pipeline returns a dictionary of outputs
# keys will code for the pipeline output identifier
def transform(self,X=None, y=None):
# Loading the component from the yml file
self.train_component = load_component(source=os.path.join(self.train_src_dir, "train.yml"))
# Now we register the component to the workspace
self.train_component = self.ml_client.create_or_update(self.train_component)
# Create (register) the component in your workspace
print(
f"Component {self.train_component.name} with Version {self.train_component.version} is registered"
)
self.data_load()
registered_model_name = "cnc_miling_modelV2"
# Let's instantiate the pipeline with the parameters of our choice
# Let's instantiate the pipeline with the parameters of our choice
print("self.train_data_csv.path" , self.train_data_csv.path, "self.test_data_csv.path: ", self.test_data_csv.path)
print("\n\n")
pipeline_train = self.cnc_miling_pipelinev2(
pipeline_train_data=Input(type="uri_file", path=self.train_data_csv.path),
pipeline_test_data=Input(type="uri_file", path=self.test_data_csv.path),
pipeline_job_learning_rate=0.05,
pipeline_job_registered_model_name=registered_model_name
)
print("self.cnc_miling_pipelinev2 OK")
self.ml_client.jobs.create_or_update(pipeline_train,
# Project's name
experiment_name="cnc_miling_exp_v2")