Initializing an LSTM that is imported using ONNX

Hi,
I am training an LSTM for RL using Ray in Python. I would like to export this model using ONNX and then import it into MATLAB. As far as I understand, I need to initialize the model in MATLAB after importing it. However, I cannot figure out the correct input shapes/formats in MATLAB to make this work.
Minimum working example:
Python code to train the LSTM:
import torch
import numpy as np
from ray.rllib.algorithms.ppo import PPOConfig
# Configure the algorithm
algo = (
    PPOConfig()
    .env_runners(num_env_runners=1)
    .resources(num_gpus=0)
    .environment(env="CartPole-v1")
    .training(model={"use_lstm": True})
    .build()
)
# Train for 2 training iterations
for i in range(2):
    result = algo.train()
# Get the policy
ppo_policy = algo.get_policy()
# Batch size
B = 1
# Initialize the LSTM inputs: observation, the two LSTM state tensors, and sequence lengths
input_dict = {"obs": torch.tensor(np.random.uniform(0, 1.0, size=(B, 4)).astype(np.float32))}
state_batches = [torch.zeros((B, 256), dtype=torch.float32), torch.zeros((B, 256), dtype=torch.float32)]
seq_lens = torch.ones([B], dtype=int)
# Apply the LSTM to the inputs
model = ppo_policy.model
print(model(input_dict, state=state_batches, seq_lens=seq_lens))
# Save the model to ONNX (opset 14)
ppo_policy.export_model('onnx14', onnx=14)
Code in MATLAB:
% Import model from where I saved it
net = importNetworkFromONNX('path/to/onnx-model');
% input shapes
obs_size = [1,4];
state_size=[2,1,256];
seq_lens_size=[1];
% initialize input arrays
obs = dlarray(rand(obs_size),"BS");
state = dlarray(rand(state_size),"SBS");
seq_len = dlarray(rand(seq_lens_size),"SB");
% initialize net
net = initialize(net,obs,state,seq_len);
Error message:
I appreciate any help!
Best,
Andreas

Answers (3)

Joss Knight 2024-7-18
This code is suspect
% initialize input arrays
obs = dlarray(rand(obs_size),"BS");
state = dlarray(rand(state_size),"SBS");
seq_len = dlarray(rand(seq_lens_size),"SB");
% initialize net
net = initialize(net,obs,state,seq_len);
I think your network has a single input, so you need to pass a single input to initialize (along with the network): basically, just some example input exactly like the one you want to pass to predict. I think you have two channels and a sequence length of 256? One of your dimensions is time, so you need a T dimension. And I don't think you have any spatial dimensions, so no S labels. So you need something like:
exampleInput = dlarray(rand(2,1,256),'CBT');
net = initialize(net, exampleInput);
Or, if you prefer, a permutation of that, such as:
exampleInput = dlarray(rand(256,2,1),'TCB');
net = initialize(net, exampleInput);
If this doesn't work, try running analyzeNetwork(net) to see where your inputs are and we can work out what to expect.
  1 Comment
Andreas 2024-7-23
Hi,
the network does not have a single input. I managed to solve the issue; see my answer below. Thank you for your help anyway!



Kaustab Pal 2024-7-19
It seems you want to determine the input dimension of your imported network. You can easily find this information using the analyzeNetwork function. This function provides an interactive visualization of the network architecture and detailed information, including:
  • Layer types
  • Sizes and formats of layer learnable parameters
  • States and activations
  • Total number of learnable parameters
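The activation sizes of the input layer(s) at the top of the network give you the expected input dimensions.
Additionally, when you create a formatted dlarray object, you assign a label to each dimension (in any order), and the software automatically permutes the dimensions so that the format follows this order: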
  • "S" (Spatial)
  • "C" (Channel)
  • "B" (Batch)
  • "T" (Time)
  • "U" (Unspecified)
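To make the reordering rule concrete, here is a small Python model of it; the check is pure bookkeeping on shapes and labels, so no MATLAB is needed to follow it. `canonical_layout` is a hypothetical helper written for illustration, not part of MATLAB or any library, and it assumes the documented constraints that "C", "B" and "T" may each appear at most once while "S" and "U" may repeat. Under this model, Joss's two suggested inputs describe the same layout, and the question's "SBS" state array ends up as two spatial dimensions followed by a batch dimension.

```python
# Canonical label order that MATLAB documents for formatted dlarray objects.
CANONICAL_ORDER = "SCBTU"


def canonical_layout(shape, fmt):
    """Hypothetical helper (not a MATLAB API): return (shape, fmt) after
    permuting dimensions into S, C, B, T, U order, as dlarray does."""
    if len(shape) != len(fmt):
        raise ValueError("need exactly one label per dimension")
    for label in fmt:
        if label not in CANONICAL_ORDER:
            raise ValueError(f"unknown dimension label {label!r}")
    for label in "CBT":
        if fmt.count(label) > 1:
            raise ValueError(f"label {label!r} may appear at most once")
    # sorted() is stable, so repeated S (or U) dimensions keep their relative order.
    order = sorted(range(len(fmt)), key=lambda i: CANONICAL_ORDER.index(fmt[i]))
    return tuple(shape[i] for i in order), "".join(fmt[i] for i in order)


# Joss's two suggestions describe the same layout:
print(canonical_layout((2, 1, 256), "CBT"))  # → ((2, 1, 256), 'CBT')
print(canonical_layout((256, 2, 1), "TCB"))  # → ((2, 1, 256), 'CBT')
# The question's state array: two spatial dimensions and a batch dimension.
print(canonical_layout((2, 1, 256), "SBS"))  # → ((2, 256, 1), 'SSB')
```

This sketch ignores the fact that MATLAB does not report trailing singleton dimensions in `size`, so it only shows the label order, not exactly what `size` would print.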
For more details, you can refer to the following links:
  1. analyzeNetwork Documentation: https://www.mathworks.com/help/deeplearning/ref/analyzenetwork.html#mw_bdd24886-fa03-4540-a111-391541a0a684
  2. dlarray Documentation: https://www.mathworks.com/help/deeplearning/ref/dlarray.html#d126e57736:~:text=When%20you%20create%20a%20formatted%20dlarray%20object%2C%20the%20software%20automatically%20permutes%20the%20dimensions%20such%20that%20the%20format%20has%20dimensions%20in%20this%20order%3A
Hope this helps.

Andreas 2024-7-23
Hello everyone,
thank you for your help. Unfortunately, I had to work around this issue, but I was able to solve it in the end. I believe MATLAB struggled because the models in Ray's RLlib contain a lot of complicated overhead. In particular, the inputs to the network are lists, dicts, etc., which undergo considerable reformatting, and that seemed to cause the problems. In the end, I extracted the relevant PyTorch modules from the trained RLlib object and combined them in a new torch.nn.Module. Exporting that module with torch.onnx.export worked just fine.
Thank you all for your help.
Best, Andreas
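The workaround replaces RLlib's nested calling convention (an input dict, a list of state tensors and seq_lens) with one module whose forward takes and returns flat arrays, so each input and output of the exported graph is a single plain tensor. The stdlib-only sketch below shows just that composition. `FlatPolicy`, `toy_encoder`, `toy_cell` and `toy_head` are hypothetical stand-ins using plain lists instead of tensors; in the real fix, `FlatPolicy` would subclass torch.nn.Module and hold the submodules extracted from the trained RLlib model (the thread does not say which attributes those are; `print(ppo_policy.model)` should list them), and an instance would then be passed to torch.onnx.export with example inputs `(obs, h, c)`. It also assumes single-step inference, so there is no seq_lens input.

```python
class FlatPolicy:
    """Hypothetical stand-in for a torch.nn.Module that chains extracted parts
    behind one flat signature: forward(obs, h, c) -> (logits, h_out, c_out)."""

    def __init__(self, encoder, recurrent_cell, head):
        self.encoder = encoder
        self.recurrent_cell = recurrent_cell
        self.head = head

    def forward(self, obs, h, c):
        features = self.encoder(obs)
        h_out, c_out = self.recurrent_cell(features, h, c)
        # Returning the new state lets the caller feed it back in on the next step.
        return self.head(h_out), h_out, c_out


# Toy components (plain lists instead of tensors, width 2 instead of 256).
def toy_encoder(obs):
    total = sum(obs)
    return [total, total]


def toy_cell(x, h, c):
    c_out = [ci + xi + hi for ci, xi, hi in zip(c, x, h)]
    h_out = [0.5 * v for v in c_out]
    return h_out, c_out


def toy_head(h):
    return [sum(h), -sum(h)]


flat_policy = FlatPolicy(toy_encoder, toy_cell, toy_head)
logits, h, c = flat_policy.forward([1.0, 2.0, 3.0, 4.0], [0.0, 0.0], [0.0, 0.0])
print(logits, h, c)  # → [10.0, -10.0] [5.0, 5.0] [10.0, 10.0]
```

With a flat signature like this, the MATLAB side would see one input per array (obs, h, c), each of which can be given its own dlarray format when calling initialize.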

Category

Find more on Sequence and Numeric Feature Data Workflows in Help Center and File Exchange

Release

R2024a
