1. Spark SQL

  • 구조화된 데이터 처리를 위한 Spark 모듈
  • 데이터프레임 작업을 SQL로 처리 가능
    • 데이터프레임에 테이블 이름 지정 후 sql 함수 사용 가능
      • Pandas에도 pandasql 모듈의 spldf 함수를 이용하는 동일한 패턴 존재
    • HQL (Hive QL)과 호환 제공
      • Hive 테이블들을 읽고 쓸 수 있음 (Hive Metastore)
      • 보통 Hive와 Spark 시스템을 동시에 사용하는 것이 일반적 (YARN 위에서)

Spark SQL vs. DataFrame

  • SQL로 가능한 작업이면 DataFrame을 사용할 이유가 없음
    • 하지만 두 개를 동시에 사용할 수 있다는 점 기억할 것!
  • Familiarity / Readability
    • SQL이 더 가독성이 좋고 많은 사람들이 사용 가능
  • Optimization
    • Spark SQL 엔진이 최적화하기 더 좋음 (SQL은 Declarative)
    • Catalyst Optimizer와 Project Tungsten
  • Interoperability / Data Management
    • SQL이 포팅도 쉽고 접근권한 체크도 쉬움

Spark SQL 사용 방법

  • 데이터프레임을 기반으로 테이블 뷰 생성 : 테이블이 만들어짐
    • createOrReplaceTempView : Spark Session이 살아있는 동안 존재
    • createOrReplaceGlobalTempView : Spark 드라이버가 살아있는 동안 존재
  • Spark Session의 sql 함수로 SQL 결과를 데이터프레임으로 받음


namegender_df.createOrReplaceTempView('namegender')
namegender_group_df = spark.sql("""
    SELECT gender, count(1) FROM namegender GROUP BY 1
""")
print(namegender_group_df.collect())

SparkSession 외부 데이터베이스 연결

  • Spark Session의 read 함수 호출
    • 로그인 관련 정보와 읽어오고자 하는 테이블 혹은 SQL 지정
  • 결과가 데이터프레임으로 반환됨


df_user_session_channel = spark.read\
    .format('jdbc')\
    .option('driver', 'com.amazon.redshift.jdbc42.Driver')\
    .option('url', 'jdbc:redshift://HOST:PORT/DB?user=ID&password=PASSWORD')\
    .option('dbtable', 'raw_data.user_session_channel')   # SELECT문도 가능  
    .load()

Aggregation

  • DataFrame이 아닌 SQL로 작성하는 것 추천

  • GroupBy
  • Window
  • Rank

JOIN

  • 두 개 혹은 그 이상의 테이블들을 공통 필드를 가지고 머지
  • 스타 스키마로 구성된 테이블들로 분산되어 있던 정보를 통합하는데 사용
  • 왼쪽 테이블을 LEFT, 오른쪽 테이블을 RIGHT라고 하면
    • 결과는 방식에 따라 양쪽 필드를 모두 가진 새로운 테이블 생성
    • 방식에 따라 두 가지가 달라짐
      • 어떤 레코드들이 선택되는지
      • 어떤 필드들이 채워지는지

스크린샷 2024-01-17 오전 9 21 44

최적화 관점에서 본 조인의 종류들

  • Shuffle Join
    • 일반 조인 방식
    • Bucket Join: 조인 키를 바탕으로 새로운 파티션을 만들고 조인하는 방식
  • Broadcast Join
    • 큰 데이터와 작은 데이터 간의 조인
    • 데이터프레임 하나가 충분이 작으면 작은 데이터프레임을 다른 데이터프레임이 있는 서버들로 뿌리는 것 (broadcasting)
      • spark.sql.autoBroadcastJoinThreshold 파라미터로 충분히 작은지 여부 결정

2. UDF (User Defined Function)

  • 데이터프레임의 경우 .withColumn 함수와 같이 사용하는 것이 일반적
    • SparkSQL에서도 사용 가능
  • Aggregation용 UDAF (User Defined Aggregation Function)도 존재
    • GROUP BY 에서 사용되는 SUM, AVG와 같은 함수를 만드는 것
    • PySpark에서 지원되지 않음. Scalar/Java를 사용해야 함

DataFrame에 사용

import pyspark.sql.functions as F
from pyspark.sql.types import *

upperUDF = F.udf(lambda z:z.upper())
df.withColumn('Curated Name', upperUDF('Name'))

SparkSQL에 사용

def upper(s):
    return s.upper()

# 먼저 테스트
upperUDF = spark.udf.register('upper', upper)
spark.sql("SELECT upper('aBcD')").show()

# DataFrame 기반 SQL에 적용
df.createOrReplaceTempView('test')
spark.sql("""SELECT name, upper(name) "Curated Name" FROM test""").show()

Pandas UDF Scalar 함수

from pyspark.sql.functions import pandas_udf
import pandas as pd

@pandas_udf(StringType())   # 각 컬럼의 타입
def upper_udf2(s:pd.Series) -> pd.Series:    
    return s.str.upper()
    # bulk로 처리 -> 더 퍼포먼스가 좋음

upperUDF = spark.udf.register('upper_udf', upper_udf2)

df.select('Name', upperUDF('Name')).show()
spark.sql("""SELECT name, upper_udf(name) 'Curated Name' FROM test""").show()

UDF - DataFrame/SQL에 Aggregation 사용

from pyspark.sql.functions import pandas_df
import pandas as pd

@pandas_udf(FloatType())
def average(v:pd.Series) -> float:
    return v.mean()

averageUDF = spark.udf.register('average', average)

spark.sql('SELECT average(b) FROM test').show()
df.agg(averageUDF('b').alias('count')).show()

3. Spark SQL 실습

실습 테이블

  • 사용자 ID
    • 보통 웹서비스는 등록된 사용자마다 유일한 ID 부여 -> 사용자 ID
  • 세션 ID
    • 사용자가 외부 링크를 타고 오거나 직접 방문해서 올 경우 세션을 생성
    • 즉 하나의 사용자 ID는 여러 개의 세션 ID를 가질 수 있음
    • 보통 세션의 경우 세션을 만들어낸 소스를 채널이라는 이름으로 기록해둠
      • 마케팅 관련 기여도 분석을 위함
    • 또한, 세션이 생긴 시간도 기록
  • 이 정보를 기반으로 다양한 데이터 분석과 지표 설정이 가능
    • 마케팅 관련
    • 사용자 트래픽 관련


스크린샷 2024-01-17 오전 10 40 47

JOIN

  • 두 개의 테이블 VitalID 기준 JOIN
from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("Python Spark SQL #1") \
    .getOrCreate()

vital = [
     { 'UserID': 100, 'VitalID': 1, 'Date': '2020-01-01', 'Weight': 75 },
     { 'UserID': 100, 'VitalID': 2, 'Date': '2020-01-02', 'Weight': 78 },
     { 'UserID': 101, 'VitalID': 3, 'Date': '2020-01-01', 'Weight': 90 },
     { 'UserID': 101, 'VitalID': 4, 'Date': '2020-01-02', 'Weight': 95 },
]

alert = [
    { 'AlertID': 1, 'VitalID': 4, 'AlertType': 'WeightIncrease', 'Date': '2020-01-01', 'UserID': 101},
    { 'AlertID': 2, 'VitalID': None, 'AlertType': 'MissingVital', 'Date': '2020-01-04', 'UserID': 100},
    { 'AlertID': 3, 'VitalID': None, 'AlertType': 'MissingVital', 'Date': '2020-01-05', 'UserID': 101}
]

rdd_vital = spark.sparkContext.parallelize(vital)
rdd_alert = spark.sparkContext.parallelize(alert)

df_vital = rdd_vital.toDF()
df_alert = rdd_alert.toDF()

스크린샷 2024-01-17 오전 10 43 26

DataFrame JOIN

  • Inner Join
join_expr = df_vital.VitalID == df_alert.VitalID
df_vital.join(df_alert, join_expr, "inner").show()
Date UserID VitalID Weight AlertID AlertType Date UserID VitalID
2020-01-02 101 4 95 1 WeightIncrease 2020-01-01 101 4


  • Left Join
join_expr = df_vital.VitalID == df_alert.VitalID
df_vital.join(df_alert, join_expr, "left").show()
Date UserID VitalID Weight AlertID AlertType Date UserID VitalID
2020-01-01 100 1 75 null null null null null
2020-01-02 100 2 78 null null null null null
2020-01-01 101 3 90 null null null null null
2020-01-02 101 4 95 1 WeightIncrease 2020-01-01 101 4


  • Right Join
join_expr = df_vital.VitalID == df_alert.VitalID
df_vital.join(df_alert, join_expr, "right").show()
Date UserID VitalID Weight AlertID AlertType Date UserID VitalID
2020-01-02 101 4 95 1 WeightIncrease 2020-01-01 101 4
null null null null 2 MissingVital 2020-01-04 100 null
null null null null 3 MissingVital 2020-01-05 101 null


  • Full Outer Join
join_expr = df_vital.VitalID == df_alert.VitalID
df_vital.join(df_alert, join_expr, "full").show()
Date UserID VitalID Weight AlertID AlertType Date UserID VitalID
null null null null 2 MissingVital 2020-01-04 100 null
null null null null 3 MissingVital 2020-01-05 101 null
2020-01-01 100 1 75 null null null null null
2020-01-02 100 2 78 null null null null null
2020-01-01 101 3 90 null null null null null
2020-01-02 101 4 95 1 WeightIncrease 2020-01-01 101 4


  • Cross Join
df_vital.join(df_alert, None, "cross").show()
Date UserID VitalID Weight AlertID AlertType Date UserID VitalID
2020-01-01 100 1 75 1 WeightIncrease 2020-01-01 101 4
2020-01-02 100 2 78 1 WeightIncrease 2020-01-01 101 4
2020-01-01 100 1 75 2 MissingVital 2020-01-04 100 null
2020-01-01 100 1 75 3 MissingVital 2020-01-05 101 null
2020-01-02 100 2 78 2 MissingVital 2020-01-04 100 null
2020-01-02 100 2 78 3 MissingVital 2020-01-05 101 null
2020-01-01 101 3 90 1 WeightIncrease 2020-01-01 101 4
2020-01-02 101 4 95 1 WeightIncrease 2020-01-01 101 4
2020-01-01 101 3 90 2 MissingVital 2020-01-04 100 null
2020-01-01 101 3 90 3 MissingVital 2020-01-05 101 null
2020-01-02 101 4 95 2 MissingVital 2020-01-04 100 null
2020-01-02 101 4 95 3 MissingVital 2020-01-05 101 null


  • Self Join
join_expr = df_vital.VitalID == df_vital.VitalID
df_vital.join(df_vital, join_expr, "left").show()
Date UserID VitalID Weight Date UserID VitalID Weight
2020-01-01 100 1 75 2020-01-01 100 1 75
2020-01-02 100 2 78 2020-01-02 100 2 78
2020-01-01 101 3 90 2020-01-01 101 3 90
2020-01-02 101 4 95 2020-01-02 101 4 95

SQL JOIN

df_vital.createOrReplaceTempView("Vital")
df_alert.createOrReplaceTempView("Alert")


  • Inner Join
df_inner_join = spark.sql("""SELECT * FROM Vital v
                JOIN Alert a ON v.vitalID = a.vitalID;""")
df_inner_join.show()


  • Left Join
df_left_join = spark.sql("""SELECT * FROM Vital v
                LEFT JOIN Alert a ON v.vitalID = a.vitalID;""")
df_left_join.show()


  • Right Join
df_right_join = spark.sql("""SELECT * FROM Vital v
                RIGHT JOIN Alert a ON v.vitalID = a.vitalID;""")
df_right_join.show()


  • Outer Join
df_outer_join = spark.sql("""SELECT * FROM Vital v
                FULL JOIN Alert a ON v.vitalID = a.vitalID;""")
df_outer_join.show()


  • Cross Join
df_cross_join = spark.sql("""SELECT * FROM Vital v
                CROSS JOIN Alert a""")
df_cross_join.show()


  • Self Join
df_self_join = spark.sql("""SELECT * FROM Vital v1
JOIN Vital v2""")
df_self_join.show()

Ranking

  • refund 여부를 고려하지 않고, 총 매출이 가장 많은 사용자 10명 찾기
필드 설명
사용자ID 총 매출


  • 3개의 테이블을 각각 데이터프레임으로 로딩
  • 데이터프레임 별로 테이블 이름 지정
  • Spark SQL로 처리
    • 조인 방식 결정
      • 조인키
      • 조인 방식
    • 간단한 지표부터 계산


# 데이터는 Redshift에서 가져옴

!pip install pyspark==3.3.1 py4j==0.10.9.5 

from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("Python Spark SQL #1") \
    .getOrCreate()


월별 채널별 매출과 방문자 정보 계산

# Redshift와 연결해서 DataFrame으로 로딩
url = "jdbc:redshift://***.***.ap-northeast-2.redshift.amazonaws.com:5439/dev?user=***&password=***"

df_user_session_channel = spark.read \
    .format("jdbc") \
    .option("driver", "com.amazon.redshift.jdbc42.Driver") \
    .option("url", url) \
    .option("dbtable", "raw_data.user_session_channel") \
    .load()

df_session_timestamp = spark.read \
    .format("jdbc") \
    .option("driver", "com.amazon.redshift.jdbc42.Driver") \
    .option("url", url) \
    .option("dbtable", "raw_data.session_timestamp") \
    .load()

df_session_transaction = spark.read \
    .format("jdbc") \
    .option("driver", "com.amazon.redshift.jdbc42.Driver") \
    .option("url", url) \
    .option("dbtable", "raw_data.session_transaction") \
    .load()

df_user_session_channel.createOrReplaceTempView("user_session_channel")
df_session_timestamp.createOrReplaceTempView("session_timestamp")
df_session_transaction.createOrReplaceTempView("session_transaction")


df_user_session_channel.show(5)
userid sessionid channel
1651 0004289ee1c7b8b08… Organic
1197 00053f5e11d1fe4e4… Facebook
1401 00056c20eb5a02958… Facebook
1399 00063cb5da1826feb… Facebook
1667 000958fdaefe0dd06… Instagram


df_session_timestamp.show(5)
sessionid ts
00029153d12ae1c9a… 2019-10-18 14:14:…
0004289ee1c7b8b08… 2019-11-16 21:20:…
0006246bee639c7a7… 2019-08-10 16:33:…
0006dd05ea1e999dd… 2019-07-06 19:54:…
000958fdaefe0dd06… 2019-11-02 14:52:…


df_session_transaction.show(5)
sessionid refunded amount
00029153d12ae1c9a… false 85
008909bd27b680698… false 13
0107acb41ef20db22… false 16
018544a2c48077d2c… false 39
020c38173caff0203… false 61

총 매출이 가장 많은 사용자 10명 찾기

  • Inner Join / Left(Right) Join 모두 가능


  • revenue(매출액)으로 order
top_rev_user_df = spark.sql("""
    SELECT userid,
        SUM(str.amount) revenue,
        SUM(CASE WHEN str.refunded = False THEN str.amount END) net_revenue
    FROM user_session_channel usc
    JOIN session_transaction str ON usc.sessionid = str.sessionid
    GROUP BY 1
    ORDER BY 2 DESC
    LIMIT 10""")

top_rev_user_df.show()
userid revenue net_revenue
989 743 743
772 556 556
1615 506 506
654 488 488
1651 463 463
973 438 438
262 422 422
1099 421 343
2682 414 414
891 412 412


  • rank 이용
top_rev_user_df2 = spark.sql("""
SELECT
    userid,
    SUM(amount) total_amount, 
    RANK() OVER (ORDER BY SUM(amount) DESC) rank
FROM session_transaction st
JOIN user_session_channel usc ON st.sessionid = usc.sessionid
GROUP BY userid
ORDER BY rank
LIMIT 10""")

top_rev_user_df2.show()
userid total_amount rank
989 743 1
772 556 2
1615 506 3
654 488 4
1651 463 5
973 438 6
262 422 7
1099 421 8
2682 414 9
891 412 10

월별 채널별 매출과 방문자 정보 계산하기

  • 연도-월, 채널, 총 매출액, 순매출액, 총 방문자, 매출발생 방문자, 매출 변환률 (매출발생 방문자 / 총 방문자)


  • 중요) 데이터를 항상 의심하기!
    • join key가 정말 Unique한지!
    • 아래 sql을 실행했을 때 count 값이 1이 아니면 unique하지 않은 것!
spark.sql("""SELECT sessionid, COUNT(1) count
FROM user_session_channel
GROUP BY 1
ORDER BY 2 DESC
LIMIT 1""").show() 


  • 월별 채널별 총 매출액, 총 방문자, 매출발생 방문자, 변환률
 mon_channel_rev_df = spark.sql("""
  SELECT LEFT(ts, 7) month,
       usc.channel,
       COUNT(DISTINCT userid) uniqueUsers,
       COUNT(DISTINCT (CASE WHEN amount >= 0 THEN userid END)) paidUsers,
       SUM(amount) grossRevenue,
       SUM(CASE WHEN refunded is not True THEN amount END) netRevenue,
       ROUND(COUNT(DISTINCT CASE WHEN amount >= 0 THEN userid END)*100
          / COUNT(DISTINCT userid), 2) conversionRate
   FROM user_session_channel usc
   LEFT JOIN session_timestamp t ON t.sessionid = usc.sessionid
   LEFT JOIN session_transaction st ON st.sessionid = usc.sessionid
   GROUP BY 1, 2
   ORDER BY 1, 2;
""")

사용자별 처음 채널과 마지막 채널 알아내기

first_last_channel_df = spark.sql("""
WITH RECORD AS (
  SELECT /*사용자의 유입에 따른, 채널 순서 매기는 쿼리*/
      userid,
      channel, 
      ROW_NUMBER() OVER (PARTITION BY userid ORDER BY ts ASC) AS seq_first,
      ROW_NUMBER() OVER (PARTITION BY userid ORDER BY ts DESC) AS seq_last
  FROM user_session_channel u
  LEFT JOIN session_timestamp t
    ON u.sessionid = t.sessionid
)
SELECT /*유저의 첫번째 유입채널, 마지막 유입 채널 구하기*/
      f.userid,
      f.channel first_channel,
      l.channel last_channel
FROM RECORD f
INNER JOIN RECORD l ON f.userid = l.userid
WHERE f.seq_first = 1 and l.seq_last = 1
ORDER BY userid
""")

또는

first_last_channel_df2 = spark.sql("""
SELECT DISTINCT A.userid,
    FIRST_VALUE(A.channel) over(partition by A.userid order by B.ts
rows between unbounded preceding and unbounded following) AS First_Channel,
    LAST_VALUE(A.channel) over(partition by A.userid order by B.ts
rows between unbounded preceding and unbounded following) AS Last_Channel
FROM user_session_channel A
LEFT JOIN session_timestamp B
ON A.sessionid = B.sessionid""")


userid first_channel last_channel
27 Youtube Instagram
29 Naver Naver
33 Google Youtube
34 Youtube Naver
36 Naver Youtube
40 Youtube Google
41 Facebook Youtube
44 Naver Instagram
45 Youtube Instagram
59 Instagram Instagram

Window 함수 - ROWS BETWEEN AND

  • window 함수는 기본적으로 레코드 수를 바꾸는 것이 아니라, 새로운 컬럼을 만드는 것
SELECT
    SUM(value) OVER(
        order by value
        rows between unbounded preceding and 2 following  
        -- unbounded: 개수 제한을 두지 않음
    ) AS rolling_sum
FROM rows_test;
value rolling_sum
1 6
2 10
3 15
4 15
5 15

4. Hive 메타스토어

Spark 데이터베이스와 테이블

  • 카탈로그: 테이블과 뷰에 관한 메타 데이터 관리
    • 기본으로 메모리 기반 카탈로그 제공 - 세션이 끝나면 사라짐
    • Hive와 호환되는 카탈로그 제공 - Persistent
  • 테이블 관리 방식
    • 테이블들은 데이터베이스라 부르는 폴더와 같은 구조로 관리 (2단계)

스크린샷 2024-01-19 오후 1 11 54


  • 메모리 기반 테이블/뷰
    • 임시 테이블로 사용
  • 스토리지 기반 테이블
    • 기본적으로 HDFS와 Parquet 포맷 사용
    • Hive와 호환되는 메타스토어 사용
    • 두 종류의 테이블이 존재 (Hive와 동일한 개념)
      • Managed Table: Spark가 실제 데이터와 메타 데이터 모두 관리
      • Unmanaged (External) Table: Spark가 메타 데이터만 관리

Spark SQL - 스토리지 기반 카탈로그

  • Hive와 호환되는 메타스토어 사용
  • SparkSession 생성 시 enableHiveSupport() 호출
    • 기본으로 default라는 이름의 데이터베이스 생성
from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("Python Spark Hive") \
    .enableHiveSupport() \
    .getOrCreate()

Spark SQL - Managed Table

  • 생성 방법
    • dataframe.saveAsTable('table')
    • SQL 문법 사용 (CREATE TABLE, CTAS)
  • spark.sql.warehouse.dir이 가리키는 위치에 데이터가 저장됨
    • Parquet이 기본 데이터 포맷
  • 선호하는 테이블 타입
  • Spark 테이블로 처리하는 것의 장점
    • JDBC/ODBC 등으로 Spark을 연결하여 접근 가능 (태블로, Power BI)

Spark SQL - External Table

  • 이미 HDFS에 존재하는 데이터에 스키마를 정의하여 사용
    • LOCATION이라는 프로퍼티 사용
  • 메타데이터만 카탈로그에 기록됨
    • 데이터는 이미 존재
    • External Table은 삭제되어도 데이터는 그대로!
CREATE TABLE table_name (
    column1 type1,
    column2 type2,
    column3 type3,
    ...
)
USING PARQUET
LOCATION 'hdfs_path';

실습

  • DataFrame을 Managed Table로 저장
  • 새로운 데이터베이스 사용
  • Spark SQL로 Managed Table 사용 (CTAS)


  • 데이터베이스 생성
spark.sql("CREATE DATABASE IF NOT EXISTS TEST_DB")
spark.sql("USE TEST_DB")

스크린샷 2024-01-27 오후 7 59 19


스크린샷 2024-01-27 오후 7 57 20

  • metastor_db : Spark 메타스토어 - Hive 메타스토어와 호환
  • spark-warehouse : HDFS 폴더에 해당
    • spark에서 managed table을 만들면 여기에 저장됨


  • 데이터베이스에 테이블 생성
  • 기본 parquet 형식
df.write.saveAsTable("TEST_DB.orders", mode="overwrite")

스크린샷 2024-01-27 오후 8 00 05


  • 테이블 값 읽기
spark.sql("SELECT * FROM TEST_DB.orders").show()

sparkt.table("TEST_DB.orders").show()

