- Create the DataFrame with an explicit schema.
>>> from pyspark.sql.types import *
>>> schema = StructType([
...     StructField("longitude", FloatType(), nullable=True),
...     StructField("latitude", FloatType(), nullable=True),
...     StructField("medage", FloatType(), nullable=True),
...     StructField("totrooms", FloatType(), nullable=True),
...     StructField("totbdrms", FloatType(), nullable=True),
...     StructField("pop", FloatType(), nullable=True),
...     StructField("houshlds", FloatType(), nullable=True),
...     StructField("medinc", FloatType(), nullable=True),
...     StructField("medhv", FloatType(), nullable=True)])
>>> housing_df = spark.read.csv(path=HOUSING_DATA, schema=schema).cache()
>>> housing_df.show()
+---------+--------+------+--------+--------+------+--------+------+--------+
|longitude|latitude|medage|totrooms|totbdrms|   pop|houshlds|medinc|   medhv|
+---------+--------+------+--------+--------+------+--------+------+--------+
|  -122.23|   37.88|  41.0|   880.0|   129.0| 322.0|   126.0|8.3252|452600.0|
|  -122.22|   37.86|  21.0|  7099.0|  1106.0|2401.0|  1138.0|8.3014|358500.0|
|  -122.24|   37.85|  52.0|  1467.0|   190.0| 496.0|   177.0|7.2574|352100.0|
|  -122.25|   37.85|  52.0|  1274.0|   235.0| 558.0|   219.0|5.6431|341300.0|
|  -122.25|   37.85|  52.0|  1627.0|   280.0| 565.0|   259.0|3.8462|342200.0|
|  -122.25|   37.85|  52.0|   919.0|   213.0| 413.0|   193.0|4.0368|269700.0|
|  -122.25|   37.84|  52.0|  2535.0|   489.0|1094.0|   514.0|3.6591|299200.0|
|  -122.25|   37.84|  52.0|  3104.0|   687.0|1157.0|   647.0|  3.12|241400.0|
|  -122.26|   37.84|  42.0|  2555.0|   665.0|1206.0|   595.0|2.0804|226700.0|
|  -122.25|   37.84|  52.0|  3549.0|   707.0|1551.0|   714.0|3.6912|261100.0|
|  -122.26|   37.85|  52.0|  2202.0|   434.0| 910.0|   402.0|3.2031|281500.0|
|  -122.26|   37.85|  52.0|  3503.0|   752.0|1504.0|   734.0|3.2705|241800.0|
|  -122.26|   37.85|  52.0|  2491.0|   474.0|1098.0|   468.0| 3.075|213500.0|
|  -122.26|   37.84|  52.0|   696.0|   191.0| 345.0|   174.0|2.6736|191300.0|
|  -122.26|   37.85|  52.0|  2643.0|   626.0|1212.0|   620.0|1.9167|159200.0|
|  -122.26|   37.85|  50.0|  1120.0|   283.0| 697.0|   264.0| 2.125|140000.0|
|  -122.27|   37.85|  52.0|  1966.0|   347.0| 793.0|   331.0| 2.775|152500.0|
|  -122.27|   37.85|  52.0|  1228.0|   293.0| 648.0|   303.0|2.1202|155500.0|
|  -122.26|   37.84|  50.0|  2239.0|   455.0| 990.0|   419.0|1.9911|158700.0|
|  -122.27|   37.84|  52.0|  1503.0|   298.0| 690.0|   275.0|2.6033|162900.0|
+---------+--------+------+--------+--------+------+--------+------+--------+
only showing top 20 rows
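- Optionally, confirm that the declared schema was applied as expected (a quick sanity check; output omitted here).
>>> housing_df.printSchema()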
- Scale down the column 'medhv' so that median house values are expressed in thousands.
>>> housing_df = housing_df.withColumn("medhv", housing_df.medhv/1000)
>>> housing_df.show()
+---------+--------+------+--------+--------+------+--------+------+-----+
|longitude|latitude|medage|totrooms|totbdrms|   pop|houshlds|medinc|medhv|
+---------+--------+------+--------+--------+------+--------+------+-----+
|  -122.23|   37.88|  41.0|   880.0|   129.0| 322.0|   126.0|8.3252|452.6|
|  -122.22|   37.86|  21.0|  7099.0|  1106.0|2401.0|  1138.0|8.3014|358.5|
|  -122.24|   37.85|  52.0|  1467.0|   190.0| 496.0|   177.0|7.2574|352.1|
|  -122.25|   37.85|  52.0|  1274.0|   235.0| 558.0|   219.0|5.6431|341.3|
|  -122.25|   37.85|  52.0|  1627.0|   280.0| 565.0|   259.0|3.8462|342.2|
|  -122.25|   37.85|  52.0|   919.0|   213.0| 413.0|   193.0|4.0368|269.7|
|  -122.25|   37.84|  52.0|  2535.0|   489.0|1094.0|   514.0|3.6591|299.2|
|  -122.25|   37.84|  52.0|  3104.0|   687.0|1157.0|   647.0|  3.12|241.4|
|  -122.26|   37.84|  42.0|  2555.0|   665.0|1206.0|   595.0|2.0804|226.7|
|  -122.25|   37.84|  52.0|  3549.0|   707.0|1551.0|   714.0|3.6912|261.1|
|  -122.26|   37.85|  52.0|  2202.0|   434.0| 910.0|   402.0|3.2031|281.5|
|  -122.26|   37.85|  52.0|  3503.0|   752.0|1504.0|   734.0|3.2705|241.8|
|  -122.26|   37.85|  52.0|  2491.0|   474.0|1098.0|   468.0| 3.075|213.5|
|  -122.26|   37.84|  52.0|   696.0|   191.0| 345.0|   174.0|2.6736|191.3|
|  -122.26|   37.85|  52.0|  2643.0|   626.0|1212.0|   620.0|1.9167|159.2|
|  -122.26|   37.85|  50.0|  1120.0|   283.0| 697.0|   264.0| 2.125|140.0|
|  -122.27|   37.85|  52.0|  1966.0|   347.0| 793.0|   331.0| 2.775|152.5|
|  -122.27|   37.85|  52.0|  1228.0|   293.0| 648.0|   303.0|2.1202|155.5|
|  -122.26|   37.84|  50.0|  2239.0|   455.0| 990.0|   419.0|1.9911|158.7|
|  -122.27|   37.84|  52.0|  1503.0|   298.0| 690.0|   275.0|2.6033|162.9|
+---------+--------+------+--------+--------+------+--------+------+-----+
only showing top 20 rows
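- As an optional sanity check, summary statistics should now show 'medhv' in the hundreds rather than the hundreds of thousands (output omitted here).
>>> housing_df.select("medhv").describe().show()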
- Prepare the stages of the Pipeline. PySpark machine learning algorithms operate on a single vector column rather than on individual feature columns.
- Convert the feature columns into one vector column using VectorAssembler.
>>> from pyspark.ml.regression import LinearRegression
>>> from pyspark.ml.feature import VectorAssembler, StandardScaler
- Define the VectorAssembler, then a StandardScaler to standardize the assembled features.
>>> feature_cols = housing_df.columns[:-1]  # every column except the label 'medhv', which must not leak into the features
>>> assembler = VectorAssembler(inputCols=feature_cols, outputCol="features", handleInvalid='keep')
>>> standardScaler = StandardScaler(inputCol="features", outputCol="features_scaled")
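- Note: by default StandardScaler divides each feature by its standard deviation (withStd=True) without centering it (withMean=False). A variant that also centers the features is sketched below; it is not used in this pipeline.
>>> centeredScaler = StandardScaler(inputCol="features", outputCol="features_scaled", withMean=True, withStd=True)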
- Declare the LinearRegression estimator.
>>> lr = LinearRegression(featuresCol='features_scaled', labelCol="medhv",
...                       predictionCol='predmedhv', maxIter=10, regParam=0.3,
...                       elasticNetParam=0.8, standardization=False)
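- Here regParam sets the overall regularization strength, and elasticNetParam mixes the penalty (0 = pure L2/ridge, 1 = pure L1/lasso), so elasticNetParam=0.8 leans toward lasso; standardization=False avoids double-scaling, since the StandardScaler stage already standardizes the features. To review all parameter settings (optional; output omitted):
>>> print(lr.explainParams())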
- Initialize a Pipeline with the stages defined above.
>>> from pyspark.ml import Pipeline
>>> ppl = Pipeline(stages=[assembler, standardScaler, lr])
- Split the data into training and test sets.
>>> train_data, test_data = housing_df.randomSplit([.8,.2])
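- Note: randomSplit is nondeterministic across runs; for a reproducible split, pass a seed (an optional variant, not used for the outputs shown here).
>>> train_data, test_data = housing_df.randomSplit([.8, .2], seed=42)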
- Fit the Pipeline on the training data and predict on the test data.
>>> model = ppl.fit(train_data)
>>> model.transform(test_data).show()
+---------+--------+------+--------+--------+------+--------+------+-----+--------------------+--------------------+------------------+
|longitude|latitude|medage|totrooms|totbdrms|   pop|houshlds|medinc|medhv|            features|     features_scaled|         predmedhv|
+---------+--------+------+--------+--------+------+--------+------+-----+--------------------+--------------------+------------------+
|  -124.25|   40.28|  32.0|  1430.0|   419.0| 434.0|   187.0|1.9417| 76.1|[76.1,1.941699981...|[0.65626456580292...| 79.25507388391912|
|  -124.21|   40.75|  32.0|  1218.0|   331.0| 620.0|   268.0|1.6528| 58.1|[58.1,1.652799963...|[0.50103773026478...| 59.55025106798064|
|  -124.18|   40.78|  33.0|  1076.0|   222.0| 656.0|   236.0|2.5096| 72.2|[72.2,2.509599924...|[0.62263208476965...| 72.81704038544564|
|  -124.17|   40.76|  26.0|  1776.0|   361.0| 992.0|   380.0|2.8056| 82.8|[82.8,2.805599927...|[0.71404344347545...| 83.39513495596003|
|  -124.17|   40.77|  30.0|  1895.0|   366.0| 990.0|   359.0|2.2227| 81.3|[81.3,2.222700119...|[0.70110787384727...| 81.67913053178052|
|  -124.17|   41.76|  20.0|  2673.0|   538.0|1282.0|   514.0|2.4605|105.9|[105.9,2.46050000...|[0.91325121574940...|104.20940155104489|
|  -124.16|    40.6|  39.0|  1322.0|   283.0| 642.0|   292.0|2.4519| 85.1|[85.1,2.451900005...|[0.73387798357199...|   86.041163302114|
|  -124.16|   40.78|  50.0|  2285.0|   403.0| 837.0|   353.0|2.5417| 85.4|[85.4,2.541699886...|[0.73646509749762...| 85.46745330865821|
|  -124.16|    40.8|  52.0|  2167.0|   480.0| 908.0|   451.0|1.6111| 74.7|[74.7,1.611099958...|[0.64419136748328...| 74.89300829090928|
|  -124.16|    40.8|  52.0|  2416.0|   618.0|1150.0|   571.0|1.7308| 80.5|[80.5,1.730800032...|[0.69420890337891...| 80.56692532800037|
|  -124.15|   40.59|  39.0|  1186.0|   238.0| 539.0|   212.0|2.0938| 79.6|[79.6,2.093800067...|[0.68644756160200...| 80.59643757652566|
|  -124.15|   40.78|  36.0|  2112.0|   374.0| 829.0|   368.0|3.3984| 90.0|[90.0,3.398400068...|[0.77613417769070...|  90.5945896867645|
|  -124.15|   41.81|  17.0|  3276.0|   628.0|3546.0|   585.0|2.2868|103.1|[103.1,2.28679990...|[0.88910481911013...| 94.78147144605492|
|  -124.14|   40.57|  29.0|  2864.0|   600.0|1314.0|   562.0|2.1354| 75.1|[75.1,2.135400056...|[0.64764085271746...| 76.22355760734331|
|  -124.14|   40.58|  25.0|  1899.0|   357.0| 891.0|   355.0|2.6987| 92.5|[92.5,2.698699951...|[0.79769346040434...| 93.63768146514934|
|  -124.14|   41.06|  32.0|  1020.0|   215.0| 421.0|   198.0|3.0208|143.4|[143.4,3.02080011...|[1.23664045645386...|143.15494451322667|
|  -124.11|   40.95|  19.0|  1734.0|   365.0| 866.0|   342.0|  2.96| 81.7|[81.7,2.960000038...|[0.70455735908145...| 82.38724990731038|
|   -124.1|    40.9|  18.0|  4032.0|   798.0|1948.0|   775.0|2.7321| 92.6|[92.6,2.732100009...|[0.79855583171288...|  92.3936172993387|
|  -124.09|   40.55|  24.0|  2978.0|   553.0|1370.0|   480.0|2.7644| 97.3|[97.3,2.764400005...|[0.83908728321451...| 97.88765777716378|
|  -124.09|   40.86|  25.0|  1322.0|   387.0| 794.0|   379.0|1.1742| 75.0|[75.0,1.174200057...|[0.64677848140892...| 75.86383177479371|
+---------+--------+------+--------+--------+------+--------+------+-----+--------------------+--------------------+------------------+
only showing top 20 rows
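- Evaluate the predictions with RegressionEvaluator; a minimal sketch (metric values depend on the random split, so none are shown here).
>>> from pyspark.ml.evaluation import RegressionEvaluator
>>> pred_df = model.transform(test_data)
>>> evaluator = RegressionEvaluator(labelCol="medhv", predictionCol="predmedhv", metricName="rmse")
>>> rmse = evaluator.evaluate(pred_df)
>>> r2 = evaluator.setMetricName("r2").evaluate(pred_df)
- The fitted LinearRegression model is the last stage of the PipelineModel, so its coefficients and intercept can be inspected directly.
>>> lr_model = model.stages[-1]
>>> print(lr_model.coefficients, lr_model.intercept)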