skorch.fit can't handle lists of lists with variable length #605
Comments
Don't know about the specifics of this in skorch, but generally you need to add padding or perform slicing so that every sample has the same length. The only exception to this is TensorFlow's Ragged Tensors, but even then you have to specify a default value to pad with when converting to regular tensors (PyTorch doesn't have Ragged Tensors yet). |
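As a minimal sketch of the padding approach described above: assuming the ID lists from the original question and that ID 0 is reserved as the padding value (an assumption, not something stated in the thread), PyTorch's `pad_sequence` turns ragged lists into one rectangular tensor.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Ragged ID lists like the ones in the question
id_lists = [[1, 12, 3], [6, 22]]

# Convert each list to a 1-D tensor, then pad to the longest one.
# padding_value=0 assumes that ID 0 is reserved for padding.
sequences = [torch.tensor(ids, dtype=torch.long) for ids in id_lists]
padded = pad_sequence(sequences, batch_first=True, padding_value=0)

# Keep the true lengths so the model can ignore the padded positions
lengths = torch.tensor([len(ids) for ids in id_lists])

print(padded)   # tensor([[ 1, 12,  3], [ 6, 22,  0]])
print(lengths)  # tensor([3, 2])
```

If 0 is a real ID in your vocabulary, pick another padding value and tell the embedding layer about it (for example through `padding_idx`).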
@econti Could you check whether Otherwise, we have an example here that shows how to potentially deal with variable length sequences. |
Thanks @BenjaminBossan, that did the trick for me. Leaving a code snippet here for anyone else who encounters a similar issue:
|
@econti Great that you found a solution and thanks for the snippet. |
I'm facing a similar issue right now, and I suspect I'm doing the same thing you're doing: padding to the longest sequence length in the dataset, which results in significantly more computation than padding at the batch level. I suspect we need something like a collate_fn that operates at the batch level to solve this the right way. |
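A batch-level `collate_fn` of the kind suggested here could look roughly like the sketch below. `RaggedDataset` and `pad_collate` are hypothetical names introduced for illustration, and 0 is again assumed to be the padding value.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset


class RaggedDataset(Dataset):
    """Hypothetical dataset holding variable-length ID lists and labels."""

    def __init__(self, id_lists, labels):
        self.id_lists = id_lists
        self.labels = labels

    def __len__(self):
        return len(self.id_lists)

    def __getitem__(self, i):
        return torch.tensor(self.id_lists[i], dtype=torch.long), self.labels[i]


def pad_collate(batch):
    """Pad only up to the longest sequence in this batch."""
    sequences, labels = zip(*batch)
    X = pad_sequence(sequences, batch_first=True, padding_value=0)
    y = torch.tensor(labels, dtype=torch.long)
    return X, y


dataset = RaggedDataset([[1, 12, 3], [6, 22], [4], [7, 8, 9, 10, 11]], [0, 1, 0, 1])
loader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=pad_collate)

# Each batch is only as wide as its own longest sequence
batch_shapes = [tuple(X.shape) for X, y in loader]
print(batch_shapes)  # [(2, 3), (2, 5)]
```

With skorch, the same function can presumably be forwarded to the underlying DataLoader via `iterator_train__collate_fn=pad_collate` and `iterator_valid__collate_fn=pad_collate`, since skorch routes `iterator_train__*` and `iterator_valid__*` parameters to the iterators.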
@ToddMorrill I don't know the exact details of your case, so maybe I'm missing something. In general though, However, this is not the canonical way of dealing with sequences of different lengths. Maybe you can make use of |
"A custom collate_fn can be used to customize collation, e.g., padding sequential data to max length of a batch." Source That's what I'm trying to do. Thanks for pointing me toward I'm not opposed to using |
To be sure, I'm trying to reuse the following torchtext code with skorch.
So far, I haven't found a way to reuse I welcome any suggestions on recycling this code. |
My apologies for all the posts but I just wanted to share a quick update before signing off and ask a question. I created a custom dataset and then implemented a custom
What's amazing about padding at the batch level is that run times went from 60 seconds per epoch to 20 seconds per epoch - a huge improvement. However, I was liking all of the functionality I had while using I'm still interested in recycling the torchtext functionality so if you have thoughts on that, I still welcome them!! Thanks for all of your help! I'm loving skorch. |
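To illustrate why batch-level padding saves work, here is a small arithmetic sketch that counts how many positions the model has to process under three padding strategies. The sequence lengths are made up (mostly short, a few long outliers) and are not taken from the thread.

```python
import random

random.seed(0)
# Hypothetical sequence lengths: mostly short, with a few long outliers
lengths = [random.randint(5, 30) for _ in range(990)] + [200] * 10
random.shuffle(lengths)
batch_size = 50


def chunk(seq, size):
    """Split seq into consecutive batches of the given size."""
    return [seq[i:i + size] for i in range(0, len(seq), size)]


def padded_cells(batches):
    """Total positions processed when each batch is padded to its own max."""
    return sum(len(batch) * max(batch) for batch in batches)


# 1) Pad every sample to the longest sequence in the whole dataset
global_cost = len(lengths) * max(lengths)
# 2) Pad each batch to its own longest sequence
batch_cost = padded_cells(chunk(lengths, batch_size))
# 3) Group similar lengths first (the idea behind bucketing), then pad per batch
bucket_cost = padded_cells(chunk(sorted(lengths), batch_size))

print(global_cost, batch_cost, bucket_cost)
```

Per-batch padding can never cost more than global padding, and grouping sequences of similar length reduces the cost further, which is the motivation for iterators such as torchtext's BucketIterator.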
That's great to hear, thanks.
Sorry that I have confused you, you can do the same thing with
It depends a bit. What does your target look like? Potentially, it could be possible to extract it and pass it as
Note that you can use the scoring functions also with |
I'm making progress on my example text classification pipeline using |
I believe it makes a lot of sense to make skorch work with popular libraries like torchtext and torchvision. When we released skorch, the former didn't exist yet, so now we might be in a place where not everything works together. However, there might still be a way. I would need to look more thoroughly at what torchtext provides and see what we can do, once I have a bit of time. @ToddMorrill please keep us up-to-date if you find some better solution. |
Hi guys,
I have good news for you :D Here I prepared a short example (somewhat similar to the one provided by @ToddMorrill ) how to integrate

    import torch
    import skorch
    import random
    import numpy as np
    import pandas as pd
    # NB: Field, BucketIterator etc. belong to torchtext's legacy API
    from torchtext.data import BucketIterator, Example, Dataset, Field, LabelField
    from sklearn.base import BaseEstimator, TransformerMixin
    from sklearn.pipeline import make_pipeline

    # Fix all seeds for reproducibility
    SEED = 137
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True


    def data(size=1000):
        # Toy dataset: two alternating sentences with alternating labels
        return pd.DataFrame({
            "query": ["This is a duck", "This is a goose"] * size,
            "target": [0, 1] * size,
        })


    class TextPreprocessor(BaseEstimator, TransformerMixin):
        """Turns a DataFrame into a torchtext Dataset and builds the vocabularies."""

        def __init__(self, fields, need_vocab=None):
            self.fields = fields
            self.need_vocab = need_vocab or {}

        def fit(self, X, y=None):
            dataset = self.transform(X, y)
            for field, min_freq in self.need_vocab.items():
                field.build_vocab(dataset, min_freq=min_freq)
            return self

        def transform(self, X, y=None):
            # Preprocess each column with its field, then build one Example per row
            proc = [X[col].apply(f.preprocess) for col, f in self.fields]
            examples = [Example.fromlist(f, self.fields) for f in zip(*proc)]
            return Dataset(examples, self.fields)


    def build_preprocessor():
        text_field = Field(lower=True)
        label_field = LabelField(is_target=True)
        fields = [
            ('query', text_field),
            ('target', label_field),
        ]
        return TextPreprocessor(fields, need_vocab={text_field: 0, label_field: 0})


    class SimpleModule(torch.nn.Module):
        def __init__(self, vocab_size=100, emb_dim=16, lstm_hidden_dim=32):
            super().__init__()
            self._emb = torch.nn.Embedding(vocab_size, emb_dim)
            self._rnn = torch.nn.LSTM(emb_dim, lstm_hidden_dim)
            self._out = torch.nn.Linear(lstm_hidden_dim, 2)

        def forward(self, inputs):
            # inputs: (seq_len, batch); use the output at the last time step
            rnn_output = self._rnn(self._emb(inputs))[0]
            # Explicit dim avoids the implicit-dimension deprecation warning
            return torch.nn.functional.softmax(self._out(rnn_output[-1]), dim=-1)


    class InputShapeSetter(skorch.callbacks.Callback):
        def on_train_begin(self, net, X, y):
            # NB: If your module relies on pretrained embeddings
            # net.set_params(module__embeddings=X.fields["query"].vocab.vectors)
            pass


    def build_model():
        model = skorch.NeuralNetClassifier(
            module=SimpleModule,
            iterator_train=BucketIterator,
            iterator_valid=BucketIterator,
            train_split=Dataset.split,
            callbacks=[InputShapeSetter()],
        )
        full = make_pipeline(
            build_preprocessor(),
            model
        )
        return full


    def main():
        df = data()
        assert isinstance(df, pd.DataFrame)
        dataset = build_preprocessor().fit_transform(df)
        assert isinstance(dataset, Dataset)
        # Putting it all together
        model = build_model().fit(
            df,  # pd.DataFrame, torchtext handles X and y
            0.7  # <<< ?? This sets split_ratio for Dataset.split
        )
        print(model.predict(df))
        assert model.score(df, df["target"]) > 0.5, "Fitting issues"


    if __name__ == '__main__':
        main()

This code should work with the latest versions of the libraries. The only strange thing is that you have to pass @BenjaminBossan It looks like you are a member of the dev team. Probably 594 is somehow related to the topic with Once again sorry for spamming. |
@kqf Thanks for posting the example, I'm taking a look at it. At the end of the day, I think it would be nice to add a notebook that showcases how to use torchtext. Ideally, it should use one of the torchtext datasets like IMDB and pretrained embeddings.
Yes, that works, but it's a bit of a hacky solution. This solution here should be clearer:

    from functools import partial

    def my_train_split(dataset, y, split_ratio):
        return dataset.split(split_ratio=split_ratio)

    ...

    def build_model():
        model = skorch.NeuralNetClassifier(
            module=SimpleModule,
            iterator_train=BucketIterator,
            iterator_valid=BucketIterator,
            train_split=partial(my_train_split, split_ratio=0.7),
            callbacks=[InputShapeSetter()],
        )
        ...

    model = build_model().fit(df)  # no need to pass split_ratio here

|
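The `partial` trick can be tried without torchtext or skorch installed. In the sketch below, `ToyDataset` is a hypothetical stand-in for any dataset with a torchtext-like `split()` method; only `functools.partial` is real library API.

```python
from functools import partial


class ToyDataset:
    """Hypothetical stand-in for a dataset with a torchtext-like split()."""

    def __init__(self, items):
        self.items = list(items)

    def split(self, split_ratio=0.7):
        cut = int(len(self.items) * split_ratio)
        return ToyDataset(self.items[:cut]), ToyDataset(self.items[cut:])


def my_train_split(dataset, y=None, split_ratio=0.7):
    # y is accepted because the caller passes it, but it is not used here
    return dataset.split(split_ratio=split_ratio)


# Bind split_ratio once; the resulting callable only needs (dataset, y)
train_split = partial(my_train_split, split_ratio=0.8)
train, valid = train_split(ToyDataset(range(10)), None)

print(len(train.items), len(valid.items))  # 8 2
```

Binding the ratio up front keeps the `fit` call free of extra positional arguments, so the second argument to `fit` is no longer misused to carry `split_ratio`.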
It's totally doable, I didn't want to download the data/embeddings on my private laptop.
Yes, I agree, but that was one of my intentions: to demonstrate that What do you think? |
Yes, what you posted is a really good starting point.
I think those two lines are acceptable :)
I think that could make sense. Do you want to work on this change?

In the meantime, I tried to implement a torchtext example with skorch that's a bit closer to a real world problem someone could have. It uses skorch with torchtext and BERT (via huggingface). Here is the notebook:

@kqf @ToddMorrill since you know torchtext much better than I do, could you check if what I did makes sense? E.g., I don't really understand what all this

The main change that I had to introduce was to slightly change

    class SkorchBucketIterator(BucketIterator):
        def __iter__(self):
            for batch in super().__iter__():
                # We make a small modification: Instead of just returning batch
                # we return batch.text and batch.label, corresponding to X and y
                yield batch.text, batch.label.long()

skorch basically really wants to always have an ping @ottonemo maybe this is also interesting for you. |
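The pattern used by `SkorchBucketIterator` (subclass the iterator and unpack each batch object into an `(X, y)` pair) can be sketched without torchtext. `ToyBatchIterator` and `TupleIterator` are hypothetical names, and `SimpleNamespace` objects stand in for torchtext batches.

```python
from types import SimpleNamespace


class ToyBatchIterator:
    """Hypothetical stand-in for an iterator that yields batch objects."""

    def __init__(self, batches):
        self.batches = batches

    def __iter__(self):
        return iter(self.batches)


class TupleIterator(ToyBatchIterator):
    """Unpack each batch object into the (X, y) pair a training loop expects."""

    def __iter__(self):
        for batch in super().__iter__():
            yield batch.text, batch.label


batches = [
    SimpleNamespace(text=[[1, 2], [3, 0]], label=[0, 1]),
    SimpleNamespace(text=[[4, 5]], label=[1]),
]
pairs = list(TupleIterator(batches))

for X, y in pairs:
    print(X, y)
```

Only `__iter__` changes; batching, shuffling and bucketing stay with the parent class.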
Yes, I'd love to help, but I will have time only on weekends. If it's ok -- I am in.
I am not an expert in I like the way you are handling
I think what you are saying is important. The default |
@BenjaminBossan One more thing about examples with In any event, if you have to pass multiple fields to
    from operator import attrgetter

    def batch2dict(batch):
        return {f: attrgetter(f)(batch) for f in batch.input_fields}

    class SkorchBucketIterator(BucketIterator):
        def __iter__(self):
            for batch in super().__iter__():
                # We make a small modification: Instead of just returning batch
                # we return dict() and empty tensor, corresponding to X and y
                yield batch2dict(batch), torch.empty(0)
So, this should demonstrate how to use multiple fields with |
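A rough sketch of why returning a dict is convenient: when X is a dict, skorch passes its entries to the module's `forward` as keyword arguments, so the field names must match the parameter names. The `forward` function and the field names `query` and `reply` below are hypothetical.

```python
from types import SimpleNamespace


def batch2dict(batch):
    # Collect every input field of the batch under its field name
    return {name: getattr(batch, name) for name in batch.input_fields}


def forward(query, reply):
    """Hypothetical module forward(); parameter names must match the dict keys."""
    return len(query) + len(reply)


# SimpleNamespace stands in for a torchtext batch with two input fields
batch = SimpleNamespace(input_fields=["query", "reply"], query=[1, 2, 3], reply=[4, 5])
X = batch2dict(batch)
result = forward(**X)  # the dict keys become keyword arguments

print(X)       # {'query': [1, 2, 3], 'reply': [4, 5]}
print(result)  # 5
```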
No problem at all. If you need help along the way, just ask.
Thanks for providing the example.
I'm curious what exactly you are doing there. I implemented some metric learning approaches in the past, typically using something like a Siamese net. You could use the |
If you ask about the application, it's a chatbot (there is a database with replies, so the model needs to find the most relevant one when supplied with the user query). And it's very similar to the example you mentioned. In my case, it's somewhat easier as I do hard and semi-hard negatives mining within a batch. I decided to separate the logic: I have a separate encoder-towers module and a loss module that mines hard-negatives and calculates the triplet loss. I didn't know about
What do you mean by |
I just meant that the |
This is working really well. My epoch times were pretty much cut in half with this modification. Thank you for your example @BenjaminBossan. I just tried to plug this into a grid search like the following and got an error. I'm including the traceback for reference. I can try to look into the error but I'm not too familiar with sklearn's internals. Is there a way forward here?
Traceback
|
@ToddMorrill Thanks for reporting. It's not easy for me to deduce what's going on. Could you either provide me with a minimal code sample to reproduce the error or check the following things for me (by using a debugger):
--> 650 X, y, groups = indexable(X, y, groups) what are the types of
|
I was able to reproduce it with your example by adding the following lines to the bottom of the script.
Running this results in the following output. No errors but it didn't train.
From the debugger:
I believe
From the debugger:
Yes, it's fantastic! |
Thanks for investigating @ToddMorrill I tracked down the weird In my opinion, this is a bug on the Basically, any code that calls
with an unknown attribute will do the wrong thing. This is especially grave with sklearn, since sklearn will at one point check I tried to override their

    def __getattr__(self, attr):
        if attr in self.fields:
            return [getattr(x, attr) for x in self.examples]
        else:
            raise AttributeError("no attribute", attr)

However, then I run into the next problem, namely these lines: They basically rely on the faulty So overall, I'm sorry to say that you might just not be able to combine |
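The failure mode described here can be reproduced in plain Python. The sketch assumes the problem is a `__getattr__` written as a generator function (my reading of the comment above, since the linked code is not shown); both class names are hypothetical.

```python
class GeneratorGetattr:
    """Mimics a __getattr__ written as a generator function (contains `yield`)."""

    fields = {"text": None}

    def __init__(self, examples):
        self.examples = examples

    def __getattr__(self, attr):
        # Because of `yield`, calling this ALWAYS returns a generator object;
        # the body (including the membership check) only runs on iteration.
        if attr in self.fields:
            for x in self.examples:
                yield x[attr]


class StrictGetattr(GeneratorGetattr):
    def __getattr__(self, attr):
        if attr in self.fields:
            return [x[attr] for x in self.examples]
        raise AttributeError(attr)


examples = [{"text": "a duck"}, {"text": "a goose"}]
loose = GeneratorGetattr(examples)
strict = StrictGetattr(examples)

print(hasattr(loose, "no_such_attribute"))   # True: a generator is returned
print(list(loose.no_such_attribute))         # []: the generator yields nothing
print(hasattr(strict, "no_such_attribute"))  # False: AttributeError is raised
print(strict.text)                           # ['a duck', 'a goose']
```

Any library that probes objects with `hasattr` or `getattr(obj, name, default)` will be misled by the first class, because no lookup ever fails.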
Thanks for that explanation @BenjaminBossan. I filed a bug with |
Quick update on this. This describes their plans a bit more. I'm hoping in the long run this will make |
Thanks for reporting back. I read it but since I'm not familiar with torchtext, I can't really judge the changes. The general idea seems to be good. Whether it makes it easier to integrate with skorch remains to be seen. @ToddMorrill do you have any experience with using the facilities provided by huggingface instead of torchtext? I wonder if those cooperate better with skorch. I think it could also be interesting to provide sklearn transformers to wrap their tokenizers, which would allow integrating them into an sklearn pipeline. |
I haven't had a chance to use huggingface's tools, but it's on my current project's roadmap. I'll share if I get anything running. |
Hey @BenjaminBossan, quick question. Circling back to my comment above - would it be possible to use |
@ToddMorrill could you try if one of these three proposals works for you?
|
Good thoughts! I tried all 3 techniques and you can see the example that I'm working on for the dask team here. There's a section in this notebook titled "Grid search with Skorch" where you'll see all 3 attempts that all resulted in |
Could you please paste the full stack trace for the error? I assume it's the same for all 3 cases? |
Indeed, the error and stack trace were the same for all 3 cases. Here it is.
|
This is interesting, it looks like it works a few times and then suddenly it breaks. Could you please initialize the net with After trying that, regardless of whether it helps, please do the following: |
FWIW, the default value for the
Makes sense, thanks for the insight. Running with
Do you think this is because my custom |
That's a bit strange:
This code path should never be reached because this line comes before it: Lines 1154 to 1155 in 6fe94fd
Could you maybe turn on the debugger and check the value of |
This release of skorch contains a few minor improvements and some nice additions. As always, we fixed a few bugs and improved the documentation. Our [learning rate scheduler](https://skorch.readthedocs.io/en/latest/callbacks.html#skorch.callbacks.LRScheduler) now optionally logs learning rate changes to the history; moreover, it now allows the user to choose whether an update step should be made after each batch or each epoch.

If you always longed for a metric that would just use whatever is defined by your criterion, look no further than [`loss_scoring`](https://skorch.readthedocs.io/en/latest/scoring.html#skorch.scoring.loss_scoring). Also, skorch now allows you to easily change the kind of nonlinearity to apply to the module's output when `predict` and `predict_proba` are called, by passing the `predict_nonlinearity` argument.

Besides these changes, we improved the customization potential of skorch. First of all, the `criterion` is now set to `train` or `valid`, depending on the phase -- this is useful if the criterion should act differently during training and validation. Next we made it easier to add custom modules, optimizers, and criteria to your neural net; this should facilitate implementing architectures like GANs. Consult the [docs](https://skorch.readthedocs.io/en/latest/user/neuralnet.html#subclassing-neuralnet) for more on this. Conveniently, [`net.save_params`](https://skorch.readthedocs.io/en/latest/net.html#skorch.net.NeuralNet.save_params) can now persist arbitrary attributes, including those custom modules.

As always, these improvements wouldn't have been possible without the community. Please keep asking questions, raising issues, and proposing new features. We are especially grateful to those community members, old and new, who contributed via PRs:

```
Aaron Berk
guybuk
kqf
Michał Słapek
Scott Sievert
Yann Dubois
Zhao Meng
```

Here is the full list of all changes:

### Added

- Added the `event_name` argument for `LRScheduler` for optional recording of LR changes inside `net.history`. NOTE: Supported only in Pytorch>=1.4
- Make it easier to add custom modules or optimizers to a neural net class by automatically registering them where necessary and by making them available to set_params
- Added the `step_every` argument for `LRScheduler` to set whether the scheduler step should be taken on every epoch or on every batch.
- Added the `scoring` module with `loss_scoring` function, which computes the net's loss (using `get_loss`) on provided input data.
- Added a parameter `predict_nonlinearity` to `NeuralNet` which allows users to control the nonlinearity to be applied to the module output when calling `predict` and `predict_proba` (#637, #661)
- Added the possibility to save the criterion with `save_params` and with checkpoint callbacks
- Added the possibility to save custom modules with `save_params` and with checkpoint callbacks

### Changed

- Removed support for schedulers with a `batch_step()` method in `LRScheduler`.
- Raise `FutureWarning` in `CVSplit` when `random_state` is not used. Will raise an exception in a future (#620)
- The behavior of method `net.get_params` changed to make it more consistent with sklearn: it will no longer return "learned" attributes like `module_`; therefore, functions like `sklearn.base.clone`, when called with a fitted net, will no longer return a fitted net but instead an uninitialized net; if you want a copy of a fitted net, use `copy.deepcopy` instead; `net.get_params` is used under the hood by many sklearn functions and classes, such as `GridSearchCV`, whose behavior may thus be affected by the change. (#521, #527)
- Raise `FutureWarning` when using `CyclicLR` scheduler, because the default behavior has changed from taking a step every batch to taking a step every epoch. (#626)
- Set train/validation on criterion if it's a PyTorch module (#621)
- Don't pass `y=None` to `NeuralNet.train_split` to enable the direct use of split functions without positional `y` in their signatures. This is useful when working with unsupervised data (#605).
- `to_numpy` is now able to unpack dicts and lists/tuples (#657, #658)
- When using `CrossEntropyLoss`, softmax is now automatically applied to the output when calling `predict` or `predict_proba`

### Fixed

- Fixed a bug where `CyclicLR` scheduler would update during both training and validation rather than just during training.
- Fixed a bug introduced by moving the `optimizer.zero_grad()` call outside of the train step function, making it incompatible with LBFGS and other optimizers that call the train step several times per batch (#636)
- Fixed pickling of the `ProgressBar` callback (#656)
@ToddMorrill any updates? |
Since there haven't been any updates for quite a while, I assume this has been resolved. Feel free to re-open if not. |
I'm having a hard time figuring out how to pass a list of lists (with variable length) to skorch's fit method.

Specifically, I have a feature that is a list of ID's (e.g. [[1, 12, 3], [6, 22]...]) which are converted to a dense representation using an embedding table in my PyTorch module's forward method:

When I call net.fit() on my data set (e.g. {"X_float": ..., "X_id_list": ...}) I get the following error caused by the list of lists:

I've also tried converting the list of lists to a pandas dataframe and numpy array (of objects) and neither works. How do you handle variable length lists of lists in skorch.fit?