PyTorch Wrapper Template

You can use your own PyTorch models in MATLAB by using the Python interface. The py_wrapper_template.py code sample provides an interface with a predefined API.

Follow the instructions in the py_wrapper_template.py sample code to implement the recommended entry points and to insert your model instantiations. Delete the entry points that are not relevant to your project.

For example, the Train PyTorch Channel Prediction Models (5G Toolbox) example shows an offline training workflow that uses the following API set.

  • train — Trains the PyTorch model.

  • predict — Runs the PyTorch model with the provided inputs.

  • load_model_weights — Loads the PyTorch model weights.
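To make the offline contract concrete, here is a torch-free sketch of a `train`-style entry point: the "model" is a single weight `w` fit to `y = w * x` by full-batch gradient descent, and the function returns the model together with per-epoch training loss and periodic validation loss, mirroring the template's return signature. Every name and value in this sketch is illustrative, not part of the template.

```python
# Torch-free sketch of the offline train() contract. The "model" is one
# weight w, trained on mean squared error; the return values mirror the
# template's (model, training_loss, validation_loss) triple.
def train(w, x_train, y_train, x_val, y_val,
          learn_rate=0.1, num_epochs=20, validation_freq=5):
    def mse(w, xs, ys):
        return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

    training_loss, validation_loss = [], []
    for epoch in range(1, num_epochs + 1):
        # Gradient of the mean squared error with respect to w
        grad = sum(2 * (w * x - y) * x
                   for x, y in zip(x_train, y_train)) / len(x_train)
        w -= learn_rate * grad
        training_loss.append(mse(w, x_train, y_train))
        if epoch % validation_freq == 0:  # validate every validation_freq epochs
            validation_loss.append(mse(w, x_val, y_val))
    return w, training_loss, validation_loss

w, train_hist, val_hist = train(0.0, [1, 2, 3], [2, 4, 6], [4], [8])
```

The real template replaces the scalar update with a PyTorch optimizer step, but the shape of the inputs and of the three return values is the same.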

py_wrapper_template.py Code Sample

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset

# Import PyTorch AI model as torch_module

#----------------------------------------
#      Example Code Start
#----------------------------------------
# import pytorch_model_module as torch_module

###################################################################
#                   Neural Network Model
###################################################################
def construct_model(param1: int, param2: float, filename: str = "") -> torch.nn.Module:
    """
    Implement a function that uses the imported PyTorch model module (torch_module) 
    to instantiate a PyTorch model using the specified parameters.

    This function uses the `torch_module` to create a PyTorch model. Replace `param1` 
    and `param2` with the specific hyperparameters required to instantiate the model. 
    You may add or remove parameters as necessary to fit the requirements of your model. 
    Use data type hints to help the MATLAB Python interface with data conversion.

    If a `filename` is provided, the function also loads the model weights from the 
    specified file.

    Args:
        param1 (int): Model parameter 1.
        param2 (float): Model parameter 2.
        filename (str, optional): File name for saved model weights.

    Returns:
        torch.nn.Module: The instantiated PyTorch model. If `filename` is provided, 
        the model is loaded with the trained weights.
    """

    # Insert your model instantiation code here



    #----------------------------------------
    #      Example Code Start
    #----------------------------------------
    # device = select_device()
    # my_model = torch_module.pytorch_model(param1, param2).to(device)
    # if filename:
    #     my_model = load_model_weights(my_model, filename)
    #----------------------------------------
    #      Example Code End
    #----------------------------------------

    return my_model

###################################################################
#                   Train
###################################################################
def train(my_model, 
          x_train: np.ndarray, y_train: np.ndarray, 
          x_val: np.ndarray, y_val: np.ndarray, 
          hyperparam1, hyperparam2, hyperparam3, hyperparam4):
    """
    Implement a function that trains a PyTorch model using the provided 
    training and validation data.

    This function is for offline training, where data is passed once to Python 
    and the trained model is returned. Update the inputs based on your training 
    needs. This example assumes one model, training and validation data as arrays, 
    and four hyperparameters.

    Replace `hyperparam1`, `hyperparam2`, `hyperparam3`, and `hyperparam4` with 
    the hyperparameters needed to train the model. Add or remove parameters 
    as needed. Use data type hints to help the MATLAB Python interface with data conversion.

    Args:
        my_model: PyTorch model to train.
        x_train (np.ndarray): Network input for training.
        y_train (np.ndarray): Network output or target for training.
        x_val (np.ndarray): Network input for validation.
        y_val (np.ndarray): Network output or target for validation.
        hyperparam1: Initial learning rate or other hyperparameter.
        hyperparam2: Minibatch size for training.
        hyperparam3: Number of training epochs.
        hyperparam4: Frequency of validation during training in epochs.

    Returns:
        my_model: Trained model.
        training_loss: Array of training loss per iteration (or epoch).
        validation_loss: Array of validation loss per `validation_freq` epochs.
    """
  
    # Select the appropriate device (CPU or GPU) for training
    device = select_device(verbose=True)

    # Insert your training code here




    #----------------------------------------
    #      Example Code Start
    #----------------------------------------
    # # Map the generic hyperparameters to the names used below
    # batch_size = hyperparam2
    # num_epochs = hyperparam3
    # validation_freq = hyperparam4
    #
    # # Create DataLoaders
    # train_dataset = TensorDataset(torch.from_numpy(x_train).float().to(device), 
    #                               torch.from_numpy(y_train).float().to(device))
    # train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    #
    # val_dataset = TensorDataset(torch.from_numpy(x_val).float().to(device), 
    #                             torch.from_numpy(y_val).float().to(device))
    # val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    #
    # # Train (hyperparam1, the learning rate, is consumed inside torch_module.train)
    # my_model, training_loss, validation_loss = torch_module.train(
    #     my_model, train_loader, val_loader, num_epochs, validation_freq)
    #----------------------------------------
    #      Example Code End
    #----------------------------------------

    return my_model, training_loss, validation_loss
  
###################################################################
#                   Initialize Trainer
###################################################################
def setup_trainer(hyperparam1: float, hyperparam2: int, hyperparam3: int, hyperparam4: int):
    """
    Implement a function that sets up a PyTorch model trainer object for online training.

    This function is for online training, where data is passed to Python 
    for each training iteration. Update the inputs based on your training 
    needs. This example assumes four hyperparameters.

    Replace `hyperparam1`, `hyperparam2`, `hyperparam3`, and `hyperparam4` with 
    the hyperparameters needed to train the model. Add or remove parameters 
    as needed. Use data type hints to help the MATLAB Python interface with data conversion.

    Args:
        hyperparam1 (float): Initial learning rate.
        hyperparam2 (int): Minibatch size for training.
        hyperparam3 (int): Number of training epochs.
        hyperparam4 (int): Frequency of validation during training in epochs.

    Returns:
        trainer: Trainer object for PyTorch model.
    """
    # Insert your setup code here




    #----------------------------------------
    #      Example Code Start
    #----------------------------------------
    #  trainer = torch_module.ModelTrainer(hyperparam1, hyperparam2, hyperparam3, hyperparam4)
    #----------------------------------------
    #      Example Code End
    #----------------------------------------

    return trainer

###################################################################
#                   Train One Iteration
###################################################################
def train_one_iteration(trainer, x_train: np.ndarray, y_train: np.ndarray):
    """
    Implement a function that trains a PyTorch model for a single iteration.

    This function is for online training, where data is passed to Python 
    for each training iteration. Update the inputs based on your training 
    needs. This example assumes one trainer and training data passed as arrays.

    Args:
        trainer: PyTorch model trainer object.
        x_train (np.ndarray): Network input for training.
        y_train (np.ndarray): Network output or target for training.

    Returns:
        training_loss (float): The training loss for this iteration.
    """
    # Insert your training code here

    

    #----------------------------------------
    #      Example Code Start
    #----------------------------------------
    # loss = trainer.update(x_train, y_train)
    #----------------------------------------
    #      Example Code End
    #----------------------------------------

    return loss

###################################################################
#                   Validation
###################################################################
def validate(trainer, x_val: np.ndarray, y_val: np.ndarray):
    """
    Implement a function that validates a PyTorch model during online training.

    This function is for online training, where data is passed to Python 
    for each training iteration. Update the inputs based on your training 
    needs. This example assumes one trainer and validation data passed as arrays.

    Args:
        trainer: PyTorch model trainer object.
        x_val (np.ndarray): Network input for validation.
        y_val (np.ndarray): Network output or target for validation.

    Returns:
        val_loss (float): The validation loss.
    """

    # Insert your validation code here



    #----------------------------------------
    #      Example Code Start
    #----------------------------------------
    # val_loss = trainer.evaluate(x_val, y_val)
    #----------------------------------------
    #      Example Code End
    #----------------------------------------

    # Return the average validation loss
    return val_loss

###################################################################
#                   Prediction
###################################################################
def predict(my_model, x_data):
    """
    Implement a function that generates predictions using a trained PyTorch model.

    This function provides a generic prediction mechanism. Update the inputs 
    based on your needs. This example assumes a model and input data array.

    Args:
        my_model: The trained PyTorch model.
        x_data (np.ndarray): Input data for predictions.

    Returns:
        predictions (np.ndarray): Predicted output data.
    """

    #----------------------------------------
    #      Example Code Start
    #----------------------------------------
    # device = select_device()
    #
    # # Convert input data to a PyTorch tensor and move it to the specified device
    # x_tensor = torch.from_numpy(x_data).float().to(device)
    #
    # # Set the model to evaluation mode
    # my_model.eval().to(device)
    #
    # # Disable gradient calculation
    # with torch.no_grad():
    #     # Perform the forward pass to get predictions
    #     outputs = my_model(x_tensor)
    #
    # # Move the predictions back to the CPU and convert them to a numpy array
    # predictions = outputs.cpu().numpy()
    #----------------------------------------
    #      Example Code End
    #----------------------------------------
    
    return predictions

###################################################################
#                   Save model weights
###################################################################
def save_model_weights(my_model, filename: str):
    """
    Implement a function that saves the PyTorch model's state dictionary to a file.

    This function saves the model's state dictionary to a specified file.
    Ensure the filename has a '.pth' extension for compatibility.

    Args:
        my_model: The trained PyTorch model.
        filename (str): The desired filename for saving the model state.

    Returns:
        str: The actual filename used for saving the model state dictionary.
    """

    #----------------------------------------
    #      Example Code Start
    #----------------------------------------
    # # Ensure the filename has a .pth extension
    # base, _ = os.path.splitext(filename)
    # new_filename = base + '.pth'
    #
    # # Move the model to CPU for saving
    # my_model.to("cpu")
    #
    # # Save the model state dictionary and metadata to the specified file
    # torch.save({
    #     'model_state_dict': my_model.state_dict()
    # }, new_filename)
    #----------------------------------------
    #      Example Code End
    #----------------------------------------

    # Return the filename used for saving
    return new_filename

###################################################################
#                   Load model weights
###################################################################
def load_model_weights(my_model, filename: str):
    """
    Implement a function that loads the PyTorch model's state dictionary from a file.

    This function loads the model's state dictionary from the specified file.
    Ensure the file has a '.pth' extension and exists.

    Args:
        my_model: The untrained PyTorch model.
        filename (str): The path to the file containing the saved model state.

    Returns:
        The PyTorch model with loaded weights.
    """

    #----------------------------------------
    #      Example Code Start
    #----------------------------------------
    # # Check if the filename has a .pth extension
    # if not filename.endswith('.pth'):
    #     raise ValueError("The file must have a .pth extension.")

    # # Check if the file exists
    # if not os.path.isfile(filename):
    #     raise FileNotFoundError(f"The file '{filename}' does not exist.")
    #
    # # Load the checkpoint onto the selected device
    # device = select_device()
    # checkpoint = torch.load(filename, map_location=device, weights_only=False)
    #
    # # Load the model's state dictionary from the checkpoint
    # my_model.load_state_dict(checkpoint['model_state_dict'])
    #
    # # Call flatten_parameters to optimize the model's internal state (if RNN)
    # if is_rnn_model(my_model):
    #     my_model.flatten_parameters()
    #----------------------------------------
    #      Example Code End
    #----------------------------------------
    
    # Return the updated instance
    return my_model

###################################################################
#                   Save model
###################################################################
def save(my_model, filename: str, param1: int, param2: float):
    """
    Implement a function that saves the PyTorch model's state dictionary and metadata to a file.

    This function saves the model's state dictionary and additional metadata
    to a specified file. Ensure the filename has a '.pth' extension.

    Update the inputs based on your needs. This example assumes a model, 
    a file name, and two parameters as metadata, which are needed to 
    instantiate the PyTorch model. Replace, remove, or add parameters as needed, 
    and update the param* variables in the following code. Use data type 
    hints to help the MATLAB Python interface with data conversion. 

    Args:
        my_model: The trained PyTorch model.
        filename (str): The desired filename for saving the model state.
        param1 (int): Hyperparameter to be saved as metadata.
        param2 (float): Hyperparameter to be saved as metadata.

    Returns:
        str: The actual filename used for saving the model state dictionary.
    """

    #----------------------------------------
    #      Example Code Start
    #----------------------------------------
    # # Ensure the filename has a .pth extension
    # base, _ = os.path.splitext(filename)
    # new_filename = base + '.pth'
    #
    # # Move the model to CPU for saving
    # my_model.to("cpu")
    #
    # # Create metadata for the model
    # metadata = {
    #     'param1': param1, # <== must match inputs
    #     'param2': param2 # <== must match inputs
    # }
    #
    # # Save the model state dictionary and metadata to the specified file
    # torch.save({
    #     'model_state_dict': my_model.state_dict(),
    #     'metadata': metadata
    # }, new_filename)
    #----------------------------------------
    #      Example Code End
    #----------------------------------------

    # Return the filename used for saving
    return new_filename

###################################################################
#                   Load model
###################################################################
def load(filename: str):
    """
    Implement a function that loads the PyTorch model's state dictionary 
    and metadata from a file.

    This function loads the model's state dictionary and additional metadata
    from a specified file. Ensure the file has a '.pth' extension and exists.

    Update the inputs based on your needs. This example assumes a file name. 
    The following code assumes that the parameters needed to instantiate the 
    PyTorch model are saved as metadata. Replace, remove, or add parameters 
    as necessary. Use data type hints to help the MATLAB Python interface 
    with data conversion.

    Args:
        filename (str): The path to the file containing the saved model state.

    Returns:
        The loaded PyTorch model with state and metadata.
    """

    #----------------------------------------
    #      Example Code Start
    #----------------------------------------
    # # Check if the filename has a .pth extension
    # if not filename.endswith('.pth'):
    #     raise ValueError("The file must have a .pth extension.")
    #
    # # Check if the file exists
    # if not os.path.isfile(filename):
    #     raise FileNotFoundError(f"The file '{filename}' does not exist.")
    #
    # # Load the checkpoint onto the selected device
    # device = select_device()
    # checkpoint = torch.load(filename, map_location=device, weights_only=False)
    # metadata = checkpoint['metadata']
    #
    # param1 = metadata['param1']  # <== must match the saved metadata
    # param2 = metadata['param2']  # <== must match the saved metadata
    # my_model = construct_model(param1, param2)  # <== must match construct_model inputs
    #
    # # Load the model's state dictionary from the checkpoint
    # my_model.load_state_dict(checkpoint['model_state_dict'])
    #
    # # Call flatten_parameters to optimize the model's internal state (if RNN)
    # if is_rnn_model(my_model):
    #     my_model.flatten_parameters()
    #----------------------------------------
    #      Example Code End
    #----------------------------------------
    
    # Return the updated instance
    return my_model

###################################################################
#                   Model Information
###################################################################
def info(my_model):
    """
    Implement a function that prints and/or returns information about the PyTorch model.

    This function prints the model architecture and returns the total number
    of learnable parameters.

    Args:
        my_model: The trained PyTorch model.

    Returns:
        int: Total number of parameters in the model.
    """
   
    #----------------------------------------
    #      Example Code Start
    #----------------------------------------
    # print("Model architecture:")
    # print(my_model)
    #
    # # Move the model to the appropriate device (if using GPU)
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # my_model.to(device)
    #
    # # Calculate the total number of parameters
    # total_params = sum(p.numel() for p in my_model.parameters())
    # return total_params
    #----------------------------------------
    #      Example Code End
    #----------------------------------------

###################################################################
#                   Helpers
###################################################################
def select_device(verbose: bool = False) -> torch.device:
    """
    Selects the computational device for PyTorch operations.

    This function chooses the appropriate computational device (GPU or CPU)
    based on the availability of CUDA.

    Args:
        verbose (bool): If True, prints information about the selected device.

    Returns:
        torch.device: The selected device, either GPU ('cuda') if available, or CPU ('cpu').
    """
    # Determine the device based on availability
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Verbose output for debugging and information
    if verbose:
        if device.type == "cuda":
            gpu_name = torch.cuda.get_device_name(device)
            total_memory = torch.cuda.get_device_properties(device).total_memory / (1024**3)  # Convert bytes to GB
            print(f"Selected device: GPU ({gpu_name}, {total_memory:.2f} GB)")
        else:
            print("Selected device: CPU")

    return device

def is_rnn_model(model: nn.Module):
    """
    Checks if the given model is an RNN or contains RNN layers.

    Args:
        model (nn.Module): The PyTorch model to check.

    Returns:
        bool: True if the model or any of its layers is an RNN, otherwise False.
    """
    # Check if the model itself is an instance of any RNN class
    if isinstance(model, (nn.RNN, nn.LSTM, nn.GRU)):
        return True
    
    # Alternatively, check if any layer in the model is an RNN
    for layer in model.children():
        if isinstance(layer, (nn.RNN, nn.LSTM, nn.GRU)):
            return True
    
    return False
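The save and load pairs above share the same file-name handling: normalize whatever extension the caller supplies to `.pth` before saving, and validate both the extension and the file's existence before loading. That handling can be exercised on its own, without a model or PyTorch; the helper names below are illustrative, not part of the template.

```python
import os
import tempfile

def normalize_to_pth(filename: str) -> str:
    # Replace any extension with .pth, as the save functions do
    base, _ = os.path.splitext(filename)
    return base + ".pth"

def check_loadable(filename: str) -> None:
    # The same guards the load functions apply before calling torch.load
    if not filename.endswith(".pth"):
        raise ValueError("The file must have a .pth extension.")
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"The file '{filename}' does not exist.")

with tempfile.TemporaryDirectory() as tmpdir:
    saved = normalize_to_pth(os.path.join(tmpdir, "model.pt"))  # ends in .pth
    open(saved, "wb").close()   # stand-in for torch.save(...)
    check_loadable(saved)       # both guards pass silently
```

Normalizing on save and validating on load keeps the two sides of the round trip consistent even when MATLAB passes a file name without an extension.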
