I'm implementing regression trees, but the routine I use to find the best split is not working. I've read and re-read the chapters on regression trees and on choosing the best split by least-squares regression, and they are clear about how to do it. I've also checked my MSE calculations against R, and they match exactly. The problem comes when I calculate the gain in MSE from performing a split. According to the texts, that calculation is:

deltaR(s, t) = R(t) - R(leftSplit) - R(rightSplit)
Best split = the split s that maximizes deltaR(s, t)
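For comparison, here is a minimal Python sketch of deltaR under one assumption about how R is defined: as I read Breiman et al.'s CART book, R(t) is the sum of squared deviations in node t divided by N, the number of cases at the root, not by the node's own size. With that shared denominator, R(left) + R(right) can never exceed R(t). The toy data and the names sse, node_risk and delta_r are hypothetical.

```python
# Sketch: CART-style node risk, ASSUMING R(t) = SSE(t) / N where N is the
# number of cases at the ROOT (not the number of cases in node t).


def sse(ys):
    """Sum of squared deviations of ys from their own mean (0 for an empty node)."""
    if not ys:
        return 0.0
    mean = sum(ys) / len(ys)
    return sum((v - mean) ** 2 for v in ys)


def node_risk(ys, n_root):
    """R(t): node SSE divided by the root sample size."""
    return sse(ys) / n_root


def delta_r(parent, left, right, n_root):
    """deltaR(s, t) = R(t) - R(left) - R(right); never negative for a real split."""
    return node_risk(parent, n_root) - node_risk(left, n_root) - node_risk(right, n_root)


# Hypothetical toy data, split after the third value.
y = [1.0, 2.0, 2.5, 4.0, 7.0, 8.0]
print(delta_r(y, y[:3], y[3:], n_root=len(y)))
```

On this toy data the split after the third value gives a deltaR of about 5.06, and no split point produces a negative value.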

However, I frequently get negative or very small values for this quantity. For example, if R(t) = 0.65, R(left) = 0.64 and R(right) = 0.43, then deltaR = 0.65 - 0.64 - 0.43 = -0.42. Because R(left) and R(right) can add up to more than R(t), the difference can go negative. In this particular case R reports an improvement of 0.17, while I get -0.42. I also tried weighting the two child terms by their relative size:

R(t) - sizeOf(leftSplit) / sizeOf(t) * R(leftSplit) - sizeOf(rightSplit) / sizeOf(t) * R(rightSplit)
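Working the example numbers through both versions shows where the sign problem comes from, assuming 0.64 and 0.43 are within-node MSEs (each divided by its own node size). The left-hand proportion p_L = N(t_L)/N(t) is not given in the post, so the weighted result can only be bounded:

```latex
% 0.65, 0.64, 0.43 are the example values; p_L = N(t_L)/N(t) is an unknown in [0, 1].
\Delta R_{\text{unweighted}} = 0.65 - 0.64 - 0.43 = -0.42

\Delta R_{\text{weighted}} = 0.65 - 0.64\,p_L - 0.43\,(1 - p_L) = 0.22 - 0.21\,p_L,
\qquad 0.01 \le \Delta R_{\text{weighted}} \le 0.22

% With R(\cdot) as within-node MSE, p_L R(t_L) = SSE(t_L)/N(t), and likewise on the right, so
\Delta R_{\text{weighted}} = \frac{SSE(t) - SSE(t_L) - SSE(t_R)}{N(t)} \ge 0
```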

But I didn't get anything close to R's gain. At this point I'm not sure they measure the gain the same way; their comments mention sum of squares, and they calculate the mean of each side of the split.
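Since rpart's comments mention sum of squares and the mean of each side, note that the reduction in sum of squares can be written using only those means, as the between-group sum of squares n_L(ȳ_L - ȳ)² + n_R(ȳ_R - ȳ)². The sketch below also divides by the parent's sum of squares. My understanding, which should be checked against the rpart documentation, is that the "improve" value rpart reports for method = "anova" is this proportional reduction, 1 - (SS_left + SS_right)/SS_parent, rather than a raw difference. The helper names and data are hypothetical.

```python
# Sketch: sum-of-squares reduction from the mean of each side of a split.
# Hypothetical helper names; assumes both sides are non-empty.


def mean(ys):
    return sum(ys) / len(ys)


def ss(ys):
    """Sum of squared deviations from the mean."""
    m = mean(ys)
    return sum((v - m) ** 2 for v in ys)


def ss_reduction(left, right):
    """SS_parent - SS_left - SS_right, computed as the between-group sum of squares."""
    parent_mean = mean(left + right)
    return (len(left) * (mean(left) - parent_mean) ** 2
            + len(right) * (mean(right) - parent_mean) ** 2)


def proportional_improvement(left, right):
    """ASSUMED rpart-style 'improve': 1 - (SS_left + SS_right) / SS_parent."""
    return ss_reduction(left, right) / ss(left + right)


y = [1.0, 2.0, 2.5, 4.0, 7.0, 8.0]  # hypothetical toy data
print(round(proportional_improvement(y[:3], y[3:]), 4))
```

Because it is a ratio, this number does not change if every node's sum of squares is divided by the same constant, so it comes out the same whether R(t) is normalized by the root size N or by the node size.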

I've compared my numbers against the improvement values from rpart() in R, and I can't figure out how it arrives at them. I've even read the rpart code, but it uses an optimization I can't quite follow, so comparing its method against mine is not straightforward. If someone understands that optimization I'd love to incorporate it into what I'm doing, but without an explanation of how it works I'm reluctant to just copy it.
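One algebraic shortcut that a sum-based split search can use is sketched below; whether it is exactly the optimization in rpart's code is my assumption, not something I have traced line by line. Centre the responses on the parent mean; then SS_parent - SS_left - SS_right = S_L^2/n_L + S_R^2/n_R, where S_L and S_R are the sums of the centred values on each side. Because the centred values sum to zero, S_R = -S_L, so after sorting by the predictor a single running sum scores every split in one pass. The function name best_split and the midpoint threshold are hypothetical choices for illustration.

```python
# Sketch: score every threshold on one predictor with a single running sum.
# Hypothetical function name; threshold is the midpoint between adjacent x values.


def best_split(x, y):
    """Return (threshold, ss_reduction) for the best split 'x <= threshold'."""
    n = len(y)
    order = sorted(range(n), key=lambda i: x[i])
    ybar = sum(y) / n
    xs = [x[i] for i in order]
    z = [y[i] - ybar for i in order]   # responses centred on the parent mean

    best_threshold, best_gain = None, 0.0
    left_sum = 0.0
    for k in range(1, n):              # left = first k cases in x order
        left_sum += z[k - 1]
        if xs[k - 1] == xs[k]:
            continue                   # no threshold separates tied x values
        right_sum = -left_sum          # centred values sum to zero
        gain = left_sum ** 2 / k + right_sum ** 2 / (n - k)
        if gain > best_gain:
            best_threshold, best_gain = (xs[k - 1] + xs[k]) / 2, gain
    return best_threshold, best_gain


x = [4, 1, 6, 2, 5, 3]                 # hypothetical toy predictor
y = [4.0, 1.0, 8.0, 2.0, 7.0, 2.5]     # hypothetical toy response
print(best_split(x, y))
```

On the toy data this picks x <= 4.5. The returned gain is SS_parent - SS_left - SS_right, so dividing it by the parent's sum of squares gives the proportional form, and dividing it by N gives the CART-style deltaR.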

What am I missing here?  I'm running out of ideas.

Thanks in advance,

Tags: cart, decisiontrees, regression

Replies to This Discussion


I don't have a lot of experience with this, but the one implementation of CART that I use is from SPSS and they use Least Squares Deviation (LSD) as the impurity measure....




On Data Science Central

© 2020 TechTarget, Inc.