Skip to content

Commit bbbabd0

Browse files
author
Ryan Kim
authored
Merge pull request #441 from kroma-network/feat/impl-radix2ditparallel
feat: impl `Radix2DitParallel`
2 parents 5de7501 + fa4c33f commit bbbabd0

21 files changed

+605
-54
lines changed

tachyon/math/finite_fields/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ tachyon_cc_library(
147147
deps = [
148148
":finite_field",
149149
":legendre_symbol",
150+
":packed_prime_field_traits_forward",
150151
":prime_field_util",
151152
"//tachyon/base:bits",
152153
"//tachyon/base/json",

tachyon/math/finite_fields/baby_bear/BUILD.bazel

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
1+
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
12
load("//bazel:tachyon.bzl", "if_aarch64", "if_has_avx512", "if_x86_64")
23
load("//bazel:tachyon_cc.bzl", "tachyon_avx512_defines", "tachyon_cc_library")
34
load("//tachyon/math/finite_fields/generator/ext_prime_field_generator:build_defs.bzl", "generate_fp4s")
4-
load("//tachyon/math/finite_fields/generator/prime_field_generator:build_defs.bzl", "generate_prime_fields")
5+
load("//tachyon/math/finite_fields/generator/prime_field_generator:build_defs.bzl", "SUBGROUP_GENERATOR", "generate_fft_prime_fields")
56

67
package(default_visibility = ["//visibility:public"])
78

8-
generate_prime_fields(
9+
string_flag(
10+
name = SUBGROUP_GENERATOR,
11+
build_setting_default = "31",
12+
)
13+
14+
generate_fft_prime_fields(
915
name = "baby_bear",
1016
class_name = "BabyBear",
1117
# 2³¹ - 2²⁷ + 1
1218
# Hex: 0x78000001
1319
modulus = "2013265921",
1420
namespace = "tachyon::math",
21+
subgroup_generator = ":" + SUBGROUP_GENERATOR,
1522
use_montgomery = True,
1623
)
1724

tachyon/math/finite_fields/baby_bear/packed_baby_bear.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ struct FiniteFieldTraits<PackedBabyBear> {
3737
using Config = BabyBear::Config;
3838
};
3939

40+
template <>
41+
struct PackedPrimeFieldTraits<BabyBear> {
42+
using PackedPrimeField = PackedBabyBear;
43+
};
44+
4045
} // namespace tachyon::math
4146

4247
namespace Eigen {

tachyon/math/finite_fields/koala_bear/BUILD.bazel

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
1+
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
12
load("//bazel:tachyon.bzl", "if_aarch64", "if_has_avx512", "if_x86_64")
23
load("//bazel:tachyon_cc.bzl", "tachyon_avx512_defines", "tachyon_cc_library")
34
load("//tachyon/math/finite_fields/generator/ext_prime_field_generator:build_defs.bzl", "generate_fp2s", "generate_fp4s")
4-
load("//tachyon/math/finite_fields/generator/prime_field_generator:build_defs.bzl", "generate_prime_fields")
5+
load("//tachyon/math/finite_fields/generator/prime_field_generator:build_defs.bzl", "SUBGROUP_GENERATOR", "generate_fft_prime_fields")
56

67
package(default_visibility = ["//visibility:public"])
78

8-
generate_prime_fields(
9+
string_flag(
10+
name = SUBGROUP_GENERATOR,
11+
build_setting_default = "3",
12+
)
13+
14+
generate_fft_prime_fields(
915
name = "koala_bear",
1016
class_name = "KoalaBear",
1117
# 2³¹ - 2²⁴ + 1
1218
# Hex: 0x7f000001
1319
modulus = "2130706433",
1420
namespace = "tachyon::math",
21+
subgroup_generator = ":" + SUBGROUP_GENERATOR,
1522
use_montgomery = True,
1623
)
1724

tachyon/math/finite_fields/koala_bear/packed_koala_bear.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ struct FiniteFieldTraits<PackedKoalaBear> {
3737
using Config = KoalaBear::Config;
3838
};
3939

40+
template <>
41+
struct PackedPrimeFieldTraits<KoalaBear> {
42+
using PackedPrimeField = PackedKoalaBear;
43+
};
44+
4045
} // namespace tachyon::math
4146

4247
namespace Eigen {

tachyon/math/finite_fields/mersenne31/packed_mersenne31.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ struct FiniteFieldTraits<PackedMersenne31> {
3737
using Config = Mersenne31::Config;
3838
};
3939

40+
template <>
41+
struct PackedPrimeFieldTraits<Mersenne31> {
42+
using PackedPrimeField = PackedMersenne31;
43+
};
44+
4045
} // namespace tachyon::math
4146

4247
namespace Eigen {

tachyon/math/finite_fields/packed_prime_field_traits_forward.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
namespace tachyon::math {
55

6-
template <typename T>
6+
template <typename T, typename SFINAE = void>
77
struct PackedPrimeFieldTraits;
88

99
} // namespace tachyon::math

tachyon/math/finite_fields/prime_field_base.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "tachyon/math/base/gmp/gmp_util.h"
2323
#include "tachyon/math/finite_fields/finite_field.h"
2424
#include "tachyon/math/finite_fields/legendre_symbol.h"
25+
#include "tachyon/math/finite_fields/packed_prime_field_traits_forward.h"
2526
#include "tachyon/math/finite_fields/prime_field_util.h"
2627

2728
namespace tachyon {
@@ -160,6 +161,12 @@ H AbslHashValue(H h, const F& prime_field) {
160161
return h;
161162
}
162163

164+
template <typename T>
165+
struct PackedPrimeFieldTraits<
166+
T, std::enable_if_t<std::is_base_of_v<math::PrimeFieldBase<T>, T>>> {
167+
using PackedPrimeField = T;
168+
};
169+
163170
} // namespace math
164171

165172
namespace base {

tachyon/math/matrix/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ tachyon_cc_library(
2626
name = "matrix_utils",
2727
hdrs = ["matrix_utils.h"],
2828
deps = [
29+
"//tachyon/base:bits",
30+
"//tachyon/base:openmp_util",
2931
"//tachyon/base/containers:container_util",
3032
"//tachyon/math/finite_fields:packed_prime_field_traits_forward",
3133
"@eigen_archive//:eigen3",

tachyon/math/matrix/matrix_utils.h

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
#ifndef TACHYON_MATH_MATRIX_MATRIX_UTILS_H_
22
#define TACHYON_MATH_MATRIX_MATRIX_UTILS_H_
33

4+
#include <utility>
45
#include <vector>
56

67
#include "third_party/eigen3/Eigen/Core"
78

9+
#include "tachyon/base/bits.h"
810
#include "tachyon/base/containers/container_util.h"
11+
#include "tachyon/base/openmp_util.h"
912
#include "tachyon/math/finite_fields/packed_prime_field_traits_forward.h"
1013

1114
namespace tachyon::math {
@@ -45,22 +48,25 @@ MakeCirculant(const Eigen::MatrixBase<ArgType>& arg) {
4548
CirculantFunctor<ArgType>(arg.derived()));
4649
}
4750

51+
// NOTE(ashjeong): Important! |matrix| should carry the same amount of rows as
52+
// the parent matrix it is a block from. |PackRowHorizontally| currently only
53+
// supports row-major matrices.
4854
template <typename PackedPrimeField, typename Derived, typename PrimeField>
49-
std::vector<PackedPrimeField> PackRowHorizontally(
50-
const Eigen::MatrixBase<Derived>& matrix, size_t row,
51-
std::vector<PrimeField>& remaining_values) {
55+
std::vector<PackedPrimeField*> PackRowHorizontally(
56+
Eigen::Block<Derived>& matrix, size_t row,
57+
std::vector<PrimeField*>& remaining_values) {
58+
static_assert(Derived::Options & Eigen::RowMajorBit);
5259
size_t num_packed = matrix.cols() / PackedPrimeField::N;
5360
size_t remaining_start_idx = num_packed * PackedPrimeField::N;
54-
remaining_values =
55-
base::CreateVector(matrix.cols() - remaining_start_idx,
56-
[row, remaining_start_idx, &matrix](size_t col) {
57-
return matrix(row, remaining_start_idx + col);
58-
});
59-
61+
remaining_values = base::CreateVector(
62+
matrix.cols() - remaining_start_idx,
63+
[row, remaining_start_idx, &matrix](size_t col) {
64+
return reinterpret_cast<PrimeField*>(
65+
matrix.data() + row * matrix.cols() + remaining_start_idx + col);
66+
});
6067
return base::CreateVector(num_packed, [row, &matrix](size_t col) {
61-
return PackedPrimeField::From([row, col, &matrix](size_t i) {
62-
return matrix(row, PackedPrimeField::N * col + i);
63-
});
68+
return reinterpret_cast<PackedPrimeField*>(
69+
matrix.data() + row * matrix.cols() + PackedPrimeField::N * col);
6470
});
6571
}
6672

@@ -74,6 +80,50 @@ std::vector<PackedPrimeField> PackRowVertically(
7480
});
7581
}
7682

83+
// Expands a |Eigen::MatrixBase|'s rows from |rows| to |rows|^(|added_bits|),
84+
// moving values from row |i| to row |i|^(|added_bits|). All new entries are set
85+
// to |F::Zero()|.
86+
template <typename Derived>
87+
void ExpandInPlaceWithZeroPad(Eigen::MatrixBase<Derived>& mat,
88+
size_t added_bits) {
89+
if (added_bits == 0) {
90+
return;
91+
}
92+
93+
Eigen::Index original_rows = mat.rows();
94+
Eigen::Index new_rows = mat.rows() << added_bits;
95+
Eigen::Index cols = mat.cols();
96+
97+
Derived padded = Derived::Zero(new_rows, cols);
98+
99+
OPENMP_PARALLEL_FOR(Eigen::Index row = 0; row < original_rows; ++row) {
100+
Eigen::Index padded_row_index = row << added_bits;
101+
// TODO(ashjeong): Check if moved properly
102+
padded.row(padded_row_index) = std::move(mat.row(row));
103+
}
104+
mat = std::move(padded);
105+
}
106+
107+
// Swaps rows of a |Eigen::MatrixBase| such that each row is changed to the row
108+
// accessed with the reversed bits of the current index. Crashes if the number
109+
// of rows is not a power of two.
110+
template <typename Derived>
111+
void ReverseMatrixIndexBits(Eigen::MatrixBase<Derived>& mat) {
112+
size_t rows = static_cast<size_t>(mat.rows());
113+
if (rows == 0) {
114+
return;
115+
}
116+
CHECK(base::bits::IsPowerOfTwo(rows));
117+
size_t log_n = base::bits::Log2Ceiling(rows);
118+
119+
OPENMP_PARALLEL_FOR(size_t row = 1; row < rows; ++row) {
120+
size_t ridx = base::bits::BitRev(row) >> (sizeof(size_t) * 8 - log_n);
121+
if (row < ridx) {
122+
mat.row(row).swap(mat.row(ridx));
123+
}
124+
}
125+
}
126+
77127
} // namespace tachyon::math
78128

79129
#endif // TACHYON_MATH_MATRIX_MATRIX_UTILS_H_

0 commit comments

Comments
 (0)