% RBF network experiment script (uses the NETLAB toolbox: rbf, rbftrain, rbffwd).
%
% Loads build.dat (training/validation samples) and bo.dat ("ordered" samples),
% builds a Gaussian RBF network, trains it on the first TRAINSIZE fraction of
% build.dat, and -- unless stop_after_training is set -- re-derives the basis
% widths from nearest-center distances, re-solves the linear output layer,
% reports train/validation errors and plots the results.

%---------- Experiment parameters
threshold = 1;      % width of centers (not referenced anywhere below)
nhidden   = 40;     % # of centers
maxepochs = 50;     % for clustering (finding centers)
showerr   = 1;      % whether to display -Likelihood while clustering
trainsize = .5;     % fraction of data as training set
thresh    = .01;    % center pairs at or below this distance are ignored for widths
sigma     = 1;      % multiplier applied to the nearest-center distance

% The original script had a bare "break" right after rbftrain. Outside a loop
% that is rejected by recent MATLAB releases (older ones just ended the script).
% The flag keeps that stop-here behaviour by default; set it to false to run
% the width reset, error reports and plots below.
stop_after_training = true;

%---------- uses NETLAB Toolbox
% exist(name,'var') only looks at workspace variables; the one-argument form
% also matches files/folders called "build"/"bo" and would then skip the load.
if ~exist('build', 'var')
    load build.dat;
end
if ~exist('bo', 'var')
    load bo.dat;
end

d  = 14;            % # dimensions
dy = 3;             % # outputs
E  = 10;            % # training epochs (not referenced anywhere below)

%---------- all training data
% Samples are stored one per row in build; transpose so each column is a sample.
inpt = build(:, 1:d)';
oupt = build(:, d+1:d+dy)';
[d, N] = size(inpt);

trange = 1:round(trainsize*N);      % leading columns -> training set
tsize  = length(trange);
vrange = (max(trange)+1):N;         % remaining columns -> validation set
vsize  = length(vrange);

tin  = inpt(:, trange);
tout = oupt(:, trange);
vin  = inpt(:, vrange);
vout = oupt(:, vrange);

%---------- ordered data
io = bo(:, 1:d)';
oo = bo(:, d+1:d+dy)';

%-------------------- RBF init
options(1)  = showerr;
% NOTE(review): per the NETLAB rbftrain docs, options(5)=1 appears to keep the
% basis-function parameters fixed (no EM update of centers) -- confirm intended.
options(5)  = 1;
options(14) = maxepochs;
net = rbf(d, nhidden, dy, 'gaussian');

%-------------------- RBF train
%net.wi = sigma*ones(size(net.wi));
net = rbftrain(net, options, tin', tout');

%----- reset RBF widths to sigma and recompute linear weights
if stop_after_training
    return
end
sigma = 1;

%----- compute Euclidean distance of centers to others
% NOTE(review): the original comment called dist "netlab's function"; NETLAB
% ships dist2 (squared distances) -- confirm which dist is on the path.
wdist  = dist(net.c, net.c');
thebig = max(max(wdist));
% Push zero/near-zero entries (each center against itself, and near-duplicate
% centers) above every real distance so min() skips them.
wzero = find(wdist <= thresh);
wdist(wzero) = (thebig*2)*ones(size(wzero));
net.wi = sigma*min(wdist);          % per center: scaled distance to nearest other center

% Re-solve the output layer by least squares on [activations, bias column].
[y, act] = rbffwd(net, tin');
temp   = pinv([act ones(tsize,1)])*tout';
net.w2 = temp(1:nhidden, :);
net.b2 = temp(nhidden+1, :);
sigma

%-------------------- train MSE
% Summed squared error over all outputs, divided by the number of samples.
touthat = rbffwd(net, tin')';
terr  = touthat - tout;
tserr = terr.*terr;
tsse  = sum(sum(tserr));
tmse  = tsse/tsize

%-------------------- validate MSE
vouthat = rbffwd(net, vin')';
verr  = vouthat - vout;
vserr = verr.*verr;
vsse  = sum(sum(vserr));
vmse  = vsse/vsize                  % was divided by tsize (training count)

%----------- plot true vs predicted
yhat = rbffwd(net, inpt')';
figure(1);
for y = 1:dy
    subplot(1, dy, y);
    % red x = training samples, blue dot = validation samples, dashed = y = x
    plot(oupt(y,trange), yhat(y,trange), 'rx', ...
         oupt(y,vrange), yhat(y,vrange), 'b.', ...
         [0 1], [0,1], 'k--');
    axis([0 1 0 1]);
    % True values are on the x axis, network outputs on the y axis.
    xlabel('True');
    ylabel('NN');
    axis square;
    tstring = sprintf('(%d) True vs NN output', y);
    title(tstring);
end

%-------------------- plot some of the outputs w.r.t. ordered outputs
% Separate figure: the original called clf on figure 1, erasing the scatter plots.
figure(2);
clf;
yo = rbffwd(net, io')';
prange = 1:min(500, size(oo, 2));   % clamp so short ordered sets do not overrun
for y = 1:dy
    subplot(dy, 1, y);
    plot(oo(y,prange), 'b-');       % original output
    hold on;
    plot(yo(y,prange), 'r-');       % network output
    xlabel('Sample');
    ystring = sprintf('Output %d', y);
    ylabel(ystring);
    % legend('Original','NN');
end