데이터분석/R

[ADP] 의사결정나무 (Decision Tree)

버섯도리 2022. 1. 15. 17:31

> # 06. 의사결정나무 (Decision Tree)


> # 2. 분류트리 

> library(MASS)

> data("biopsy")
> str(biopsy)
'data.frame': 699 obs. of  11 variables:
 $ ID   : chr  "1000025" "1002945" "1015425" "1016277" ...
 $ V1   : int  5 5 3 6 4 8 1 2 2 4 ...
 $ V2   : int  1 4 1 8 1 10 1 1 1 2 ...
 $ V3   : int  1 4 1 8 1 10 1 2 1 1 ...
 $ V4   : int  1 5 1 1 3 8 1 1 1 1 ...
 $ V5   : int  2 7 2 3 2 7 2 2 2 2 ...
 $ V6   : int  1 10 2 4 1 10 10 1 1 1 ...
 $ V7   : int  3 3 3 3 3 9 3 3 1 2 ...
 $ V8   : int  1 2 1 7 1 7 1 1 1 1 ...
 $ V9   : int  1 1 1 1 1 1 1 1 5 1 ...
 $ class: Factor w/ 2 levels "benign","malignant": 1 1 1 1 1 2 1 1 1 1 ...
> biopsy <- biopsy[,-1]
> names(biopsy) <- c("thick","size","shape","adhsn","s.size","nucl","chrom","n.nuc","mit","class")

> head(biopsy)
  thick size shape adhsn s.size nucl chrom n.nuc mit     class
1     5    1     1     1      2    1     3     1   1    benign
2     5    4     4     5      7   10     3     2   1    benign
3     3    1     1     1      2    2     3     1   1    benign
4     6    8     8     1      3    4     3     7   1    benign
5     4    1     1     3      2    1     3     1   1    benign
6     8   10    10     8      7   10     9     7   1 malignant
> biopsy.v2 <- na.omit(biopsy)

> set.seed(123) # 난수 발생 초기화
> ind <- sample(2, nrow(biopsy.v2), replace = TRUE, prob = c(0.7,0.3))
> train <- biopsy.v2[ind==1,]
> test <- biopsy.v2[ind==2,]

> library(rpart)

> tree.train <- rpart(class~., data = train)
> tree.train$cptable
          CP nsplit rel error    xerror       xstd
1 0.79651163      0 1.0000000 1.0000000 0.06086254
2 0.07558140      1 0.2034884 0.3081395 0.03988975
3 0.01162791      2 0.1279070 0.1744186 0.03082013
4 0.01000000      3 0.1162791 0.1802326 0.03129429
> # cp = 복잡도 파라미터
> # nsplit = 트리의 분할 횟수
> # rel error = RSS(k) / RSS(0)
> # xerror = 평균오차
> # xstd = 표준편차
> # 전체 데이터에서 2번 분할하면 가장 낮은 오차를 얻음

> library(partykit)
> # 2번 가지치기를 하고 시각화한다.
> cp <- tree.train$cptable[3,"CP"]
> prune.tree.train <- prune(tree.train, cp=cp)
plot(as.party(prune.tree.train))


> rparty.test <- predict(prune.tree.train, newdata = test, type = "class")
table(rparty.test, test$class)
           
rparty.test benign malignant
  benign       136         3
  malignant      6        64
> # 정확도는 (136+64)/209 = 95.7%

plot(as.party(tree.train))

> # 가지치기 이전의 트리구조를 확인할 수 있다.


> # 3. 회귀트리 

> library(ElemStatLearn)

> data("prostate")
> str(prostate)
'data.frame': 97 obs. of  10 variables:
 $ lcavol : num  -0.58 -0.994 -0.511 -1.204 0.751 ...
 $ lweight: num  2.77 3.32 2.69 3.28 3.43 ...
 $ age    : int  50 58 74 58 62 50 64 58 47 63 ...
 $ lbph   : num  -1.39 -1.39 -1.39 -1.39 -1.39 ...
 $ svi    : int  0 0 0 0 0 0 0 0 0 0 ...
 $ lcp    : num  -1.39 -1.39 -1.39 -1.39 -1.39 ...
 $ gleason: int  6 6 7 6 6 6 6 6 6 6 ...
 $ pgg45  : int  0 0 20 0 0 0 0 0 0 0 ...
 $ lpsa   : num  -0.431 -0.163 -0.163 -0.163 0.372 ...
 $ train  : logi  TRUE TRUE TRUE TRUE TRUE TRUE ...
