Post

[DEV] 14주차. Hadoop과 Spark (3)

[DEV] 14주차. Hadoop과 Spark (3)

1. Spark SQL

  • 구조화된 데이터 처리를 위한 Spark 모듈
  • 데이터프레임 작업을 SQL로 처리 가능
    • 데이터프레임에 테이블 이름 지정 후 sql 함수 사용 가능
      • Pandas에도 pandasql 모듈의 spldf 함수를 이용하는 동일한 패턴 존재
    • HQL (Hive QL)과 호환 제공
      • Hive 테이블들을 읽고 쓸 수 있음 (Hive Metastore)
      • 보통 Hive와 Spark 시스템을 동시에 사용하는 것이 일반적 (YARN 위에서)

Spark SQL vs. DataFrame

  • SQL로 가능한 작업이면 DataFrame을 사용할 이유가 없음
    • 하지만 두 개를 동시에 사용할 수 있다는 점 기억할 것!
  • Familiarity / Readability
    • SQL이 더 가독성이 좋고 많은 사람들이 사용 가능
  • Optimization
    • Spark SQL 엔진이 최적화하기 더 좋음 (SQL은 Declarative)
    • Catalyst Optimizer와 Project Tungsten
  • Interoperability / Data Management
    • SQL이 포팅도 쉽고 접근권한 체크도 쉬움

Spark SQL 사용 방법

  • 데이터프레임을 기반으로 테이블 뷰 생성 : 테이블이 만들어짐
    • createOrReplaceTempView : Spark Session이 살아있는 동안 존재
    • createOrReplaceGlobalTempView : Spark 드라이버가 살아있는 동안 존재
  • Spark Session의 sql 함수로 SQL 결과를 데이터프레임으로 받음


1
2
3
4
5
namegender_df.createOrReplaceTempView('namegender')
namegender_group_df = spark.sql("""
    SELECT gender, count(1) FROM namegender GROUP BY 1
""")
print(namegender_group_df.collect())

SparkSession 외부 데이터베이스 연결

  • Spark Session의 read 함수 호출
    • 로그인 관련 정보와 읽어오고자 하는 테이블 혹은 SQL 지정
  • 결과가 데이터프레임으로 반환됨


1
2
3
4
5
6
df_user_session_channel = spark.read\
    .format('jdbc')\
    .option('driver', 'com.amazon.redshift.jdbc42.Driver')\
    .option('url', 'jdbc:redshift://HOST:PORT/DB?user=ID&password=PASSWORD')\
    .option('dbtable', 'raw_data.user_session_channel')   # SELECT문도 가능  
    .load()

Aggregation

  • DataFrame이 아닌 SQL로 작성하는 것 추천

  • GroupBy
  • Window
  • Rank

JOIN

  • 두 개 혹은 그 이상의 테이블들을 공통 필드를 가지고 머지
  • 스타 스키마로 구성된 테이블들로 분산되어 있던 정보를 통합하는데 사용
  • 왼쪽 테이블을 LEFT, 오른쪽 테이블을 RIGHT라고 하면
    • 결과는 방식에 따라 양쪽 필드를 모두 가진 새로운 테이블 생성
    • 방식에 따라 두 가지가 달라짐
      • 어떤 레코드들이 선택되는지
      • 어떤 필드들이 채워지는지

스크린샷 2024-01-17 오전 9 21 44

최적화 관점에서 본 조인의 종류들

  • Shuffle Join
    • 일반 조인 방식
    • Bucket Join: 조인 키를 바탕으로 새로운 파티션을 만들고 조인하는 방식
  • Broadcast Join
    • 큰 데이터와 작은 데이터 간의 조인
    • 데이터프레임 하나가 충분이 작으면 작은 데이터프레임을 다른 데이터프레임이 있는 서버들로 뿌리는 것 (broadcasting)
      • spark.sql.autoBroadcastJoinThreshold 파라미터로 충분히 작은지 여부 결정

2. UDF (User Defined Function)

  • 데이터프레임의 경우 .withColumn 함수와 같이 사용하는 것이 일반적
    • SparkSQL에서도 사용 가능
  • Aggregation용 UDAF (User Defined Aggregation Function)도 존재
    • GROUP BY 에서 사용되는 SUM, AVG와 같은 함수를 만드는 것
    • PySpark에서 지원되지 않음. Scalar/Java를 사용해야 함

DataFrame에 사용

1
2
3
4
5
import pyspark.sql.functions as F
from pyspark.sql.types import *

upperUDF = F.udf(lambda z:z.upper())
df.withColumn('Curated Name', upperUDF('Name'))

SparkSQL에 사용

1
2
3
4
5
6
7
8
9
10
def upper(s):
    return s.upper()

# 먼저 테스트
upperUDF = spark.udf.register('upper', upper)
spark.sql("SELECT upper('aBcD')").show()

# DataFrame 기반 SQL에 적용
df.createOrReplaceTempView('test')
spark.sql("""SELECT name, upper(name) "Curated Name" FROM test""").show()

Pandas UDF Scalar 함수

1
2
3
4
5
6
7
8
9
10
11
12
from pyspark.sql.functions import pandas_udf
import pandas as pd

@pandas_udf(StringType())   # 각 컬럼의 타입
def upper_udf2(s:pd.Series) -> pd.Series:    
    return s.str.upper()
    # bulk로 처리 -> 더 퍼포먼스가 좋음

upperUDF = spark.udf.register('upper_udf', upper_udf2)

df.select('Name', upperUDF('Name')).show()
spark.sql("""SELECT name, upper_udf(name) 'Curated Name' FROM test""").show()

UDF - DataFrame/SQL에 Aggregation 사용

1
2
3
4
5
6
7
8
9
10
11
from pyspark.sql.functions import pandas_df
import pandas as pd

@pandas_udf(FloatType())
def average(v:pd.Series) -> float:
    return v.mean()

averageUDF = spark.udf.register('average', average)

spark.sql('SELECT average(b) FROM test').show()
df.agg(averageUDF('b').alias('count')).show()

3. Spark SQL 실습

실습 테이블

  • 사용자 ID
    • 보통 웹서비스는 등록된 사용자마다 유일한 ID 부여 -> 사용자 ID
  • 세션 ID
    • 사용자가 외부 링크를 타고 오거나 직접 방문해서 올 경우 세션을 생성
    • 즉 하나의 사용자 ID는 여러 개의 세션 ID를 가질 수 있음
    • 보통 세션의 경우 세션을 만들어낸 소스를 채널이라는 이름으로 기록해둠
      • 마케팅 관련 기여도 분석을 위함
    • 또한, 세션이 생긴 시간도 기록
  • 이 정보를 기반으로 다양한 데이터 분석과 지표 설정이 가능
    • 마케팅 관련
    • 사용자 트래픽 관련


스크린샷 2024-01-17 오전 10 40 47

JOIN

  • 두 개의 테이블 VitalID 기준 JOIN
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("Python Spark SQL #1") \
    .getOrCreate()

vital = [
     { 'UserID': 100, 'VitalID': 1, 'Date': '2020-01-01', 'Weight': 75 },
     { 'UserID': 100, 'VitalID': 2, 'Date': '2020-01-02', 'Weight': 78 },
     { 'UserID': 101, 'VitalID': 3, 'Date': '2020-01-01', 'Weight': 90 },
     { 'UserID': 101, 'VitalID': 4, 'Date': '2020-01-02', 'Weight': 95 },
]

alert = [
    { 'AlertID': 1, 'VitalID': 4, 'AlertType': 'WeightIncrease', 'Date': '2020-01-01', 'UserID': 101},
    { 'AlertID': 2, 'VitalID': None, 'AlertType': 'MissingVital', 'Date': '2020-01-04', 'UserID': 100},
    { 'AlertID': 3, 'VitalID': None, 'AlertType': 'MissingVital', 'Date': '2020-01-05', 'UserID': 101}
]

