
Is it suitable for regression prediction? #7

Open
zichuan-liu opened this issue Feb 23, 2022 · 5 comments
zichuan-liu commented Feb 23, 2022

Hello, I'd like to ask: if I want to do regression prediction with output_dim == 1, is SDT applicable? (It seems to be used only for classification models.)

Thanks!

xuyxu (Owner) commented Feb 23, 2022

Hi @775269512, I think it is natural to use SDT on regression tasks: simply change the training criterion in main.py, and there should be no need to modify anything inside the implementation of SDT.
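For instance, a minimal sketch of that change, assuming a training loop like the one in main.py (the names `tree`, `data`, and `target` follow the snippets in this thread and may differ from the actual script):

'''
import torch.nn as nn

# Regression criterion instead of nn.CrossEntropyLoss()
criterion = nn.MSELoss()

# Inside the training loop; the target must be float and have the
# same shape as the model output.
output, penalty = tree.forward(data, is_training_data=True)
loss = criterion(output, target.float().view_as(output))
loss += penalty  # keep the regularization term on the inner nodes
'''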

zichuan-liu (Author) commented Feb 23, 2022

Hi, I did a simple experiment. Although the overall loss is decreasing, the output for every sample is the same, so it cannot regress (here out_dim = 1):

x = tensor([[ 1.,  1.,  1.,  1.,  1.],
            [ 2.,  2.,  2.,  2.,  2.],
            [ 3.,  3.,  3.,  3.,  3.],
            [ 4.,  4.,  4.,  4.,  4.],
            [ 5.,  5.,  5.,  5.,  5.],
            [ 6.,  6.,  6.,  6.,  6.],
            [ 7.,  7.,  7.,  7.,  7.],
            [ 8.,  8.,  8.,  8.,  8.],
            [ 9.,  9.,  9.,  9.,  9.],
            [10., 10., 10., 10., 10.]])
and
y = np.array([5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05]).ravel()

I got this result:

tensor([[7.1672],
        [7.3185],
        [7.3203],
        [7.3204],
        [7.3204],
        [7.3204],
        [7.3204],
        [7.3204],
        [7.3204],
        [7.3204]], grad_fn=<...>)
Epoch: 499 | Loss: 1.93230 | Correct: 000/128

I don't know how to fix it. It seems that when out_dim = 1, the value of each leaf node is the same.

QAQ

xuyxu (Owner) commented Feb 23, 2022

Could you show me the code snippet for training and evaluation?

zichuan-liu (Author) commented Feb 23, 2022

Yep, it's here. In addition, I found a paper based on SDT that I will study: "SDTR: Soft Decision Tree Regressor for Tabular Data". The differentiable decision tree is hard to understand, but it is a good interpretable model and I want to use it to do something.

btw, I will also start postgraduate study at NJU next semester. I found that you are my senior~

'''
import torch
import torch.nn as nn
from sklearn.datasets import fetch_california_housing

from SDT import SDT

# Load data
housing = fetch_california_housing()
xs = torch.from_numpy(housing["data"]).float()
ys = torch.from_numpy(housing["target"]).unsqueeze(1).float()

print(xs.size())
print(ys.size())
print(xs)

input_dim = xs.size(1)
output_dim = ys.size(1)

# Model and optimizer (depth, lamda, use_cuda, lr, weight_decay,
# and epochs are set elsewhere)
tree = SDT(input_dim, output_dim, depth, lamda, use_cuda)

optimizer = torch.optim.Adam(tree.parameters(),
                             lr=lr,
                             weight_decay=weight_decay)

# Utils
training_loss_list = []
criterion = nn.MSELoss()
device = torch.device("cuda" if use_cuda else "cpu")

for epoch in range(epochs):
    # Training (full batch, no dataloader)
    tree.train()
    output, penalty = tree.forward(xs, is_training_data=True)
    print(output)

    # Keep the (N, 1) shape of ys so it matches output;
    # ys.view(-1) would make MSELoss broadcast to (N, N)
    loss = criterion(output, ys)
    # loss += penalty

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Print training status; the argmax "accuracy" from the
    # classification script is meaningless for regression:
    # pred = output.data.max(1)[1]
    # correct = pred.eq(ys.view(-1).data).sum()
    msg = "Epoch: {:02d} | Loss: {:.5f}"
    print(msg.format(epoch, loss))
    training_loss_list.append(loss.cpu().data.numpy())
'''

xuyxu (Owner) commented Feb 27, 2022

It looks like you are using a full-batch training process (i.e., without a dataloader that samples mini-batches). Maybe you should consider using one and training the SDT in a stochastic way, as sketched below. Besides, what are the values of the learning rate and weight decay?
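For reference, a minimal sketch of mini-batch training with a DataLoader, reusing the names from your snippet (the batch size here is illustrative):

'''
from torch.utils.data import DataLoader, TensorDataset

# Sample shuffled mini-batches instead of feeding the full dataset.
train_loader = DataLoader(TensorDataset(xs, ys),
                          batch_size=128, shuffle=True)

for epoch in range(epochs):
    tree.train()
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        output, penalty = tree.forward(batch_x, is_training_data=True)
        loss = criterion(output, batch_y)  # both shaped (batch_size, 1)
        loss += penalty
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
'''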