> summary(prostate)
     lcavol           lweight           age             lbph              svi              lcp             gleason     
 Min.   :-1.3471   Min.   :2.375   Min.   :41.00   Min.   :-1.3863   Min.   :0.0000   Min.   :-1.3863   Min.   :6.000  
 1st Qu.: 0.5128   1st Qu.:3.376   1st Qu.:60.00   1st Qu.:-1.3863   1st Qu.:0.0000   1st Qu.:-1.3863   1st Qu.:6.000  
 Median : 1.4469   Median :3.623   Median :65.00   Median : 0.3001   Median :0.0000   Median :-0.7985   Median :7.000  
 Mean   : 1.3500   Mean   :3.629   Mean   :63.87   Mean   : 0.1004   Mean   :0.2165   Mean   :-0.1794   Mean   :6.753  
 3rd Qu.: 2.1270   3rd Qu.:3.876   3rd Qu.:68.00   3rd Qu.: 1.5581   3rd Qu.:0.0000   3rd Qu.: 1.1787   3rd Qu.:7.000  
 Max.   : 3.8210   Max.   :4.780   Max.   :79.00   Max.   : 2.3263   Max.   :1.0000   Max.   : 2.9042   Max.   :9.000  
     pgg45             lpsa           train        
 Min.   :  0.00   Min.   :-0.4308   Mode :logical  
 1st Qu.:  0.00   1st Qu.: 1.7317   FALSE:30       
 Median : 15.00   Median : 2.5915   TRUE :67       
 Mean   : 24.38   Mean   : 2.4784                  
 3rd Qu.: 40.00   3rd Qu.: 3.0564                  
 Max.   :100.00   Max.   : 5.5829                  


> prostate$gleason <- ifelse(prostate$gleason == 6, 0, 1) # 범주형으로 변환
> summary(prostate)
     lcavol           lweight           age             lbph              svi              lcp             gleason      
 Min.   :-1.3471   Min.   :2.375   Min.   :41.00   Min.   :-1.3863   Min.   :0.0000   Min.   :-1.3863   Min.   :0.0000  
 1st Qu.: 0.5128   1st Qu.:3.376   1st Qu.:60.00   1st Qu.:-1.3863   1st Qu.:0.0000   1st Qu.:-1.3863   1st Qu.:0.0000  
 Median : 1.4469   Median :3.623   Median :65.00   Median : 0.3001   Median :0.0000   Median :-0.7985   Median :1.0000  
 Mean   : 1.3500   Mean   :3.629   Mean   :63.87   Mean   : 0.1004   Mean   :0.2165   Mean   :-0.1794   Mean   :0.6392  
 3rd Qu.: 2.1270   3rd Qu.:3.876   3rd Qu.:68.00   3rd Qu.: 1.5581   3rd Qu.:0.0000   3rd Qu.: 1.1787   3rd Qu.:1.0000  
 Max.   : 3.8210   Max.   :4.780   Max.   :79.00   Max.   : 2.3263   Max.   :1.0000   Max.   : 2.9042   Max.   :1.0000  
     pgg45             lpsa           train        
 Min.   :  0.00   Min.   :-0.4308   Mode :logical  
 1st Qu.:  0.00   1st Qu.: 1.7317   FALSE:30       
 Median : 15.00   Median : 2.5915   TRUE :67       
 Mean   : 24.38   Mean   : 2.4784                  
 3rd Qu.: 40.00   3rd Qu.: 3.0564                  
 Max.   :100.00   Max.   : 5.5829                  
> pros.train <- subset(prostate, train==TRUE)[,1:9]
> pros.test <- subset(prostate, train==FALSE)[,1:9]

> tree.pros <- rpart(lpsa~., data = pros.train)
> print(tree.pros$cptable)
          CP nsplit rel error    xerror       xstd
1 0.35852251      0 1.0000000 1.0282472 0.18211906
2 0.12295687      1 0.6414775 1.0048719 0.13250840
3 0.11639953      2 0.5185206 0.8004360 0.10353351
4 0.05350873      3 0.4021211 0.7520035 0.09101813
5 0.01032838      4 0.3486124 0.6782018 0.08373662
6 0.01000000      5 0.3382840 0.6702477 0.08200366
> # 전체 데이터셋에서 5번 분할하면 가장 낮은 오차를 얻는다.

> plotcp(tree.pros)


> cp <- tree.pros$cptable[6,"CP"]
> prune.tree.pros <- prune(tree.pros, cp=cp)
plot(as.party(prune.tree.pros))



> # 가지치기를 한 트리가 Test 데이터를 이용해 수행 능력을 확인한다.
> party.pros.test = predict(prune.tree.pros, newdata = pros.test)
rpart.resid = party.pros.test - pros.test$lpsa
mean(rpart.resid^2) # 평균제곱오차(MSE)
[1] 0.6136057

 

 

 

 

 

출처 : 2020 데이터 분석 전문가 ADP 필기 한 권으로 끝내기