rdd_vital = spark.sparkContext.parallelize(vital)
rdd_alert = spark.sparkContext.parallelize(alert)

df_vital = rdd_vital.toDF()
df_alert = rdd_alert.toDF()

스크린샷 2024-01-17 오전 10 43 26

DataFrame JOIN

  • Inner Join
1
2
join_expr = df_vital.VitalID == df_alert.VitalID
df_vital.join(df_alert, join_expr, "inner").show()
DateUserIDVitalIDWeightAlertIDAlertTypeDateUserIDVitalID
2020-01-021014951WeightIncrease2020-01-011014


  • Left Join
1
2
join_expr = df_vital.VitalID == df_alert.VitalID
df_vital.join(df_alert, join_expr, "left").show()
DateUserIDVitalIDWeightAlertIDAlertTypeDateUserIDVitalID
2020-01-01100175nullnullnullnullnull
2020-01-02100278nullnullnullnullnull
2020-01-01101390nullnullnullnullnull
2020-01-021014951WeightIncrease2020-01-011014


  • Right Join
1
2
join_expr = df_vital.VitalID == df_alert.VitalID
df_vital.join(df_alert, join_expr, "right").show()
DateUserIDVitalIDWeightAlertIDAlertTypeDateUserIDVitalID
2020-01-021014951WeightIncrease2020-01-011014
nullnullnullnull2MissingVital2020-01-04100null
nullnullnullnull3MissingVital2020-01-05101null


  • Full Outer Join
1
2
join_expr = df_vital.VitalID == df_alert.VitalID
df_vital.join(df_alert, join_expr, "full").show()
DateUserIDVitalIDWeightAlertIDAlertTypeDateUserIDVitalID
nullnullnullnull2MissingVital2020-01-04100null
nullnullnullnull3MissingVital2020-01-05101null
2020-01-01100175nullnullnullnullnull
2020-01-02100278nullnullnullnullnull
2020-01-01101390nullnullnullnullnull
2020-01-021014951WeightIncrease2020-01-011014


  • Cross Join
1
df_vital.join(df_alert, None, "cross").show()
DateUserIDVitalIDWeightAlertIDAlertTypeDateUserIDVitalID
2020-01-011001751WeightIncrease2020-01-011014
2020-01-021002781WeightIncrease2020-01-011014
2020-01-011001752MissingVital2020-01-04100null
2020-01-011001753MissingVital2020-01-05101null
2020-01-021002782MissingVital2020-01-04100null
2020-01-021002783MissingVital2020-01-05101null
2020-01-011013901WeightIncrease2020-01-011014
2020-01-021014951WeightIncrease2020-01-011014
2020-01-011013902MissingVital2020-01-04100null
2020-01-011013903MissingVital2020-01-05101null
2020-01-021014952MissingVital2020-01-04100null
2020-01-021014953MissingVital2020-01-05101null


  • Self Join
1
2
join_expr = df_vital.VitalID == df_vital.VitalID
df_vital.join(df_vital, join_expr, "left").show()
DateUserIDVitalIDWeightDateUserIDVitalIDWeight
2020-01-011001752020-01-01100175
2020-01-021002782020-01-02100278
2020-01-011013902020-01-01101390
2020-01-021014952020-01-02101495

SQL JOIN

1
2
df_vital.createOrReplaceTempView("Vital")
df_alert.createOrReplaceTempView("Alert")


  • Inner Join
1
2
3
df_inner_join = spark.sql("""SELECT * FROM Vital v
                JOIN Alert a ON v.vitalID = a.vitalID;""")
df_inner_join.show()


  • Left Join
1
2
3
df_left_join = spark.sql("""SELECT * FROM Vital v
                LEFT JOIN Alert a ON v.vitalID = a.vitalID;""")
df_left_join.show()


  • Right Join
1
2
3
df_right_join = spark.sql("""SELECT * FROM Vital v
                RIGHT JOIN Alert a ON v.vitalID = a.vitalID;""")
