PyStarburst: the DataFrame API

  • Lester Martin

    Developer Advocate

    Starburst

I’m so excited about the Starburst blog post Introducing Python DataFrames in Starburst Galaxy, and I wanted to show a few examples of this new feature set.

NOTE: This blog post is NOT attempting to teach you everything you need to know about the DataFrame API, but it will provide some insight into this rich subject matter.

The real goal is to see it in action!!

Set up your environment

As the Py in PyStarburst suggests, you clearly need Python installed. On my Mac, I set this up with Homebrew (brew) a long time ago; your environment may call for something different.

$ brew install python
    ... many lines rm'd ...
$ python3 --version
Python 3.10.9

I then needed to get pip set up. Here’s what I did.

$ python3 -m ensurepip
    ... many lines rm'd ...
$ python3 -m pip install --upgrade pip
    ... many lines rm'd ...
$ pip --version
pip 23.2.1 from ... (python 3.10)

At this point, you can get more help from Starburst Galaxy by visiting Partner connect >> Drivers & Clients >> PyStarburst, which opens a pop-up with connection details. Use the Select cluster pulldown to choose the cluster you want to run your PyStarburst code against.

Click the Download connection file button to get a file named main.py, like the one below, that has everything filled in except the password. I masked out my host and username values, too.

import trino
from pystarburst import Session

db_parameters = {
    "host": "tXXXXXXXXXXe.trino.galaxy.starburst.io",
    "port": 443,
    "http_scheme": "https",
    # Setup authentication through login or password or any other supported authentication methods
    # See docs: https://github.com/trinodb/trino-python-client#authentication-mechanisms
    "auth": trino.auth.BasicAuthentication("lXXXXXX/XXXXXXn", "<password>")
}
session = Session.builder.configs(db_parameters).create()
session.sql("SELECT * FROM system.runtime.nodes").collect()

To clean that up and make things go a bit smoother, delete lines 8 and 9 (the two authentication comment lines) and then add the following two lines after line 2 (the from pystarburst import Session line).

from pystarburst import functions as f
from pystarburst.functions import col

Lastly, replace the last line with the following (assuming you are using the TPCH catalog on the cluster you selected earlier).

session.table("tpch.tiny.region").show()

Back in the pop-up from earlier, there is a link to the PyStarburst docs site. From there, run the pip install command listed in the Install the library section. That page also has some boilerplate code, which you have already adapted above.
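Before running main.py, a quick sanity check can save you from a confusing traceback. The sketch below is my own helper, not part of PyStarburst: it reports the Python version and whether the trino and pystarburst modules that main.py imports can be found. It only checks that they are installed, not which versions you have.

```python
import importlib.util
import sys


def check_environment(modules=("trino", "pystarburst")):
    """Return the Python version and whether each module main.py imports can be found.

    This only checks that the modules are installed; it does not check their versions.
    """
    report = {"python": "{}.{}.{}".format(*sys.version_info[:3])}
    for name in modules:
        # find_spec() returns None when a top-level module is not installed
        report[name] = importlib.util.find_spec(name) is not None
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

If either module shows False, rerun the pip install command from the docs before moving on.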

Test the boilerplate code

The docs site above also points to an example Jupyter notebook, which suggests you should be using Jupyter or another web-based notebook tool. That’s a great path to go down, but I’m going to keep it simpler and just run my code from the CLI.

$ python3 main.py
----------------------------------------------------------------------------------
|"regionkey"  |"name"       |"comment"                                           |
----------------------------------------------------------------------------------
|0            |AFRICA       |lar deposits. blithely final packages cajole. r...  |
|1            |AMERICA      |hs use ironic, even requests. s                     |
|2            |ASIA         |ges. thinly even pinto beans ca                     |
|3            |EUROPE       |ly final courts cajole furiously final excuse       |
|4            |MIDDLE EAST  |uickly special accounts cajole carefully blithe...  |
----------------------------------------------------------------------------------

Awesome! We used the API to run what is basically a SELECT statement, which verified that we can create a DataFrame with code that runs in Starburst Galaxy. In fact, you can see the query on the Query history page.

Explore the API

The docs page from above has a link to the detailed PyStarburst DataFrame API documentation site. As mentioned at the start of this post, I am NOT going to try to teach you Spark’s DataFrame API here. If this is totally new to you, one place you might start is this programming guide on the Apache Spark website.

I’ll be building some training around PyStarburst, and it will surely start from the basics of what a DataFrame is and build from there. Ping me if you’re interested in such a class. Of course, I’ll explain what the code below is doing, at least at a high level.

Select a full table

Add the following lines to the end of your Python source file and run it with python3 main.py as shown above. They use the table() function to grab hold of the customer table and then display its first 10 rows (called without an integer argument, show() defaults to 10).

custDF = session.table("tpch.tiny.customer")
custDF.show()

Here is the output.

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"custkey"  |"name"              |"address"                              |"nationkey"  |"phone"          |"acctbal"  |"mktsegment"  |"comment"                                           |
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|1          |Customer#000000001  |IVhzIApeRb ot,c,E                      |15           |25-989-741-2988  |711.56     |BUILDING      |to the even, regular platelets. regular, ironic...  |
|2          |Customer#000000002  |XSTf4,NCwDVaWNe6tEgvwfmRchLXak         |13           |23-768-687-3665  |121.65     |AUTOMOBILE    |l accounts. blithely ironic theodolites integra...  |
|3          |Customer#000000003  |MG9kdTD2WBHm                           |1            |11-719-748-3364  |7498.12    |AUTOMOBILE    | deposits eat slyly ironic, even instructions. ...  |
|4          |Customer#000000004  |XxVSJsLAGtn                            |4            |14-128-190-5944  |2866.83    |MACHINERY     | requests. final, regular ideas sleep final accou   |
|5          |Customer#000000005  |KvpyuHCplrB84WgAiGV6sYpZq7Tj           |3            |13-750-942-6364  |794.47     |HOUSEHOLD     |n accounts will have to unwind. foxes cajole accor  |
|6          |Customer#000000006  |sKZz0CsnMD7mp4Xd0YrBvx,LREYKUWAh yVn   |20           |30-114-968-4951  |7638.57    |AUTOMOBILE    |tions. even deposits boost according to the sly...  |
|7          |Customer#000000007  |TcGe5gaZNgVePxU5kRrvXBfkasDTea         |18           |28-190-982-9759  |9561.95    |AUTOMOBILE    |ainst the ironic, express theodolites. express,...  |
|8          |Customer#000000008  |I0B10bB0AymmC, 0PrRYBCP1yGJ8xcBPmWhl5  |17           |27-147-574-9335  |6819.74    |BUILDING      |among the slyly regular theodolites kindle blit...  |
|9          |Customer#000000009  |xKiAFTjUsCuxfeleNqefumTrjS             |8            |18-338-906-3675  |8324.07    |FURNITURE     |r theodolites according to the requests wake th...  |
|10         |Customer#000000010  |6LrEaV6KR6PLVcgl2ArL Q3rqzLzcT1 v2     |5            |15-741-346-9870  |2753.54    |HOUSEHOLD     |es regular deposits haggle. fur                     |
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

That is quite busy in the CLI, but probably looks good in a notebook since it won’t wrap the text.

Use projection

We really only need a couple of columns, so we can use the select() method on the existing DataFrame to identify the ones we want. There is a complementary drop() function that would be better if we wanted to keep most of the columns and only remove a few.

projectedDF = custDF.select(custDF.name, custDF.acctbal, custDF.nationkey)
projectedDF.show()

Here is the output after adding those lines of code above and running your Python program again. It looks a bit more manageable.

------------------------------------------------
|"name"              |"acctbal"  |"nationkey"  |
------------------------------------------------
|Customer#000000751  |2130.98    |0            |
|Customer#000000752  |8363.66    |8            |
|Customer#000000753  |8114.44    |17           |
|Customer#000000754  |-566.86    |0            |
|Customer#000000755  |7631.94    |16           |
|Customer#000000756  |8116.99    |14           |
|Customer#000000757  |9334.82    |3            |
|Customer#000000758  |6352.14    |17           |
|Customer#000000759  |3477.59    |1            |
|Customer#000000760  |2883.24    |2            |
------------------------------------------------

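Again, show() without an argument displays only 10 rows. Notice that they are not the same 10 customers as before; without a sort, there is no guarantee about which rows come back first.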
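To make the select() versus drop() trade-off concrete, here is a plain-Python sketch that only tracks column names (no data, no PyStarburst). The column list is the customer header from the output above, and select_columns() and drop_columns() are hypothetical helpers. Consistent with the outputs in this post, the sketch assumes select() returns columns in the order you list them, while drop() leaves the remaining columns in table order.

```python
# Column names of tpch.tiny.customer, as shown in the show() output above.
CUSTOMER_COLUMNS = ["custkey", "name", "address", "nationkey",
                    "phone", "acctbal", "mktsegment", "comment"]


def select_columns(columns, *keep):
    """Mimic select(): name the columns you want, in the order you want them."""
    unknown = [c for c in keep if c not in columns]
    if unknown:
        raise KeyError(f"unknown columns: {unknown}")
    return list(keep)


def drop_columns(columns, *remove):
    """Mimic drop(): name the columns you don't want; the rest keep table order."""
    return [c for c in columns if c not in remove]


if __name__ == "__main__":
    print(select_columns(CUSTOMER_COLUMNS, "name", "acctbal", "nationkey"))
    # Dropping 5 of the 8 columns reaches the same set of columns, in table order.
    print(drop_columns(CUSTOMER_COLUMNS, "custkey", "address", "phone", "mktsegment", "comment"))
```

Here select() needs three names while drop() needs five, which is why select() is the better fit for this step.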

Filter the rows

Well-named, the filter() function does exactly what we need it to do. In this example, we limit the results to the customers with the highest account balances. Add these next lines to the end of your Python source file and run it again.

filteredDF = projectedDF.filter(projectedDF.acctbal > 9900.0)
filteredDF.show(100)

Here is the output.

------------------------------------------------
|"name"              |"acctbal"  |"nationkey"  |
------------------------------------------------
|Customer#000001106  |9977.62    |21           |
|Customer#000000043  |9904.28    |19           |
|Customer#000000045  |9983.38    |9            |
|Customer#000000140  |9963.15    |4            |
|Customer#000000200  |9967.6     |16           |
|Customer#000000213  |9987.71    |24           |
|Customer#000000381  |9931.71    |5            |
------------------------------------------------

Notice that even though 100 records were requested, only 7 records meet this criterion.

Select a second table

Later, we are going to join our customer records to the nation table to get the name of the country, not just a key value for it. In the example below, we are chaining methods together instead of assigning each output to a distinct variable as we have done up until now.

nationDF = session.table("tpch.tiny.nation") \
                  .drop("regionkey", "comment") \
                  .rename("name", "nation_name") \
                  .rename("nationkey", "n_nationkey")
nationDF.show()

We have already presented table() and drop(). The rename() function simply changes a column’s name to something else as you can see in the output.

---------------------------------
|"n_nationkey"  |"nation_name"  |
---------------------------------
|0              |ALGERIA        |
|1              |ARGENTINA      |
|2              |BRAZIL         |
|3              |CANADA         |
|4              |EGYPT          |
|5              |ETHIOPIA       |
|6              |FRANCE         |
|7              |GERMANY        |
|8              |INDIA          |
|9              |INDONESIA      |
---------------------------------

Join the tables

Now we can join() the two DataFrames by matching the customer’s nationkey to the nation’s n_nationkey.

joinedDF = filteredDF.join(nationDF, filteredDF.nationkey == nationDF.n_nationkey)
joinedDF.show()

Here is the output.

--------------------------------------------------------------------------------
|"name"              |"acctbal"  |"nationkey"  |"n_nationkey"  |"nation_name"  |
--------------------------------------------------------------------------------
|Customer#000000140  |9963.15    |4            |4              |EGYPT          |
|Customer#000000381  |9931.71    |5            |5              |ETHIOPIA       |
|Customer#000000045  |9983.38    |9            |9              |INDONESIA      |
|Customer#000000200  |9967.6     |16           |16             |MOZAMBIQUE     |
|Customer#000000043  |9904.28    |19           |19             |ROMANIA        |
|Customer#000001106  |9977.62    |21           |21             |VIETNAM        |
|Customer#000000213  |9987.71    |24           |24             |UNITED STATES  |
--------------------------------------------------------------------------------

As you can see, join() keeps every column from both DataFrames, including both key columns, so the result still has columns we don’t want.

Project the joined result

How do we clean up those unwanted columns? Exactly right, we talked about this before!

projectedJoinDF = joinedDF.drop("nationkey").drop("n_nationkey")
projectedJoinDF.show()
--------------------------------------------------
|"name"              |"acctbal"  |"nation_name"  |
--------------------------------------------------
|Customer#000000140  |9963.15    |EGYPT          |
|Customer#000000381  |9931.71    |ETHIOPIA       |
|Customer#000000045  |9983.38    |INDONESIA      |
|Customer#000000200  |9967.6     |MOZAMBIQUE     |
|Customer#000000043  |9904.28    |ROMANIA        |
|Customer#000001106  |9977.62    |VIETNAM        |
|Customer#000000213  |9987.71    |UNITED STATES  |
--------------------------------------------------
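If join semantics are new to you, this plain-Python sketch mimics what the last two steps did, using a few rows copied from the outputs above: an inner join on nationkey = n_nationkey that keeps every column from both sides, followed by a drop(). The helper names are hypothetical, and this is an illustration of the result, not how PyStarburst or Trino executes a join.

```python
# A few rows copied from the filteredDF and nationDF outputs above.
FILTERED = [
    {"name": "Customer#000000140", "acctbal": 9963.15, "nationkey": 4},
    {"name": "Customer#000000213", "acctbal": 9987.71, "nationkey": 24},
]
NATIONS = [
    {"n_nationkey": 4, "nation_name": "EGYPT"},
    {"n_nationkey": 24, "nation_name": "UNITED STATES"},
    {"n_nationkey": 0, "nation_name": "ALGERIA"},
]


def inner_join(left, right, left_key, right_key):
    """Nested-loop inner join: emit one merged row per matching pair."""
    joined = []
    for l_row in left:
        for r_row in right:
            if l_row[left_key] == r_row[right_key]:
                # Merging the two dicts keeps every column from both sides.
                joined.append({**l_row, **r_row})
    return joined


def drop(rows, *columns):
    """Return copies of the rows without the named columns."""
    return [{k: v for k, v in row.items() if k not in columns} for row in rows]


joined = inner_join(FILTERED, NATIONS, "nationkey", "n_nationkey")
cleaned = drop(joined, "nationkey", "n_nationkey")
print(cleaned)
```

The sketch also hints at why renaming nationkey to n_nationkey helped: in this dict-based version, two columns with the same name would silently overwrite each other.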

Apply a sort

I love it when the methods do what they say; sort() is no different.

orderedDF = projectedJoinDF.sort(col("acctbal"), ascending=False)
orderedDF.show()
--------------------------------------------------
|"name"              |"acctbal"  |"nation_name"  |
--------------------------------------------------
|Customer#000000213  |9987.71    |UNITED STATES  |
|Customer#000000045  |9983.38    |INDONESIA      |
|Customer#000001106  |9977.62    |VIETNAM        |
|Customer#000000200  |9967.6     |MOZAMBIQUE     |
|Customer#000000140  |9963.15    |EGYPT          |
|Customer#000000381  |9931.71    |ETHIOPIA       |
|Customer#000000043  |9904.28    |ROMANIA        |
--------------------------------------------------

Put it all together

While multiple DataFrame objects were created above, in practice (as discussed when fetching the nation table) most DataFrame API programmers chain many methods together, a bit more like this.

nationDF = session.table("tpch.tiny.nation") \
            .drop("regionkey", "comment") \
            .rename("name", "nation_name") \
            .rename("nationkey", "n_nationkey")
apiSQL = session.table("tpch.tiny.customer") \
            .select("name", "acctbal", "nationkey") \
            .filter(col("acctbal") > 9900.0) \
            .join(nationDF, col("nationkey") == nationDF.n_nationkey) \
            .drop("nationkey").drop("n_nationkey") \
            .sort(col("acctbal"), ascending=False)
apiSQL.show()

This produces the same result as before. There is a lot more going on in the PyStarburst implementation, including the lazy execution model that the DataFrame API is known for. In a nutshell, transformations such as select(), filter(), and join() only describe the query, and the program waits until an action such as show() or collect() actually needs results before running anything on the Trino engine that Starburst Galaxy is built on top of.

If only these three statements were run after the session object was created in the boilerplate source, then ultimately only a single SQL statement was sent to Starburst Galaxy; again, you can find it on the Query history page.
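To illustrate the lazy model, and why the generated SQL below is so deeply nested, here is a toy sketch. ToyFrame and SENT_TO_ENGINE are hypothetical names, and this is NOT how PyStarburst is implemented: each transformation just wraps the previous query text in a new subquery, and only show() finalizes a statement and "sends" it, once.

```python
SENT_TO_ENGINE = []  # hypothetical log of statements "sent" to the engine


class ToyFrame:
    """Hypothetical lazy DataFrame: transformations only build SQL text."""

    def __init__(self, sql):
        self.sql = sql

    @classmethod
    def table(cls, name):
        return cls(f"SELECT * FROM {name}")

    def select(self, *columns):
        # Each transformation wraps the previous query in a new subquery.
        return ToyFrame(f"SELECT {', '.join(columns)} FROM ( {self.sql} )")

    def filter(self, predicate):
        return ToyFrame(f"SELECT * FROM ( {self.sql} ) WHERE ({predicate})")

    def sort(self, column, ascending=True):
        direction = "ASC" if ascending else "DESC"
        return ToyFrame(f"SELECT * FROM ( {self.sql} ) ORDER BY {column} {direction}")

    def show(self, n=10):
        # The only action: finalize the statement and "send" it exactly once.
        statement = f"{self.sql} LIMIT {n}"
        SENT_TO_ENGINE.append(statement)
        return statement


chained = (ToyFrame.table("tpch.tiny.customer")
           .select("name", "acctbal", "nationkey")
           .filter("acctbal > 9900.0")
           .sort("acctbal", ascending=False))
assert SENT_TO_ENGINE == []  # building the chain sent nothing
print(chained.show())        # now exactly one statement is "sent"
```

The real generated SQL below is more elaborate (aliases, a join, OFFSET 0 ROWS), but you can spot the same one-subquery-per-call nesting.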

The generated SQL

SELECT "name" , "acctbal" , "nation_name" FROM ( SELECT "name" , "acctbal" , "n_nationkey" , "nation_name" FROM ( SELECT * FROM (( SELECT "name" "name" , "acctbal" "acctbal" , "nationkey" "nationkey" FROM ( SELECT * FROM ( SELECT "name" , "acctbal" , "nationkey" FROM ( SELECT * FROM tpch.tiny.customer ) ) WHERE ("acctbal" > DOUBLE '9900.0') ) ) INNER JOIN ( SELECT "n_nationkey" "n_nationkey" , "nation_name" "nation_name" FROM ( SELECT "nationkey" "n_nationkey" , "nation_name" FROM ( SELECT "nationkey" , "name" "nation_name" FROM ( SELECT "nationkey" , "name" FROM ( SELECT * FROM tpch.tiny.nation ) ) ) ) ) ON ("nationkey" = "n_nationkey")) ) ) ORDER BY "acctbal" DESC NULLS LAST OFFSET 0 ROWS LIMIT 10

The generated SQL above is clearly something a program created; in fairness, PyStarburst is walking the chain of function calls and building some pretty ugly SQL. The good news is that the cost-based optimizer (CBO) inside Trino deciphered it all and broke it down into a very efficient 3-stage job that used a broadcast join, as the visualization of the query plan’s directed acyclic graph (DAG) showed.

If all that CBO and DAG talk was mumbo-jumbo, and you want to learn more, check out these free training modules from Starburst Academy.
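For a rough feel of what a broadcast join does, here is a single-process Python simulation. It is my own simplification, not Trino’s implementation: the small nation side is built into a hash table and copied in full to every simulated worker, while the large customer side is split across workers and never moves between them. The rows are copied from the outputs in this post; broadcast_join() and the worker count are hypothetical.

```python
import copy

# Small (broadcast) side: a few nation rows from the nationDF output above.
NATIONS = [
    {"n_nationkey": 0, "nation_name": "ALGERIA"},
    {"n_nationkey": 4, "nation_name": "EGYPT"},
    {"n_nationkey": 5, "nation_name": "ETHIOPIA"},
    {"n_nationkey": 9, "nation_name": "INDONESIA"},
]
# Large (probe) side in real life: filtered customer rows from the output above.
CUSTOMERS = [
    {"name": "Customer#000000140", "acctbal": 9963.15, "nationkey": 4},
    {"name": "Customer#000000381", "acctbal": 9931.71, "nationkey": 5},
    {"name": "Customer#000000045", "acctbal": 9983.38, "nationkey": 9},
]


def broadcast_join(large, small, large_key, small_key, workers=2):
    """Simulate an inner broadcast hash join across `workers` pretend nodes."""
    # Build a hash table from the small side once...
    hash_table = {}
    for row in small:
        hash_table.setdefault(row[small_key], []).append(row)

    results = []
    for worker in range(workers):
        # ...and "broadcast" it: every worker gets its own full copy.
        local_table = copy.deepcopy(hash_table)
        # Each worker probes with only its own slice of the large side.
        for row in large[worker::workers]:
            for match in local_table.get(row[large_key], []):
                results.append({**row, **match})
    return results


for row in broadcast_join(CUSTOMERS, NATIONS, "nationkey", "n_nationkey"):
    print(row)
```

Broadcasting pays off when one side is tiny, like the 25-row nation table: copying it everywhere is cheaper than reshuffling the large side across the cluster.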

Or… just run some SQL

I’ll be honest: I actually LIKE the code above, chaining methods together while looking back and forth at the API docs, but I’m a programmer. If you were following the code along the way, you realized we were just building the equivalent of a rather simple SQL statement that filters and projects, joins two tables, and sorts the results.

Are you wondering whether, instead of using the Session object’s table() function to start our efforts, there is a way to just run some SQL?

Well, yes, there is. It is called the sql() method, and here is an example of it running a hand-crafted, rather simple SQL statement that does the same thing as the rest of this post.

# Adjacent string literals inside the parentheses are joined into one SQL string
dfSQL = session.sql("SELECT c.name, c.acctbal, n.name "
                    "  FROM tpch.tiny.customer c "
                    "  JOIN tpch.tiny.nation n "
                    "    ON c.nationkey = n.nationkey "
                    " WHERE c.acctbal > 9900.0 "
                    " ORDER BY c.acctbal DESC ")
dfSQL.show()

The generated SQL, which you can find on the Query history page in Starburst Galaxy, is probably a bit more obvious than before.

The generated SQL

SELECT c.name , c.acctbal , n.name FROM (tpch.tiny.customer c INNER JOIN tpch.tiny.nation n ON (c.nationkey = n.nationkey)) WHERE (c.acctbal > DECIMAL '9900.0') ORDER BY c.acctbal DESC OFFSET 0 ROWS LIMIT 10

If you look closely, you’ll see that the SQL was modified a bit, such as adding the INNER keyword for the join type and a LIMIT 10 clause due to the show() function’s default behavior. It is not simply “passing through” the query.

More interesting is that Trino ran the same 3-stage job with a broadcast join, producing the same text and visual query plans.
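If you want to convince yourself locally that the hand-written statement returns the same rows, in the same order, as the chained version, here is a sketch that uses Python’s built-in sqlite3 module as a stand-in for Trino. Two assumptions to flag: the tables hold only rows copied from the outputs in this post (not the full tpch.tiny tables), and the tpch.tiny. prefix is dropped because SQLite has no such catalog and schema.

```python
import sqlite3

# Rows copied from the outputs shown in this post (not the full tpch.tiny tables).
CUSTOMERS = [
    ("Customer#000000751", 2130.98, 0),
    ("Customer#000000752", 8363.66, 8),
    ("Customer#000000754", -566.86, 0),
    ("Customer#000000757", 9334.82, 3),
    ("Customer#000001106", 9977.62, 21),
    ("Customer#000000043", 9904.28, 19),
    ("Customer#000000045", 9983.38, 9),
    ("Customer#000000140", 9963.15, 4),
    ("Customer#000000200", 9967.6, 16),
    ("Customer#000000213", 9987.71, 24),
    ("Customer#000000381", 9931.71, 5),
]
NATIONS = [
    (0, "ALGERIA"), (3, "CANADA"), (4, "EGYPT"), (5, "ETHIOPIA"), (8, "INDIA"),
    (9, "INDONESIA"), (16, "MOZAMBIQUE"), (19, "ROMANIA"), (21, "VIETNAM"),
    (24, "UNITED STATES"),
]

# Same statement as the session.sql() example, minus the tpch.tiny. prefix.
QUERY = (
    "SELECT c.name, c.acctbal, n.name "
    "  FROM customer c "
    "  JOIN nation n "
    "    ON c.nationkey = n.nationkey "
    " WHERE c.acctbal > 9900.0 "
    " ORDER BY c.acctbal DESC"
)


def run_query():
    """Load the sample rows into an in-memory database and run QUERY."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.execute("CREATE TABLE customer (name TEXT, acctbal REAL, nationkey INTEGER)")
        conn.execute("CREATE TABLE nation (nationkey INTEGER, name TEXT)")
        conn.executemany("INSERT INTO customer VALUES (?, ?, ?)", CUSTOMERS)
        conn.executemany("INSERT INTO nation VALUES (?, ?)", NATIONS)
        return conn.execute(QUERY).fetchall()
    finally:
        conn.close()


if __name__ == "__main__":
    for row in run_query():
        print(row)
```

The four below-threshold customers are included so the WHERE clause has something to remove.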

Wrap up

You’ve had a quick tour of the DataFrame API implementation with Python that runs the code ultimately as SQL on Starburst Galaxy.

We’ve seen just a tiny bit of the rich API that is available to data engineers who prefer writing programs over SQL. We also saw that we can often replace the “neat” function calls with hand-written SQL and, in all fairness, using the sql() function to create DataFrames when we can is a great idea for code maintainability.

I hope you are as excited as I am to experiment more with PyStarburst! 

Blog post originally published here.