-
Notifications
You must be signed in to change notification settings - Fork 1
/
randomizeModelParameters.m
35 lines (29 loc) · 1.21 KB
/
randomizeModelParameters.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
function [model, layersRandomized] = randomizeModelParameters(model, start, count)
% randomizeModelParameters Randomise the parameters of a contiguous range of layers.
%   [model, layersRandomized] = randomizeModelParameters(model, start, count)
%   passes layers START through START+COUNT of MODEL (indices into
%   model.Layers) to randomizeLayer. Every layer that randomizeLayer reports
%   as randomised is swapped into the network's layer graph, and the graph
%   is then reassembled with assembleNetwork.
%
%   Inputs:
%     model - a SeriesNetwork or DAGNetwork.
%     start - positive integer index (into model.Layers) of the first layer
%             to randomise. Default 1.
%     count - non-negative integer number of additional layers after START
%             to randomise. Default 0, i.e. only layer START.
%
%   Outputs:
%     model            - the reassembled network.
%     layersRandomized - cell array of the names of layers that
%                        randomizeLayer reported as randomised.
%
%   An error with identifier 'randomizeModelParameters:indexOutOfRange' is
%   raised when START+COUNT exceeds the number of layers in MODEL.
%
%   Example:
%     randModel = randomizeModelParameters(model, 1, 4) % randomise the first 5 layers of the model

% Copyright 2019 - 2020 The MathWorks, Inc.
arguments
    model
    start (1,1) {mustBeNumeric, mustBeInteger, mustBePositive} = 1
    count (1,1) {mustBeNumeric, mustBeInteger, mustBeNonnegative} = 0
end

layers = model.Layers;
lastIdx = start + count;

% Fail early with a descriptive message instead of a generic indexing error.
if lastIdx > numel(layers)
    error('randomizeModelParameters:indexOutOfRange', ...
        'Requested layers %d to %d, but the model has only %d layers.', ...
        start, lastIdx, numel(layers));
end

% A SeriesNetwork's graph is built from its layer array; otherwise the graph
% is built from the network itself (DAGNetwork).
if isa(model, 'SeriesNetwork')
    lGraph = layerGraph(layers);
else
    lGraph = layerGraph(model);
end

layersRandomized = {};
for i = start:lastIdx
    % Capture the name before randomising, so replaceLayer targets the
    % original layer without relying on lGraph.Layers sharing the ordering
    % of model.Layers.
    originalName = layers(i).Name;
    [layers(i), wasRandomized] = randomizeLayer(layers(i));
    % NOTE(review): presumably randomizeLayer reports false for layers with
    % nothing to randomise — confirm against its implementation.
    if wasRandomized
        lGraph = replaceLayer(lGraph, originalName, layers(i));
        layersRandomized{end+1} = layers(i).Name; %#ok<AGROW>
    end
end

model = assembleNetwork(lGraph);
end