df_right_join.show()


  • Outer Join
1
2
3
df_outer_join = spark.sql("""SELECT * FROM Vital v
                FULL JOIN Alert a ON v.vitalID = a.vitalID;""")
df_outer_join.show()


  • Cross Join
1
2
3
df_cross_join = spark.sql("""SELECT * FROM Vital v
                CROSS JOIN Alert a""")
df_cross_join.show()


  • Self Join
1
2
3
df_self_join = spark.sql("""SELECT * FROM Vital v1
JOIN Vital v2""")
df_self_join.show()

Ranking

  • refund 여부를 고려하지 않고, 총 매출이 가장 많은 사용자 10명 찾기
필드설명
사용자ID총 매출


  • 3개의 테이블을 각각 데이터프레임으로 로딩
  • 데이터프레임 별로 테이블 이름 지정
  • Spark SQL로 처리
    • 조인 방식 결정
      • 조인키
      • 조인 방식
    • 간단한 지표부터 계산


1
2
3
4
5
6
7
8
9
10
# 데이터는 Redshift에서 가져옴

!pip install pyspark==3.3.1 py4j==0.10.9.5 

from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("Python Spark SQL #1") \
    .getOrCreate()


월별 채널별 매출과 방문자 정보 계산

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# Redshift와 연결해서 DataFrame으로 로딩
url = "jdbc:redshift://***.***.ap-northeast-2.redshift.amazonaws.com:5439/dev?user=***&password=***"

df_user_session_channel = spark.read \
    .format("jdbc") \
    .option("driver", "com.amazon.redshift.jdbc42.Driver") \
    .option("url", url) \
    .option("dbtable", "raw_data.user_session_channel") \
    .load()

df_session_timestamp = spark.read \
    .format("jdbc") \
    .option("driver", "com.amazon.redshift.jdbc42.Driver") \
    .option("url", url) \
    .option("dbtable", "raw_data.session_timestamp") \
    .load()

df_session_transaction = spark.read \
    .format("jdbc") \
    .option("driver", "com.amazon.redshift.jdbc42.Driver") \
    .option("url", url) \
    .option("dbtable", "raw_data.session_transaction") \
    .load()

df_user_session_channel.createOrReplaceTempView("user_session_channel")
df_session_timestamp.createOrReplaceTempView("session_timestamp")
df_session_transaction.createOrReplaceTempView("session_transaction")


1
df_user_session_channel.show(5)
useridsessionidchannel
16510004289ee1c7b8b08…Organic
119700053f5e11d1fe4e4…Facebook
140100056c20eb5a02958…Facebook
139900063cb5da1826feb…Facebook
1667000958fdaefe0dd06…Instagram


1
df_session_timestamp.show(5)
sessionidts
00029153d12ae1c9a…2019-10-18 14:14:…
0004289ee1c7b8b08…2019-11-16 21:20:…
0006246bee639c7a7…2019-08-10 16:33:…
0006dd05ea1e999dd…2019-07-06 19:54:…
000958fdaefe0dd06…2019-11-02 14:52:…


1
df_session_transaction.show(5)
sessionidrefundedamount
00029153d12ae1c9a…false85
008909bd27b680698…false13
0107acb41ef20db22…false16
018544a2c48077d2c…false39
020c38173caff0203…false61

총 매출이 가장 많은 사용자 10명 찾기

  • Inner Join / Left(Right) Join 모두 가능


  • revenue(매출액)으로 order
1
2
3
4
5
6
7
8
9
10
11
top_rev_user_df = spark.sql("""
    SELECT userid,
        SUM(str.amount) revenue,
        SUM(CASE WHEN str.refunded = False THEN str.amount END) net_revenue
    FROM user_session_channel usc
    JOIN session_transaction str ON usc.sessionid = str.sessionid
    GROUP BY 1
    ORDER BY 2 DESC
    LIMIT 10""")

top_rev_user_df.show()
useridrevenuenet_revenue
989743743
772556556
1615506506
654488488
1651463463
973438438
262422422
1099421343
2682414414
891412412


  • rank 이용
1
2
3
4
5
6
7
8
9
10
11
12
top_rev_user_df2 = spark.sql("""
SELECT
    userid,
    SUM(amount) total_amount, 
    RANK() OVER (ORDER BY SUM(amount) DESC) rank
FROM session_transaction st
JOIN user_session_channel usc ON st.sessionid = usc.sessionid
GROUP BY userid
ORDER BY rank
LIMIT 10""")

top_rev_user_df2.show()
useridtotal_amountrank
9897431
7725562
16155063
6544884
16514635
9734386
2624227
10994218
26824149
89141210

월별 채널별 매출과 방문자 정보 계산하기

  • 연도-월, 채널, 총 매출액, 순매출액, 총 방문자, 매출발생 방문자, 매출 변환률 (매출발생 방문자 / 총 방문자)


  • 중요) 데이터를 항상 의심하기!
    • join key가 정말 Unique한지!
    • 아래 sql을 실행했을 때 count 값이 1이 아니면 unique하지 않은 것!
1
2
3
4
5
spark.sql("""SELECT sessionid, COUNT(1) count
FROM user_session_channel
GROUP BY 1
ORDER BY 2 DESC
LIMIT 1""").show() 


  • 월별 채널별 총 매출액, 총 방문자, 매출발생 방문자, 변환률
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
 mon_channel_rev_df = spark.sql("""
  SELECT LEFT(ts, 7) month,
       usc.channel,
       COUNT(DISTINCT userid) uniqueUsers,
       COUNT(DISTINCT (CASE WHEN amount >= 0 THEN userid END)) paidUsers,
       SUM(amount) grossRevenue,
       SUM(CASE WHEN refunded is not True THEN amount END) netRevenue,
       ROUND(COUNT(DISTINCT CASE WHEN amount >= 0 THEN userid END)*100
          / COUNT(DISTINCT userid), 2) conversionRate
   FROM user_session_channel usc
   LEFT JOIN session_timestamp t ON t.sessionid = usc.sessionid
   LEFT JOIN session_transaction st ON st.sessionid = usc.sessionid
   GROUP BY 1, 2
   ORDER BY 1, 2;
""")

사용자별 처음 채널과 마지막 채널 알아내기

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
first_last_channel_df = spark.sql("""
WITH RECORD AS (
  SELECT /*사용자의 유입에 따른, 채널 순서 매기는 쿼리*/
      userid,
      channel, 
      ROW_NUMBER() OVER (PARTITION BY userid ORDER BY ts ASC) AS seq_first,
      ROW_NUMBER() OVER (PARTITION BY userid ORDER BY ts DESC) AS seq_last
  FROM user_session_channel u
  LEFT JOIN session_timestamp t
    ON u.sessionid = t.sessionid
)
SELECT /*유저의 첫번째 유입채널, 마지막 유입 채널 구하기*/
      f.userid,
      f.channel first_channel,
      l.channel last_channel
FROM RECORD f
INNER JOIN RECORD l ON f.userid = l.userid
WHERE f.seq_first = 1 and l.seq_last = 1
ORDER BY userid
""")

또는

1
2
3
4
5
6
7
8
9
first_last_channel_df2 = spark.sql("""
SELECT DISTINCT A.userid,
    FIRST_VALUE(A.channel) over(partition by A.userid order by B.ts
rows between unbounded preceding and unbounded following) AS First_Channel,
    LAST_VALUE(A.channel) over(partition by A.userid order by B.ts
rows between unbounded preceding and unbounded following) AS Last_Channel
FROM user_session_channel A
LEFT JOIN session_timestamp B
ON A.sessionid = B.sessionid""")


useridfirst_channellast_channel
27YoutubeInstagram
29NaverNaver
33GoogleYoutube
34YoutubeNaver
36NaverYoutube
40YoutubeGoogle
41FacebookYoutube
44NaverInstagram
45YoutubeInstagram
59InstagramInstagram

