본문 바로가기
[업무 지식]/Machine learning

[선형회귀] 회귀분석

by 에디터 윤슬 2024. 11. 18.

자주 쓰는 함수

  • sklearn.linear_model.LinearRegression: 선형회귀 모델 클래스
    • coef_: 회귀 계수
    • intercept: 편향(bias)
    • fit: 데이터 학습
    • predict: 데이터 예측

선형회귀 실습

'tips' 데이터를 가지고 전체 금액(X)를 알면 받을 수 있는 팁(y)에 대한 회귀분석을 진행한다.

  • seaborn 시각화 라이브러리 데이터셋 'tips'
tips_df = sns.load_dataset('tips')
tips_df.head(3)

  • 전체 금액(total_bill)과 팁(tip)과의 선형성을 산점도로 확인한다
sns.scatterplot(data = tips_df, x = 'total_bill', y = 'tip')

  • 종속변수(y), 독립변수(x) 설정
# X: total_bill  -- X, 대문자
# y: tip

model_lr = LinearRegression()
X = tips_df[['total_bill']]
y = tips_df[['tip']]
model_lr.fit(X, y)

 

  • 편향(w0), 회귀계수(w1) 설정
# y(tip) = w1 * x(total_bill) + w0

w1_tip = model_lr.coef_[0][0]
w0_tip = model_lr.intercept_[0]

print('y = {}x + {}'.format(w1_tip.round(2), w0_tip.round(2)))

y = 0.11x + 0.92

# 전체 결제금액이 1달러 오를때, 팁은 0.11달러 추가된다.
# 전체 결제금액이 100달러 오를때, 팁은 11달러 추가된다.
  • 예측값 생성
y_true_tip = tips_df['tip']
y_pred_tip = model_lr.predict(tips_df[['total_bill']])

y_true_tip[:5]
y_pred_tip[:5]

 

  • MSE, rsquare 설정
mean_squared_error(y_true_tip, y_pred_tip)
1.036019442011377

r2_score(y_true_tip, y_pred_tip)
0.45661658635167657
  • 예측값 df에 추가
tips_df['pred'] = y_pred_tip
tips_df.head(3)

  • 산점도와 라인플롯 시각화
sns.scatterplot(data = tips_df, x = 'total_bill', y = 'tip')
sns.lineplot(data = tips_df, x = 'total_bill', y = 'pred', color = 'red')

 

범주형 데이터 선형회귀

  • 범주형 데이터 숫자로 인코딩
# female 0 , male 1

def get_sex(x):
	if x == 'Female':
    	return 0
    else:
    	return 1
  • apply 함수 적용
tips_df['sex_en'] = tips_df['sex'].apply(get_sex)
tips_df.head(3)

  • 모델 설계도 가져오기, 학습, 평가
model_lr2 = LinearRegrssion()
X = tips_df[['total_bill', 'sex_en']]
y = tips_df[['tip']]
model_lr2.fit(X, y)
  • 예측
y_pred_tip2 = model_lr2.predict(X)
y_pred_tip2[:5]

  • MSE, Rsquare 확인
# 단순선형회귀 mse: X변수가 전체 금액
# 다중선형회귀 mse: X변수가 전체 금액, 성별
print('단순선형회귀', mean_squared_error(y_true_tip, y_pred_tip))
print('다중선형회귀', mean_squared_error(y_true_tip, y_pred_tip2))

단순선형회귀 1.036019442011377
다중선형회귀 1.0358604137213614

print('단순선형회귀', r2_score(y_true_tip, y_pred_tip))
print('다중선형회귀', r2_score(y_true_tip, y_pred_tip2))
단순선형회귀 0.45661658635167657
다중선형회귀 0.45669999534149974

 

'[업무 지식] > Machine learning' 카테고리의 다른 글

[로지스틱회귀] 분류분석  (0) 2024.11.19