Skip to content

Commit e9902c9

Browse files
authored
Upgrade antspy version (#25)
* Upgrade antspy version * Extend to other python versions * Modigy github.workflow * Try to make github actions work * Remove 3.8 and 3.9 because tf version * Try to remove smaller constrain on tensorflow * Try python version 3.8 * Try version 3.8 (2) * Remove tensorflow constrains and add all python versions * Try to add colors to CI * Remove Keras version * Focus on 3.8 * Upgrade tensorflow * Fix tf seed * Update to KerasTensor and compat.v1 * Remove Keras dependency * Remove allensdk dependency * Remove SimpleITK dependency * Fix issues * Add macos in github actions * Fix typo
1 parent cfdcd18 commit e9902c9

File tree

17 files changed

+131
-129
lines changed

17 files changed

+131
-129
lines changed

.github/workflows/run-tests.yml

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,24 @@ jobs:
1111

1212
run_test:
1313

14-
runs-on: ubuntu-latest
14+
runs-on: ${{ matrix.os }}
1515

1616
strategy:
17+
1718
matrix:
18-
python-version: [3.7] # currently ANTsPy wheels are only avail for py37
19+
os: [ubuntu-latest] # macos 11 is currently in preview, macos-latest == 1.10.15
20+
python-version: [
21+
3.7,
22+
3.8,
23+
3.9,
24+
]
25+
include:
26+
- python-version: 3.7
27+
tox-env: py37
28+
- python-version: 3.8
29+
tox-env: py38
30+
- python-version: 3.9
31+
tox-env: py39
1932

2033
steps:
2134

@@ -30,13 +43,13 @@ jobs:
3043
- name: install python dependencies
3144
run: |
3245
python -m pip install --upgrade pip
33-
pip install tox
46+
pip install tox tox-gh-actions
3447
3548
- name: linting and code style
3649
run: tox -vv -e lint
3750

3851
- name: tests and coverage
39-
run: tox -vv -e py37
52+
run: tox -vv -e ${{ matrix.tox-env }} -- --color=yes
4053

4154
- name: docs
4255
run: tox -vv -e docs

atlalign/allen/utils.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,16 @@
2727
"""
2828

2929
import os
30+
import urllib
3031

3132
import matplotlib.pyplot as plt
3233
import numpy as np
3334
import requests
34-
from allensdk.api.queries.image_download_api import ImageDownloadApi
3535

3636
CACHE_FOLDER = os.path.expanduser("~/.atlalign/")
3737

3838

39-
def get_image(image_id, folder=None, **kwargs):
39+
def get_image(image_id, folder=None, expression=False):
4040
"""Get any image from Allen's database just by its id.
4141
4242
Notes
@@ -52,8 +52,9 @@ def get_image(image_id, folder=None, **kwargs):
5252
folder : str or LocalPath or None
5353
Local folder where image saved. If None then automatically defaults to `CACHE_FOLDER`.
5454
55-
**kwargs
56-
Additional parameters to be passed onto the `download_image` method of ``ImageDownloadApi``. See
55+
expression : bool
56+
If True, retrieve the specified SectionImage's expression mask image.
57+
Otherwise, retrieve the specified SectionImage.
5758
See references for details.
5859
5960
Returns
@@ -76,11 +77,7 @@ def get_image(image_id, folder=None, **kwargs):
7677
os.makedirs(folder)
7778

7879
# Create full path
79-
additional_speficier = "_".join(
80-
sorted(["{}_{}".format(k, v) for k, v in kwargs.items()])
81-
)
82-
if additional_speficier:
83-
additional_speficier = "_{}".format(additional_speficier)
80+
additional_speficier = "_expression" if expression else ""
8481
path = "{}{}{}.jpg".format(folder, image_id, additional_speficier)
8582

8683
# Check image exists
@@ -93,10 +90,13 @@ def get_image(image_id, folder=None, **kwargs):
9390
return img
9491

9592
else:
96-
97-
img_api = ImageDownloadApi()
98-
img_api.download_image(image_id, file_path=path, **kwargs)
99-
return get_image(image_id, **kwargs)
93+
image_url = (
94+
f"http://api.brain-map.org/api/v2/section_image_download/{str(image_id)}"
95+
)
96+
if expression:
97+
image_url += "?view=expression"
98+
urllib.request.urlretrieve(image_url, path)
99+
return get_image(image_id, expression=expression)
100100

101101

102102
def get_2d(image_id, ref2inp=False, add_last=False):

atlalign/base.py

Lines changed: 13 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,6 @@
3737
SmoothBivariateSpline,
3838
griddata,
3939
)
40-
41-
try:
42-
import SimpleITK as sitk
43-
except ImportError as e:
44-
print(e)
45-
4640
from skimage.transform import resize
4741