Window 함수 - ROWS BETWEEN AND

  • window 함수는 기본적으로 레코드 수를 바꾸는 것이 아니라, 새로운 컬럼을 만드는 것
1
2
3
4
5
6
7
SELECT
    SUM(value) OVER(
        order by value
        rows between unbounded preceding and 2 following  
        -- unbounded: 개수 제한을 두지 않음
    ) AS rolling_sum
FROM rows_test;
valuerolling_sum
16
210
315
415
515

4. Hive 메타스토어

Spark 데이터베이스와 테이블

  • 카탈로그: 테이블과 뷰에 관한 메타 데이터 관리
    • 기본으로 메모리 기반 카탈로그 제공 - 세션이 끝나면 사라짐
    • Hive와 호환되는 카탈로그 제공 - Persistent
  • 테이블 관리 방식
    • 테이블들은 데이터베이스라 부르는 폴더와 같은 구조로 관리 (2단계)

스크린샷 2024-01-19 오후 1 11 54


  • 메모리 기반 테이블/뷰
    • 임시 테이블로 사용
  • 스토리지 기반 테이블
    • 기본적으로 HDFS와 Parquet 포맷 사용
    • Hive와 호환되는 메타스토어 사용
    • 두 종류의 테이블이 존재 (Hive와 동일한 개념)
      • Managed Table: Spark가 실제 데이터와 메타 데이터 모두 관리
      • Unmanaged (External) Table: Spark가 메타 데이터만 관리

Spark SQL - 스토리지 기반 카탈로그

  • Hive와 호환되는 메타스토어 사용
  • SparkSession 생성 시 enableHiveSupport() 호출
    • 기본으로 default라는 이름의 데이터베이스 생성
1
2
3
4
5
6
7
from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("Python Spark Hive") \
    .enableHiveSupport() \
    .getOrCreate()

Spark SQL - Managed Table

  • 생성 방법
    • dataframe.saveAsTable('table')
    • SQL 문법 사용 (CREATE TABLE, CTAS)
  • spark.sql.warehouse.dir이 가리키는 위치에 데이터가 저장됨
    • Parquet이 기본 데이터 포맷
  • 선호하는 테이블 타입
  • Spark 테이블로 처리하는 것의 장점
    • JDBC/ODBC 등으로 Spark을 연결하여 접근 가능 (태블로, Power BI)

Spark SQL - External Table

  • 이미 HDFS에 존재하는 데이터에 스키마를 정의하여 사용
    • LOCATION이라는 프로퍼티 사용
  • 메타데이터만 카탈로그에 기록됨
    • 데이터는 이미 존재
    • External Table은 삭제되어도 데이터는 그대로!
1
2
3
4
5
6
7
8
CREATE TABLE table_name (
    column1 type1,
    column2 type2,
    column3 type3,
    ...
)
USING PARQUET
LOCATION 'hdfs_path';

실습

  • DataFrame을 Managed Table로 저장
  • 새로운 데이터베이스 사용
  • Spark SQL로 Managed Table 사용 (CTAS)


  • 데이터베이스 생성
1
2
spark.sql("CREATE DATABASE IF NOT EXISTS TEST_DB")
spark.sql("USE TEST_DB")

스크린샷 2024-01-27 오후 7 59 19


스크린샷 2024-01-27 오후 7 57 20

  • metastor_db : Spark 메타스토어 - Hive 메타스토어와 호환
  • spark-warehouse : HDFS 폴더에 해당
    • spark에서 managed table을 만들면 여기에 저장됨


  • 데이터베이스에 테이블 생성
  • 기본 parquet 형식
1
df.write.saveAsTable("TEST_DB.orders", mode="overwrite")

스크린샷 2024-01-27 오후 8 00 05


  • 테이블 값 읽기
1
2
3
spark.sql("SELECT * FROM TEST_DB.orders").show()

sparkt.table("TEST_DB.orders").show()

스크린샷 2024-01-27 오후 8 01 50


  • spark catalog
    • catalog가 인메모리가 아닌 HDFS에 영구적으로 저장되는 메타스토어
