import torch
import torch.nn as nn
from kan import KAN  # assuming a pykan-style KAN layer that takes width=[in, out]


class KanMLP(nn.Module):
    """MLP block that uses KAN layers in place of nn.Linear."""
    def __init__(self,
                 in_features=1152,
                 hidden_features=None,
                 out_features=None,
                 drop=0.
                 ):
        super().__init__()
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.mlp = nn.ModuleDict(
            dict(
                c_fc=KAN(width=[in_features, hidden_features]),
                c_proj=KAN(width=[hidden_features, out_features]),
                act=approx_gelu(),  # tanh-approximate GELU
                dropout=nn.Dropout(drop),
            )
        )
        m = self.mlp
        self.mlpf = lambda x: m.dropout(
            m.c_proj(m.act(m.c_fc(x)))
        )  # MLP forward

    def forward(self, x):
        x = self.mlpf(x)
        return x


net = KanMLP(1152, 1152 * 4).to("cuda")
x = torch.rand(size=(4, 4096 * 4, 1152)).to("cuda")
net(x)
When the number of tokens reaches a certain size, the forward pass fails with:

CUDA out of memory.
I dropped your code into Claude, and hopefully this gives you some indication:
The main reason you're running out of CUDA memory is the large size of your input tensor. Let's break down the memory usage:
Input tensor x:
- Shape: (4, 4096*4, 1152) = (4, 16384, 1152)
- Elements: 4 * 16384 * 1152 = 75,497,472
- Assuming float32, this tensor alone requires about 302 MB of memory.

Network parameters:
- Input size: 1152
- Hidden size: 1152 * 4 = 4608
- This results in two large matrices in the KAN layers, each potentially using significant memory.

Intermediate activations:
- The forward pass will create several large intermediate tensors, further increasing memory usage.
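As a rough sanity check, here is a small sketch that reproduces these numbers (float32, 4 bytes per element; the hidden size 4608 comes from the 1152 -> 4608 -> 1152 block in the snippet above):

# Back-of-the-envelope memory estimate, float32 = 4 bytes per element.
batch, seq_len, d_in, d_hidden = 4, 4096 * 4, 1152, 1152 * 4

def size_mb(n_elements, bytes_per_element=4):
    return n_elements * bytes_per_element / 1e6

print(size_mb(batch * seq_len * d_in))      # input x:            ~302 MB
print(size_mb(batch * seq_len * d_hidden))  # hidden activation: ~1208 MB

That is just the input and a single hidden activation; the KAN layers' parameters and any internal buffers come on top of this.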
To address this issue, you can try the following approaches:
Reduce batch size:
Instead of processing 4 samples at once, try reducing it to 1 or 2:
x = torch.rand(size=(1, 4096*4, 1152)).to("cuda")
Use gradient accumulation:
If you need to process larger batches for training stability, you can use gradient accumulation. This involves processing smaller sub-batches and accumulating gradients before performing an optimization step.
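A minimal sketch of what that could look like; the optimizer, loss_fn, and dataloader below are hypothetical placeholders, not part of the original snippet:

# Gradient accumulation sketch (hypothetical optimizer / loss_fn / dataloader).
accum_steps = 4  # effective batch size = micro-batch size * accum_steps

optimizer.zero_grad()
for step, (inputs, targets) in enumerate(dataloader):
    outputs = net(inputs.to("cuda"))
    loss = loss_fn(outputs, targets.to("cuda")) / accum_steps  # scale so gradients average correctly
    loss.backward()                                            # gradients accumulate in .grad
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()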
Use mixed precision training:
Utilize float16 (half-precision) computations to reduce memory usage. You can use NVIDIA's Apex library or PyTorch's native AMP (Automatic Mixed Precision):
from torch.cuda.amp import autocast

with autocast():
    output = net(x)
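For training (not just a forward pass), the usual pattern also adds a GradScaler so float16 gradients don't underflow; a sketch with the same hypothetical optimizer, loss_fn, and dataloader as above:

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
for inputs, targets in dataloader:              # hypothetical dataloader
    optimizer.zero_grad()
    with autocast():                            # run eligible ops in float16
        loss = loss_fn(net(inputs.to("cuda")), targets.to("cuda"))
    scaler.scale(loss).backward()               # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)                      # unscales gradients, then steps
    scaler.update()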
Optimize your model architecture:
Consider if you can reduce the size of your hidden layers or use more memory-efficient architectures.
Use gradient checkpointing:
This technique trades computation for memory by not storing all intermediate activations.
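With the KanMLP above, one way to try this is torch.utils.checkpoint around the block's forward; a sketch (it only pays off when gradients are needed, i.e. during training):

import torch.utils.checkpoint as cp

class CheckpointedKanMLP(KanMLP):
    def forward(self, x):
        # Recompute this block's activations in the backward pass instead of storing them.
        return cp.checkpoint(self.mlpf, x, use_reentrant=False)

net = CheckpointedKanMLP(1152, 1152 * 4).to("cuda")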
If possible, process your data in smaller chunks:
Instead of processing the entire 4096*4 sequence length at once, you might be able to process it in smaller segments.
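Since the block is a pure per-token MLP (each token passes through the same layers independently, with no mixing across the sequence), the sequence can simply be split and the outputs concatenated; a sketch for inference:

# Split the 16384-token sequence into chunks along dim=1 and process them one by one.
chunk_size = 4096
with torch.no_grad():
    out = torch.cat([net(part) for part in x.split(chunk_size, dim=1)], dim=1)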