I was also getting this issue and raised a ticket with MS support. They acknowledged the issue and gave me the following workaround in the meantime.
Add these lines at the top of your code:
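# Import every class in org.apache.spark.sql.api.python into the driver's JVM view
# (`spark` is the active SparkSession, e.g. the one a notebook predefines)
from py4j.java_gateway import java_import
java_import(spark._sc._jvm, "org.apache.spark.sql.api.python.*")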
For example: