본문 바로가기
[업무 지식]/Seaborn

[scatterplot] 한번에 모든 컬럼 시각화

by 에디터 윤슬 2024. 11. 27.
import matplotlib.pyplot as plt
import seaborn as sns

def get_scatter(df, target_column=None):
    """
    데이터프레임의 열들 간 산점도를 플롯합니다.
    
    매개변수:
        - df: pandas DataFrame, 데이터를 포함한 데이터프레임
        - target_column: str, 특정 타겟 열과 나머지 열 간의 산점도를 그릴 경우 지정 (기본값: None)
    """
    columns = df.columns
    num_columns = len(columns)
    
    if target_column:
        # 특정 타겟 열과 나머지 열 간의 산점도만 생성
        columns = [col for col in df.columns if col != target_column]
        num_plots = len(columns)
        rows = (num_plots // 5) + (num_plots % 5 > 0)
        
        plt.figure(figsize=(20, rows * 5))
        for i, column in enumerate(columns):
            ax = plt.subplot(rows, 5, i + 1)
            sns.scatterplot(data=df, x=column, y=target_column)
            plt.xlabel(column, fontsize=12)
            plt.ylabel(target_column, fontsize=12)
        
    else:
        # 모든 열 쌍에 대해 산점도 생성
        num_plots = num_columns * (num_columns - 1) // 2  # 조합 개수 계산
        rows = (num_plots // 5) + (num_plots % 5 > 0)
        
        plt.figure(figsize=(20, rows * 5))
        plot_idx = 1
        for i in range(num_columns):
            for j in range(i + 1, num_columns):
                ax = plt.subplot(rows, 5, plot_idx)
                sns.scatterplot(data=df, x=df.columns[i], y=df.columns[j])
                plt.xlabel(df.columns[i], fontsize=12)
                plt.ylabel(df.columns[j], fontsize=12)
                plot_idx += 1

    plt.tight_layout()
    plt.show()
    
get_scatter(df)