Skip to content

Commit 9f4aa72

Browse files
author
fabian.froehlich
committed
Merge branch 'stable'
2 parents ecb4c4d + 264a42a commit 9f4aa72

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+2529
-1168
lines changed

.gitignore

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -49,32 +49,6 @@ index.html
4949

5050
AMICI.mlappinstall
5151

52-
examples/example_1/simulate_model_example_1.m
53-
54-
examples/example_2/simulate_model_example_2.m
55-
56-
examples/example_3/simulate_model_example_3.m
57-
58-
examples/example_4/simulate_model_example_4.m
59-
60-
examples/example_5/simulate_model_example_5.m
61-
62-
examples/example_6/simulate_model_example_6.m
63-
64-
examples/geneExpression/simulate_geneExpression.m
65-
66-
examples/example_1/html/*
67-
68-
examples/example_2/html/*
69-
70-
examples/example_3/html/*
71-
72-
examples/example_4/html/*
73-
74-
examples/example_5/html/*
75-
76-
examples/example_6/html/*
77-
7852
*.mat
7953

8054
mtoc/makeExampleDoc.m
@@ -170,3 +144,9 @@ examples/example_dirac_adjoint/dxdotdp_model_dirac_adjoint.m
170144
examples/example_adjoint/J_model_adjoint.m
171145

172146
examples/example_adjoint/dxdotdp_model_adjoint.m
147+
148+
examples/example_jakstat_adjoint_hvp/simulate_model_jakstat_adjoint_hvp.m
149+
150+
examples/example_dirac_adjoint_hessVecProd/simulate_model_dirac_adjoint_hessVecProd.m
151+
152+
examples/example_adjoint_hessian/simulate_model_adjoint_hessian.m

@amifun/gccode.m

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,35 @@
2222
this.sym = subs(this.sym,sym('D([2], am_min)'),sym('D2am_min'));
2323
end
2424

25-
if(model.splineflag)
26-
for nodes = [3,4,5,10]
27-
for ideriv = 1:nodes
28-
this.sym = subs(this.sym,sym(['D([' num2str(ideriv*2+1) '], spline_pos' num2str(nodes) ')']),sym(['D' num2str(ideriv*2+1) 'spline_pos' num2str(nodes)]));
29-
this.sym = subs(this.sym,sym(['D([' num2str(ideriv*2+1) '], spline' num2str(nodes) ')']),sym(['D' num2str(ideriv*2+1) 'spline' num2str(nodes)]));
25+
% If we have spline, we need to parse them to get derivatives
26+
if (model.splineflag)
27+
symstr = char(this.sym);
28+
if (strfind(symstr, 'spline'))
29+
tokens = regexp(symstr, 't\,\s(\w+\.\w+)\,', 'tokens');
30+
nNodes = round(str2double(tokens{1}));
31+
end
32+
if (regexp(symstr, 'D\(\[(\w+|\w+\,\w+)\]\,.am_spline'))
33+
isDSpline = true;
34+
else
35+
isDSpline = false;
36+
end
37+
38+
if (isDSpline)
39+
[~, nCol] = size(this.sym);
40+
for iCol = 1 : nCol
41+
for iNode = 1 : nNodes
42+
if (model.o2flag)
43+
for jNode = 1:nNodes
44+
this.sym(:,iCol) = subs(this.sym(:,iCol),sym(['D([' num2str(iNode*2+2) ', ' num2str(jNode*2+2) '], am_spline_pos)']),sym(['D' num2str(iNode*2+2) 'D' num2str(jNode*2+2) 'am_spline_pos']));
45+
this.sym(:,iCol) = subs(this.sym(:,iCol),sym(['D([' num2str(iNode*2+2) ', ' num2str(jNode*2+2) '], am_spline)']),sym(['D' num2str(iNode*2+2) 'D' num2str(jNode*2+2) 'am_spline']));
46+
end
47+
end
48+
this.sym(:,iCol) = subs(this.sym(:,iCol),sym(['D([' num2str(iNode*2+2) '], am_spline_pos)']),sym(['D' num2str(iNode*2+2) 'am_spline_pos']));
49+
this.sym(:,iCol) = subs(this.sym(:,iCol),sym(['D([' num2str(iNode*2+2) '], am_spline)']),sym(['D' num2str(iNode*2+2) 'am_spline']));
50+
end
3051
end
3152
end
32-
end
33-
53+
end
3454

3555
cstr = ccode(this.sym);
3656
if(~strcmp(cstr(3:4),'t0'))
@@ -46,8 +66,17 @@
4666
cstr = strrep(cstr,'log','amilog');
4767
% fix derivatives again (we cant do this before as this would yield
4868
% incorrect symbolic expressions
49-
cstr = regexprep(cstr,'D([0-9]*)([\w]*)\(','D$2\($1,');
69+
cstr = regexprep(regexprep(cstr,'D([0-9]*)([\w]*)\(','D$2\($1,'),'DD([0-9]*)([\w]*)\(','DD$2\($1,');
70+
cstr = strrep(strrep(cstr, 'DDam_spline', 'am_DDspline'), 'Dam_spline', 'am_Dspline');
5071

72+
if (model.splineflag)
73+
if (strfind(symstr, 'spline'))
74+
% The floating numbers after 't' must be converted to integers
75+
cstr = regexprep(cstr, '(spline|spline_pos)\(t\,\w+\.\w+\,', ['$1\(t\,', num2str(nNodes), '\,']);
76+
cstr = regexprep(cstr, '(spline|spline_pos)\((\w+)\,t\,\w+\.\w+\,', ['$1\($2\,t\,', num2str(nNodes), '\,']);
77+
cstr = regexprep(cstr, '(spline|spline_pos)\((\w+)\,(\w+)\,t\,\w+\.\w+\,', ['$1\($2\,$3\,t\,', num2str(nNodes), '\,']);
78+
end
79+
end
5180

5281
if(numel(cstr)>1)
5382

@@ -83,7 +112,7 @@
83112
cstr = regexprep(cstr,'xdotdp([0-9]+)',['xdotdp\[$1 + ip*' num2str(model.nx) '\]']);
84113

85114
if(strcmp(this.cvar,'qBdot'))
86-
cstr = regexprep(cstr,'qBdot\[([0-9]*)\]','qBdot\[ip]');
115+
cstr = regexprep(cstr,'qBdot\[([0-9]*)\]','qBdot\[ip+np*$1]');
87116
elseif(strcmp(this.cvar,'stau'))
88117
cstr = regexprep(cstr,'stau\[([0-9]*)\]','stau\[ip]');
89118
elseif(strcmp(this.cvar,'y'))
@@ -124,7 +153,7 @@
124153
cstr = regexprep(cstr,'dydx_([0-9]+)','dydx\[$1]');
125154
cstr = regexprep(cstr,'dydp_([0-9]+)',['dydp\[$1+ip*' num2str(model.ny) ']']);
126155
cstr = regexprep(cstr,'my_([0-9]+)','my\[it+nt*$1]');
127-
cstr = regexprep(cstr,'sdy_([0-9]+)','sd_y\[$1\]');
156+
cstr = regexprep(cstr,'sigma_y_([0-9]+)','sigma_y\[$1\]');
128157
cstr = regexprep(cstr,'dsdydp\[([0-9]*)\]','dsigma_ydp\[$1\]');
129158
if(strcmp(this.cvar,'sJy'))
130159
cstr = regexprep(cstr,'sy_([0-9]+)','sy\[$1\]');
@@ -141,7 +170,7 @@
141170
cstr = regexprep(cstr,'dzdp_([0-9]+)',['dzdp\[$1+ip*' num2str(model.nz) ']']);
142171
cstr = regexprep(cstr,'mz_([0-9]+)','mz\[nroots[ie]+nmaxevent*$1]');
143172
cstr = regexprep(cstr,'sz_([0-9]+)',['sz\[nroots[ie]+nmaxevent*\($1+ip*' num2str(model.nz) '\)\]']);
144-
cstr = regexprep(cstr,'sdz_([0-9]+)','sd_z\[$1\]');
173+
cstr = regexprep(cstr,'sigma_z_([0-9]+)','sigma_z\[$1\]');
145174
cstr = regexprep(cstr,'dsdzdp\[([0-9]*)\]','dsigma_zdp\[$1\]');
146175
cstr = regexprep(cstr,'z_([0-9]+)','z\[nroots[ie]+nmaxevent*$1\]');
147176
cstr = strrep(cstr,'=','+=');

@amifun/getArgs.m

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -120,21 +120,21 @@
120120
case 's2root'
121121
this.argstr = '(realtype t, int ie, int *nroots, realtype *s2root, N_Vector x, N_Vector *sx, void *user_data)';
122122
case 'Jy'
123-
this.argstr = '(realtype t, int it, realtype *Jy, realtype *y, N_Vector x, realtype *my, realtype *sd_y, void *user_data)';
123+
this.argstr = '(realtype t, int it, realtype *Jy, realtype *y, N_Vector x, realtype *my, realtype *sigma_y, void *user_data)';
124124
case 'dJydx'
125-
this.argstr = '(realtype t, int it, realtype *dJydx, realtype *y, N_Vector x, realtype *dydx, realtype *my, realtype *sd_y, void *user_data)';
125+
this.argstr = '(realtype t, int it, realtype *dJydx, realtype *y, N_Vector x, realtype *dydx, realtype *my, realtype *sigma_y, void *user_data)';
126126
case 'dJydp'
127-
this.argstr = '(realtype t, int it, realtype *dJydp, realtype *y, N_Vector x, realtype *dydp, realtype *my, realtype *sd_y, realtype *dsigma_ydp, void *user_data)';
127+
this.argstr = '(realtype t, int it, realtype *dJydp, realtype *y, N_Vector x, realtype *dydp, realtype *my, realtype *sigma_y, realtype *dsigma_ydp, void *user_data)';
128128
case 'sJy'
129129
this.argstr = '(realtype t, int it, realtype *sJy, realtype *dJydy, realtype *dJydp, realtype *sy, void *user_data)';
130130
case 'Jz'
131-
this.argstr = '(realtype t, int ie, realtype *Jz, realtype *z, N_Vector x, realtype *mz, realtype *sd_z, void *user_data, void *temp_data)';
131+
this.argstr = '(realtype t, int ie, realtype *Jz, realtype *z, N_Vector x, realtype *mz, realtype *sigma_z, void *user_data, void *temp_data)';
132132
case 'dJzdx'
133-
this.argstr = '(realtype t, int ie, realtype *dJzdx, realtype *z, N_Vector x, realtype *dzdx, realtype *mz, realtype *sd_z, void *user_data, void *temp_data)';
133+
this.argstr = '(realtype t, int ie, realtype *dJzdx, realtype *z, N_Vector x, realtype *dzdx, realtype *mz, realtype *sigma_z, void *user_data, void *temp_data)';
134134
case 'dJzdp'
135-
this.argstr = '(realtype t, int ie, realtype *dJzdp, realtype *z, N_Vector x, realtype *dzdp, realtype *mz, realtype *sd_z, realtype *dsigma_zdp, void *user_data, void *temp_data)';
135+
this.argstr = '(realtype t, int ie, realtype *dJzdp, realtype *z, N_Vector x, realtype *dzdp, realtype *mz, realtype *sigma_z, realtype *dsigma_zdp, void *user_data, void *temp_data)';
136136
case 'sJz'
137-
this.argstr = '(realtype t, int ie, realtype *sJz, realtype *z, N_Vector x, realtype *dzdp, realtype *sz, realtype *mz, realtype *sd_z, realtype *dsigma_zdp, void *user_data, void *temp_data)';
137+
this.argstr = '(realtype t, int ie, realtype *sJz, realtype *z, N_Vector x, realtype *dzdp, realtype *sz, realtype *mz, realtype *sigma_z, realtype *dsigma_zdp, void *user_data, void *temp_data)';
138138
case 'w'
139139
this.argstr = ['(realtype t, N_Vector x, N_Vector dx, void *user_data)'];
140140
case 'dwdp'

@amifun/getSyms.m

Lines changed: 88 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,14 @@
330330
this.strsym(idx) = Js;
331331

332332
case 'JB'
333-
this.sym=-transpose(model.fun.J.sym);
333+
this.sym = sym(zeros(model.nx, model.nx));
334+
% Augmentation needs a transposition in the submatrices
335+
for ix = 1 : model.ng
336+
for jx = 1 : model.ng
337+
this.sym((ix-1)*model.nxtrue+1:ix*model.nxtrue, (jx-1)*model.nxtrue+1:jx*model.nxtrue) = ...
338+
-transpose(model.fun.J.sym((ix-1)*model.nxtrue+1:ix*model.nxtrue, (jx-1)*model.nxtrue+1:jx*model.nxtrue));
339+
end
340+
end
334341

335342
case 'dxdotdp'
336343
if(~isempty(w))
@@ -408,19 +415,43 @@
408415
% transform to symbolic variable
409416
vs = sym(vs);
410417
% multiply
411-
this.sym = -transpose(model.fun.J.sym)*vs;
418+
this.sym = model.fun.JB.sym*vs;
412419

413420
case 'xBdot'
414421
if(strcmp(model.wtype,'iw'))
415422
syms t
416423
this.sym = diff(transpose(model.fun.M.sym),t)*model.fun.xB.sym + transpose(model.fun.M.sym)*model.fun.dxB.sym - transpose(model.fun.dfdx.sym)*model.fun.xB.sym;
417424
else
418-
this.sym = -transpose(model.fun.J.sym)*model.fun.xB.sym;
425+
% Augmenting the system needs transposition of submatrices
426+
% I'm sure, there is a more intelligent solution to it...
427+
this.sym = sym(zeros(nx,1));
428+
if(model.o2flag)
429+
for ix = 1 : model.ng
430+
for jx = 1 : model.ng
431+
this.sym((ix-1)*model.nxtrue+1 : ix*model.nxtrue) = this.sym((ix-1)*model.nxtrue+1 : ix*model.nxtrue) - ...
432+
transpose(model.fun.J.sym((ix-1)*model.nxtrue+1 : ix*model.nxtrue , (jx-1)*model.nxtrue+1 : jx*model.nxtrue)) ...
433+
* model.fun.xB.sym((jx-1)*model.nxtrue+1 : jx*model.nxtrue);
434+
end
435+
end
436+
else
437+
this.sym = -transpose(model.fun.J.sym) * model.fun.xB.sym;
438+
end
419439
end
420440

421441
case 'qBdot'
422-
this.sym = -transpose(model.fun.xB.sym)*model.fun.dxdotdp.sym;
423-
442+
% If we do second order adjoints, we have to augment
443+
if (model.nxtrue < nx)
444+
this.sym = sym(zeros(model.ng, model.np));
445+
this.sym(1,:) = -transpose(model.fun.xB.sym(1:model.nxtrue)) * model.fun.dxdotdp.sym(1:model.nxtrue, :);
446+
for ig = 2 : model.ng
447+
this.sym(ig,:) = ...
448+
-transpose(model.fun.xB.sym(1:model.nxtrue)) * model.fun.dxdotdp.sym((ig-1)*model.nxtrue+1 : ig*model.nxtrue, :) ...
449+
-transpose(model.fun.xB.sym((ig-1)*model.nxtrue+1 : ig*model.nxtrue)) * model.fun.dxdotdp.sym(1:model.nxtrue, :);
450+
end
451+
else
452+
this.sym = -transpose(model.fun.xB.sym)*model.fun.dxdotdp.sym;
453+
end
454+
424455
case 'dsigma_ydp'
425456
this.sym = jacobian(model.fun.sigma_y.sym,p);
426457
this = makeStrSyms(this);
@@ -490,7 +521,7 @@
490521
staus = sym(staus);
491522
% multiply
492523
this.strsym = staus;
493-
524+
494525
case 'deltax'
495526
if(nevent>0)
496527
this.sym = [model.event.bolus];
@@ -513,7 +544,7 @@
513544
for ievent = 1:nevent
514545
this.sym(:,ievent,:) = jacobian(model.fun.deltax.sym(:,ievent),x);
515546
end
516-
547+
517548
case 'ddeltaxdt'
518549
this.sym = diff(model.fun.deltax.sym,'t');
519550

@@ -547,9 +578,18 @@
547578
end
548579

549580
case 'deltaqB'
550-
this.sym = sym(zeros(np,nevent));
581+
if (model.nxtrue < nx)
582+
ng_tmp = round(nx / model.nxtrue);
583+
this.sym = sym(zeros(np*ng_tmp,nevent));
584+
else
585+
this.sym = sym(zeros(np,nevent));
586+
end
587+
551588
for ievent = 1:nevent
552-
this.sym(:,ievent) = transpose(model.fun.xB.sym)*squeeze(model.fun.ddeltaxdp.sym(:,ievent,:));
589+
this.sym(1:np,ievent) = transpose(model.fun.xB.sym)*squeeze(model.fun.ddeltaxdp.sym(:,ievent,:));
590+
% This is just a very quick fix. Events in adjoint systems
591+
% have to be implemented in a way more rigorous way later
592+
% on... Some day...
553593
end
554594

555595
case 'deltaxB'
@@ -625,19 +665,52 @@
625665

626666
case 'Jy'
627667
this.sym = model.sym.Jy;
668+
% replace unify symbolic expression
669+
this.sym = mysubs(this.sym,model.sym.y,model.fun.y.strsym);
628670
case 'dJydy'
629-
this.sym = jacobian(model.fun.Jy.sym,model.fun.y.strsym);
671+
this.sym = sym(zeros(model.nytrue, model.ng, model.ny));
672+
for iy = 1 : model.nytrue
673+
this.sym(iy,:,:) = jacobian(model.fun.Jy.sym(iy,:),model.fun.y.strsym);
674+
end
630675
this = makeStrSyms(this);
631676
case 'dJydx'
632-
this.sym = model.fun.dJydy.sym*model.fun.dydx.strsym;
677+
this.sym = sym(zeros(model.nytrue, model.nxtrue, model.ng)); % Maybe nxtrue is sufficient...
678+
dJydy_tmp = sym(zeros(model.ng, model.ny));
679+
for iy = 1 : model.nytrue
680+
dJydy_tmp(:,:) = model.fun.dJydy.sym(iy,:,:);
681+
this.sym(iy,:,:) = transpose(dJydy_tmp * model.fun.dydx.strsym(:,1:model.nxtrue));
682+
% Transposition is necessary to have things sorted
683+
% correctly in gccode.m
684+
end
685+
disp('');
633686
case 'dJydsigma'
634-
this.sym = jacobian(model.fun.Jy.sym,model.fun.sigma_y.strsym);
687+
this.sym = sym(zeros(model.nytrue, model.ng, model.nytrue));
688+
for iy = 1 : model.nytrue
689+
this.sym(iy,:,:) = jacobian(model.fun.Jy.sym(iy,:),model.fun.sigma_y.strsym(1:model.nytrue));
690+
end
635691
case 'dJydp'
636-
this.sym = model.fun.dJydy.sym*model.fun.dydp.strsym + model.fun.dJydsigma.sym*model.fun.dsigma_ydp.strsym;
692+
this.sym = sym(zeros(model.nytrue, model.np, model.ng));
693+
dJydy_tmp = sym(zeros(model.ng, model.ny));
694+
dJydsigma_tmp = sym(zeros(model.ng, model.nytrue));
695+
for iy = 1 : model.nytrue
696+
dJydy_tmp(:,:) = model.fun.dJydy.sym(iy,:,:);
697+
dJydsigma_tmp(:,:) = model.fun.dJydsigma.sym(iy,:,:);
698+
this.sym(iy,:,:) = transpose(dJydy_tmp * model.fun.dydp.strsym ...
699+
+ dJydsigma_tmp * model.fun.dsigma_ydp.strsym(1:model.nytrue,:));
700+
% Transposition is necessary to have things sorted
701+
% correctly in gccode.m
702+
end
637703
this = makeStrSyms(this);
638704
case 'sJy'
639-
this.sym = model.fun.dJydy.strsym*model.fun.sy.strsym + model.fun.dJydp.strsym;
640-
705+
this.sym = sym(zeros(model.nytrue, model.np, model.ng));
706+
dJydy_tmp = sym(zeros(model.ng, model.ny));
707+
for iy = 1 : model.nytrue
708+
dJydy_tmp(:,:) = model.fun.dJydy.strsym(iy,:,:);
709+
this.sym(iy,:,:) = transpose(dJydy_tmp*model.fun.sy.strsym);
710+
% Transposition is necessary to have things sorted
711+
% correctly in gccode.m
712+
end
713+
this.sym = this.sym + model.fun.dJydp.strsym;
641714
case 'Jz'
642715
this.sym = model.sym.Jz(:);
643716
case 'dJzdz'

@amifun/printLocalVars.m

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function printLocalVars(this,model,fid)
4646
fprintf(fid,'int ix;\n');
4747
fprintf(fid,['memset(xBdot_tmp,0,sizeof(realtype)*' num2str(nx) ');\n']);
4848
case 'qBdot'
49-
fprintf(fid,'memset(qBdot_tmp,0,sizeof(realtype)*np);\n');
49+
fprintf(fid,'memset(qBdot_tmp,0,sizeof(realtype)*np*ng);\n');
5050
case 'x0'
5151
fprintf(fid,['memset(x0_tmp,0,sizeof(realtype)*' num2str(nx) ');\n']);
5252
case 'dx0'
@@ -112,7 +112,7 @@ function printLocalVars(this,model,fid)
112112
case 'deltaxB'
113113
fprintf(fid,['memset(deltaxB,0,sizeof(realtype)*' num2str(nx) ');\n']);
114114
case 'deltaqB'
115-
fprintf(fid,['memset(deltaqB,0,sizeof(realtype)*np);\n']);
115+
fprintf(fid,['memset(deltaqB,0,sizeof(realtype)*np*ng);\n']);
116116
case 'deltasx'
117117
fprintf(fid,['memset(deltasx,0,sizeof(realtype)*' num2str(nx) '*np);\n']);
118118
case 'stau'

@amifun/writeCcode.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ function writeCcode(this,model,fid)
8787
if(any(any(nonzero)))
8888
for iy = 1:model.nytrue
8989
fprintf(fid,['if(!mxIsNaN(my[' num2str(iy-1) '*nt+it])){\n']);
90-
tmpfun.sym = this.sym(iy,:);
90+
tmpfun.sym = permute(this.sym(iy,:,:),[2,3,1]);
9191
tmpfun.gccode(model,fid);
9292
fprintf(fid,'}\n');
9393
end

@amifun/writeCcode_sensi.m

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ function writeCcode_sensi(this,model,fid)
1010
% void
1111

1212
np = model.np;
13+
ng = model.ng;
1314

1415
nonzero_idx = find(this.sym);
1516
nonzero = zeros(size(this.sym));
@@ -18,7 +19,7 @@ function writeCcode_sensi(this,model,fid)
1819
if(strcmp(this.funstr,'deltaqB'))
1920
if(any(nonzero))
2021
tmpfun = this;
21-
for ip=1:np
22+
for ip=1:np*ng
2223
if(nonzero(ip))
2324
fprintf(fid,[' case ' num2str(ip-1) ': {\n']);
2425
tmpfun.sym = this.sym(ip,:);
@@ -54,6 +55,20 @@ function writeCcode_sensi(this,model,fid)
5455
end
5556
end
5657
end
58+
elseif(strcmp(this.funstr,'qBdot'))
59+
nonzero = this.sym ~=0;
60+
if(any(any(nonzero)))
61+
tmpfun = this;
62+
for ip=1:np
63+
if(any(nonzero(:,ip)))
64+
fprintf(fid,[' case ' num2str(ip-1) ': {\n']);
65+
tmpfun.sym = this.sym(:,ip);
66+
tmpfun.writeCcode(model,fid);
67+
fprintf(fid,'\n');
68+
fprintf(fid,' } break;\n\n');
69+
end
70+
end
71+
end
5772
else
5873
nonzero = this.sym ~=0;
5974
if(any(any(nonzero)))

0 commit comments

Comments
 (0)