Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,24 @@
package com.microsoft.azure.synapse.ml.causal

import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.train.{TrainClassifier, TrainRegressor}
import com.microsoft.azure.synapse.ml.core.schema.{DatasetExtensions, SchemaConstants}
import com.microsoft.azure.synapse.ml.core.utils.StopWatch
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.stages.DropColumns
import org.apache.spark.annotation.Experimental
import com.microsoft.azure.synapse.ml.train.{TrainClassifier, TrainRegressor}
import org.apache.commons.math3.stat.descriptive.rank.Percentile
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Estimator, Model, Pipeline}
import org.apache.commons.math3.stat.inference.TestUtils
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.ml.classification.ProbabilisticClassifier
import org.apache.spark.ml.regression.{GeneralizedLinearRegression, Regressor}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.param.{DoubleArrayParam, ParamMap}
import org.apache.spark.ml.param.shared.{HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, HasWeightCol}
import org.apache.spark.ml.param.{DoubleArrayParam, ParamMap}
import org.apache.spark.ml.regression.{GeneralizedLinearRegression, Regressor}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.{BooleanType, DataType, DoubleType, IntegerType, LongType, StructType}
import org.apache.commons.math3.stat.inference.TestUtils
import org.apache.spark.sql.functions.{col, lit, when}

import scala.concurrent.Future

Expand Down Expand Up @@ -114,8 +113,8 @@ class DoubleMLEstimator(override val uid: String)
val oneAte = totalTime.measure {
trainInternal(redrewDF)
}
log.info(s"Completed ATE calculation on iteration $index and got ATE value: $oneAte, " +
s"time elapsed: ${totalTime.elapsed() / 6e10} minutes")
log.info(s"Completed ATE calculation on iteration $index " +
s"and got ATE value: $oneAte, time elapsed: ${totalTime.elapsed() / 6e10} minutes")
Some(oneAte)
} catch {
case ex: Throwable =>
Expand All @@ -128,6 +127,7 @@ class DoubleMLEstimator(override val uid: String)
}

val ates = awaitFutures(ateFutures).flatten

if (ates.isEmpty) {
throw new Exception("ATE calculation failed on all iterations. Please check the log for details.")
}
Expand Down Expand Up @@ -193,35 +193,47 @@ class DoubleMLEstimator(override val uid: String)

/** Computes cross-fitted residuals for the DML final stage.
  *
  * Fits the treatment and outcome nuisance models on `train`, then scores the
  * held-out `test` split, producing the outcome-residual column and the
  * treatment residual assembled into a vector column (for the final-stage
  * generalized linear regression).
  *
  * NOTE(review): the pasted source contained both the pre- and post-change
  * sides of a diff merged together; this body is the reconstructed
  * post-change implementation (estimators fitted inside the pipelines,
  * pipelines fitted on `train` and applied to `test`).
  *
  * @param train split used to fit both nuisance pipelines
  * @param test  split the fitted pipelines score; residuals are computed here
  * @return `test` with residual columns appended and the treatment residual
  *         assembled into a vector column
  */
def calculateResiduals(train: Dataset[_], test: Dataset[_]): DataFrame = {
  // Treatment nuisance estimator: predict the treatment from all covariates,
  // i.e. every column except the treatment and outcome columns themselves.
  // Left unfitted here; the Pipeline below fits it on `train`.
  val treatmentModel = treatmentEstimator.setInputCols(
    train.columns.filterNot(Array(getTreatmentCol, getOutcomeCol).contains)
  )

  // Outcome nuisance estimator: same covariate set.
  val outcomeModel = outcomeEstimator.setInputCols(
    train.columns.filterNot(Array(getOutcomeCol, getTreatmentCol).contains)
  )

  // residual = observed treatment - predicted treatment
  val treatmentResidual =
    new ResidualTransformer()
      .setObservedCol(getTreatmentCol)
      .setPredictedCol(treatmentResidualPredictionColName)
      .setOutputCol(treatmentResidualCol)
  // Drop the intermediate prediction columns so they don't collide when the
  // second pipeline runs over the same DataFrame.
  val dropTreatmentPredictedColumns = new DropColumns().setCols(treatmentPredictionColsToDrop)

  // residual = observed outcome - predicted outcome
  val outcomeResidual =
    new ResidualTransformer()
      .setObservedCol(getOutcomeCol)
      .setPredictedCol(outcomeResidualPredictionColName)
      .setOutputCol(outcomeResidualCol)
  val dropOutcomePredictedColumns = new DropColumns().setCols(outcomePredictionColsToDrop)

  // TODO: org.apache.spark.ml.functions.array_to_vector may be slightly more efficient.
  // "skip" drops rows whose residual is null/NaN rather than failing the job.
  val treatmentResidualVA =
    new VectorAssembler()
      .setInputCols(Array(treatmentResidualCol))
      .setOutputCol(treatmentResidualVecCol)
      .setHandleInvalid("skip")

  // Fit each nuisance pipeline on the training split only (cross-fitting:
  // residuals must be computed on data the nuisance models never saw).
  val treatmentPipeline = new Pipeline()
    .setStages(Array(treatmentModel, treatmentResidual, dropTreatmentPredictedColumns))
    .fit(train)

  val outcomePipeline = new Pipeline()
    .setStages(Array(outcomeModel, outcomeResidual, dropOutcomePredictedColumns))
    .fit(train)

  // Score the held-out split; cache intermediates since the final-stage
  // regression will traverse the result more than once.
  val withTreatmentResidual = treatmentPipeline.transform(test).cache
  val withBothResiduals = outcomePipeline.transform(withTreatmentResidual).cache
  treatmentResidualVA.transform(withBothResiduals)
}

// Note, we perform these steps to get ATE
Expand All @@ -234,8 +246,8 @@ class DoubleMLEstimator(override val uid: String)
*/
val splits = dataset.randomSplit(getSampleSplitRatio)
val (train, test) = (splits(0).cache, splits(1).cache)
val residualsDF1 = calculateResiduals(train, test)
val residualsDF2 = calculateResiduals(test, train)
val residualsDF1 = calculateResiduals(train, test).select(outcomeResidualCol, treatmentResidualVecCol)
val residualsDF2 = calculateResiduals(test, train).select(outcomeResidualCol, treatmentResidualVecCol)

// Average slopes from the two residual models.
val regressor = new GeneralizedLinearRegression()
Expand All @@ -247,7 +259,6 @@ class DoubleMLEstimator(override val uid: String)

val coefficients = Array(residualsDF1, residualsDF2).map(regressor.fit).map(_.coefficients(0))
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on the logs, DML notebook run is being blocked here.
Any idea?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

val ate = coefficients.sum / coefficients.length

Seq(train, test).foreach(_.unpersist)
ate
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class ResidualTransformer(override val uid: String) extends Transformer
} else dataset

val predictedColDataType = convertedDataset.schema(getPredictedCol).dataType

predictedColDataType match {
case SQLDataTypes.VectorType =>
// For probability vector, compute the residual as "observed - probability($index)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,16 @@ class TrainClassifier(override val uid: String) extends AutoTrainer[TrainedClass
.setOneHotEncodeCategoricals(oneHotEncodeCategoricals)
.setNumFeatures(featuresToHashTo)
val featurizedModel = featurizer.fit(convertedLabelDataset)

val processedData = featurizedModel.transform(convertedLabelDataset)

processedData.cache()
if (!processedData.storageLevel.useMemory) {
processedData.cache()
}

// For neural network, need to modify input layer so it will automatically work during train
if (modifyInputLayer) {

val multilayerPerceptronClassifier = classifier.asInstanceOf[MultilayerPerceptronClassifier]
val row = processedData.take(1)(0)
val featuresVector = row.get(row.fieldIndex(getFeaturesCol))
Expand All @@ -185,7 +189,9 @@ class TrainClassifier(override val uid: String) extends AutoTrainer[TrainedClass
// Train the learner
val fitModel = classifier.fit(processedData)

processedData.unpersist()
if (processedData.storageLevel.useMemory) {
processedData.unpersist()
}

// Note: The fit shouldn't do anything here
val pipelineModel = new Pipeline().setStages(Array(featurizedModel, fitModel)).fit(convertedLabelDataset)
Expand Down Expand Up @@ -350,7 +356,6 @@ class TrainedClassifierModel(val uid: String)
else CategoricalUtilities.setLevels(scoredDataWithUpdatedScoredLabels,
SchemaConstants.SparkPredictionColumn,
getLevels)

// add metadata to the scored labels and true labels for the levels in label column
if (get(levels).isEmpty || !labelColumnExists) scoredDataWithUpdatedScoredLevels
else CategoricalUtilities.setLevels(scoredDataWithUpdatedScoredLevels,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.regression._
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

import java.util.UUID
Expand Down Expand Up @@ -106,6 +107,7 @@ class TrainRegressor(override val uid: String) extends AutoTrainer[TrainedRegres
.setNumFeatures(featuresToHashTo)

val featurizedModel = featurizer.fit(convertedLabelDataset)

val processedData = featurizedModel.transform(convertedLabelDataset)

processedData.cache()
Expand Down Expand Up @@ -168,7 +170,6 @@ class TrainedRegressorModel(val uid: String)
if (!labelColumnExists) cleanedScoredData
else SparkSchema.setLabelColumnName(
cleanedScoredData, moduleName, getLabelCol, SchemaConstants.RegressionKind)

SparkSchema.updateColumnMetadata(schematizedScoredDataWithLabel,
moduleName, SchemaConstants.SparkPredictionColumn, SchemaConstants.RegressionKind)
})
Expand Down