Last Updated: 2024-08-15
PyStarburst is a library that brings Python DataFrames to Starburst. If you're a data analyst or data scientist, this means that you can choose to use SQL or Python to analyze your data, all while taking advantage of the Trino distributed query engine.
This tutorial will guide you through setting up a schema and tables in Starburst Galaxy. It will then show you how to use PyStarburst to answer some questions about the data. The equivalent SQL commands will be provided for reference.
Once you've completed this tutorial, you will be able to:
In one of the prerequisites to this tutorial, you created a catalog called tmp_cat, which you connected to an S3 bucket owned by Starburst. The S3 bucket that you connected to also contains some flight data, which you will use in this tutorial. Before you can analyze the data, you'll need to complete a few tasks to prepare your environment.
First, you'll have to add a location privilege to the accountadmin role in Starburst Galaxy to ensure that you can write to the folder within the S3 bucket that contains the flight data. After that, you'll use the Query editor to execute some SQL to create the necessary schema and tables.
Your current role is listed in the top right-hand corner of the screen.
You must add a location privilege for the accountadmin role to be able to write to the S3 bucket location. If you don't, you'll get an error when you try to create your schema.
s3://starburst-tutorials/*
Now that you've added the location privilege, you're ready to create a schema for the aviation data.
CREATE SCHEMA IF NOT EXISTS "tmp_cat"."aviation" WITH (location = 's3://starburst-tutorials/projects/aviation/')
raw_flight table
The data you'll be working with comprises four CSV files: flights.csv, airports.csv, carriers.csv, and plane-data.csv. You're going to make one table for each file, beginning with flights.csv. Review the following ERD to understand the tables and their logical relationships.
Create the raw_flight table:
CREATE TABLE tmp_cat.aviation.raw_flight (
month smallInt,
day_of_month smallInt,
day_of_week smallInt,
dep_time smallInt,
arr_time smallInt,
unique_carrier varchar(15),
flight_number smallInt,
tail_number varchar(15),
elapsed_time smallInt,
air_time smallInt,
arr_delay smallInt,
dep_delay smallInt,
origination varchar(15),
destination varchar(15),
distance smallInt,
taxi_in smallInt,
taxi_out smallInt,
cancelled varchar(15),
cancellation_code varchar(15),
diverted varchar(15)
) WITH (
external_location = 's3://starburst-tutorials/aviation/flights',
type = 'HIVE',
format = 'TEXTFILE',
textfile_field_separator = ','
);
raw_carrier table
Next up is the raw_carrier table. The SQL for this one is much shorter.
Create the raw_carrier table:
CREATE TABLE tmp_cat.aviation.raw_carrier (
code varchar(15),
description varchar(150)
) WITH (
external_location = 's3://starburst-tutorials/aviation/carriers',
format = 'TEXTFILE',
type = 'HIVE',
textfile_field_separator = ','
);
raw_airport table
Create the raw_airport table:
CREATE TABLE tmp_cat.aviation.raw_airport (
code varchar(15),
description varchar(150),
city varchar(150),
state varchar(150),
country varchar(150),
lat decimal(10,8),
lng decimal(11,8)
) WITH (
external_location = 's3://starburst-tutorials/aviation/airports',
type = 'HIVE',
format = 'TEXTFILE',
textfile_field_separator = ','
);
raw_plane table
Just one more to go!
Create the raw_plane table:
CREATE TABLE tmp_cat.aviation.raw_plane (
tail_number varchar(15),
usage varchar(150),
manufacturer varchar(150),
issue_date varchar(150),
model varchar(150),
status varchar(150),
aircraft_type varchar(150),
engine_type varchar(150),
year_built smallint
) WITH (
external_location = 's3://starburst-tutorials/aviation/plane-data',
type = 'HIVE',
format = 'TEXTFILE',
textfile_field_separator = ','
);
For this section of the tutorial, we'll be writing all of our code in a file named aviation.py. We'll then use a terminal window to execute the code with the command python3 aviation.py.
We will be exploring the aviation data and asking seven analytical questions. For each PyStarburst solution, we will also provide the equivalent SQL.
Create aviation.py and add imports
Your aviation.py file needs to include code to import the required libraries and also connect to your Starburst Galaxy cluster. We've put that together for you – you simply have to add your Starburst Galaxy cluster details.
Add the following code to aviation.py:
import trino
from pystarburst import Session
from pystarburst import functions as f
from pystarburst.functions import col, lag, round, row_number
from pystarburst.window import Window
db_parameters = {
"host": "<your host>",
"port": 443,
"http_scheme": "https",
"auth": trino.auth.BasicAuthentication("<your galaxy username>", "<your password>")
}
session = Session.builder.configs(db_parameters).create()
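If you would rather not keep your password in aviation.py, the connection details can be read from environment variables instead. The following is a minimal sketch of that idea, not part of the original tutorial: the helper names and the variable names GALAXY_HOST, GALAXY_PORT, GALAXY_USER, and GALAXY_PASSWORD are hypothetical, so choose whatever suits your setup.

```python
import os


def build_db_parameters(env=os.environ):
    """Collect connection settings from environment variables.

    The variable names are hypothetical; they are not defined by PyStarburst.
    """
    return {
        "host": env["GALAXY_HOST"],
        "port": int(env.get("GALAXY_PORT", "443")),  # same default port as above
        "http_scheme": "https",
        "user": env["GALAXY_USER"],
        "password": env["GALAXY_PASSWORD"],
    }


def create_session(params):
    """Build a session from the plain settings, using the same calls as above."""
    # Imported here so build_db_parameters() works without the libraries installed.
    import trino
    from pystarburst import Session

    db_parameters = {
        "host": params["host"],
        "port": params["port"],
        "http_scheme": params["http_scheme"],
        "auth": trino.auth.BasicAuthentication(params["user"], params["password"]),
    }
    return Session.builder.configs(db_parameters).create()
```

With the variables exported in your terminal, session = create_session(build_db_parameters()) would take the place of the hard-coded dictionary.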
flight table?
It's time to answer the first question. This is a simple solution whether you're using SQL or PyStarburst. In the case of PyStarburst, you just retrieve the raw_flight table as a DataFrame and then call the count() function.
Add the following code to your aviation.py file. Save the changes when complete.
allFs = session.table("tmp_cat.aviation.raw_flight")
print(allFs.count())
python3 aviation.py
SELECT count(*)
FROM tmp_cat.aviation.raw_flight;
This one is also relatively straightforward. After selecting the raw_airport table, we'll use the group_by() function and then call count() on the grouped rows. Finally, we'll order the results by the number of rows for each country and show a single result.
Add the following code to your aviation.py file. Save the changes when complete.
# get the whole table, aggregate & sort
mostAs = session \
.table("tmp_cat.aviation.raw_airport") \
.group_by("country").count() \
.sort("count", ascending=False)
mostAs.show(1)
python3 aviation.py
SELECT country, count() AS num_airports
FROM tmp_cat.aviation.raw_airport
GROUP BY country
ORDER BY num_airports DESC;
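To see what the group, count, and sort steps do without connecting to a cluster, here is a small plain-Python sketch of the same logic. The list of countries is made up for illustration and is not taken from the raw_airport table.

```python
from collections import Counter

# Made-up sample of the country column, for illustration only.
countries = ["USA", "USA", "Canada", "USA", "Canada", "Mexico"]

# Like group_by("country").count(): one tally per distinct value.
counts = Counter(countries)

# Like sort("count", ascending=False) followed by show(1): keep the largest group.
top_country, top_count = counts.most_common(1)[0]
print(top_country, top_count)
```

The difference in PyStarburst is that the grouping and sorting run on the Trino cluster rather than in your local Python process.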
This one is very similar to the previous question.
Add the following code to your aviation.py file. Save the changes when complete.
# get the whole table, aggregate & sort
mostFs = session \
.table("tmp_cat.aviation.raw_flight") \
.group_by("unique_carrier").count() \
.rename("unique_carrier", "carr") \
.sort("count", ascending=False)
mostFs.show(5)
python3 aviation.py
SELECT unique_carrier, count() as num_flights
FROM tmp_cat.aviation.raw_flight
GROUP BY unique_carrier
ORDER BY num_flights DESC
LIMIT 5;
This is the same as the previous question, except we're looking for the name of the airline rather than its code. We'll create a DataFrame for the raw_carrier table to join on. Then, we'll take the code from Question 3 and chain a few more methods on it, namely the join().
Add the following code to your aviation.py file. Save the changes when complete.
# get all of the carriers
allCs = session.table("tmp_cat.aviation.raw_carrier")
# repurpose mostFs from above (or chain on it)
# to join the 2 DFs and sort the results that
# have already been grouped
top5CarrNm = mostFs \
.join(allCs, mostFs.carr == allCs.code) \
.drop("code") \
.sort("count", ascending=False)
top5CarrNm.show(5, 30)
python3 aviation.py
SELECT c.description, count() as num_flights
FROM tmp_cat.aviation.raw_flight f
JOIN tmp_cat.aviation.raw_carrier c
ON (f.unique_carrier = c.code)
GROUP BY c.description
ORDER BY num_flights DESC
LIMIT 5;
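The join can also be pictured in plain Python as a lookup from carrier code to carrier name. This is only an illustrative sketch with made-up rows; it mimics the inner join of the SQL above, where codes with no match in the carrier table are dropped.

```python
# Made-up rows, for illustration only.
flight_counts = [("AA", 3), ("ZZ", 5), ("BB", 1)]   # (carr, count)
carriers = {"AA": "Airline A", "BB": "Airline B"}   # code -> description

# Inner join on carr == code: rows without a matching code are left out.
joined = [
    (carr, carriers[carr], count)
    for carr, count in flight_counts
    if carr in carriers
]

# Order by count, largest first, as in ORDER BY num_flights DESC.
joined.sort(key=lambda row: row[2], reverse=True)
print(joined)
```

Note that "ZZ" disappears from the result even though it has the highest count, because no carrier row matches it.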
We'll be utilizing another join here.
Add the following code to your aviation.py file. Save the changes when complete.
# trimFs are flights projected & filtered
trimFs = session.table("tmp_cat.aviation.raw_flight") \
.rename("tail_number", "tNbr") \
.select("tNbr", "distance") \
.filter(col("distance") > 1500)
# trimPs are planes table projected & filtered
trimPs = session.table("tmp_cat.aviation.raw_plane") \
.select("tail_number", "model") \
.filter("model is not null")
# join, group & sort
q5Answer = trimFs \
.join(trimPs, trimFs.tNbr == trimPs.tail_number) \
.drop("tail_number") \
.group_by("model").count() \
.sort("count", ascending=False)
q5Answer.show()
python3 aviation.py
SELECT p.model, count() as num_flights
FROM tmp_cat.aviation.raw_flight f
JOIN tmp_cat.aviation.raw_plane p
ON (f.tail_number = p.tail_number)
WHERE f.distance > 1500
AND p.model IS NOT NULL
GROUP BY p.model
ORDER BY num_flights desc
LIMIT 10;
Let's begin with the SQL solution for this question. This solution leverages Common Table Expressions (CTEs), which can be thought of as temporary tables. We'll follow the same general approach in the Python solution, where the code is explained in a bit more detail.
SQL solution:
WITH agg_flights AS (
SELECT origination, month,
COUNT(*) AS num_flights
FROM tmp_cat.aviation.raw_flight
GROUP BY 1,2
),
change_flights AS (
SELECT origination, month, num_flights,
LAG(num_flights, 1)
OVER(PARTITION BY origination
ORDER BY month ASC)
AS num_flights_before
FROM agg_flights
)
SELECT origination, month, num_flights, num_flights_before,
ROUND((1.0 * (num_flights - num_flights_before)) /
(1.0 * (num_flights_before)), 2)
AS perc_change
FROM change_flights;
PyStarburst solution:
Equivalent of the agg_flights CTE above:
# temp DF holds counts for each originating airport
# by month
aggFlights = session.table("tmp_cat.aviation.raw_flight") \
.select("origination", "month") \
.rename("origination", "orig") \
.group_by("orig", "month").count() \
.rename("count", "num_fs")
# define a window specification
w1 = Window.partition_by("orig").order_by("month")
# add col to grab the prior row's nbr flights
changeFlights = aggFlights \
.withColumn("num_fs_b4", \
lag("num_fs",1).over(w1))
# add col for the percentage change
q6Answer = changeFlights \
.withColumn("perc_chg", \
round((1.0 * (col("num_fs") - col("num_fs_b4")) / \
(1.0 * col("num_fs_b4"))), 1))
q6Answer.show()
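The window logic is easier to follow on a handful of rows. The sketch below reproduces the lag and percentage-change steps in plain Python on made-up data; the function name and the digits parameter are hypothetical. Note that the SQL version rounds to two decimal places while the PyStarburst version rounds to one, so digits is left adjustable here.

```python
from itertools import groupby

# Made-up (orig, month, num_fs) rows, for illustration only.
rows = [
    ("AAA", 1, 100),
    ("AAA", 2, 110),
    ("AAA", 3, 99),
    ("BBB", 1, 40),
    ("BBB", 2, 50),
]


def with_lag_and_change(rows, digits=2):
    """Add the previous month's count and the rounded fractional change."""
    result = []
    # Sorting by (orig, month) plays the role of PARTITION BY orig ORDER BY month.
    ordered = sorted(rows, key=lambda r: (r[0], r[1]))
    for orig, group in groupby(ordered, key=lambda r: r[0]):
        previous = None  # the first row of each partition has no prior row
        for _, month, num_fs in group:
            if previous is None:
                change = None  # mirrors the NULL that LAG yields on the first row
            else:
                change = round((num_fs - previous) / previous, digits)
            result.append((orig, month, num_fs, previous, change))
            previous = num_fs
    return result


for row in with_lag_and_change(rows):
    print(row)
```

The first month for each airport has no earlier row to compare against, which is why its change is empty in both the SQL and the DataFrame results.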
This is another CTE solution where we will begin with the SQL solution and use that as a guide for the Python solution.
SQL solution:
WITH popular_routes AS (
SELECT origination, destination,
COUNT(*) AS num_flights
FROM tmp_cat.aviation.raw_flight
GROUP BY 1, 2
),
ranked_routes AS (
SELECT origination, destination,
ROW_NUMBER()
OVER(PARTITION BY origination
ORDER BY num_flights DESC)
AS rank
FROM popular_routes
)
SELECT origination, destination, rank
FROM ranked_routes
WHERE rank <= 3
ORDER BY origination, rank;
PyStarburst solution:
Equivalent of the popular_routes CTE above:
# determine counts from orig>dest pairs
popularRoutes = session \
.table("tmp_cat.aviation.raw_flight") \
.rename("origination", "orig") \
.rename("destination", "dest") \
.group_by("orig", "dest").count() \
.rename("count", "num_fs")
# define a window specification
w2 = Window.partition_by("orig") \
.order_by(col("num_fs").desc())
# add col to put the curr row's ranking in
rankedRoutes = popularRoutes \
.withColumn("rank", \
row_number().over(w2))
# just show up to 3 for each orig airport
q7Answer = rankedRoutes \
.filter(col("rank") <= 3) \
.sort("orig", "rank")
q7Answer.show(17)
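As with the previous question, the ranking step can be traced by hand on a few rows. This plain-Python sketch uses made-up route counts and a hypothetical top_routes helper to show what numbering rows within each origin and keeping the first three amounts to.

```python
from collections import defaultdict

# Made-up (orig, dest, num_fs) rows, for illustration only.
routes = [
    ("AAA", "BBB", 50),
    ("AAA", "CCC", 70),
    ("AAA", "DDD", 10),
    ("AAA", "EEE", 30),
    ("BBB", "AAA", 20),
]


def top_routes(routes, limit=3):
    """Rank destinations per origin by flight count and keep the top few."""
    by_orig = defaultdict(list)
    for orig, dest, num_fs in routes:
        by_orig[orig].append((dest, num_fs))  # PARTITION BY orig

    ranked = []
    for orig in sorted(by_orig):  # final ordering by orig
        # ORDER BY num_fs DESC within the partition
        ordered = sorted(by_orig[orig], key=lambda d: d[1], reverse=True)
        # Row numbers start at 1; the slice applies the rank <= limit filter.
        for rank, (dest, _) in enumerate(ordered[:limit], start=1):
            ranked.append((orig, dest, rank))
    return ranked


print(top_routes(routes))
```

Origins with fewer than three destinations simply return fewer rows, which is why the DataFrame result shows "up to 3" per airport.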
Congratulations! You have reached the end of this tutorial, and the end of this stage of your journey.
Now that you've learned a bit about how to use PyStarburst for data analysis, we encourage you to explore the API documentation and write your own code.
At Starburst, we believe in continuous learning. This tutorial provides the foundation for further training available on this platform, and you can return to it as many times as you like. Future tutorials will make use of the concepts used here.
Starburst has lots of other tutorials to help you get up and running quickly. Each one breaks down an individual problem and guides you to a solution using a step-by-step approach to learning.
Visit the Tutorials section to view the full list of tutorials and keep moving forward on your journey!