@@ -111,20 +111,42 @@ end
111
111
# ShortestPathGraphKernel
112
112
# ================================================================
113
113
114
+
115
+ """
116
+ ShortestPathGraphKernel <: AbstractGraphKernel
117
+
118
+ A graph kernel that compares two graphs `g` and `g'` by comparing all pairs
119
+ of vertices `(u, v)` of the `g` and `(u', v')` of `g'` if their shortest distance
120
+ is smaller than `tol`. In that case, the vertices `u` and `u'`, as well as `v`, `v'` are
121
+ compared with `vertex_kernel`.
122
+
123
+ # Keywords
124
+ - `tol=0.0`: Only pairs of vertices where the shortest distance is at most `tol` are
125
+ compared.
126
+ - `dist_key=:`: The key for the edge values to compute the shortest distance with. Can be either
127
+ an `Integer` or a `Symbol` for a key to a specific edge value, `nothing` to use a default distance
128
+ of `1` for each edge, or `:` in which case the default edge weight for that graph type
129
+ is used.
130
+ - `vertex_kernel=ConstVertexKernel(1.0)`: The kernel used to compare two vertices.
131
+
132
+ # References
133
+ [Borgwardt, K. M., & Kriegel, H. P.: Shortest-path kernels on graphs](https://www.dbs.ifi.lmu.de/~borgward/papers/BorKri05.pdf)
134
+ """
114
135
struct ShortestPathGraphKernel{VK <: AbstractVertexKernel } <: AbstractGraphKernel
115
136
116
137
tol:: Float64
117
138
vertex_kernel:: VK
139
+ dist_key:: Union{Int, Symbol, Colon, Nothing}
118
140
end
119
141
120
- function ShortestPathGraphKernel (;tol= 0.0 , vertex_kernel= ConstVertexKernel (1.0 ))
142
+ function ShortestPathGraphKernel (;tol= 0.0 , vertex_kernel= ConstVertexKernel (1.0 ), dist_key = Colon () )
121
143
122
- return ShortestPathGraphKernel (tol, vertex_kernel)
144
+ return ShortestPathGraphKernel (tol, vertex_kernel, dist_key )
123
145
end
124
146
125
147
function preprocessed_form (kernel:: ShortestPathGraphKernel , g:: AbstractGraph )
126
148
127
- dists = _make_dists (g)
149
+ dists = _make_dists (g, kernel . dist_key )
128
150
129
151
ds = map (t -> t. dist, dists)
130
152
us = map (t -> t. u, dists)
@@ -147,6 +169,7 @@ function apply_preprocessed(kernel::ShortestPathGraphKernel, pre1, pre2)
147
169
len1 = length (ds1)
148
170
len2 = length (ds2)
149
171
172
+ # TODO we are not using tol here at the moment
150
173
i2 = 1
151
174
@inbounds for i1 in Base. OneTo (length (ds1))
152
175
d1 = ds1[i1]
@@ -165,9 +188,16 @@ function apply_preprocessed(kernel::ShortestPathGraphKernel, pre1, pre2)
165
188
166
189
end
167
190
168
- function _make_dists (g)
191
+ function _make_dists (g, dist_key)
192
+
193
+ dists = if dist_key === Colon ()
194
+ floyd_warshall_shortest_paths (g). dists
195
+ elseif dist_key === nothing
196
+ floyd_warshall_shortest_paths (g, LightGraphs. DefaultDistance (nv (g))). dists
197
+ else
198
+ floyd_warshall_shortest_paths (g, weights (g, dist_key)). dists
199
+ end
169
200
170
- dists = floyd_warshall_shortest_paths (g). dists
171
201
verts = vertices (g)
172
202
tm = typemax (eltype (dists))
173
203
dists_list = [(dist= dists[u, v], u= u, v= v) for u in verts for v in verts if u != v && dists[u, v] != tm]
0 commit comments