Use Pipeline in PySpark

Teradata® pyspark2teradataml User Guide

Teradata Package for Python
Release Number: December 2024
Product Category: Teradata Vantage
  1. Create a DataFrame with an explicit schema and display it.
    >>> from pyspark.sql.types import *
    >>> schema = StructType([
    ...         StructField("longitude", FloatType(), nullable=True),
    ...         StructField("latitude", FloatType(), nullable=True),
    ...         StructField("medage", FloatType(), nullable=True),
    ...         StructField("totrooms", FloatType(), nullable=True),
    ...         StructField("totbdrms", FloatType(), nullable=True),
    ...         StructField("pop", FloatType(), nullable=True),
    ...         StructField("houshlds", FloatType(), nullable=True),
    ...         StructField("medinc", FloatType(), nullable=True),
    ...         StructField("medhv", FloatType(), nullable=True)])
    >>> housing_df = spark.read.csv("<path_to_housing_data>", schema=schema).cache()
    >>> housing_df.show()
    |longitude|latitude|medage|totrooms|totbdrms|   pop|houshlds|medinc|   medhv|
    |  -122.23|   37.88|  41.0|   880.0|   129.0| 322.0|   126.0|8.3252|452600.0|
    |  -122.22|   37.86|  21.0|  7099.0|  1106.0|2401.0|  1138.0|8.3014|358500.0|
    |  -122.24|   37.85|  52.0|  1467.0|   190.0| 496.0|   177.0|7.2574|352100.0|
    |  -122.25|   37.85|  52.0|  1274.0|   235.0| 558.0|   219.0|5.6431|341300.0|
    |  -122.25|   37.85|  52.0|  1627.0|   280.0| 565.0|   259.0|3.8462|342200.0|
    |  -122.25|   37.85|  52.0|   919.0|   213.0| 413.0|   193.0|4.0368|269700.0|
    |  -122.25|   37.84|  52.0|  2535.0|   489.0|1094.0|   514.0|3.6591|299200.0|
    |  -122.25|   37.84|  52.0|  3104.0|   687.0|1157.0|   647.0|  3.12|241400.0|
    |  -122.26|   37.84|  42.0|  2555.0|   665.0|1206.0|   595.0|2.0804|226700.0|
    |  -122.25|   37.84|  52.0|  3549.0|   707.0|1551.0|   714.0|3.6912|261100.0|
    |  -122.26|   37.85|  52.0|  2202.0|   434.0| 910.0|   402.0|3.2031|281500.0|
    |  -122.26|   37.85|  52.0|  3503.0|   752.0|1504.0|   734.0|3.2705|241800.0|
    |  -122.26|   37.85|  52.0|  2491.0|   474.0|1098.0|   468.0| 3.075|213500.0|
    |  -122.26|   37.84|  52.0|   696.0|   191.0| 345.0|   174.0|2.6736|191300.0|
    |  -122.26|   37.85|  52.0|  2643.0|   626.0|1212.0|   620.0|1.9167|159200.0|
    |  -122.26|   37.85|  50.0|  1120.0|   283.0| 697.0|   264.0| 2.125|140000.0|
    |  -122.27|   37.85|  52.0|  1966.0|   347.0| 793.0|   331.0| 2.775|152500.0|
    |  -122.27|   37.85|  52.0|  1228.0|   293.0| 648.0|   303.0|2.1202|155500.0|
    |  -122.26|   37.84|  50.0|  2239.0|   455.0| 990.0|   419.0|1.9911|158700.0|
    |  -122.27|   37.84|  52.0|  1503.0|   298.0| 690.0|   275.0|2.6033|162900.0|
    only showing top 20 rows
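FloatType is Spark's 32-bit single-precision type, so a value such as medinc = 1.9417 is stored as the nearest 32-bit float. When VectorAssembler later copies it into a double-precision vector, the rounding becomes visible, which is why the features column in Step 6 shows 1.941699981... instead of 1.9417. The plain-Python sketch below (no Spark involved; the helper name `to_float32` is hypothetical) reproduces that rounding with the standard `struct` module. If the extra digits matter, DoubleType in the schema avoids them.

```python
import struct

def to_float32(x: float) -> float:
    """Round a 64-bit Python float to the nearest 32-bit float, then widen it back."""
    return struct.unpack("f", struct.pack("f", x))[0]

medinc = 1.9417               # medinc of the first row in the Step 6 output
stored = to_float32(medinc)   # what a FloatType column actually holds

print(repr(stored))           # starts with 1.94169998, like the features column in Step 6
print(stored == medinc)       # False: 1.9417 has no exact 32-bit representation
```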
  2. Scale down the column 'medhv' (median house value) by a factor of 1000 and display the result.
    >>> housing_df = housing_df.withColumn("medhv", housing_df.medhv/1000)
    >>> housing_df.show()
    |longitude|latitude|medage|totrooms|totbdrms|   pop|houshlds|medinc|medhv|
    |  -122.23|   37.88|  41.0|   880.0|   129.0| 322.0|   126.0|8.3252|452.6|
    |  -122.22|   37.86|  21.0|  7099.0|  1106.0|2401.0|  1138.0|8.3014|358.5|
    |  -122.24|   37.85|  52.0|  1467.0|   190.0| 496.0|   177.0|7.2574|352.1|
    |  -122.25|   37.85|  52.0|  1274.0|   235.0| 558.0|   219.0|5.6431|341.3|
    |  -122.25|   37.85|  52.0|  1627.0|   280.0| 565.0|   259.0|3.8462|342.2|
    |  -122.25|   37.85|  52.0|   919.0|   213.0| 413.0|   193.0|4.0368|269.7|
    |  -122.25|   37.84|  52.0|  2535.0|   489.0|1094.0|   514.0|3.6591|299.2|
    |  -122.25|   37.84|  52.0|  3104.0|   687.0|1157.0|   647.0|  3.12|241.4|
    |  -122.26|   37.84|  42.0|  2555.0|   665.0|1206.0|   595.0|2.0804|226.7|
    |  -122.25|   37.84|  52.0|  3549.0|   707.0|1551.0|   714.0|3.6912|261.1|
    |  -122.26|   37.85|  52.0|  2202.0|   434.0| 910.0|   402.0|3.2031|281.5|
    |  -122.26|   37.85|  52.0|  3503.0|   752.0|1504.0|   734.0|3.2705|241.8|
    |  -122.26|   37.85|  52.0|  2491.0|   474.0|1098.0|   468.0| 3.075|213.5|
    |  -122.26|   37.84|  52.0|   696.0|   191.0| 345.0|   174.0|2.6736|191.3|
    |  -122.26|   37.85|  52.0|  2643.0|   626.0|1212.0|   620.0|1.9167|159.2|
    |  -122.26|   37.85|  50.0|  1120.0|   283.0| 697.0|   264.0| 2.125|140.0|
    |  -122.27|   37.85|  52.0|  1966.0|   347.0| 793.0|   331.0| 2.775|152.5|
    |  -122.27|   37.85|  52.0|  1228.0|   293.0| 648.0|   303.0|2.1202|155.5|
    |  -122.26|   37.84|  50.0|  2239.0|   455.0| 990.0|   419.0|1.9911|158.7|
    |  -122.27|   37.84|  52.0|  1503.0|   298.0| 690.0|   275.0|2.6033|162.9|
    only showing top 20 rows
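Dividing by 1000 only re-expresses medhv in thousands (452600.0 becomes 452.6), so the label and the predictions in Step 6 are read on that smaller scale. Spark's `/` operator casts non-decimal operands to double, so medhv is a double column after this step, which is why it no longer shows 32-bit rounding noise. The same arithmetic on the first three rows, as a plain-Python sketch:

```python
# medhv values from the first three rows of the Step 1 output
medhv = [452600.0, 358500.0, 352100.0]

# Same operation as housing_df.medhv / 1000, applied row by row
medhv_thousands = [v / 1000 for v in medhv]
print(medhv_thousands)  # [452.6, 358.5, 352.1]
```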
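  3. Prepare the stages of the Pipeline.
    PySpark machine learning algorithms expect all input features in a single vector column, so the first stage assembles the feature columns into one vector.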
    1. Import the required classes and convert the feature columns to a vector column using VectorAssembler.
      >>> from pyspark.ml.regression import LinearRegression
      >>> from pyspark.ml.feature import VectorAssembler, StandardScaler
      >>> assembler = VectorAssembler(inputCols=housing_df.columns[::-1], outputCol="features", handleInvalid='keep')
    2. Define a StandardScaler to scale the features.
      >>> standardScaler = StandardScaler(inputCol="features", outputCol="features_scaled")
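    3. Define the LinearRegression estimator. The features are already scaled by the StandardScaler stage, so the estimator's own standardization is disabled. regParam sets the regularization strength, and elasticNetParam=0.8 weights the elastic-net penalty mostly toward L1 (lasso).
      >>> lr = LinearRegression(featuresCol='features_scaled', labelCol="medhv", predictionCol='predmedhv',
      ...                       maxIter=10, regParam=0.3, elasticNetParam=0.8, standardization=False)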
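Two details of Step 3 are easy to miss. First, `housing_df.columns[::-1]` is every column in reverse order, so it includes the label medhv, which becomes the first element of `features` (in the Step 6 output, each features vector starts with that row's medhv value). The model can therefore see the answer; in practice you would pass only the feature columns, for example `[c for c in housing_df.columns if c != "medhv"]`. Second, StandardScaler with its defaults (withStd=True, withMean=False) divides each feature by its sample standard deviation without centering it. The plain-Python sketch below mimics both stages on three rows from Step 2; the helpers `assemble`, `fit_std`, and `scale` are hypothetical and are not Spark's implementation.

```python
import statistics

# Column order of the schema defined in Step 1
columns = ["longitude", "latitude", "medage", "totrooms", "totbdrms",
           "pop", "houshlds", "medinc", "medhv"]

# Three rows from the Step 2 output (medhv already divided by 1000)
rows = [
    dict(zip(columns, [-122.23, 37.88, 41.0, 880.0, 129.0, 322.0, 126.0, 8.3252, 452.6])),
    dict(zip(columns, [-122.22, 37.86, 21.0, 7099.0, 1106.0, 2401.0, 1138.0, 8.3014, 358.5])),
    dict(zip(columns, [-122.24, 37.85, 52.0, 1467.0, 190.0, 496.0, 177.0, 7.2574, 352.1])),
]

def assemble(row, input_cols):
    """VectorAssembler-like: concatenate the chosen columns into one feature list."""
    return [row[c] for c in input_cols]

def fit_std(vectors):
    """Per-feature sample standard deviation (n - 1 denominator)."""
    return [statistics.stdev(feature) for feature in zip(*vectors)]

def scale(vector, stds):
    """Divide each feature by its std without subtracting the mean; guard against zero std."""
    return [v / s if s != 0 else 0.0 for v, s in zip(vector, stds)]

# As in Step 3: reversed column order, so the label is the first "feature"
leaky = [assemble(r, columns[::-1]) for r in rows]
print(leaky[0][0])  # 452.6, the label itself

# Label excluded from the features, then scaled
feature_cols = [c for c in columns if c != "medhv"]
features = [assemble(r, feature_cols) for r in rows]
stds = fit_std(features)
scaled = [scale(f, stds) for f in features]
```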
  4. Create a Pipeline with the stages from Step 3, listed in the order they run.
    >>> from pyspark.ml import Pipeline
    >>> ppl = Pipeline(stages=[assembler, standardScaler, lr])
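A Pipeline runs its stages in order. During fit, a Transformer stage (such as VectorAssembler) only transforms the data, while an Estimator stage (such as StandardScaler or LinearRegression) is first fit on the data it receives and then replaced by its fitted model; the result is a PipelineModel that applies the fitted stages in the same order. The toy sketch below illustrates that control flow in plain Python; `SimplePipeline`, `Doubler`, and `MeanShift` are hypothetical names, not part of PySpark or teradatamlspk.

```python
class Transformer:
    def transform(self, data):
        raise NotImplementedError

class Estimator:
    def fit(self, data):
        """Learn parameters from data and return a fitted Transformer."""
        raise NotImplementedError

class SimplePipeline(Estimator):
    def __init__(self, stages):
        self.stages = stages

    def fit(self, data):
        fitted = []
        for stage in self.stages:
            if isinstance(stage, Estimator):
                stage = stage.fit(data)    # estimators are fit on the data reaching them
            fitted.append(stage)
            data = stage.transform(data)   # each stage's output feeds the next stage
        return SimplePipelineModel(fitted)

class SimplePipelineModel(Transformer):
    def __init__(self, stages):
        self.stages = stages

    def transform(self, data):
        for stage in self.stages:
            data = stage.transform(data)
        return data

class Doubler(Transformer):
    """Fixed transformation with nothing to learn (the role VectorAssembler plays)."""
    def transform(self, data):
        return [x * 2 for x in data]

class MeanShift(Estimator):
    """Learns a statistic during fit (the role StandardScaler plays)."""
    def fit(self, data):
        return _Shift(sum(data) / len(data))

class _Shift(Transformer):
    def __init__(self, mean):
        self.mean = mean

    def transform(self, data):
        return [x - self.mean for x in data]

model = SimplePipeline([Doubler(), MeanShift()]).fit([1.0, 2.0, 3.0])  # learns mean 4.0
print(model.transform([5.0]))  # [6.0]: the mean learned during fit is reused
```

This is also why Step 6 calls fit on the training data and transform on the test data: statistics such as the scaler's standard deviations are learned once and reused unchanged.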
  5. Split the data into training (80%) and test (20%) sets.
    >>> train_data, test_data = housing_df.randomSplit([.8,.2])
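randomSplit assigns each row to a split at random in proportion to the weights (weights that do not sum to 1 are normalized), so the 80/20 sizes are approximate and the split changes between runs unless you pass a seed, for example `housing_df.randomSplit([.8, .2], seed=42)`. The sketch below shows the idea in plain Python; `random_split` is a hypothetical helper, and Spark's actual implementation, which samples partition by partition, differs in detail.

```python
import random
from itertools import accumulate

def random_split(rows, weights, seed=None):
    """Assign each row to one split with probability proportional to its weight."""
    total = sum(weights)
    bounds = list(accumulate(w / total for w in weights))  # e.g. [0.8, 1.0]
    rng = random.Random(seed)
    splits = [[] for _ in weights]
    for row in rows:
        x = rng.random()  # uniform in [0, 1)
        for i, upper in enumerate(bounds):
            if x < upper:
                splits[i].append(row)
                break
        else:
            splits[-1].append(row)  # guard against rounding in the last bound
    return splits

train, test = random_split(range(1000), [.8, .2], seed=42)
print(len(train), len(test))  # close to 800 and 200, but usually not exact
```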
  6. Fit the Pipeline on the training data and predict medhv for the test data.
    >>> model = ppl.fit(train_data)
    >>> model.transform(test_data).show()
    |longitude|latitude|medage|totrooms|totbdrms|   pop|houshlds|medinc|medhv|            features|     features_scaled|         predmedhv|
    |  -124.25|   40.28|  32.0|  1430.0|   419.0| 434.0|   187.0|1.9417| 76.1|[76.1,1.941699981   |[0.65626456580292   | 79.25507388391912|
    |  -124.21|   40.75|  32.0|  1218.0|   331.0| 620.0|   268.0|1.6528| 58.1|[58.1,1.652799963   |[0.50103773026478   | 59.55025106798064|
    |  -124.18|   40.78|  33.0|  1076.0|   222.0| 656.0|   236.0|2.5096| 72.2|[72.2,2.509599924   |[0.62263208476965   | 72.81704038544564|
    |  -124.17|   40.76|  26.0|  1776.0|   361.0| 992.0|   380.0|2.8056| 82.8|[82.8,2.805599927   |[0.71404344347545   | 83.39513495596003|
    |  -124.17|   40.77|  30.0|  1895.0|   366.0| 990.0|   359.0|2.2227| 81.3|[81.3,2.222700119   |[0.70110787384727   | 81.67913053178052|
    |  -124.17|   41.76|  20.0|  2673.0|   538.0|1282.0|   514.0|2.4605|105.9|[105.9,2.46050000   |[0.91325121574940   |104.20940155104489|
    |  -124.16|    40.6|  39.0|  1322.0|   283.0| 642.0|   292.0|2.4519| 85.1|[85.1,2.451900005   |[0.73387798357199   |   86.041163302114|
    |  -124.16|   40.78|  50.0|  2285.0|   403.0| 837.0|   353.0|2.5417| 85.4|[85.4,2.541699886   |[0.73646509749762   | 85.46745330865821|
    |  -124.16|    40.8|  52.0|  2167.0|   480.0| 908.0|   451.0|1.6111| 74.7|[74.7,1.611099958   |[0.64419136748328   | 74.89300829090928|
    |  -124.16|    40.8|  52.0|  2416.0|   618.0|1150.0|   571.0|1.7308| 80.5|[80.5,1.730800032   |[0.69420890337891   | 80.56692532800037|
    |  -124.15|   40.59|  39.0|  1186.0|   238.0| 539.0|   212.0|2.0938| 79.6|[79.6,2.093800067   |[0.68644756160200   | 80.59643757652566|
    |  -124.15|   40.78|  36.0|  2112.0|   374.0| 829.0|   368.0|3.3984| 90.0|[90.0,3.398400068   |[0.77613417769070   |  90.5945896867645|
    |  -124.15|   41.81|  17.0|  3276.0|   628.0|3546.0|   585.0|2.2868|103.1|[103.1,2.28679990   |[0.88910481911013   | 94.78147144605492|
    |  -124.14|   40.57|  29.0|  2864.0|   600.0|1314.0|   562.0|2.1354| 75.1|[75.1,2.135400056   |[0.64764085271746   | 76.22355760734331|
    |  -124.14|   40.58|  25.0|  1899.0|   357.0| 891.0|   355.0|2.6987| 92.5|[92.5,2.698699951   |[0.79769346040434   | 93.63768146514934|
    |  -124.14|   41.06|  32.0|  1020.0|   215.0| 421.0|   198.0|3.0208|143.4|[143.4,3.02080011   |[1.23664045645386   |143.15494451322667|
    |  -124.11|   40.95|  19.0|  1734.0|   365.0| 866.0|   342.0|  2.96| 81.7|[81.7,2.960000038   |[0.70455735908145   | 82.38724990731038|
    |   -124.1|    40.9|  18.0|  4032.0|   798.0|1948.0|   775.0|2.7321| 92.6|[92.6,2.732100009   |[0.79855583171288   |  92.3936172993387|
    |  -124.09|   40.55|  24.0|  2978.0|   553.0|1370.0|   480.0|2.7644| 97.3|[97.3,2.764400005   |[0.83908728321451   | 97.88765777716378|
    |  -124.09|   40.86|  25.0|  1322.0|   387.0| 794.0|   379.0|1.1742| 75.0|[75.0,1.174200057   |[0.64677848140892   | 75.86383177479371|
    only showing top 20 rows
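To judge the predictions, compare predmedhv with medhv on the test set. PySpark provides RegressionEvaluator in `pyspark.ml.evaluation` for this, for example `RegressionEvaluator(labelCol="medhv", predictionCol="predmedhv", metricName="rmse").evaluate(predictions)`, where `predictions` stands for the DataFrame returned by `model.transform(test_data)`. The plain-Python sketch below computes the same root-mean-squared error for the first five rows shown above. Because medhv is also part of the feature vector in this example, errors this small are optimistic.

```python
import math

# (medhv, predmedhv) pairs from the first five rows of the output above
pairs = [
    (76.1, 79.25507388391912),
    (58.1, 59.55025106798064),
    (72.2, 72.81704038544564),
    (82.8, 83.39513495596003),
    (81.3, 81.67913053178052),
]

def rmse(pairs):
    """Root-mean-squared error between labels and predictions."""
    return math.sqrt(sum((y - p) ** 2 for y, p in pairs) / len(pairs))

print(round(rmse(pairs), 2))  # 1.61, in thousands like medhv
```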