4842
from atlalign.utils import griddata_custom
@@ -990,24 +984,23 @@ def pseudo_inverse(
990984
"""
991985
interpolator_kwargs = interpolator_kwargs or {}
992986

993-
if interpolation_method != "itk":
994-
x, y = np.meshgrid(list(range(self.shape[1])), list(range(self.shape[0])))
995-
xi = (y, x)
996-
x_r, y_r = x.ravel(), y.ravel()
987+
x, y = np.meshgrid(list(range(self.shape[1])), list(range(self.shape[0])))
988+
xi = (y, x)
989+
x_r, y_r = x.ravel(), y.ravel()
997990

998-
points = np.hstack(
999-
(
1000-
(y_r + self.delta_y.ravel()).reshape(-1, 1),
1001-
(x_r + self.delta_x.ravel()).reshape(-1, 1),
1002-
)
991+
points = np.hstack(
992+
(
993+
(y_r + self.delta_y.ravel()).reshape(-1, 1),
994+
(x_r + self.delta_x.ravel()).reshape(-1, 1),
1003995
)
996+
)
1004997

1005-
# Downsampling
1006-
points = points[::ds_f]
1007-
x_r_ds = x_r[::ds_f]
1008-
y_r_ds = y_r[::ds_f]
998+
# Downsampling
999+
points = points[::ds_f]
1000+
x_r_ds = x_r[::ds_f]
1001+
y_r_ds = y_r[::ds_f]
10091002

1010-
x_, y_ = points[:, 1], points[:, 0]
1003+
x_, y_ = points[:, 1], points[:, 0]
10111004

10121005
if interpolation_method == "griddata":
10131006
values_grid_x = griddata(points=points, values=x_r_ds, xi=xi)
@@ -1023,35 +1016,6 @@ def pseudo_inverse(
10231016
delta_x = values_grid_x.reshape(self.shape) - x
10241017
delta_y = values_grid_y.reshape(self.shape) - y
10251018

1026-
elif interpolation_method == "itk":
1027-
# ~ 30 ms per image
1028-
df_sitk = sitk.GetImageFromArray(
1029-
np.concatenate(
1030-
(self.delta_x[..., np.newaxis], self.delta_y[..., np.newaxis]),
1031-
axis=2,
1032-
),
1033-
isVector=True,
1034-
)
1035-
1036-
invertor = sitk.InvertDisplacementFieldImageFilter()
1037-
1038-
# Set behaviour
1039-
user_spec = {
1040-
"n_iter": interpolator_kwargs.get("n_iter", 20),
1041-
"tol": interpolator_kwargs.get("tol", 1e-3),
1042-
}
1043-
1044-
# invertor.EnforceBoundaryConditionOn()
1045-
invertor.SetMeanErrorToleranceThreshold(user_spec["tol"]) # big effect
1046-
invertor.SetMaximumNumberOfIterations(user_spec["n_iter"]) # big effect
1047-
1048-
# Run
1049-
df_sitk_inv = invertor.Execute(df_sitk)
1050-
1051-
delta_xy = sitk.GetArrayFromImage(df_sitk_inv)
1052-
1053-
delta_x, delta_y = delta_xy[..., 0], delta_xy[..., 1]
1054-
10551019
elif interpolation_method == "noop":
10561020
# for benchmarking purposes
10571021
delta_x, delta_y = np.zeros(self.shape), np.zeros(self.shape)

atlalign/metrics.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,13 +1000,14 @@ def perceptual_loss_img(y_true, y_pred, model="net-lin", net="vgg"):
10001000
y_true = np.stack((y_true,) * 3, axis=-1)
10011001
y_pred = np.stack((y_pred,) * 3, axis=-1)
10021002

1003-
image0_ph = tf.placeholder(tf.float32)
1004-
image1_ph = tf.placeholder(tf.float32)
1003+
image0_ph = tf.Variable(tf.float32) # noqa: F841
1004+
image1_ph = tf.Variable(tf.float32) # noqa: F841
10051005

1006-
distance_t = lpips_tf.lpips(image0_ph, image1_ph, model=model, net=net)
1006+
@tf.function
1007+
def lpips(image0_ph, image1_ph):
1008+
return lpips_tf.lpips(image0_ph, image1_ph, model=model, net=net)
10071009

1008-
with tf.Session() as session:
1009-
pl = session.run(distance_t, feed_dict={image0_ph: y_true, image1_ph: y_pred})
1010+
pl = lpips(y_true, y_pred)
10101011

10111012
tf.reset_default_graph()
10121013

atlalign/ml_utils/callbacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
import pathlib
2323

2424
import h5py
25-
import keras
2625
import mlflow
2726
import pandas as pd
27+
from tensorflow import keras
2828

2929
from atlalign.data import annotation_volume, segmentation_collapsing_labels
3030
from atlalign.metrics import evaluate_single

atlalign/ml_utils/io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
import pathlib
2525

2626
import h5py
27-
import keras
2827
import mlflow
2928
import numpy as np
29+
from tensorflow import keras
3030

3131
from atlalign.base import DisplacementField
3232
from atlalign.data import nissl_volume

atlalign/ml_utils/layers.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,19 @@
2121

2222
import numpy as np
2323
import tensorflow as tf
24-
from keras import backend as K
25-
from keras.engine.topology import Layer
26-
from keras.layers import (
24+
from tensorflow.keras import backend as K
25+
from tensorflow.keras.layers import (
2726
BatchNormalization,
2827
Conv2D,
2928
Dense,
3029
Flatten,
3130
Lambda,
31+
Layer,
3232
MaxPool2D,
3333
ReLU,
3434
Reshape,
3535
)
36+
from tensorflow_addons.image import resampler
3637

3738

3839
def K_meshgrid(x, y):
@@ -349,7 +350,7 @@ def call(self, tensors, mask=None):
349350

350351
f_x_f_y = grid + dvfs
351352

352-
output = tf.contrib.resampler.resampler(imgs, f_x_f_y)
353+
output = resampler(imgs, f_x_f_y)
353354

354355
return output
355356

atlalign/ml_utils/losses.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
along with this program. If not, see <https://www.gnu.org/licenses/>.
2020
"""
2121

22-
import keras
23-
import keras.backend as K
22+
import tensorflow.keras.backend as K
23+
from tensorflow import keras
2424

2525
try:
2626
import lpips_tf
@@ -141,7 +141,7 @@ def ncc(self, I, J): # noqa
141141
cc = cross * cross / (I_var * J_var + self.eps)
142142

143143
# return negative cc.
144-
return tf.reduce_mean(cc)
144+
return tf.reduce_mean(input_tensor=cc)
145145

146146
def loss(self, I, J): # noqa
147147
"""Compute loss."""
@@ -269,12 +269,12 @@ def _diffs(self, y):
269269
def loss(self, _, y_pred):
270270
"""Compute loss."""
271271
if self.penalty == "l1":
272-
df = [tf.reduce_mean(tf.abs(f)) for f in self._diffs(y_pred)]
272+
df = [tf.reduce_mean(input_tensor=tf.abs(f)) for f in self._diffs(y_pred)]
273273
else:
274274
assert self.penalty == "l2", (
275275
"penalty can only be l1 or l2. Got: %s" % self.penalty
276276
)
277-
df = [tf.reduce_mean(f * f) for f in self._diffs(y_pred)]
277+
df = [tf.reduce_mean(input_tensor=f * f) for f in self._diffs(y_pred)]
278278
return tf.add_n(df) / len(df)
279279

280280

@@ -310,7 +310,8 @@ def jacobian(_, y_pred):
310310

311311
n_pixels = tf.constant(np.prod(keras.backend.int_shape(det)[1:]), dtype=tf.float32)
312312
count_artifacts = tf.cast(
313-
tf.count_nonzero(tf.greater_equal(-det, 0.0), axis=(1, 2)), dtype=tf.float32
313+
tf.math.count_nonzero(tf.greater_equal(-det, 0.0), axis=(1, 2)),
314+
dtype=tf.float32,
314315
)
315316
perc_artifacts = count_artifacts / n_pixels
316317

@@ -394,10 +395,12 @@ def vector_distance(y_true, y_pred):
394395
diff = y_pred - y_true
395396

396397
vector_distance_per_output = tf.reduce_mean(
397-
tf.sqrt(tf.abs(tf.square(diff[..., 0]) + tf.square(diff[..., 1]) + 0.001)),
398+
input_tensor=tf.sqrt(
399+
tf.abs(tf.square(diff[..., 0]) + tf.square(diff[..., 1]) + 0.001)
400+
),
398401
axis=0,
399402
)
400-
vector_distance_average = tf.reduce_mean(vector_distance_per_output)
403+
vector_distance_average = tf.reduce_mean(input_tensor=vector_distance_per_output)
401404

402405
return vector_distance_average
403406

@@ -421,8 +424,8 @@ def mse_po(y_true, y_pred):
421424
diff = y_pred - y_true
422425

423426
vector_distance_per_output = tf.reduce_mean(
424-
tf.square(diff[..., 0]) + tf.square(diff[..., 1]), axis=0
427+
input_tensor=tf.square(diff[..., 0]) + tf.square(diff[..., 1]), axis=0
425428
)
426-
vector_distance_average = tf.reduce_mean(vector_distance_per_output)
429+
vector_distance_average = tf.reduce_mean(input_tensor=vector_distance_per_output)
427430

428431
return vector_distance_average

atlalign/ml_utils/models.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
import pathlib
2323
from copy import deepcopy
2424

25-
from keras.layers import Lambda, concatenate
26-
from keras.models import Model
27-
from keras.models import load_model as load_model_keras
28-
from keras.models import model_from_json
25+
from tensorflow.keras.layers import Lambda, concatenate
26+
from tensorflow.keras.models import Model
27+
from tensorflow.keras.models import load_model as load_model_keras
28+
from tensorflow.keras.models import model_from_json
2929

3030
from atlalign.ml_utils import (
3131
Affine2DVF,
@@ -120,7 +120,7 @@ def save_model(model, path, separate=True, overwrite=True):
120120
raise ValueError("Please specify a path without extension (folder).")
121121

122122
if not separate:
123-
model.save(str(path) + ".h5", overwrite=overwrite)
123+
model.save(str(path) + ".h5", overwrite=overwrite, save_format="h5")
124124

125125
else:
126126
path_architecture = path / (path.stem + ".json")

atlalign/nn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
along with this program. If not, see <https://www.gnu.org/licenses/>.
2020
"""
2121

22-
import keras
2322
import mlflow
2423
import numpy as np
25-
from keras.layers import (
24+
from tensorflow import keras
25+
from tensorflow.keras.layers import (
2626
Conv2D,
2727
Cropping2D,
2828
Dense,
@@ -36,7 +36,7 @@
3636
ZeroPadding2D,
3737
concatenate,
3838
)
39-
from keras.models import Model
39+
from tensorflow.keras.models import Model
4040

4141
from atlalign.ml_utils import (
4242
ALL_DVF_LOSSES,

0 commit comments

Comments
 (0)