28 changes: 14 additions & 14 deletions src/main/java/tech/sourced/gemini/WeightedMinHash.java
@@ -23,9 +23,9 @@ public class WeightedMinHash implements Serializable {
   protected int dim;
   protected int sampleSize;
 
-  protected double[][] rs;
-  protected double[][] lnCs;
-  protected double[][] betas;
+  protected float[][] rs;
+  protected float[][] lnCs;
+  protected float[][] betas;
 
   /**
    * Initializes a WeightedMinHash
@@ -53,37 +53,37 @@ public WeightedMinHash(int dim, int sampleSize, long seed) {

     GammaDistribution gammaGen = new GammaDistribution(randSrc, 2, 1);
 
-    rs = new double[sampleSize][dim];
+    rs = new float[sampleSize][dim];
 
     for (int y = 0; y < sampleSize; y++) {
-      double[] arr = rs[y];
+      float[] arr = rs[y];
       for (int x = 0; x < dim; x++) {
-        arr[x] = gammaGen.sample();
+        arr[x] = (float)gammaGen.sample();
       }
     }
 
-    lnCs = new double[sampleSize][dim];
+    lnCs = new float[sampleSize][dim];
 
    for (int y = 0; y < sampleSize; y++) {
-      double[] arr = lnCs[y];
+      float[] arr = lnCs[y];
       for (int x = 0; x < dim; x++) {
-        arr[x] = log(gammaGen.sample());
+        arr[x] = (float)log(gammaGen.sample());
       }
     }
 
     UniformRealDistribution uniformGen = new UniformRealDistribution(randSrc, 0, 1);
 
-    betas = new double[sampleSize][dim];
+    betas = new float[sampleSize][dim];
 
     for (int y = 0; y < sampleSize; y++) {
-      double[] arr = betas[y];
+      float[] arr = betas[y];
       for (int x = 0; x < dim; x++) {
-        arr[x] = uniformGen.sample();
+        arr[x] = (float)uniformGen.sample();
       }
     }
   }
 
-  WeightedMinHash(int dim, int sampleSize, double[][] rs, double[][] lnCs, double[][] betas) {
+  WeightedMinHash(int dim, int sampleSize, float[][] rs, float[][] lnCs, float[][] betas) {
     this.dim = dim;
     this.sampleSize = sampleSize;
     this.rs = rs;
@@ -97,7 +97,7 @@ public WeightedMinHash(int dim, int sampleSize, long seed) {
    * @param values weighted vector
    * @return weighted MinHash
    */
-  public long[][] hash(double[] values) {
+  public long[][] hash(float[] values) {
     if (values.length != dim) {
       throw new IllegalArgumentException("input dimension mismatch, expected " + dim);
     }
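Switching the three sampling matrices rs, lnCs, and betas from double to float halves their payload. A back-of-the-envelope check of what that buys (a sketch; the dim and sampleSize values are illustrative, and per-array JVM object headers are ignored):

```scala
object WmhMemoryEstimate {
  // rs, lnCs and betas are each a sampleSize x dim matrix, so the raw
  // payload is roughly 3 * dim * sampleSize * bytesPerElement.
  def bytes(dim: Long, sampleSize: Long, bytesPerElement: Long): Long =
    3L * dim * sampleSize * bytesPerElement

  def main(args: Array[String]): Unit = {
    println(bytes(1000000L, 160L, 8L)) // double: 3840000000 bytes (~3.8 GB)
    println(bytes(1000000L, 160L, 4L)) // float:  1920000000 bytes (~1.9 GB)
  }
}
```

The float figure lines up with the 1920009680 bytes measured with JAMM in the Hash.scala comment below; the small remainder is per-row array object overhead.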
8 changes: 4 additions & 4 deletions src/main/scala/tech/sourced/gemini/FeaturesHash.scala
@@ -25,7 +25,7 @@ object FeaturesHash {
    * Factory method for initializing WMH data structure
    * \w default parameters, specific to Gemini.
    *
-   * Allocates at least 2*dim*sampleSize*8 bytes of RAM
+   * Allocates at least 3*dim*sampleSize*4 bytes of RAM
    *
    * @param dim weight vector size
    * @param sampleSize number of samples
@@ -49,14 +49,14 @@
    * @param docFreq
    * @return
    */
-  def toBagOfFeatures(features: Iterator[Feature], docFreq: OrderedDocFreq): Array[Double] = {
+  def toBagOfFeatures(features: Iterator[Feature], docFreq: OrderedDocFreq): Array[Float] = {
     val OrderedDocFreq(docs, tokens, df) = docFreq
 
-    val bag = new Array[Double](tokens.size)
+    val bag = new Array[Float](tokens.size)
     features.foreach { feature =>
       tokens.search(feature.name) match {
         case Found(idx) => {
-          val tf = feature.weight.toDouble
+          val tf = feature.weight
 
           bag(idx) = MathUtil.logTFlogIDF(tf, df(feature.name), docs)
         }
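For context, toBagOfFeatures turns a stream of features into a TF-IDF-weighted vector indexed by a sorted vocabulary. A self-contained sketch of that logic, with simplified stand-ins for Gemini's Feature and OrderedDocFreq types (the stand-in definitions are hypothetical, not the project's API, and the TF-IDF formula from MathUtil is inlined):

```scala
import scala.collection.Searching._

// Hypothetical stand-in for tech.sourced.gemini's Feature.
case class Feature(name: String, weight: Int)

def toBag(features: Iterator[Feature],
          tokens: IndexedSeq[String],   // sorted vocabulary, binary-searchable
          df: Map[String, Int],         // document frequency per token
          docs: Int): Array[Float] = {
  val bag = new Array[Float](tokens.size)
  features.foreach { feature =>
    tokens.search(feature.name) match {
      case Found(idx) =>
        // log-TF * log-IDF, computed in Double then narrowed to Float
        bag(idx) = (math.log(1 + feature.weight.toDouble) *
          math.log(docs.toDouble / df(feature.name))).toFloat
      case _ => // unknown token: leave its weight at 0
    }
  }
  bag
}
```

tokens.search is a binary search, so the vocabulary must be sorted, which is presumably what the ordering in OrderedDocFreq provides in the real code.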
4 changes: 2 additions & 2 deletions src/main/scala/tech/sourced/gemini/Hash.scala
@@ -238,8 +238,8 @@ class Hash(session: SparkSession,
    * Create WeightedMinHash instance and broadcasts it
    *
    * create it only once and keep on node
-   * because the instance is relatively huge (2 * N of features * sampleSize(160 or 256 depends on mode) * 8)
-   * According to tests ~1.6 Gb per 1 PGA bucket (but really depends on bucket)
+   * because the instance is relatively huge (3 * N of features * sampleSize(160 or 256 depends on mode) * 4)
+   * On my system with 1.000.000 tokens it allocated 1920009680 bytes (~2Gb) (measured using JAMM)
    *
    * @param tokens number of features
    * @param sampleSize depends on hashing mode and threshold
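The pattern this comment describes, sketched minimally (method name and seed are illustrative, not Gemini's actual code; WeightedMinHash implements Serializable, which is what makes the broadcast possible):

```scala
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.SparkSession

// Build the large WeightedMinHash once on the driver and broadcast it,
// so every executor holds a single read-only copy of the ~2 GB instance
// instead of re-allocating it per task.
def broadcastWmh(session: SparkSession,
                 tokens: Int,
                 sampleSize: Int): Broadcast[WeightedMinHash] = {
  // fixed seed: hashes must be comparable across runs and nodes
  val wmh = new WeightedMinHash(tokens, sampleSize, 0x1abL)
  session.sparkContext.broadcast(wmh)
}
```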
2 changes: 1 addition & 1 deletion src/main/scala/tech/sourced/gemini/util/MathUtil.scala
@@ -26,5 +26,5 @@ object MathUtil {
     )
   }
 
-  def logTFlogIDF(tf: Double, df: Double, ndocs: Int): Double = math.log(1 + tf) * math.log(ndocs / df)
+  def logTFlogIDF(tf: Float, df: Int, ndocs: Int): Float = (math.log(1 + tf) * math.log(ndocs / df)).toFloat
 }
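A worked example of the weighting, plus one subtlety worth flagging: with df and ndocs both Int after this change, ndocs / df truncates to an integer before the logarithm, whereas the previous df: Double signature divided in floating point. A script-style sketch (for illustration; not part of the diff):

```scala
// Same definition as above.
def logTFlogIDF(tf: Float, df: Int, ndocs: Int): Float =
  (math.log(1 + tf) * math.log(ndocs / df)).toFloat

// tf = 3, df = 10, ndocs = 1000:
//   log(1 + 3) * log(1000 / 10) = log(4) * log(100) ≈ 1.3863 * 4.6052 ≈ 6.38
println(logTFlogIDF(3f, 10, 1000))

// Integer division at work: 1000 / 300 == 3, so this computes
// log(4) * log(3) ≈ 1.52, where a Double df would give
// log(4) * log(3.33...) ≈ 1.67.
println(logTFlogIDF(3f, 300, 1000))
```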
2 changes: 1 addition & 1 deletion src/test/resources/docfreq.json

Large diffs are not rendered by default.
