[DEV] 14주차. Hadoop과 Spark (3)
1. Spark SQL
- 구조화된 데이터 처리를 위한 Spark 모듈
- 데이터프레임 작업을 SQL로 처리 가능
- 데이터프레임에 테이블 이름 지정 후 sql 함수 사용 가능
- Pandas에도 pandasql 모듈의 spldf 함수를 이용하는 동일한 패턴 존재
- HQL (Hive QL)과 호환 제공
- Hive 테이블들을 읽고 쓸 수 있음 (Hive Metastore)
- 보통 Hive와 Spark 시스템을 동시에 사용하는 것이 일반적 (YARN 위에서)
- 데이터프레임에 테이블 이름 지정 후 sql 함수 사용 가능
Spark SQL vs. DataFrame
- SQL로 가능한 작업이면 DataFrame을 사용할 이유가 없음
- 하지만 두 개를 동시에 사용할 수 있다는 점 기억할 것!
- Familiarity / Readability
- SQL이 더 가독성이 좋고 많은 사람들이 사용 가능
- Optimization
- Spark SQL 엔진이 최적화하기 더 좋음 (SQL은 Declarative)
- Catalyst Optimizer와 Project Tungsten
- Interoperability / Data Management
- SQL이 포팅도 쉽고 접근권한 체크도 쉬움
Spark SQL 사용 방법
- 데이터프레임을 기반으로 테이블 뷰 생성 : 테이블이 만들어짐
createOrReplaceTempView
: Spark Session이 살아있는 동안 존재createOrReplaceGlobalTempView
: Spark 드라이버가 살아있는 동안 존재
- Spark Session의 sql 함수로 SQL 결과를 데이터프레임으로 받음
namegender_df.createOrReplaceTempView('namegender')
namegender_group_df = spark.sql("""
SELECT gender, count(1) FROM namegender GROUP BY 1
""")
print(namegender_group_df.collect())
SparkSession 외부 데이터베이스 연결
- Spark Session의 read 함수 호출
- 로그인 관련 정보와 읽어오고자 하는 테이블 혹은 SQL 지정
- 결과가 데이터프레임으로 반환됨
df_user_session_channel = spark.read\
.format('jdbc')\
.option('driver', 'com.amazon.redshift.jdbc42.Driver')\
.option('url', 'jdbc:redshift://HOST:PORT/DB?user=ID&password=PASSWORD')\
.option('dbtable', 'raw_data.user_session_channel') # SELECT문도 가능
.load()
Aggregation
-
DataFrame이 아닌 SQL로 작성하는 것 추천
- GroupBy
- Window
- Rank
JOIN
- 두 개 혹은 그 이상의 테이블들을 공통 필드를 가지고 머지
- 스타 스키마로 구성된 테이블들로 분산되어 있던 정보를 통합하는데 사용
- 왼쪽 테이블을 LEFT, 오른쪽 테이블을 RIGHT라고 하면
- 결과는 방식에 따라 양쪽 필드를 모두 가진 새로운 테이블 생성
- 방식에 따라 두 가지가 달라짐
- 어떤 레코드들이 선택되는지
- 어떤 필드들이 채워지는지
최적화 관점에서 본 조인의 종류들
- Shuffle Join
- 일반 조인 방식
- Bucket Join: 조인 키를 바탕으로 새로운 파티션을 만들고 조인하는 방식
- Broadcast Join
- 큰 데이터와 작은 데이터 간의 조인
- 데이터프레임 하나가 충분이 작으면 작은 데이터프레임을 다른 데이터프레임이 있는 서버들로 뿌리는 것 (broadcasting)
spark.sql.autoBroadcastJoinThreshold
파라미터로 충분히 작은지 여부 결정
2. UDF (User Defined Function)
- 데이터프레임의 경우
.withColumn
함수와 같이 사용하는 것이 일반적- SparkSQL에서도 사용 가능
- Aggregation용 UDAF (User Defined Aggregation Function)도 존재
- GROUP BY 에서 사용되는 SUM, AVG와 같은 함수를 만드는 것
- PySpark에서 지원되지 않음. Scalar/Java를 사용해야 함
DataFrame에 사용
import pyspark.sql.functions as F
from pyspark.sql.types import *
upperUDF = F.udf(lambda z:z.upper())
df.withColumn('Curated Name', upperUDF('Name'))
SparkSQL에 사용
def upper(s):
return s.upper()
# 먼저 테스트
upperUDF = spark.udf.register('upper', upper)
spark.sql("SELECT upper('aBcD')").show()
# DataFrame 기반 SQL에 적용
df.createOrReplaceTempView('test')
spark.sql("""SELECT name, upper(name) "Curated Name" FROM test""").show()
Pandas UDF Scalar 함수
from pyspark.sql.functions import pandas_udf
import pandas as pd
@pandas_udf(StringType()) # 각 컬럼의 타입
def upper_udf2(s:pd.Series) -> pd.Series:
return s.str.upper()
# bulk로 처리 -> 더 퍼포먼스가 좋음
upperUDF = spark.udf.register('upper_udf', upper_udf2)
df.select('Name', upperUDF('Name')).show()
spark.sql("""SELECT name, upper_udf(name) 'Curated Name' FROM test""").show()
UDF - DataFrame/SQL에 Aggregation 사용
from pyspark.sql.functions import pandas_df
import pandas as pd
@pandas_udf(FloatType())
def average(v:pd.Series) -> float:
return v.mean()
averageUDF = spark.udf.register('average', average)
spark.sql('SELECT average(b) FROM test').show()
df.agg(averageUDF('b').alias('count')).show()
3. Spark SQL 실습
실습 테이블
- 사용자 ID
- 보통 웹서비스는 등록된 사용자마다 유일한 ID 부여 -> 사용자 ID
- 세션 ID
- 사용자가 외부 링크를 타고 오거나 직접 방문해서 올 경우 세션을 생성
- 즉 하나의 사용자 ID는 여러 개의 세션 ID를 가질 수 있음
- 보통 세션의 경우 세션을 만들어낸 소스를 채널이라는 이름으로 기록해둠
- 마케팅 관련 기여도 분석을 위함
- 또한, 세션이 생긴 시간도 기록
- 이 정보를 기반으로 다양한 데이터 분석과 지표 설정이 가능
- 마케팅 관련
- 사용자 트래픽 관련
JOIN
- 두 개의 테이블
VitalID
기준 JOIN
from pyspark.sql import SparkSession
spark = SparkSession \
.builder \
.appName("Python Spark SQL #1") \
.getOrCreate()
vital = [
{ 'UserID': 100, 'VitalID': 1, 'Date': '2020-01-01', 'Weight': 75 },
{ 'UserID': 100, 'VitalID': 2, 'Date': '2020-01-02', 'Weight': 78 },
{ 'UserID': 101, 'VitalID': 3, 'Date': '2020-01-01', 'Weight': 90 },
{ 'UserID': 101, 'VitalID': 4, 'Date': '2020-01-02', 'Weight': 95 },
]
alert = [
{ 'AlertID': 1, 'VitalID': 4, 'AlertType': 'WeightIncrease', 'Date': '2020-01-01', 'UserID': 101},
{ 'AlertID': 2, 'VitalID': None, 'AlertType': 'MissingVital', 'Date': '2020-01-04', 'UserID': 100},
{ 'AlertID': 3, 'VitalID': None, 'AlertType': 'MissingVital', 'Date': '2020-01-05', 'UserID': 101}
]
rdd_vital = spark.sparkContext.parallelize(vital)
rdd_alert = spark.sparkContext.parallelize(alert)
df_vital = rdd_vital.toDF()
df_alert = rdd_alert.toDF()
DataFrame JOIN
- Inner Join
join_expr = df_vital.VitalID == df_alert.VitalID
df_vital.join(df_alert, join_expr, "inner").show()
Date | UserID | VitalID | Weight | AlertID | AlertType | Date | UserID | VitalID |
---|---|---|---|---|---|---|---|---|
2020-01-02 | 101 | 4 | 95 | 1 | WeightIncrease | 2020-01-01 | 101 | 4 |
- Left Join
join_expr = df_vital.VitalID == df_alert.VitalID
df_vital.join(df_alert, join_expr, "left").show()
Date | UserID | VitalID | Weight | AlertID | AlertType | Date | UserID | VitalID |
---|---|---|---|---|---|---|---|---|
2020-01-01 | 100 | 1 | 75 | null | null | null | null | null |
2020-01-02 | 100 | 2 | 78 | null | null | null | null | null |
2020-01-01 | 101 | 3 | 90 | null | null | null | null | null |
2020-01-02 | 101 | 4 | 95 | 1 | WeightIncrease | 2020-01-01 | 101 | 4 |
- Right Join
join_expr = df_vital.VitalID == df_alert.VitalID
df_vital.join(df_alert, join_expr, "right").show()
Date | UserID | VitalID | Weight | AlertID | AlertType | Date | UserID | VitalID |
---|---|---|---|---|---|---|---|---|
2020-01-02 | 101 | 4 | 95 | 1 | WeightIncrease | 2020-01-01 | 101 | 4 |
null | null | null | null | 2 | MissingVital | 2020-01-04 | 100 | null |
null | null | null | null | 3 | MissingVital | 2020-01-05 | 101 | null |
- Full Outer Join
join_expr = df_vital.VitalID == df_alert.VitalID
df_vital.join(df_alert, join_expr, "full").show()
Date | UserID | VitalID | Weight | AlertID | AlertType | Date | UserID | VitalID |
---|---|---|---|---|---|---|---|---|
null | null | null | null | 2 | MissingVital | 2020-01-04 | 100 | null |
null | null | null | null | 3 | MissingVital | 2020-01-05 | 101 | null |
2020-01-01 | 100 | 1 | 75 | null | null | null | null | null |
2020-01-02 | 100 | 2 | 78 | null | null | null | null | null |
2020-01-01 | 101 | 3 | 90 | null | null | null | null | null |
2020-01-02 | 101 | 4 | 95 | 1 | WeightIncrease | 2020-01-01 | 101 | 4 |
- Cross Join
df_vital.join(df_alert, None, "cross").show()
Date | UserID | VitalID | Weight | AlertID | AlertType | Date | UserID | VitalID |
---|---|---|---|---|---|---|---|---|
2020-01-01 | 100 | 1 | 75 | 1 | WeightIncrease | 2020-01-01 | 101 | 4 |
2020-01-02 | 100 | 2 | 78 | 1 | WeightIncrease | 2020-01-01 | 101 | 4 |
2020-01-01 | 100 | 1 | 75 | 2 | MissingVital | 2020-01-04 | 100 | null |
2020-01-01 | 100 | 1 | 75 | 3 | MissingVital | 2020-01-05 | 101 | null |
2020-01-02 | 100 | 2 | 78 | 2 | MissingVital | 2020-01-04 | 100 | null |
2020-01-02 | 100 | 2 | 78 | 3 | MissingVital | 2020-01-05 | 101 | null |
2020-01-01 | 101 | 3 | 90 | 1 | WeightIncrease | 2020-01-01 | 101 | 4 |
2020-01-02 | 101 | 4 | 95 | 1 | WeightIncrease | 2020-01-01 | 101 | 4 |
2020-01-01 | 101 | 3 | 90 | 2 | MissingVital | 2020-01-04 | 100 | null |
2020-01-01 | 101 | 3 | 90 | 3 | MissingVital | 2020-01-05 | 101 | null |
2020-01-02 | 101 | 4 | 95 | 2 | MissingVital | 2020-01-04 | 100 | null |
2020-01-02 | 101 | 4 | 95 | 3 | MissingVital | 2020-01-05 | 101 | null |
- Self Join
join_expr = df_vital.VitalID == df_vital.VitalID
df_vital.join(df_vital, join_expr, "left").show()
Date | UserID | VitalID | Weight | Date | UserID | VitalID | Weight |
---|---|---|---|---|---|---|---|
2020-01-01 | 100 | 1 | 75 | 2020-01-01 | 100 | 1 | 75 |
2020-01-02 | 100 | 2 | 78 | 2020-01-02 | 100 | 2 | 78 |
2020-01-01 | 101 | 3 | 90 | 2020-01-01 | 101 | 3 | 90 |
2020-01-02 | 101 | 4 | 95 | 2020-01-02 | 101 | 4 | 95 |
SQL JOIN
df_vital.createOrReplaceTempView("Vital")
df_alert.createOrReplaceTempView("Alert")
- Inner Join
df_inner_join = spark.sql("""SELECT * FROM Vital v
JOIN Alert a ON v.vitalID = a.vitalID;""")
df_inner_join.show()
- Left Join
df_left_join = spark.sql("""SELECT * FROM Vital v
LEFT JOIN Alert a ON v.vitalID = a.vitalID;""")
df_left_join.show()
- Right Join
df_right_join = spark.sql("""SELECT * FROM Vital v
RIGHT JOIN Alert a ON v.vitalID = a.vitalID;""")
df_right_join.show()
- Outer Join
df_outer_join = spark.sql("""SELECT * FROM Vital v
FULL JOIN Alert a ON v.vitalID = a.vitalID;""")
df_outer_join.show()
- Cross Join
df_cross_join = spark.sql("""SELECT * FROM Vital v
CROSS JOIN Alert a""")
df_cross_join.show()
- Self Join
df_self_join = spark.sql("""SELECT * FROM Vital v1
JOIN Vital v2""")
df_self_join.show()
Ranking
- refund 여부를 고려하지 않고, 총 매출이 가장 많은 사용자 10명 찾기
필드 | 설명 |
---|---|
사용자ID | 총 매출 |
- 3개의 테이블을 각각 데이터프레임으로 로딩
- 데이터프레임 별로 테이블 이름 지정
- Spark SQL로 처리
- 조인 방식 결정
- 조인키
- 조인 방식
- 간단한 지표부터 계산
- 조인 방식 결정
# 데이터는 Redshift에서 가져옴
!pip install pyspark==3.3.1 py4j==0.10.9.5
from pyspark.sql import SparkSession
spark = SparkSession \
.builder \
.appName("Python Spark SQL #1") \
.getOrCreate()
월별 채널별 매출과 방문자 정보 계산
# Redshift와 연결해서 DataFrame으로 로딩
url = "jdbc:redshift://***.***.ap-northeast-2.redshift.amazonaws.com:5439/dev?user=***&password=***"
df_user_session_channel = spark.read \
.format("jdbc") \
.option("driver", "com.amazon.redshift.jdbc42.Driver") \
.option("url", url) \
.option("dbtable", "raw_data.user_session_channel") \
.load()
df_session_timestamp = spark.read \
.format("jdbc") \
.option("driver", "com.amazon.redshift.jdbc42.Driver") \
.option("url", url) \
.option("dbtable", "raw_data.session_timestamp") \
.load()
df_session_transaction = spark.read \
.format("jdbc") \
.option("driver", "com.amazon.redshift.jdbc42.Driver") \
.option("url", url) \
.option("dbtable", "raw_data.session_transaction") \
.load()
df_user_session_channel.createOrReplaceTempView("user_session_channel")
df_session_timestamp.createOrReplaceTempView("session_timestamp")
df_session_transaction.createOrReplaceTempView("session_transaction")
df_user_session_channel.show(5)
userid | sessionid | channel |
---|---|---|
1651 | 0004289ee1c7b8b08… | Organic |
1197 | 00053f5e11d1fe4e4… | |
1401 | 00056c20eb5a02958… | |
1399 | 00063cb5da1826feb… | |
1667 | 000958fdaefe0dd06… |
df_session_timestamp.show(5)
sessionid | ts |
---|---|
00029153d12ae1c9a… | 2019-10-18 14:14:… |
0004289ee1c7b8b08… | 2019-11-16 21:20:… |
0006246bee639c7a7… | 2019-08-10 16:33:… |
0006dd05ea1e999dd… | 2019-07-06 19:54:… |
000958fdaefe0dd06… | 2019-11-02 14:52:… |
df_session_transaction.show(5)
sessionid | refunded | amount |
---|---|---|
00029153d12ae1c9a… | false | 85 |
008909bd27b680698… | false | 13 |
0107acb41ef20db22… | false | 16 |
018544a2c48077d2c… | false | 39 |
020c38173caff0203… | false | 61 |
총 매출이 가장 많은 사용자 10명 찾기
- Inner Join / Left(Right) Join 모두 가능
- revenue(매출액)으로 order
top_rev_user_df = spark.sql("""
SELECT userid,
SUM(str.amount) revenue,
SUM(CASE WHEN str.refunded = False THEN str.amount END) net_revenue
FROM user_session_channel usc
JOIN session_transaction str ON usc.sessionid = str.sessionid
GROUP BY 1
ORDER BY 2 DESC
LIMIT 10""")
top_rev_user_df.show()
userid | revenue | net_revenue |
---|---|---|
989 | 743 | 743 |
772 | 556 | 556 |
1615 | 506 | 506 |
654 | 488 | 488 |
1651 | 463 | 463 |
973 | 438 | 438 |
262 | 422 | 422 |
1099 | 421 | 343 |
2682 | 414 | 414 |
891 | 412 | 412 |
- rank 이용
top_rev_user_df2 = spark.sql("""
SELECT
userid,
SUM(amount) total_amount,
RANK() OVER (ORDER BY SUM(amount) DESC) rank
FROM session_transaction st
JOIN user_session_channel usc ON st.sessionid = usc.sessionid
GROUP BY userid
ORDER BY rank
LIMIT 10""")
top_rev_user_df2.show()
userid | total_amount | rank |
---|---|---|
989 | 743 | 1 |
772 | 556 | 2 |
1615 | 506 | 3 |
654 | 488 | 4 |
1651 | 463 | 5 |
973 | 438 | 6 |
262 | 422 | 7 |
1099 | 421 | 8 |
2682 | 414 | 9 |
891 | 412 | 10 |
월별 채널별 매출과 방문자 정보 계산하기
- 연도-월, 채널, 총 매출액, 순매출액, 총 방문자, 매출발생 방문자, 매출 변환률 (매출발생 방문자 / 총 방문자)
- 중요) 데이터를 항상 의심하기!
- join key가 정말 Unique한지!
- 아래 sql을 실행했을 때 count 값이 1이 아니면 unique하지 않은 것!
spark.sql("""SELECT sessionid, COUNT(1) count
FROM user_session_channel
GROUP BY 1
ORDER BY 2 DESC
LIMIT 1""").show()
- 월별 채널별 총 매출액, 총 방문자, 매출발생 방문자, 변환률
mon_channel_rev_df = spark.sql("""
SELECT LEFT(ts, 7) month,
usc.channel,
COUNT(DISTINCT userid) uniqueUsers,
COUNT(DISTINCT (CASE WHEN amount >= 0 THEN userid END)) paidUsers,
SUM(amount) grossRevenue,
SUM(CASE WHEN refunded is not True THEN amount END) netRevenue,
ROUND(COUNT(DISTINCT CASE WHEN amount >= 0 THEN userid END)*100
/ COUNT(DISTINCT userid), 2) conversionRate
FROM user_session_channel usc
LEFT JOIN session_timestamp t ON t.sessionid = usc.sessionid
LEFT JOIN session_transaction st ON st.sessionid = usc.sessionid
GROUP BY 1, 2
ORDER BY 1, 2;
""")
사용자별 처음 채널과 마지막 채널 알아내기
first_last_channel_df = spark.sql("""
WITH RECORD AS (
SELECT /*사용자의 유입에 따른, 채널 순서 매기는 쿼리*/
userid,
channel,
ROW_NUMBER() OVER (PARTITION BY userid ORDER BY ts ASC) AS seq_first,
ROW_NUMBER() OVER (PARTITION BY userid ORDER BY ts DESC) AS seq_last
FROM user_session_channel u
LEFT JOIN session_timestamp t
ON u.sessionid = t.sessionid
)
SELECT /*유저의 첫번째 유입채널, 마지막 유입 채널 구하기*/
f.userid,
f.channel first_channel,
l.channel last_channel
FROM RECORD f
INNER JOIN RECORD l ON f.userid = l.userid
WHERE f.seq_first = 1 and l.seq_last = 1
ORDER BY userid
""")
또는
first_last_channel_df2 = spark.sql("""
SELECT DISTINCT A.userid,
FIRST_VALUE(A.channel) over(partition by A.userid order by B.ts
rows between unbounded preceding and unbounded following) AS First_Channel,
LAST_VALUE(A.channel) over(partition by A.userid order by B.ts
rows between unbounded preceding and unbounded following) AS Last_Channel
FROM user_session_channel A
LEFT JOIN session_timestamp B
ON A.sessionid = B.sessionid""")
userid | first_channel | last_channel |
---|---|---|
27 | Youtube | |
29 | Naver | Naver |
33 | Youtube | |
34 | Youtube | Naver |
36 | Naver | Youtube |
40 | Youtube | |
41 | Youtube | |
44 | Naver | |
45 | Youtube | |
59 |
Window 함수 - ROWS BETWEEN AND
- window 함수는 기본적으로 레코드 수를 바꾸는 것이 아니라, 새로운 컬럼을 만드는 것
SELECT
SUM(value) OVER(
order by value
rows between unbounded preceding and 2 following
-- unbounded: 개수 제한을 두지 않음
) AS rolling_sum
FROM rows_test;
value | rolling_sum |
---|---|
1 | 6 |
2 | 10 |
3 | 15 |
4 | 15 |
5 | 15 |
4. Hive 메타스토어
Spark 데이터베이스와 테이블
- 카탈로그: 테이블과 뷰에 관한 메타 데이터 관리
- 기본으로 메모리 기반 카탈로그 제공 - 세션이 끝나면 사라짐
- Hive와 호환되는 카탈로그 제공 - Persistent
- 테이블 관리 방식
- 테이블들은 데이터베이스라 부르는 폴더와 같은 구조로 관리 (2단계)
- 메모리 기반 테이블/뷰
- 임시 테이블로 사용
- 스토리지 기반 테이블
- 기본적으로 HDFS와 Parquet 포맷 사용
- Hive와 호환되는 메타스토어 사용
- 두 종류의 테이블이 존재 (Hive와 동일한 개념)
- Managed Table: Spark가 실제 데이터와 메타 데이터 모두 관리
- Unmanaged (External) Table: Spark가 메타 데이터만 관리
Spark SQL - 스토리지 기반 카탈로그
- Hive와 호환되는 메타스토어 사용
- SparkSession 생성 시
enableHiveSupport()
호출- 기본으로
default
라는 이름의 데이터베이스 생성
- 기본으로
from pyspark.sql import SparkSession
spark = SparkSession \
.builder \
.appName("Python Spark Hive") \
.enableHiveSupport() \
.getOrCreate()
Spark SQL - Managed Table
- 생성 방법
dataframe.saveAsTable('table')
- SQL 문법 사용 (
CREATE TABLE
, CTAS)
spark.sql.warehouse.dir
이 가리키는 위치에 데이터가 저장됨- Parquet이 기본 데이터 포맷
- 선호하는 테이블 타입
- Spark 테이블로 처리하는 것의 장점
- JDBC/ODBC 등으로 Spark을 연결하여 접근 가능 (태블로, Power BI)
Spark SQL - External Table
- 이미 HDFS에 존재하는 데이터에 스키마를 정의하여 사용
LOCATION
이라는 프로퍼티 사용
- 메타데이터만 카탈로그에 기록됨
- 데이터는 이미 존재
- External Table은 삭제되어도 데이터는 그대로!
CREATE TABLE table_name (
column1 type1,
column2 type2,
column3 type3,
...
)
USING PARQUET
LOCATION 'hdfs_path';
실습
- DataFrame을 Managed Table로 저장
- 새로운 데이터베이스 사용
- Spark SQL로 Managed Table 사용 (CTAS)
- 데이터베이스 생성
spark.sql("CREATE DATABASE IF NOT EXISTS TEST_DB")
spark.sql("USE TEST_DB")
- metastor_db : Spark 메타스토어 - Hive 메타스토어와 호환
- spark-warehouse : HDFS 폴더에 해당
- spark에서 managed table을 만들면 여기에 저장됨
- 데이터베이스에 테이블 생성
- 기본 parquet 형식
df.write.saveAsTable("TEST_DB.orders", mode="overwrite")
- 테이블 값 읽기
spark.sql("SELECT * FROM TEST_DB.orders").show()
sparkt.table("TEST_DB.orders").show()
- spark catalog
- catalog가 인메모리가 아닌 HDFS에 영구적으로 저장되는 메타스토어
spark.catalog.listTables()
isTemporary=False
: 임시테이블이 아님tableType='MANAGED"
: managed table
- CTAS로 테이블 생성
spark.sql("""
CREATE TABLE IF NOT EXISTS TEST_DB.orders_count AS
SELECT order_id, COUNT(1) as count
FROM TEST_DB.orders
GROUP BY 1""")
5. 유닛테스트
- 코드 상의 특정 기능 (보통 메소드 형태)을 테스트하기 위해 작성된 코드
- 보통 정해진 입력을 주고 예상된 출력이 나오는지 테스트
- CI/CD를 사용하려면 전체 코드의 테스트 coverage가 매우 중요해짐 (7-80% 이상)
- 각 언어별로 정해진 테스트 프레임워크를 사용하는 것이 일반적
- JAVA : JUnit
- .NET : NUnit
- Python : unittest
- 실제 환경에서
- 내 코드를 어떻게 짜면 테스트하기 쉬울까 고민!
- 함수 - input, output
- 작은 이슈, 큰 이슈가 생길 때마다 어떻게 테스트를 했으면 이슈를 미연에 방지할 수 있었을까
- TDD (Test Driven Development)
- 코드 작성 전 테스트 코드를 먼저 만들어보고 그것에 맞춰 함수, 기능을 채워나가는 논리
실습
- 일반적으로 colab에서 테스트를 돌리지는 않음
- test 코드를 따로 만든 다음 해당 코드로 테스트 할 함수를 import해서 사용하는 것이 일반적
from unittest import TestCase
"""
일반적으로는 아래 함수가 정의된 모듈을 임포트하고 그걸 테스트
- upper_udf_f
- load_gender
- get_gender_count
이외에도 2가지 방법 더 존재
- from pyspark.sql.tests import SparkTestingBase
- pytest-spark (pytest testing framework plugin)
"""
class UtilsTestCase(TestCase):
spark = None
@classmethod
def setUpClass(cls) -> None:
cls.spark = SparkSession.builder \
.appName("Spark Unit Test") \
.getOrCreate()
def test_datafile_loading(self):
sample_df = load_gender(self.spark, "name_gender.csv")
result_count = sample_df.count()
self.assertEqual(result_count, 100, "Record count should be 100")
def test_gender_count(self):
sample_df = load_gender(self.spark, "name_gender.csv")
count_list = get_gender_count(self.spark, sample_df, "gender").collect()
count_dict = dict()
for row in count_list:
count_dict[row["gender"]] = row["count"]
self.assertEqual(count_dict["F"], 65, "Count for F should be 65")
self.assertEqual(count_dict["M"], 28, "Count for M should be 28")
self.assertEqual(count_dict["Unisex"], 7, "Count for Unisex should be 7")
def test_upper_udf(self):
test_data = [
{ "name": "John Kim" },
{ "name": "Johnny Kim"},
{ "name": "1234" }
]
expected_results = [ "JOHN KIM", "JOHNNY KIM", "1234" ]
upperUDF = self.spark.udf.register("upper_udf", upper_udf_f)
test_df = self.spark.createDataFrame(test_data)
names = test_df.select("name", upperUDF("name").alias("NAME")).collect()
results = []
for name in names:
results.append(name["NAME"])
self.assertCountEqual(results, expected_results)
@classmethod
def tearDownClass(cls) -> None:
cls.spark.stop()
import unittest
unittest.main(argv=[''], verbosity=2, exit=False)