Dynamic sequence length for transformer-based model - error when exporting from Python to MATLAB

We developed a simple transformer architecture in Python (see the code below). The model can handle sequences of different lengths, and I want to use it in MATLAB. I tried exporting it to both ONNX and PyTorch (.pt) format, but in both cases I had to fix the input shape to export the model. I used the torch.jit.script() function in Python to script the model and save it in .pt format. However, I think pytorchmex from the Deep Learning Toolbox Converter for PyTorch Models only works with models produced by torch.jit.trace.
I want to find a way to use a model in MATLAB that can accept inputs of any length.
Any help would be much appreciated.
# Python Code
import torch
import torch.nn as nn


# Model class to export
class TransformerModel(nn.Module):
    def __init__(
        self,
        input_dim,
        model_dim,
        n_classes,
        num_heads,
        num_layers,
    ):
        super().__init__()
        self.model_dim = model_dim

        # Embedding Layer: projects input features to the model dimension
        self.embedding = nn.Linear(input_dim, model_dim)

        # Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=num_heads,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
        )

        # Output Layer: per-time-step class scores
        self.output_layer = nn.Linear(model_dim, n_classes)

    def forward(self, x, padding_mask):
        # Invert the boolean mask before passing it as src_key_padding_mask
        padding_mask = ~padding_mask
        x = self.embedding(x)

        # Transformer Encoder
        x = self.transformer_encoder(x, src_key_padding_mask=padding_mask)

        # Model prediction
        output = self.output_layer(x)
        return output
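
For reference, here is a minimal sketch of how variable-length inputs for the forward() above might be prepared. It assumes the incoming padding_mask is True for real time steps and False for padding. That assumption is inferred from the `~padding_mask` inversion, since src_key_padding_mask treats True as "ignore this position". The helper name pad_batch and the sample values are hypothetical, and plain Python lists stand in for tensors.

```python
from typing import List, Tuple

Sequence = List[List[float]]  # one sequence: time steps x input_dim


def pad_batch(
    sequences: List[Sequence],
    pad_value: float = 0.0,
) -> Tuple[List[Sequence], List[List[bool]]]:
    """Pad non-empty sequences to the longest length in the batch.

    Returns (padded, valid_mask); valid_mask[i][t] is True for a real
    time step and False for padding (assumed convention, see above).
    """
    max_len = max(len(seq) for seq in sequences)
    input_dim = len(sequences[0][0])
    padded, valid_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        pad_steps = [[pad_value] * input_dim for _ in range(n_pad)]
        padded.append([list(step) for step in seq] + pad_steps)
        valid_mask.append([True] * len(seq) + [False] * n_pad)
    return padded, valid_mask


# Hypothetical batch: two sequences of lengths 3 and 1, input_dim = 2
batch = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    [[0.7, 0.8]],
]
padded, valid_mask = pad_batch(batch)

# Equivalent of `~padding_mask` in forward(): True now marks padding
key_padding_mask = [[not v for v in row] for row in valid_mask]
```

With PyTorch available, the two lists could be converted with torch.tensor(padded) and torch.tensor(valid_mask, dtype=torch.bool) before calling the model.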

Answers (0)


Release

R2024b