1
spark.catalog.listTables() 

스크린샷 2024-01-27 오후 8 04 50

  • isTemporary=False : 임시테이블이 아님
  • tableType='MANAGED" : managed table


  • CTAS로 테이블 생성
1
2
3
4
5
spark.sql("""
    CREATE TABLE IF NOT EXISTS TEST_DB.orders_count AS 
    SELECT order_id, COUNT(1) as count 
    FROM TEST_DB.orders
    GROUP BY 1""")

스크린샷 2024-01-27 오후 8 07 13

스크린샷 2024-01-27 오후 8 07 44

5. 유닛테스트

  • 코드 상의 특정 기능 (보통 메소드 형태)을 테스트하기 위해 작성된 코드
  • 보통 정해진 입력을 주고 예상된 출력이 나오는지 테스트
  • CI/CD를 사용하려면 전체 코드의 테스트 coverage가 매우 중요해짐 (7-80% 이상)
  • 각 언어별로 정해진 테스트 프레임워크를 사용하는 것이 일반적
    • JAVA : JUnit
    • .NET : NUnit
    • Python : unittest


  • 실제 환경에서
    • 내 코드를 어떻게 짜면 테스트하기 쉬울까 고민!
    • 함수 - input, output
    • 작은 이슈, 큰 이슈가 생길 때마다 어떻게 테스트를 했으면 이슈를 미연에 방지할 수 있었을까
    • TDD (Test Driven Development)
      • 코드 작성 전 테스트 코드를 먼저 만들어보고 그것에 맞춰 함수, 기능을 채워나가는 논리


실습

  • 일반적으로 colab에서 테스트를 돌리지는 않음
    • test 코드를 따로 만든 다음 해당 코드로 테스트 할 함수를 import해서 사용하는 것이 일반적


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from unittest import TestCase

"""
일반적으로는 아래 함수가 정의된 모듈을 임포트하고 그걸 테스트
 - upper_udf_f
 - load_gender
 - get_gender_count

이외에도 2가지 방법 더 존재
 - from pyspark.sql.tests import SparkTestingBase
 - pytest-spark (pytest testing framework plugin)
"""

class UtilsTestCase(TestCase):
    spark = None

    @classmethod
    def setUpClass(cls) -> None:
        cls.spark = SparkSession.builder \
            .appName("Spark Unit Test") \
            .getOrCreate()

    def test_datafile_loading(self):
        sample_df = load_gender(self.spark, "name_gender.csv")
        result_count = sample_df.count()
        self.assertEqual(result_count, 100, "Record count should be 100")

    def test_gender_count(self):
        sample_df = load_gender(self.spark, "name_gender.csv")
        count_list = get_gender_count(self.spark, sample_df, "gender").collect()
        count_dict = dict()
        for row in count_list:
            count_dict[row["gender"]] = row["count"]
        self.assertEqual(count_dict["F"], 65, "Count for F should be 65")
        self.assertEqual(count_dict["M"], 28, "Count for M should be 28")
        self.assertEqual(count_dict["Unisex"], 7, "Count for Unisex should be 7")

    def test_upper_udf(self):
        test_data = [
            { "name": "John Kim" },
            { "name": "Johnny Kim"},
            { "name": "1234" }
        ]
        expected_results = [ "JOHN KIM", "JOHNNY KIM", "1234" ]

        upperUDF = self.spark.udf.register("upper_udf", upper_udf_f)
        test_df = self.spark.createDataFrame(test_data)
        names = test_df.select("name", upperUDF("name").alias("NAME")).collect()
        results = []
        for name in names:
            results.append(name["NAME"])
        self.assertCountEqual(results, expected_results)

    @classmethod
    def tearDownClass(cls) -> None:
        cls.spark.stop()
1
2
3
import unittest

unittest.main(argv=[''], verbosity=2, exit=False)

스크린샷 2024-01-27 오후 8 29 17

This post is licensed under CC BY 4.0 by the author.