Skip to content

Commit 4a17b25

Browse files
committed
Add dist_key for ShortestPathGraphKernel
1 parent abecdd9 commit 4a17b25

File tree

1 file changed

+35
-5
lines changed

1 file changed

+35
-5
lines changed

src/graph_kernels.jl

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,20 +111,42 @@ end
111111
# ShortestPathGraphKernel
112112
# ================================================================
113113

114+
115+
"""
116+
ShortestPathGraphKernel <: AbstractGraphKernel
117+
118+
A graph kernel that compares two graphs `g` and `g'` by comparing all pairs
119+
of vertices `(u, v)` of the `g` and `(u', v')` of `g'` if their shortest distance
120+
is smaller than `tol`. In that case, the vertices `u` and `u'`, as well as `v`, `v'` are
121+
compared with `vertex_kernel`.
122+
123+
# Keywords
124+
- `tol=0.0`: Only pairs of vertices where the shortest distance is at most `tol` are
125+
compared.
126+
- `dist_key=:`: The key for the edge values to compute the shortest distance with. Can be either
127+
an `Integer` or a `Symbol` for a key to a specific edge value, `nothing` to use a default distance
128+
of `1` for each edge, or `:` in which case the default edge weight for that graph type
129+
is used.
130+
- `vertex_kernel=ConstVertexKernel(1.0)`: The kernel used to compare two vertices.
131+
132+
# References
133+
[Borgwardt, K. M., & Kriegel, H. P.: Shortest-path kernels on graphs](https://www.dbs.ifi.lmu.de/~borgward/papers/BorKri05.pdf)
134+
"""
114135
struct ShortestPathGraphKernel{VK <: AbstractVertexKernel} <: AbstractGraphKernel
115136

116137
tol::Float64
117138
vertex_kernel::VK
139+
dist_key::Union{Int, Symbol, Colon, Nothing}
118140
end
119141

120-
function ShortestPathGraphKernel(;tol=0.0, vertex_kernel=ConstVertexKernel(1.0))
142+
function ShortestPathGraphKernel(;tol=0.0, vertex_kernel=ConstVertexKernel(1.0), dist_key=Colon())
121143

122-
return ShortestPathGraphKernel(tol, vertex_kernel)
144+
return ShortestPathGraphKernel(tol, vertex_kernel, dist_key)
123145
end
124146

125147
function preprocessed_form(kernel::ShortestPathGraphKernel, g::AbstractGraph)
126148

127-
dists = _make_dists(g)
149+
dists = _make_dists(g, kernel.dist_key)
128150

129151
ds = map(t -> t.dist, dists)
130152
us = map(t -> t.u, dists)
@@ -147,6 +169,7 @@ function apply_preprocessed(kernel::ShortestPathGraphKernel, pre1, pre2)
147169
len1 = length(ds1)
148170
len2 = length(ds2)
149171

172+
# TODO we are not using tol here at the moment
150173
i2 = 1
151174
@inbounds for i1 in Base.OneTo(length(ds1))
152175
d1 = ds1[i1]
@@ -165,9 +188,16 @@ function apply_preprocessed(kernel::ShortestPathGraphKernel, pre1, pre2)
165188

166189
end
167190

168-
function _make_dists(g)
191+
function _make_dists(g, dist_key)
192+
193+
dists = if dist_key === Colon()
194+
floyd_warshall_shortest_paths(g).dists
195+
elseif dist_key === nothing
196+
floyd_warshall_shortest_paths(g, LightGraphs.DefaultDistance(nv(g))).dists
197+
else
198+
floyd_warshall_shortest_paths(g, weights(g, dist_key)).dists
199+
end
169200

170-
dists = floyd_warshall_shortest_paths(g).dists
171201
verts = vertices(g)
172202
tm = typemax(eltype(dists))
173203
dists_list = [(dist=dists[u, v], u=u, v=v) for u in verts for v in verts if u != v && dists[u, v] != tm]

0 commit comments

Comments
 (0)