스크린샷 2024-01-27 오후 8 01 50


  • spark catalog
    • catalog가 인메모리가 아닌 HDFS에 영구적으로 저장되는 메타스토어
spark.catalog.listTables() 

스크린샷 2024-01-27 오후 8 04 50

  • isTemporary=False : 임시테이블이 아님
  • tableType='MANAGED" : managed table


  • CTAS로 테이블 생성
spark.sql("""
    CREATE TABLE IF NOT EXISTS TEST_DB.orders_count AS 
    SELECT order_id, COUNT(1) as count 
    FROM TEST_DB.orders
    GROUP BY 1""")

스크린샷 2024-01-27 오후 8 07 13

스크린샷 2024-01-27 오후 8 07 44

5. 유닛테스트

  • 코드 상의 특정 기능 (보통 메소드 형태)을 테스트하기 위해 작성된 코드
  • 보통 정해진 입력을 주고 예상된 출력이 나오는지 테스트
  • CI/CD를 사용하려면 전체 코드의 테스트 coverage가 매우 중요해짐 (7-80% 이상)
  • 각 언어별로 정해진 테스트 프레임워크를 사용하는 것이 일반적
    • JAVA : JUnit
    • .NET : NUnit
    • Python : unittest


  • 실제 환경에서
    • 내 코드를 어떻게 짜면 테스트하기 쉬울까 고민!
    • 함수 - input, output
    • 작은 이슈, 큰 이슈가 생길 때마다 어떻게 테스트를 했으면 이슈를 미연에 방지할 수 있었을까
    • TDD (Test Driven Development)
      • 코드 작성 전 테스트 코드를 먼저 만들어보고 그것에 맞춰 함수, 기능을 채워나가는 논리


실습

  • 일반적으로 colab에서 테스트를 돌리지는 않음
    • test 코드를 따로 만든 다음 해당 코드로 테스트 할 함수를 import해서 사용하는 것이 일반적


from unittest import TestCase

"""
일반적으로는 아래 함수가 정의된 모듈을 임포트하고 그걸 테스트
 - upper_udf_f
 - load_gender
 - get_gender_count

이외에도 2가지 방법 더 존재
 - from pyspark.sql.tests import SparkTestingBase
 - pytest-spark (pytest testing framework plugin)
"""

class UtilsTestCase(TestCase):
    spark = None

    @classmethod
    def setUpClass(cls) -> None:
        cls.spark = SparkSession.builder \
            .appName("Spark Unit Test") \
            .getOrCreate()

    def test_datafile_loading(self):
        sample_df = load_gender(self.spark, "name_gender.csv")
        result_count = sample_df.count()
        self.assertEqual(result_count, 100, "Record count should be 100")

    def test_gender_count(self):
        sample_df = load_gender(self.spark, "name_gender.csv")
        count_list = get_gender_count(self.spark, sample_df, "gender").collect()
        count_dict = dict()
        for row in count_list:
            count_dict[row["gender"]] = row["count"]
        self.assertEqual(count_dict["F"], 65, "Count for F should be 65")
        self.assertEqual(count_dict["M"], 28, "Count for M should be 28")
        self.assertEqual(count_dict["Unisex"], 7, "Count for Unisex should be 7")

    def test_upper_udf(self):
        test_data = [
            { "name": "John Kim" },
            { "name": "Johnny Kim"},
            { "name": "1234" }
        ]
        expected_results = [ "JOHN KIM", "JOHNNY KIM", "1234" ]

        upperUDF = self.spark.udf.register("upper_udf", upper_udf_f)
        test_df = self.spark.createDataFrame(test_data)
        names = test_df.select("name", upperUDF("name").alias("NAME")).collect()
        results = []
        for name in names:
            results.append(name["NAME"])
        self.assertCountEqual(results, expected_results)

    @classmethod
    def tearDownClass(cls) -> None:
        cls.spark.stop()
import unittest

unittest.main(argv=[''], verbosity=2, exit=False)

스크린샷 2024-01-27 오후 8 29 17