PYTHON DATA 시각화 – SEABORN #2
seaborn 그리드 사용햔 다변량 분석¶
seaborn 은 여러 도면을 그리드에 나타낼 수 있다. catplot, implot, pairplot, jointplot, clustermap
figure나 grid 함수는 대부분 axes 함수를 사용해 그리드를 만든다. grid 함수에서 반환된 최종 객체는 grid 형식인데, 서로 다른 4가지 형식이 있다.
1. 성별과 인종에 따른 근무 경력과 급여 간의 관계¶
In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
In [3]:
emp = pd.read_csv('data/employee.csv', parse_dates=['HIRE_DATE','JOB_DATE'])
In [4]:
def yrs_exp(df_):
days_hired = pd.to_datetime('12-1-2016') - df_.HIRE_DATE
return days_hired.dt.days / 365.25
In [5]:
emp = (emp
.assign(YEARS_EXPERIENCE=yrs_exp)
)
emp[['HIRE_DATE','YEARS_EXPERIENCE']]
Out[5]:
1-1. seaborn regplot¶
In [7]:
fig, ax = plt.subplots(figsize=(8,6))
sns.regplot(x='YEARS_EXPERIENCE', y='BASE_SALARY', data=emp, ax=ax)
fig.savefig('c13-scat4.png',dpi=300, bbox_inches='tight')
1-2 seaborn lmplot¶
In [9]:
grid = sns.lmplot(x='YEARS_EXPERIENCE', y='BASE_SALARY', hue='GENDER', scatter_kws={'s':10}, data=emp)
grid.fig.set_size_inches(8,6)
grid.fig.savefig('c13-scat5.png', dpi=300, bbox_inches='tight')
In [10]:
grid = sns.lmplot(x='YEARS_EXPERIENCE', y='BASE_SALARY', hue='GENDER', col='RACE', col_wrap=3,sharex=False,
line_kws = {'linewidth':5}, data=emp)
grid.set(ylim=(20000, 120000))
grid.fig.savefig('c13-scat6.png', dpi=300, bbox_inches='tight')
1-3 seaborn violinplot, catplot¶
In [19]:
deps = emp['DEPARTMENT'].value_counts().index[:2]
races = emp['RACE'].value_counts().index[:3]
is_dep = emp['DEPARTMENT'].isin(deps)
is_race = emp['RACE'].isin(races)
emp2 = (emp
[is_dep & is_race]
.assign(DEPARTMENT=lambda df_ : df_['DEPARTMENT'].str.extract('(HPD|HFD)',expand=True))
)
emp2.shape
emp2['DEPARTMENT'].value_counts()
emp2['RACE'].value_counts()
Out[19]:
In [23]:
common_depts = (emp.groupby('DEPARTMENT')
.filter(lambda group: len(group) > 50)
)
In [24]:
fig, ax = plt.subplots(figsize=(8,6))
sns.violinplot(x='YEARS_EXPERIENCE', y='GENDER', data=common_depts)
fig.savefig('c13-vio1.png', dpi=300, bbox_inches='tight')
In [26]:
grid = sns.catplot(x='YEARS_EXPERIENCE',y='GENDER',col='RACE',row='DEPARTMENT',height=3, aspect=2, data=emp2, kind='violin')
grid.fig.savefig('c13-vio2.png',dpi=300, bbox_inches='tight')
2. seaborn diamonds 데이터셋의 심슨 역설 발견¶
높은 품질의 다이아몬드가 낮은 품질의 다이아몬드보다 더 가치가 없다는 것을 암시하는 황당한 결론에 도달한다. 그 반대가 사실이라는 것을 알려주는 데이터를 좀 더 정교히 취함으로써 심슨의 역설을 밝혀낸다
In [27]:
dia = pd.read_csv('data/diamonds.csv')
dia
Out[27]:
In [29]:
cut_cats = ['Fair','Good','Very Good', 'Premium', 'Ideal']
color_cats = ['J','I','H','G','F','E','D']
claritt_cats = ['I1','SI2','SI1','VS2','VS1','VVS2','VVS1','IF']
dia2 = (dia
.assign(cut=pd.Categorical(dia['cut'],
categories=cut_cats,
ordered=True),
color=pd.Categorical(dia['color'],
categories=color_cats,
ordered=True),
clarity=pd.Categorical(dia['clarity'],
categories=claritt_cats,
ordered=True))
)
dia2
Out[29]:
In [32]:
fig, (ax1, ax2, ax3) = plt.subplots(1,3, figsize=(14,4))
sns.barplot(x='color', y='price', data=dia2, ax=ax1)
sns.barplot(x='cut', y='price', data=dia2, ax=ax2)
sns.barplot(x='clarity', y='price', data=dia2, ax=ax3)
fig.suptitle('Price Descrasing with Increasing Quality?')
fig.savefig('c13-bar4.png', dpi=300, bbox_inches='tight')
In [33]:
grid = sns.catplot(x='color', y='price', col='clarity', col_wrap=4, data=dia2, kind='bar')
grid.fig.savefig('c13-bar5.png', dpi=300, bbox_inches='tight')
색상의 품질이 증가할수록 가격이 하락하는 것처럼 보였지만 투명도가 최고 수준일 때는 오히려 가격이 상승했다. 다이아몬드 크기에는 주의를 기울이지 않고 그저 다이아몬드 가격만 살펴보고 있다.
In [34]:
fig, (ax1, ax2, ax3) = plt.subplots(1,3, figsize=(14,4))
sns.barplot(x='color', y='carat', data=dia2, ax=ax1)
sns.barplot(x='cut', y='carat', data=dia2, ax=ax2)
sns.barplot(x='clarity', y='carat', data=dia2, ax=ax3)
fig.suptitle('Diamond size decreases with quality')
fig.savefig('c13-bar6.png', dpi=300, bbox_inches='tight')
고품질의 다이아몬드는 크기가 작은 것 같고 이는 직관적으로 말이 된다.
In [39]:
dia2 = (dia2.assign(carat_category=pd.cut(dia2.carat, 5)))
from matplotlib.cm import Greys
import numpy as np
greys = Greys(np.arange(50,250,40))
grid = sns.catplot(x='clarity',y='price',data=dia2, hue='carat_category', col='color', col_wrap=4, kind='point')
grid.fig.suptitle('Diamond price by size, color and clarity', y=1.02, size=20)
grid.fig.savefig('c13-bar7.png', dpi=300, bbox_inches='tight')
In [41]:
g = sns.PairGrid(dia2, height=5,
x_vars=['color','cut','clarity'],
y_vars=['price'])
g.map(sns.barplot)
g.fig.suptitle('Replication of Step 3 with PairGrid', y=1.02)
g.fig.savefig('c13-bar8.png',dpi=300, bbox_inches='tight')
In [ ]: