Create optimal_quantization.py #11
Conversation
quantization/optimal_quantization.py
Outdated
@@ -0,0 +1,72 @@
"""
Optimal quantization
Can you describe a little bit more what this does, please?
quantization/optimal_quantization.py
Outdated
S2metric = HypersphereMetric(dimension=2)

TOLERANCE = 1e-5
IMPLEMENTED = ['S2']
Use a tuple, not a list.
quantization/optimal_quantization.py
Outdated
def sample_from(points, size=1):
    """
    Sample from the empirical distribution associated to points
Please be consistent with periods at the end of sentences in the comments.
quantization/optimal_quantization.py
Outdated
n_points = points.shape[0]
dimension = points.shape[-1]

ind = np.random.randint(low=0, high=n_points, size=size)
Use full words: what does ind stand for?
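For context, here is a minimal sketch of what sample_from does, with the index variable spelled out as the reviewer suggests (the name indices is an assumed replacement, not the name used in the PR):

```python
import numpy as np

def sample_from(points, size=1):
    """Draw size points uniformly at random from the rows of points."""
    n_points = points.shape[0]
    # Spelled-out name instead of the abbreviated `ind`.
    indices = np.random.randint(low=0, high=n_points, size=size)
    return points[indices, :]

points = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
sample = sample_from(points, size=5)
```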
quantization/optimal_quantization.py
Outdated
# random initialization of the centers
centers = sample_from(points, n_centers)

gap = 1
Write gap = 1.0 to show that it is a floating-point value.
quantization/optimal_quantization.py
Outdated
while gap > tolerance:
    step += 1
    k = np.floor(step / n_repetition) + 1
What's the rationale for starting at step = 1 instead of step = 0?
Starting at step = 0 would replace the center to be updated with the new sample, instead of just moving it in the direction of the sample (as in any step > 0). It is equivalent to a slightly different initialization of the centers (only one is different). Since there is no reason for that initialization to be better than the previous one, I start at step = 1.
OK, but you should find a better name for k: it is hard to understand what this variable means. Usually k is used for indexing an array, and variables like i, j, k should have trivial behavior, starting at 0 and incrementing. Anything more complex should have a clear name.
Yes, Nina had the same comment, so I changed it to step_size.
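For readers following along, a small sketch of the schedule being discussed. The names n_repetitions and step_size follow the renaming above; the exact update rule in the PR may differ:

```python
import numpy as np

n_repetitions = 20
# Starting at step = 1 avoids replacing a center outright on the first
# update; step_size then grows by 1 every n_repetitions steps.
steps = np.arange(1, 61)
step_size = np.floor(steps / n_repetitions) + 1
```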
quantization/optimal_quantization.py
Outdated
centers[ind, :] = new_center

return centers, step
You need unit tests for this file.
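A minimal example of the kind of unit test being requested here; it checks that a drawn sample is indeed one of the input points. The sample_from helper is re-sketched so the test is self-contained (it is not the PR's exact implementation):

```python
import unittest
import numpy as np

def sample_from(points, size=1):
    # Minimal stand-in for the function under review.
    indices = np.random.randint(low=0, high=points.shape[0], size=size)
    return points[indices, :]

class TestSampleFrom(unittest.TestCase):
    def test_sample_is_a_data_point(self):
        points = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
        sample = sample_from(points)
        # The sample must coincide with one of the rows of points.
        found = any(np.allclose(points[i, :], sample)
                    for i in range(points.shape[0]))
        self.assertTrue(found)

suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestSampleFrom)
result = unittest.TextTestRunner().run(suite)
```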
Corrections of the script, including new descriptions, an implementation of the circle (which will be useful in the unit tests), and new outputs.
Add Karcher flow algorithm for the purpose of the unit tests.
Thanks for the reviews! Here are the latest changes:
Awesome! Some comments, mostly syntax.
quantization/optimal_quantization.py
Outdated
@@ -0,0 +1,171 @@
"""
Optimal quantization of the empirical distribution of a dataset -
Nit: why not a point at the end of the line?
I replaced it with a dot.
quantization/optimal_quantization.py
Outdated
def diameter_of_data(points, space=None):
    """
    Compute the two points that are farthest away from each other in points.
NIT: Compute the distance between the two points
Yes, thanks!
quantization/optimal_quantization.py
Outdated
return index_closest_neighbor


def diameter_of_data(points, space=None):
NIT: maybe diameter would sound better, and add it to riemannian_metric.py?
Corrected.
quantization/optimal_quantization.py
Outdated
n_points = points.shape[0]

for i in range(n_points-1):
    dist_to_neighbors = metric.dist(points[i, :], points[i+1:, :])
Doesn't this give only one real number: the distance between point i and point i+1? If so, I don't understand the next line taking the max of a single real number.
No, because the second argument contains all the points from i+1 to the end (the ":" is hard to see).
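To make the point concrete: points[i+1:, :] is the whole tail of the array, so the distance call returns a vector. A Euclidean stand-in replaces metric.dist here (an assumption, just to keep the sketch dependency-free):

```python
import numpy as np

def diameter(points):
    """Largest pairwise distance, looping over each point against the
    tail of the array, as in the code under review."""
    n_points = points.shape[0]
    diam = 0.0
    for i in range(n_points - 1):
        # A whole vector of distances, one per remaining point.
        dist_to_neighbors = np.linalg.norm(
            points[i, :] - points[i + 1:, :], axis=-1)
        diam = max(diam, np.max(dist_to_neighbors))
    return diam

square = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
```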
quantization/optimal_quantization.py
Outdated
"""
Compute the Karcher mean of points using a Karcher flow algorithm.
Return :
- the karcher mean
NIT: Uppercase letter for Karcher everywhere.
Corrected.
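For readers unfamiliar with the algorithm: the Karcher flow repeatedly averages the log maps of the data at the current estimate and moves along the result. This sketch uses the Euclidean metric, where the log map is simply p - base and the exp map is base + vector; the PR uses the manifold's metric.log and metric.exp instead:

```python
import numpy as np

def karcher_flow(points, tolerance=1e-5):
    """Karcher flow sketch with the Euclidean metric."""
    mean = points[0, :].copy()
    while True:
        tangent_vectors = points - mean      # Euclidean log map
        mean_tangent = np.mean(tangent_vectors, axis=0)
        mean = mean + mean_tangent           # Euclidean exp map
        if np.linalg.norm(mean_tangent) < tolerance:
            break
    return mean

corners = np.array([[0.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 2.0]])
```

In the Euclidean case the flow converges to the ordinary centroid, which makes it easy to test.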
sample = oq.sample_from(self.points)
result = False
for i in range(self.n_points):
    if (self.points[i, :] == sample).all():
Could you use np.allclose(self.points[i, :], sample)?
Corrected.
result = False
for i in range(self.n_points):
    if (self.points[i, :] == sample).all():
        result = True
You can add a break when sample is found to be one of the points, in order to stop the for loop.
Thanks for the tip!
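A sketch combining both suggestions above (np.allclose for the comparison and break to stop the scan early):

```python
import numpy as np

points = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
sample = np.array([0.0, 1.0])

result = False
for i in range(points.shape[0]):
    if np.allclose(points[i, :], sample):
        result = True
        break  # stop scanning once the sample is found
```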
closest_neighbor = self.points[index, :]
result = False
for i in range(self.n_points):
    if (self.points[i, :] == closest_neighbor).all():
Same as above.
import unittest

import numpy as np
Use import geomstats.backend as gs.
for i in range(self.n_points):
    tangent_vectors[i, :] = self.metric.log(
        self.points[i, :], karcher_mean)
sum_tan_vec = np.sum(tangent_vectors, axis=0)
NIT: sum_tangent_vecs?
quantization/optimal_quantization.py
Outdated
return diameter


def karcher_flow(points, space=None, tolerance=TOLERANCE):
This looks like: https://github.com/geomstats/geomstats/blob/master/geomstats/riemannian_metric.py#L219
Could we merge the two implementations, maybe?
quantization/optimal_quantization.py
Outdated
IMPLEMENTED = ('S1', 'S2')


def sample_from(points, size=1):
There might be a function doing this already, something like numpy.random.choice? https://docs.scipy.org/doc/numpy/reference/generated/numpy.random.choice.html
It seems like numpy.random.choice takes only 1D arrays as input, so it cannot be used directly to sample from a set of points in 2 or more dimensions. I could use it to choose an index, but I don't think that would be very different from using numpy.random.randint?
Ah, true. Too bad, it seems that they wanted to add the option but haven't done it so far: numpy/numpy#7810
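For the record, the index-based workaround discussed above looks like this; as noted, it is essentially equivalent to using numpy.random.randint:

```python
import numpy as np

points = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])

# np.random.choice only accepts 1-D arrays (or an int n, meaning
# arange(n)), so we sample row indices rather than the rows themselves.
indices = np.random.choice(points.shape[0], size=4)
sample = points[indices, :]
```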
quantization/optimal_quantization.py
Outdated
return sample


def closest_neighbor(point, neighbors, space=None):
Add to riemannian_metric.py
Done!
Remove closest_neighbor, diameter and karcher_flow
I had missed the
Thank you! Almost there! I added a decent amount of new comments, because I've understood more of the code with this second review :)
quantization/optimal_quantization.py
Outdated
IMPLEMENTED = ('S1', 'S2')


def sample_from(points, size=1):
How about n_samples=1 instead of size=1, to be consistent with the other sampling functions?
Corrected.
quantization/optimal_quantization.py
Outdated
dimension = points.shape[-1]

index = gs.random.randint(low=0, high=n_points, size=size)
sample = points[gs.ix_(index, gs.arange(dimension))]
I think I don't understand the size parameter. Why not:
index = gs.random.randint(low=0, high=n_points, size=(n_samples,))
sample = points[index, :]
?
It seems that size=n_samples (previously size=size) gives the same result as size=(n_samples,).
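This can be checked with a fixed seed: an integer size and a 1-tuple size produce identical index arrays of the same shape.

```python
import numpy as np

n_samples = 3
rng_a = np.random.RandomState(0)
rng_b = np.random.RandomState(0)

# Same seed, two spellings of the size argument.
index_int = rng_a.randint(low=0, high=10, size=n_samples)
index_tuple = rng_b.randint(low=0, high=10, size=(n_samples,))
```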
quantization/optimal_quantization.py
Outdated
if gs.isclose(gap, 0, atol=tolerance):
    break

if iteration is n_max_iterations:
This should be n_max_iterations - 1, or the while loop condition above should be <=.
Use == to compare integers: https://stackoverflow.com/questions/2239737/is-it-better-to-use-is-or-for-number-comparison-in-python
Yes, thanks!
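Both review points can be seen in a tiny example; the loop here is a simplified stand-in for the while loop in the PR:

```python
# Off-by-one: a loop guarded by `< n_max_iterations` never sees the value
# n_max_iterations itself, so the "did we hit the limit?" check must
# compare against n_max_iterations - 1 (or the guard must become <=).
n_max_iterations = 5
iteration = 0
last_seen = None
while iteration < n_max_iterations:
    last_seen = iteration
    iteration += 1
hit_limit = (last_seen == n_max_iterations - 1)

# Identity vs equality: `is` compares object identity and only happens to
# work for small cached integers, so always use `==` for numbers.
a = 10 ** 6
b = 10 ** 6
numbers_equal = (a == b)
```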
quantization/optimal_quantization.py
Outdated
else:
    metric = HypersphereMetric(dimension=2)

# random initialization of the centers
No need for this comment: the code is clear enough and this will save us from meaningless leftover comments if we later change the code but forget to adapt the comments.
Corrected.
quantization/plot_quantization_s1.py
Outdated
centers, weights, clustering, n_iterations = oq.optimal_quantization(
    points, n_centers, space='S1', n_repetitions=20, tolerance=1e-6
)
theta = gs.linspace(0, 2*gs.pi, 100)
Could you add this as a class Circle in visualization.py, similar to the class Sphere? Thank you!
https://github.com/geomstats/geomstats/blob/master/geomstats/visualization.py#L64
quantization/plot_quantization_s1.py
Outdated
from geomstats.hypersphere import Hypersphere

CIRCLE = Hypersphere(dimension=1)
Add all the constants at the beginning:
N_POINTS = 1000
N_CENTERS = 5
N_REPETITIONS = 20
TOLERANCE = 1e-6
quantization/optimal_quantization.py
Outdated
def optimal_quantization(points, n_centers=10, space=None, n_repetitions=20,
                         tolerance=TOLERANCE, n_max_iterations=50000):
    """
Could you add a short explanation of how you use n_repetitions? Thanks.
quantization/plot_quantization_s2.py
Outdated
plt.figure(1)
ax = plt.subplot(111, projection="3d", aspect="equal")
color = np.random.rand(n_centers, 3)
ax.plot_wireframe(sphere.sphere_x,
Use methods of the sphere: sphere.draw, sphere.add_points, sphere.draw_points, etc.
https://github.com/geomstats/geomstats/blob/master/geomstats/visualization.py#L94
I have tried to use them in the best way possible; however, it would be more satisfactory if the sphere were not drawn as many times as the number of clusters (I need each cluster to be drawn in a different color, so I repeat sphere.draw n_clusters times). Maybe this could be fixed by creating a new list of points each time add_points is called, which could then be plotted in a different color (or some other way)?
Also, I would recommend changing the alpha=0.5 in draw to a lower value such as alpha=0.2 so that the points are more visible, or making it an adjustable parameter.
Never mind, the sphere is not drawn when using sphere.draw_points instead of sphere.draw. Sorry about that!
Nice! Just a few additional comments.
quantization/plot_quantization_s1.py
Outdated
    n_repetitions=N_REPETITIONS, tolerance=TOLERANCE
)
visualization.plot(centers, space='S1', color='red')
Does it make sense to also draw the points, each in the color corresponding to its center, to mimic plot_quantization_s2 example?
Yes! I added that.
quantization/optimal_quantization.py
Outdated
return sample


def optimal_quantization(points, metric, n_centers=10, n_repetitions=20,
Put
N_CENTERS = 10
N_REPETITIONS = 20
N_MAX_ITERATIONS = 5000
at the beginning of the file.
quantization/plot_quantization_s2.py
Outdated
color = gs.random.rand(N_CENTERS, 3)
for i in range(N_CENTERS):
    cluster_i = gs.vstack([point for point in clusters[i]])
    sphere = visualization.Sphere()
Do we need a new Sphere for each cluster?
I changed visualization.py in that direction. See geomstats/geomstats#141.
quantization/plot_quantization_s2.py
Outdated
cluster_i = gs.vstack([point for point in clusters[i]])
sphere = visualization.Sphere()
sphere.add_points(cluster_i)
if i == 0:
Do we need this edge case? If so, could we tackle it directly in visualization.py?
Same as above, see PR #141.
quantization/optimal_quantization.py
Outdated
Return :
- n_centers centers
- n_centers weights between 0 and 1
- a dictionary containing the clusters
Add something like: "where each key is the cluster index, and its value is the list of points belonging to the cluster."?
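A hypothetical illustration of the suggested docstring wording, showing the clustering output shape (the keys and point values here are made up for the example):

```python
import numpy as np

# Each key is a cluster index; each value is the list of points
# belonging to that cluster.
clusters = {
    0: [np.array([1.0, 0.0]), np.array([0.9, 0.1])],
    1: [np.array([0.0, 1.0])],
}
# Stacking one cluster into an array, as the plotting scripts do.
cluster_0 = np.vstack(clusters[0])
```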
Thanks! Let's first agree on the visualization.py version and I'll have another look at this one after.