GPU Coder auto code generation for DAGNetwork fails

Dear community,
I am trying to generate C/C++ code from a DAGNetwork imported from a YOLOv7 ONNX model. When I run this model in Simulink using the Predict block, no errors occur and everything works. I saved the DAGNetwork as a *.mat file, and loading and running the model again also works well.
Next, when I try to generate C/C++ code from the model via Simulink, I get the error message: ??? Layer hyper-parameters for custom layer 'concatenationLayer' must be numeric scalar, scalar logical, character or string array, or a matrix of type double or single.
Once I remove this layer, "concatenationLayer", code generation works perfectly.
I understand that concatenationLayer is a custom layer, but how can I get rid of this error?
4 comments
Sergio Matiz Romero
Hi Oualid,
Thank you for sharing additional details on the issue you are facing. Based on the code you shared, the custom layer property ONNXParams in Reshape_To_ConcatLayer1211 is likely a structure:
[output, outputNumDims, state] = Reshape_To_ConcatGraph1200(onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, NumDims.onnx__Reshape_488, NumDims.onnx__Reshape_527, NumDims.onnx__Reshape_566, Vars, NumDims, Training, params.State);
Structure types are not supported as layer properties for deep learning code generation, according to:
Therefore, my recommendation in this case would be to modify the code so that the structure fields are stored as individual properties in the Reshape_To_ConcatLayer constructor. An example of what this would look like is provided below:
classdef Reshape_To_ConcatLayer1211 < nnet.layer.Layer & nnet.layer.Formattable
    properties
        OnnxParam1
        OnnxParam2
        OnnxParam3
    end
    methods
        function this = Reshape_To_ConcatLayer1211(name, onnxParams)
            this.Name = name;
            this.NumInputs = 3;
            this.OutputNames = {'output'};
            % Below I use generic names for the different properties and
            % store them individually, avoiding the use of structures.
            % Assign through "this."; a bare assignment only creates a local variable.
            % Each stored value must itself be numeric, logical, char, or string.
            this.OnnxParam1 = onnxParams.Param1;
            this.OnnxParam2 = onnxParams.Param2;
            this.OnnxParam3 = onnxParams.Param3;
        end
        % predict (and forward) methods go here
    end
end
Then you would have to modify the rest of the code so that it uses the individual properties instead of the structure ONNXParams.
I hope this helps you solve the issue you are facing.
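The rule quoted in the error message can be turned into a quick check that shows which properties of a custom layer would trigger it. This is a hypothetical helper, not a MathWorks API: it approximates the rule from the wording of the error message only, and the list of inherited nnet.layer.Layer properties it skips is an assumption.

```matlab
% Approximation of the rule quoted in the code generation error message:
% numeric or logical scalar, character or string array, or double/single array.
isAllowedHyperParam = @(v) ...
    ((isnumeric(v) || islogical(v)) && isscalar(v)) || ...
    ischar(v) || isstring(v) || ...
    isa(v, 'double') || isa(v, 'single');

% Properties inherited from nnet.layer.Layer; skipping them is an assumption.
baseProps = {'Name', 'Description', 'Type', 'NumInputs', 'InputNames', ...
             'NumOutputs', 'OutputNames'};

% Index helper, because an anonymous function cannot index a call result directly.
pick = @(c, mask) c(mask);

% fieldnames returns public property names for objects and field names for
% structs, so the check works on a layer instance or on a plain struct.
userProps = @(obj) setdiff(fieldnames(obj), baseProps);
findNonCodegenProps = @(obj) pick(userProps(obj), ...
    ~cellfun(@(n) isAllowedHyperParam(obj.(n)), userProps(obj)));

% Usage (layer = your custom layer instance):
%   bad = findNonCodegenProps(layer)
```

Run on the original imported layer, a check like this would be expected to flag ONNXParams. It is only a diagnostic aid and does not replace an actual code generation run.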
Oualid on 2022-8-4
Edited: Oualid on 2022-8-4
Hi Sergio,
Thank you for your assistance. I changed the code as you suggested, but I still have the same issue. Here are the steps I followed:
1. Load the ONNX model as a DAGNetwork:
   net = importONNXNetwork(modelfile,OutputDataFormats="TBC")
2. Create my own layer graph:
   Lgraph = net.layerGraph
3. Extract the ONNX parameters of the layer causing the problem (Reshape_To_ConcatLayer1211):
   Param = Lgraph.Layers(300,1).ONNXParams
4. Create a new Reshape_To_ConcatLayer1211 layer and pass the 5 parameters one by one. I also changed the auto-generated layer to accept these 5 parameters as inputs, in the way you suggested:
   layer = Reshape_To_ConcatLayer1211('Reshape_To_ConcatFcn',Param.Learnables,Param.Nonlearnables,Param.State,Param.NumDimensions,Param.NetworkFunctionName);
5. Replace the old layer in Lgraph and create a new graph named mygraph:
   mygraph= replaceLayer(Lgraph,'Reshape_To_ConcatLayer1211',layer,'ReconnectBy','order')
6. Create my own network:
   mynet = assembleNetwork(mygraph)
7. Save it as a .mat file.
8. Load it into Simulink and run the simulation: everything works perfectly.
9. Automatic code generation fails with exactly the same error.
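Step 3 selects the layer by position (Layers(300,1)), while step 5 identifies it by the name 'Reshape_To_ConcatLayer1211'. A name-based lookup keeps the two steps consistent if the graph changes. A minimal sketch follows; the struct array and the names 'conv_1' and 'output' are placeholders standing in for Lgraph.Layers so the snippet runs on its own.

```matlab
% Placeholder for Lgraph.Layers: any array whose elements have a Name works.
layers = struct('Name', {'conv_1', 'Reshape_To_ConcatLayer1211', 'output'});
target = 'Reshape_To_ConcatLayer1211';

% Collect all names into a cell array and find the matching position.
idx = find(strcmp({layers.Name}, target));
if numel(idx) ~= 1
    error('Expected exactly one layer named %s, found %d.', target, numel(idx));
end

% With the real graph (not run here):
%   Param = Lgraph.Layers(idx).ONNXParams;
```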
The modified code is below:
classdef Reshape_To_ConcatLayer1211 < nnet.layer.Layer & nnet.layer.Formattable
% A custom layer auto-generated while importing an ONNX network.
%#codegen
%#ok<*PROPLC>
%#ok<*NBRAK>
%#ok<*INUSL>
%#ok<*VARARG>
properties (Learnable)
end
properties
OnnxParam1
OnnxParam2
OnnxParam3
OnnxParam4
OnnxParam5
end
methods
function this = Reshape_To_ConcatLayer1211(name, onnxParam1,onnxParam2,onnxParam3,onnxParam4,onnxParam5)
this.Name = name;
this.NumInputs = 3;
this.OutputNames = {'output'};
this.Description = 'output';
this.Type = 'Relu';
this.OnnxParam1 = onnxParam1;
this.OnnxParam2 = onnxParam2;
this.OnnxParam3 = onnxParam3;
this.OnnxParam4 = onnxParam4;
this.OnnxParam5 = onnxParam5;
end
function [output] = predict(this, onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566)
if isdlarray(onnx__Reshape_488)
onnx__Reshape_488 = stripdims(onnx__Reshape_488);
end
if isdlarray(onnx__Reshape_527)
onnx__Reshape_527 = stripdims(onnx__Reshape_527);
end
if isdlarray(onnx__Reshape_566)
onnx__Reshape_566 = stripdims(onnx__Reshape_566);
end
onnx__Reshape_488NumDims = 4;
onnx__Reshape_527NumDims = 4;
onnx__Reshape_566NumDims = 4;
Param1=this.OnnxParam1;
Param2= this.OnnxParam2;
Param3= this.OnnxParam3;
Param4= this.OnnxParam4;
Param5= this.OnnxParam5;
[output, outputNumDims] = Reshape_To_ConcatFcn(onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, onnx__Reshape_488NumDims, onnx__Reshape_527NumDims, onnx__Reshape_566NumDims,Param1,Param2,Param3,Param4,Param5, 'Training', false, ...
'InputDataPermutation', {[4 3 1 2], [4 3 1 2], [4 3 1 2], ['as-is'], ['as-is'], ['as-is']}, ...
'OutputDataPermutation', {[3 2 1], ['as-is']});
if any(cellfun(@(A)isempty(A)||~isnumeric(A), {output}))
fprintf('Runtime error in network. The custom layer ''%s'' output an empty or non-numeric value.\n', 'Reshape_To_ConcatLayer1211');
error(message('nnet_cnn_onnx:onnx:BadCustomLayerRuntimeOutput', 'Reshape_To_ConcatLayer1211'));
end
output = dlarray(single(output), 'CBT');
if ~coder.target('MATLAB')
output = extractdata(output);
end
end
function [output] = forward(this, onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566)
if isdlarray(onnx__Reshape_488)
onnx__Reshape_488 = stripdims(onnx__Reshape_488);
end
if isdlarray(onnx__Reshape_527)
onnx__Reshape_527 = stripdims(onnx__Reshape_527);
end
if isdlarray(onnx__Reshape_566)
onnx__Reshape_566 = stripdims(onnx__Reshape_566);
end
onnx__Reshape_488NumDims = 4;
onnx__Reshape_527NumDims = 4;
onnx__Reshape_566NumDims = 4;
Param1 = this.OnnxParam1;
Param2= this.OnnxParam2;
Param3= this.OnnxParam3;
Param4= this.OnnxParam4;
Param5= this.OnnxParam5;
[output, outputNumDims] = Reshape_To_ConcatFcn(onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, onnx__Reshape_488NumDims, onnx__Reshape_527NumDims, onnx__Reshape_566NumDims,Param1,Param2,Param3,Param4,Param5, 'Training', true, ...
'InputDataPermutation', {[4 3 1 2], [4 3 1 2], [4 3 1 2], ['as-is'], ['as-is'], ['as-is']}, ...
'OutputDataPermutation', {[3 2 1], ['as-is']});
if any(cellfun(@(A)isempty(A)||~isnumeric(A), {output}))
fprintf('Runtime error in network. The custom layer ''%s'' output an empty or non-numeric value.\n', 'Reshape_To_ConcatLayer1211');
error(message('nnet_cnn_onnx:onnx:BadCustomLayerRuntimeOutput', 'Reshape_To_ConcatLayer1211'));
end
output = dlarray(single(output), 'CBT');
if ~coder.target('MATLAB')
output = extractdata(output);
end
end
end
end
function [output, outputNumDims, state] = Reshape_To_ConcatFcn(onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, onnx__Reshape_488NumDims, onnx__Reshape_527NumDims, onnx__Reshape_566NumDims, Param1,Param2,Param3,Param4,Param5,varargin)
% Preprocess the input data and arguments:
[onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, Training, outputDataPerms, anyDlarrayInputs] = preprocessInput(onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, varargin{:});
% Put all variables into a single struct to implement dynamic scoping:
[Vars, NumDims] = packageVariables(Param1,Param2,Param3,Param4, {'onnx__Reshape_488', 'onnx__Reshape_527', 'onnx__Reshape_566'}, {onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566}, [onnx__Reshape_488NumDims onnx__Reshape_527NumDims onnx__Reshape_566NumDims]);
% Call the top-level graph function:
[output, outputNumDims, state] = Reshape_To_ConcatGraph1200(onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, NumDims.onnx__Reshape_488, NumDims.onnx__Reshape_527, NumDims.onnx__Reshape_566, Vars, NumDims, Training, Param3);
% Postprocess the output data
[output] = postprocessOutput(output, outputDataPerms, anyDlarrayInputs, Training, varargin{:});
end
function [output, outputNumDims1210, state] = Reshape_To_ConcatGraph1200(onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, onnx__Reshape_488NumDims1207, onnx__Reshape_527NumDims1208, onnx__Reshape_566NumDims1209, Vars, NumDims, Training, state)
% Function implementing the graph 'Reshape_To_ConcatGraph1200'
% Update Vars and NumDims from the graph's formal input parameters. Note that state variables are already in Vars.
Vars.onnx__Reshape_488 = onnx__Reshape_488;
NumDims.onnx__Reshape_488 = onnx__Reshape_488NumDims1207;
Vars.onnx__Reshape_527 = onnx__Reshape_527;
NumDims.onnx__Reshape_527 = onnx__Reshape_527NumDims1208;
Vars.onnx__Reshape_566 = onnx__Reshape_566;
NumDims.onnx__Reshape_566 = onnx__Reshape_566NumDims1209;
% Execute the operators:
% Reshape:
[shape, NumDims.onnx__Transpose_500] = prepareReshapeArgs(Vars.onnx__Reshape_488, Vars.onnx__Reshape_613, NumDims.onnx__Reshape_488, 0);
Vars.onnx__Transpose_500 = reshape(Vars.onnx__Reshape_488, shape{:});
% Transpose:
[perm, NumDims.onnx__Sigmoid_501] = prepareTransposeArgs(Vars.TransposePerm1201, NumDims.onnx__Transpose_500);
if ~isempty(perm)
Vars.onnx__Sigmoid_501 = permute(Vars.onnx__Transpose_500, perm);
end
% Sigmoid:
Vars.y = sigmoid(Vars.onnx__Sigmoid_501);
NumDims.y = NumDims.onnx__Sigmoid_501;
% Split:
[Vars.onnx__Mul_503, Vars.onnx__Mul_504, Vars.onnx__Concat_505, NumDims.onnx__Mul_503, NumDims.onnx__Mul_504, NumDims.onnx__Concat_505] = onnxSplit(Vars.y, 4, Vars.SplitSplit1202, 0, NumDims.y);
% Mul:
Vars.onnx__Add_507 = Vars.onnx__Mul_503 .* Vars.onnx__Pow_614;
NumDims.onnx__Add_507 = max(NumDims.onnx__Mul_503, NumDims.onnx__Pow_614);
% Add:
Vars.onnx__Mul_509 = Vars.onnx__Add_507 + Vars.onnx__Add_508;
NumDims.onnx__Mul_509 = max(NumDims.onnx__Add_507, NumDims.onnx__Add_508);
% Mul:
Vars.onnx__Concat_511 = Vars.onnx__Mul_509 .* Vars.onnx__Mul_510;
NumDims.onnx__Concat_511 = max(NumDims.onnx__Mul_509, NumDims.onnx__Mul_510);
% Mul:
Vars.onnx__Pow_513 = Vars.onnx__Mul_504 .* Vars.onnx__Pow_614;
NumDims.onnx__Pow_513 = max(NumDims.onnx__Mul_504, NumDims.onnx__Pow_614);
% Pow:
Vars.onnx__Mul_516 = power(Vars.onnx__Pow_513, Vars.onnx__Pow_614);
NumDims.onnx__Mul_516 = max(NumDims.onnx__Pow_513, NumDims.onnx__Pow_614);
% Mul:
Vars.onnx__Concat_518 = Vars.onnx__Mul_516 .* Vars.onnx__Mul_517;
NumDims.onnx__Concat_518 = max(NumDims.onnx__Mul_516, NumDims.onnx__Mul_517);
% Concat:
[Vars.onnx__Reshape_519, NumDims.onnx__Reshape_519] = onnxConcat(4, {Vars.onnx__Concat_511, Vars.onnx__Concat_518, Vars.onnx__Concat_505}, [NumDims.onnx__Concat_511, NumDims.onnx__Concat_518, NumDims.onnx__Concat_505]);
% Reshape:
[shape, NumDims.onnx__Concat_526] = prepareReshapeArgs(Vars.onnx__Reshape_519, Vars.onnx__Reshape_618, NumDims.onnx__Reshape_519, 0);
Vars.onnx__Concat_526 = reshape(Vars.onnx__Reshape_519, shape{:});
% Reshape:
[shape, NumDims.onnx__Transpose_539] = prepareReshapeArgs(Vars.onnx__Reshape_527, Vars.onnx__Reshape_624, NumDims.onnx__Reshape_527, 0);
Vars.onnx__Transpose_539 = reshape(Vars.onnx__Reshape_527, shape{:});
% Transpose:
[perm, NumDims.onnx__Sigmoid_540] = prepareTransposeArgs(Vars.TransposePerm1203, NumDims.onnx__Transpose_539);
if ~isempty(perm)
Vars.onnx__Sigmoid_540 = permute(Vars.onnx__Transpose_539, perm);
end
% Sigmoid:
Vars.y_3 = sigmoid(Vars.onnx__Sigmoid_540);
NumDims.y_3 = NumDims.onnx__Sigmoid_540;
% Split:
[Vars.onnx__Mul_542, Vars.onnx__Mul_543, Vars.onnx__Concat_544, NumDims.onnx__Mul_542, NumDims.onnx__Mul_543, NumDims.onnx__Concat_544] = onnxSplit(Vars.y_3, 4, Vars.SplitSplit1204, 0, NumDims.y_3);
% Mul:
Vars.onnx__Add_546 = Vars.onnx__Mul_542 .* Vars.onnx__Pow_614;
NumDims.onnx__Add_546 = max(NumDims.onnx__Mul_542, NumDims.onnx__Pow_614);
% Add:
Vars.onnx__Mul_548 = Vars.onnx__Add_546 + Vars.onnx__Add_547;
NumDims.onnx__Mul_548 = max(NumDims.onnx__Add_546, NumDims.onnx__Add_547);
% Mul:
Vars.onnx__Concat_550 = Vars.onnx__Mul_548 .* Vars.onnx__Mul_549;
NumDims.onnx__Concat_550 = max(NumDims.onnx__Mul_548, NumDims.onnx__Mul_549);
% Mul:
Vars.onnx__Pow_552 = Vars.onnx__Mul_543 .* Vars.onnx__Pow_614;
NumDims.onnx__Pow_552 = max(NumDims.onnx__Mul_543, NumDims.onnx__Pow_614);
% Pow:
Vars.onnx__Mul_555 = power(Vars.onnx__Pow_552, Vars.onnx__Pow_614);
NumDims.onnx__Mul_555 = max(NumDims.onnx__Pow_552, NumDims.onnx__Pow_614);
% Mul:
Vars.onnx__Concat_557 = Vars.onnx__Mul_555 .* Vars.onnx__Mul_556;
NumDims.onnx__Concat_557 = max(NumDims.onnx__Mul_555, NumDims.onnx__Mul_556);
% Concat:
[Vars.onnx__Reshape_558, NumDims.onnx__Reshape_558] = onnxConcat(4, {Vars.onnx__Concat_550, Vars.onnx__Concat_557, Vars.onnx__Concat_544}, [NumDims.onnx__Concat_550, NumDims.onnx__Concat_557, NumDims.onnx__Concat_544]);
% Reshape:
[shape, NumDims.onnx__Concat_565] = prepareReshapeArgs(Vars.onnx__Reshape_558, Vars.onnx__Reshape_618, NumDims.onnx__Reshape_558, 0);
Vars.onnx__Concat_565 = reshape(Vars.onnx__Reshape_558, shape{:});
% Reshape:
[shape, NumDims.onnx__Transpose_578] = prepareReshapeArgs(Vars.onnx__Reshape_566, Vars.onnx__Reshape_635, NumDims.onnx__Reshape_566, 0);
Vars.onnx__Transpose_578 = reshape(Vars.onnx__Reshape_566, shape{:});
% Transpose:
[perm, NumDims.onnx__Sigmoid_579] = prepareTransposeArgs(Vars.TransposePerm1205, NumDims.onnx__Transpose_578);
if ~isempty(perm)
Vars.onnx__Sigmoid_579 = permute(Vars.onnx__Transpose_578, perm);
end
% Sigmoid:
Vars.y_7 = sigmoid(Vars.onnx__Sigmoid_579);
NumDims.y_7 = NumDims.onnx__Sigmoid_579;
% Split:
[Vars.onnx__Mul_581, Vars.onnx__Mul_582, Vars.onnx__Concat_583, NumDims.onnx__Mul_581, NumDims.onnx__Mul_582, NumDims.onnx__Concat_583] = onnxSplit(Vars.y_7, 4, Vars.SplitSplit1206, 0, NumDims.y_7);
% Mul:
Vars.onnx__Add_585 = Vars.onnx__Mul_581 .* Vars.onnx__Pow_614;
NumDims.onnx__Add_585 = max(NumDims.onnx__Mul_581, NumDims.onnx__Pow_614);
% Add:
Vars.onnx__Mul_587 = Vars.onnx__Add_585 + Vars.onnx__Add_586;
NumDims.onnx__Mul_587 = max(NumDims.onnx__Add_585, NumDims.onnx__Add_586);
% Mul:
Vars.onnx__Concat_589 = Vars.onnx__Mul_587 .* Vars.onnx__Mul_588;
NumDims.onnx__Concat_589 = max(NumDims.onnx__Mul_587, NumDims.onnx__Mul_588);
% Mul:
Vars.onnx__Pow_591 = Vars.onnx__Mul_582 .* Vars.onnx__Pow_614;
NumDims.onnx__Pow_591 = max(NumDims.onnx__Mul_582, NumDims.onnx__Pow_614);
% Pow:
Vars.onnx__Mul_594 = power(Vars.onnx__Pow_591, Vars.onnx__Pow_614);
NumDims.onnx__Mul_594 = max(NumDims.onnx__Pow_591, NumDims.onnx__Pow_614);
% Mul:
Vars.onnx__Concat_596 = Vars.onnx__Mul_594 .* Vars.onnx__Mul_595;
NumDims.onnx__Concat_596 = max(NumDims.onnx__Mul_594, NumDims.onnx__Mul_595);
% Concat:
[Vars.onnx__Reshape_597, NumDims.onnx__Reshape_597] = onnxConcat(4, {Vars.onnx__Concat_589, Vars.onnx__Concat_596, Vars.onnx__Concat_583}, [NumDims.onnx__Concat_589, NumDims.onnx__Concat_596, NumDims.onnx__Concat_583]);
% Reshape:
[shape, NumDims.onnx__Concat_604] = prepareReshapeArgs(Vars.onnx__Reshape_597, Vars.onnx__Reshape_618, NumDims.onnx__Reshape_597, 0);
Vars.onnx__Concat_604 = reshape(Vars.onnx__Reshape_597, shape{:});
% Concat:
[Vars.output, NumDims.output] = onnxConcat(1, {Vars.onnx__Concat_526, Vars.onnx__Concat_565, Vars.onnx__Concat_604}, [NumDims.onnx__Concat_526, NumDims.onnx__Concat_565, NumDims.onnx__Concat_604]);
% Set graph output arguments from Vars and NumDims:
output = Vars.output;
outputNumDims1210 = NumDims.output;
% Set output state from Vars:
state = updateStruct(state, Vars);
end
function [inputDataPerms, outputDataPerms, Training] = parseInputs(onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, numDataOutputs, varargin)
% Function to validate inputs to Reshape_To_ConcatFcn:
p = inputParser;
isValidArrayInput = @(x)isnumeric(x) || isstring(x);
addRequired(p, 'onnx__Reshape_488', isValidArrayInput);
addRequired(p, 'onnx__Reshape_527', isValidArrayInput);
addRequired(p, 'onnx__Reshape_566', isValidArrayInput);
addParameter(p, 'InputDataPermutation', 'auto');
addParameter(p, 'OutputDataPermutation', 'auto');
addParameter(p, 'Training', false);
parse(p, onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, varargin{:});
inputDataPerms = p.Results.InputDataPermutation;
outputDataPerms = p.Results.OutputDataPermutation;
Training = p.Results.Training;
if isnumeric(inputDataPerms)
inputDataPerms = {inputDataPerms};
end
if isstring(inputDataPerms) && isscalar(inputDataPerms) || ischar(inputDataPerms)
inputDataPerms = repmat({inputDataPerms},1,3);
end
if isnumeric(outputDataPerms)
outputDataPerms = {outputDataPerms};
end
if isstring(outputDataPerms) && isscalar(outputDataPerms) || ischar(outputDataPerms)
outputDataPerms = repmat({outputDataPerms},1,numDataOutputs);
end
end
function [onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, Training, outputDataPerms, anyDlarrayInputs] = preprocessInput(onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, varargin)
% Parse input arguments
[inputDataPerms, outputDataPerms, Training] = parseInputs(onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566, 1, varargin{:});
anyDlarrayInputs = any(cellfun(@(x)isa(x, 'dlarray'), {onnx__Reshape_488, onnx__Reshape_527, onnx__Reshape_566}));
% Make the input variables into unlabelled dlarrays:
onnx__Reshape_488 = makeUnlabeledDlarray(onnx__Reshape_488);
onnx__Reshape_527 = makeUnlabeledDlarray(onnx__Reshape_527);
onnx__Reshape_566 = makeUnlabeledDlarray(onnx__Reshape_566);
% Permute inputs if requested:
onnx__Reshape_488 = permuteInputVar(onnx__Reshape_488, inputDataPerms{1}, 4);
onnx__Reshape_527 = permuteInputVar(onnx__Reshape_527, inputDataPerms{2}, 4);
onnx__Reshape_566 = permuteInputVar(onnx__Reshape_566, inputDataPerms{3}, 4);
end
function [output] = postprocessOutput(output, outputDataPerms, anyDlarrayInputs, Training, varargin)
% Set output type:
if ~anyDlarrayInputs && ~Training
if isdlarray(output)
output = extractdata(output);
end
end
% Permute outputs if requested:
output = permuteOutputVar(output, outputDataPerms{1}, 3);
end
%% dlarray functions implementing ONNX operators:
function [Y, numDimsY] = onnxConcat(ONNXAxis, XCell, numDimsXArray)
% Concatenation that treats all empties the same. Necessary because
% dlarray.cat does not allow, for example, cat(1, 1x1, 1x0) because the
% second dimension sizes do not match.
numDimsY = numDimsXArray(1);
XCell(cellfun(@isempty, XCell)) = [];
if isempty(XCell)
Y = dlarray([]);
else
if ONNXAxis<0
ONNXAxis = ONNXAxis + numDimsY;
end
DLTAxis = numDimsY - ONNXAxis;
Y = cat(DLTAxis, XCell{:});
end
end
function varargout = onnxSplit(X, ONNXaxis, splits, numSplits, numDimsX)
% Implements the ONNX Split operator
% ONNXaxis is origin 0. splits is a vector of the lengths of each segment.
% If numSplits is nonzero, instead split into segments of equal length.
if ONNXaxis<0
ONNXaxis = ONNXaxis + numDimsX;
end
DLTAxis = numDimsX - ONNXaxis;
if numSplits > 0
C = size(X, DLTAxis);
sz = floor(C/numSplits);
splits = repmat(sz, 1, numSplits);
else
splits = extractdata(splits);
end
S = struct;
S.type = '()';
S.subs = repmat({':'}, 1, ndims(X));
splitIndices = [0 cumsum(splits(:)')];
numY = numel(splitIndices)-1;
for i = 1:numY
from = splitIndices(i) + 1;
to = splitIndices(i+1);
S.subs{DLTAxis} = from:to;
% The first numY outputs are the Y's. The second numY outputs are their
% numDims. We assume all the outputs of Split have the same numDims as
% the input.
varargout{i} = subsref(X, S);
varargout{i + numY} = numDimsX;
end
end
function [DLTShape, numDimsY] = prepareReshapeArgs(X, ONNXShape, numDimsX, allowzero)
% Prepares arguments for implementing the ONNX Reshape operator
ONNXShape = flip(extractdata(ONNXShape)); % First flip the shape to make it correspond to the dimensions of X.
% In ONNX, 0 means "unchanged" if allowzero is false, and -1 means "infer". In DLT, there is no
% "unchanged", and [] means "infer".
DLTShape = num2cell(ONNXShape); % Make a cell array so we can include [].
% Replace zeros with the actual input size when allowzero is false
if any(ONNXShape==0) && allowzero==0
i0 = find(ONNXShape==0);
DLTShape(i0) = num2cell(size(X, numDimsX - numel(ONNXShape) + i0)); % right-align the shape vector and dims
end
if any(ONNXShape == -1)
% Replace -1 with []
i = ONNXShape == -1;
DLTShape{i} = [];
end
if numel(DLTShape)==1
DLTShape = [DLTShape 1];
end
numDimsY = numel(ONNXShape);
end
function [perm, numDimsA] = prepareTransposeArgs(ONNXPerm, numDimsA)
% Prepares arguments for implementing the ONNX Transpose operator
if numDimsA <= 1 % Tensors of numDims 0 or 1 are unchanged by ONNX Transpose.
perm = [];
else
if isempty(ONNXPerm) % Empty ONNXPerm means reverse the dimensions.
perm = numDimsA:-1:1;
else
perm = numDimsA-flip(ONNXPerm);
end
end
end
%% Utility functions:
function s = appendStructs(varargin)
% s = appendStructs(s1, s2,...). Assign all fields in s1, s2,... into s.
if isempty(varargin)
s = struct;
else
s = varargin{1};
for i = 2:numel(varargin)
fromstr = varargin{i};
fs = fieldnames(fromstr);
for j = 1:numel(fs)
s.(fs{j}) = fromstr.(fs{j});
end
end
end
end
function checkInputSize(inputShape, expectedShape, inputName)
if numel(expectedShape)==0
% The input is a scalar
if ~isequal(inputShape, [1 1])
inputSizeStr = makeSizeString(inputShape);
error(message('nnet_cnn_onnx:onnx:InputNeedsResize',inputName, "[1,1]", inputSizeStr));
end
elseif numel(expectedShape)==1
% The input is a vector
if ~shapeIsColumnVector(inputShape) || ~iSizesMatch({inputShape(1)}, expectedShape)
expectedShape{2} = 1;
expectedSizeStr = makeSizeString(expectedShape);
inputSizeStr = makeSizeString(inputShape);
error(message('nnet_cnn_onnx:onnx:InputNeedsResize',inputName, expectedSizeStr, inputSizeStr));
end
else
% The input has 2 dimensions or more
% The input dimensions have been reversed; flip them back to compare to the
% expected ONNX shape.
inputShape = fliplr(inputShape);
% If the expected shape has fewer dims than the input shape, error.
if numel(expectedShape) < numel(inputShape)
expectedSizeStr = strjoin(["[", strjoin(string(expectedShape), ","), "]"], "");
error(message('nnet_cnn_onnx:onnx:InputHasGreaterNDims', inputName, expectedSizeStr));
end
% Prepad the input shape with trailing ones up to the number of elements in
% expectedShape
inputShape = num2cell([ones(1, numel(expectedShape) - length(inputShape)) inputShape]);
% Find the number of variable size dimensions in the expected shape
numVariableInputs = sum(cellfun(@(x) isa(x, 'char') || isa(x, 'string'), expectedShape));
% Find the number of input dimensions that are not in the expected shape
% and cannot be represented by a variable dimension
nonMatchingInputDims = setdiff(string(inputShape), string(expectedShape));
numNonMatchingInputDims = numel(nonMatchingInputDims) - numVariableInputs;
expectedSizeStr = makeSizeString(expectedShape);
inputSizeStr = makeSizeString(inputShape);
if numNonMatchingInputDims == 0 && ~iSizesMatch(inputShape, expectedShape)
% The actual and expected input dimensions match, but in
% a different order. The input needs to be permuted.
error(message('nnet_cnn_onnx:onnx:InputNeedsPermute',inputName, expectedSizeStr, inputSizeStr));
elseif numNonMatchingInputDims > 0
% The actual and expected input sizes do not match.
error(message('nnet_cnn_onnx:onnx:InputNeedsResize',inputName, expectedSizeStr, inputSizeStr));
end
end
end
function doesMatch = iSizesMatch(inputShape, expectedShape)
% Check whether the input and expected shapes match, in order.
% Size elements match if (1) the elements are equal, or (2) the expected
% size element is a variable (represented by a character vector or string)
doesMatch = true;
for i=1:numel(inputShape)
if ~(isequal(inputShape{i},expectedShape{i}) || ischar(expectedShape{i}) || isstring(expectedShape{i}))
doesMatch = false;
return
end
end
end
function sizeStr = makeSizeString(shape)
sizeStr = strjoin(["[", strjoin(string(shape), ","), "]"], "");
end
function isVec = shapeIsColumnVector(shape)
if numel(shape) == 2 && shape(2) == 1
isVec = true;
else
isVec = false;
end
end
function X = makeUnlabeledDlarray(X)
% Make numeric X into an unlabelled dlarray
if isa(X, 'dlarray')
X = stripdims(X);
elseif isnumeric(X)
if isinteger(X)
% Make ints double so they can combine with anything without
% reducing precision
X = double(X);
end
X = dlarray(X);
end
end
function [Vars, NumDims] = packageVariables(Param1,Param2,Param3,Param4, inputNames, inputValues, inputNumDims)
% inputNames, inputValues are cell arrays. inputNumDims is a numeric vector.
Vars = appendStructs(Param1, Param2,Param3);
NumDims = Param4;
% Add graph inputs
for i = 1:numel(inputNames)
Vars.(inputNames{i}) = inputValues{i};
NumDims.(inputNames{i}) = inputNumDims(i);
end
end
function X = permuteInputVar(X, userDataPerm, onnxNDims)
% Returns reverse-ONNX ordering
if onnxNDims == 0
return;
elseif onnxNDims == 1 && isvector(X)
X = X(:);
return;
elseif isnumeric(userDataPerm)
% Permute into reverse ONNX ordering
if numel(userDataPerm) ~= onnxNDims
error(message('nnet_cnn_onnx:onnx:InputPermutationSize', numel(userDataPerm), onnxNDims));
end
perm = fliplr(userDataPerm);
elseif isequal(userDataPerm, 'auto') && onnxNDims == 4
% Permute MATLAB HWCN to reverse onnx (WHCN)
perm = [2 1 3 4];
elseif isequal(userDataPerm, 'as-is')
% Do not permute the input
perm = 1:ndims(X);
else
% userDataPerm is either 'none' or 'auto' with no default, which means
% it's already in onnx ordering, so just make it reverse onnx
perm = max(2,onnxNDims):-1:1;
end
X = permute(X, perm);
end
function Y = permuteOutputVar(Y, userDataPerm, onnxNDims)
switch onnxNDims
case 0
perm = [];
case 1
if isnumeric(userDataPerm)
% Use the user's permutation because Y is a column vector which
% already matches ONNX.
perm = userDataPerm;
elseif isequal(userDataPerm, 'auto')
% Treat the 1D onnx vector as a 2D column and transpose it
perm = [2 1];
else
% userDataPerm is 'none'. Leave Y alone because it already
% matches onnx.
perm = [];
end
otherwise
% ndims >= 2
if isnumeric(userDataPerm)
% Use the inverse of the user's permutation. This is not just the
% flip of the permutation vector.
perm = onnxNDims + 1 - userDataPerm;
elseif isequal(userDataPerm, 'auto')
if onnxNDims == 2
% Permute reverse ONNX CN to DLT CN (do nothing)
perm = [];
elseif onnxNDims == 4
% Permute reverse onnx (WHCN) to MATLAB HWCN
perm = [2 1 3 4];
else
% User wants the output in ONNX ordering, so just reverse it from
% reverse onnx
perm = onnxNDims:-1:1;
end
elseif isequal(userDataPerm, 'as-is')
% Do not permute the output
perm = 1:ndims(Y);
else
% userDataPerm is 'none', so just make it reverse onnx
perm = onnxNDims:-1:1;
end
end
if ~isempty(perm)
Y = permute(Y, perm);
end
end
function s = updateStruct(s, t)
% Set all existing fields in s from fields in t, ignoring extra fields in t.
for name = transpose(fieldnames(s))
s.(name{1}) = t.(name{1});
end
end
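In the modified class above, OnnxParam1 to OnnxParam4 receive Param.Learnables, Param.Nonlearnables, Param.State and Param.NumDimensions. The generated helpers treat these as structs (appendStructs and packageVariables read them with fieldnames), so the layer still holds struct-valued properties, which would explain why the error is unchanged. One possible direction, not verified against GPU Coder here, is to declare one property per struct field and rebuild Param1 to Param4 as local structs inside predict and forward before calling Reshape_To_ConcatFcn. The helpers below are hypothetical (not a MathWorks API); they only generate source lines to paste into the classdef, so field names such as onnx__Reshape_613, TransposePerm1201 and SplitSplit1202 do not have to be typed by hand. The prefixes are illustrative. Separate prefixes matter because NumDimensions appears to reuse the value names (compare Vars.onnx__Pow_614 and NumDims.onnx__Pow_614 in the graph function).

```matlab
% Hypothetical source-generation helpers.
%   s          - a struct such as Param.Nonlearnables or Param.NumDimensions
%   structName - local struct to rebuild in predict/forward ('Param1'...'Param4')
%   propPrefix - prefix that keeps property names unique, e.g. 'N_' or 'ND_'

% Property names for the classdef "properties ... end" block.
propLines = @(s, propPrefix) cellfun(@(n) [propPrefix n], ...
    fieldnames(s), 'UniformOutput', false);

% Assignment lines for predict/forward, e.g.
%   Param2.onnx__Pow_614 = this.N_onnx__Pow_614;
rebuildLines = @(s, structName, propPrefix) cellfun( ...
    @(n) sprintf('%s.%s = this.%s%s;', structName, n, propPrefix, n), ...
    fieldnames(s), 'UniformOutput', false);

% Usage with the ONNXParams object from step 3 (not run here):
%   decl = propLines(Param.Nonlearnables, 'N_');              fprintf('%s\n', decl{:});
%   body = rebuildLines(Param.Nonlearnables, 'Param2', 'N_'); fprintf('%s\n', body{:});
% Likewise Learnables -> 'Param1', State -> 'Param3', NumDimensions -> 'Param4'.
```

Two caveats apply. The graph code calls extractdata on some of these values (in prepareReshapeArgs and onnxSplit), so they are presumably dlarray objects; storing them as plain numeric arrays may require extractdata in the constructor and a dlarray(...) wrap when rebuilding. Whether the remaining helpers, such as the dynamic field names in appendStructs, are accepted by code generation is also an open question.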


Answers (0)
