Skip to content

Commit 74dd391

Browse files
authored
Merge pull request #20 from iden3/fix-msm
MSM optimizations
2 parents 35dfe13 + 76825e5 commit 74dd391

File tree

5 files changed

+93
-114
lines changed

5 files changed

+93
-114
lines changed

.github/workflows/build.yml

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -42,25 +42,25 @@ jobs:
4242
- name: Build prover Android ARM64
4343
run: |
4444
mkdir -p build_prover_android && cd build_prover_android
45-
cmake .. -DTARGET_PLATFORM=ANDROID -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android -DUSE_OPENMP=OFF
45+
cmake .. -DTARGET_PLATFORM=ANDROID -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android -DBUILD_TESTS=OFF -DUSE_OPENMP=OFF
4646
make -j4 && make install
4747
4848
- name: Build prover Android ARM64 with OpenMP
4949
run: |
5050
mkdir -p build_prover_android_openmp && cd build_prover_android_openmp
51-
cmake .. -DTARGET_PLATFORM=ANDROID -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android_openmp -DUSE_OPENMP=ON
51+
cmake .. -DTARGET_PLATFORM=ANDROID -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android_openmp -DBUILD_TESTS=OFF -DUSE_OPENMP=ON
5252
make -j4 && make install
5353
5454
- name: Build prover Android x86_64
5555
run: |
5656
mkdir -p build_prover_android_x86_64 && cd build_prover_android_x86_64
57-
cmake .. -DTARGET_PLATFORM=ANDROID_x86_64 -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android_x86_64 -DUSE_OPENMP=OFF
57+
cmake .. -DTARGET_PLATFORM=ANDROID_x86_64 -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android_x86_64 -DBUILD_TESTS=OFF -DUSE_OPENMP=OFF
5858
make -j4 && make install
5959
6060
- name: Build prover Android x86_64 with OpenMP
6161
run: |
6262
mkdir -p build_prover_android_openmp_x86_64 && cd build_prover_android_openmp_x86_64
63-
cmake .. -DTARGET_PLATFORM=ANDROID_x86_64 -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android_openmp_x86_64 -DUSE_OPENMP=ON
63+
cmake .. -DTARGET_PLATFORM=ANDROID_x86_64 -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android_openmp_x86_64 -DBUILD_TESTS=OFF -DUSE_OPENMP=ON
6464
make -j4 && make install
6565
6666
- name: Build prover Linux
@@ -184,13 +184,13 @@ jobs:
184184
if [[ ! -d "depends/gmp/package_macos_arm64" ]]; then ./build_gmp.sh macos_arm64; fi
185185
186186
mkdir -p build_prover_ios && cd build_prover_ios
187-
cmake .. -GXcode -DTARGET_PLATFORM=IOS -DCMAKE_INSTALL_PREFIX=../package_ios
187+
cmake .. -GXcode -DTARGET_PLATFORM=IOS -DCMAKE_INSTALL_PREFIX=../package_ios -DBUILD_TESTS=OFF
188188
xcodebuild -destination 'generic/platform=iOS' -scheme rapidsnarkStatic -project rapidsnark.xcodeproj -configuration Release
189189
cp ../depends/gmp/package_ios_arm64/lib/libgmp.a src/Release-iphoneos
190190
cd ../
191191
192-
mkdir -p build_prover_ios_simulator && cd build_prover_ios_simulator
193-
cmake .. -GXcode -DTARGET_PLATFORM=IOS -DCMAKE_INSTALL_PREFIX=../package_ios_simulator -DUSE_ASM=NO
192+
mkdir -p build_prover_ios_simulator && cd build_prover_ios_simulator
193+
cmake .. -GXcode -DTARGET_PLATFORM=IOS -DCMAKE_INSTALL_PREFIX=../package_ios_simulator -DUSE_ASM=NO -DBUILD_TESTS=OFF
194194
xcodebuild -destination 'generic/platform=iOS Simulator' -scheme rapidsnarkStatic -project rapidsnark.xcodeproj
195195
cp ../depends/gmp/package_iphone_simulator/lib/libgmp.a src/Debug-iphonesimulator
196196
cd ../

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@ project(rapidsnark LANGUAGES CXX C ASM)
1010
set(CMAKE_CXX_STANDARD 11)
1111
set(CMAKE_CXX_STANDARD_REQUIRED ON)
1212

13+
message("BITS_PER_CHUNK=" ${BITS_PER_CHUNK})
1314
message("USE_ASM=" ${USE_ASM})
1415
message("USE_OPENMP=" ${USE_OPENMP})
1516
message("CMAKE_CROSSCOMPILING=" ${CMAKE_CROSSCOMPILING})

src/CMakeLists.txt

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,10 @@ if(USE_ASM)
1010
endif()
1111
endif()
1212

13+
if(DEFINED BITS_PER_CHUNK)
14+
add_definitions(-DMSM_BITS_PER_CHUNK=${BITS_PER_CHUNK})
15+
endif()
16+
1317
if(USE_ASM AND ARCH MATCHES "x86_64")
1418

1519
if (CMAKE_HOST_SYSTEM_NAME MATCHES "Darwin")
@@ -131,12 +135,15 @@ if(USE_SODIUM)
131135
target_link_libraries(prover sodium)
132136
endif()
133137

138+
option(BUILD_TESTS "Build the tests" ON)
134139

135-
enable_testing()
136-
add_executable(test_public_size test_public_size.c)
137-
target_link_libraries(test_public_size rapidsnarkStaticFrFq)
138-
add_test(NAME test_public_size COMMAND test_public_size circuit_final.zkey 86
139-
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/testdata)
140+
if(BUILD_TESTS)
141+
enable_testing()
142+
add_executable(test_public_size test_public_size.c)
143+
target_link_libraries(test_public_size rapidsnarkStaticFrFq pthread)
144+
add_test(NAME test_public_size COMMAND test_public_size circuit_final.zkey 86
145+
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/testdata)
146+
endif()
140147

141148
if(OpenMP_CXX_FOUND)
142149

src/groth16.cpp

Lines changed: 72 additions & 101 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
#include "random_generator.hpp"
22
#include "logging.hpp"
3-
#include <future>
3+
#include "misc.hpp"
4+
#include <vector>
5+
#include <mutex>
46

57
namespace Groth16 {
68

@@ -46,114 +48,84 @@ std::unique_ptr<Prover<Engine>> makeProver(
4648
template <typename Engine>
4749
std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement *wtns) {
4850

49-
#ifdef USE_OPENMP
51+
ThreadPool &threadPool = ThreadPool::defaultPool();
52+
5053
LOG_TRACE("Start Multiexp A");
5154
uint32_t sW = sizeof(wtns[0]);
5255
typename Engine::G1Point pi_a;
53-
E.g1.multiMulByScalar(pi_a, pointsA, (uint8_t *)wtns, sW, nVars);
56+
E.g1.multiMulByScalarMSM(pi_a, pointsA, (uint8_t *)wtns, sW, nVars);
5457
std::ostringstream ss2;
5558
ss2 << "pi_a: " << E.g1.toString(pi_a);
5659
LOG_DEBUG(ss2);
5760

5861
LOG_TRACE("Start Multiexp B1");
5962
typename Engine::G1Point pib1;
60-
E.g1.multiMulByScalar(pib1, pointsB1, (uint8_t *)wtns, sW, nVars);
63+
E.g1.multiMulByScalarMSM(pib1, pointsB1, (uint8_t *)wtns, sW, nVars);
6164
std::ostringstream ss3;
6265
ss3 << "pib1: " << E.g1.toString(pib1);
6366
LOG_DEBUG(ss3);
6467

6568
LOG_TRACE("Start Multiexp B2");
6669
typename Engine::G2Point pi_b;
67-
E.g2.multiMulByScalar(pi_b, pointsB2, (uint8_t *)wtns, sW, nVars);
70+
E.g2.multiMulByScalarMSM(pi_b, pointsB2, (uint8_t *)wtns, sW, nVars);
6871
std::ostringstream ss4;
6972
ss4 << "pi_b: " << E.g2.toString(pi_b);
7073
LOG_DEBUG(ss4);
7174

7275
LOG_TRACE("Start Multiexp C");
7376
typename Engine::G1Point pi_c;
74-
E.g1.multiMulByScalar(pi_c, pointsC, (uint8_t *)((uint64_t)wtns + (nPublic +1)*sW), sW, nVars-nPublic-1);
77+
E.g1.multiMulByScalarMSM(pi_c, pointsC, (uint8_t *)((uint64_t)wtns + (nPublic +1)*sW), sW, nVars-nPublic-1);
7578
std::ostringstream ss5;
7679
ss5 << "pi_c: " << E.g1.toString(pi_c);
7780
LOG_DEBUG(ss5);
78-
#else
79-
LOG_TRACE("Start Multiexp A");
80-
uint32_t sW = sizeof(wtns[0]);
81-
typename Engine::G1Point pi_a;
82-
auto pA_future = std::async([&]() {
83-
E.g1.multiMulByScalar(pi_a, pointsA, (uint8_t *)wtns, sW, nVars);
84-
});
85-
86-
LOG_TRACE("Start Multiexp B1");
87-
typename Engine::G1Point pib1;
88-
auto pB1_future = std::async([&]() {
89-
E.g1.multiMulByScalar(pib1, pointsB1, (uint8_t *)wtns, sW, nVars);
90-
});
91-
92-
LOG_TRACE("Start Multiexp B2");
93-
typename Engine::G2Point pi_b;
94-
auto pB2_future = std::async([&]() {
95-
E.g2.multiMulByScalar(pi_b, pointsB2, (uint8_t *)wtns, sW, nVars);
96-
});
97-
98-
LOG_TRACE("Start Multiexp C");
99-
typename Engine::G1Point pi_c;
100-
auto pC_future = std::async([&]() {
101-
E.g1.multiMulByScalar(pi_c, pointsC, (uint8_t *)((uint64_t)wtns + (nPublic +1)*sW), sW, nVars-nPublic-1);
102-
});
103-
#endif
10481

10582
LOG_TRACE("Start Initializing a b c A");
10683
auto a = new typename Engine::FrElement[domainSize];
10784
auto b = new typename Engine::FrElement[domainSize];
10885
auto c = new typename Engine::FrElement[domainSize];
10986

110-
#pragma omp parallel for
111-
for (u_int32_t i=0; i<domainSize; i++) {
112-
E.fr.copy(a[i], E.fr.zero());
113-
E.fr.copy(b[i], E.fr.zero());
114-
}
87+
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
88+
for (u_int32_t i=begin; i<end; i++) {
89+
E.fr.copy(a[i], E.fr.zero());
90+
E.fr.copy(b[i], E.fr.zero());
91+
}
92+
});
11593

11694
LOG_TRACE("Processing coefs");
117-
#ifdef _OPENMP
118-
#define NLOCKS 1024
119-
omp_lock_t locks[NLOCKS];
120-
for (int i=0; i<NLOCKS; i++) omp_init_lock(&locks[i]);
121-
#pragma omp parallel for
122-
#endif
123-
for (u_int64_t i=0; i<nCoefs; i++) {
124-
typename Engine::FrElement *ab = (coefs[i].m == 0) ? a : b;
125-
typename Engine::FrElement aux;
126-
127-
E.fr.mul(
128-
aux,
129-
wtns[coefs[i].s],
130-
coefs[i].coef
131-
);
132-
#ifdef _OPENMP
133-
omp_set_lock(&locks[coefs[i].c % NLOCKS]);
134-
#endif
135-
E.fr.add(
136-
ab[coefs[i].c],
137-
ab[coefs[i].c],
138-
aux
139-
);
140-
#ifdef _OPENMP
141-
omp_unset_lock(&locks[coefs[i].c % NLOCKS]);
142-
#endif
143-
}
144-
#ifdef _OPENMP
145-
for (int i=0; i<NLOCKS; i++) omp_destroy_lock(&locks[i]);
146-
#endif
14795

96+
#define NLOCKS 1024
97+
std::vector<std::mutex> locks(NLOCKS);
98+
99+
threadPool.parallelFor(0, nCoefs, [&] (int begin, int end, int numThread) {
100+
for (u_int64_t i=begin; i<end; i++) {
101+
typename Engine::FrElement *ab = (coefs[i].m == 0) ? a : b;
102+
typename Engine::FrElement aux;
103+
104+
E.fr.mul(
105+
aux,
106+
wtns[coefs[i].s],
107+
coefs[i].coef
108+
);
109+
110+
std::lock_guard<std::mutex> guard(locks[coefs[i].c % NLOCKS]);
111+
112+
E.fr.add(
113+
ab[coefs[i].c],
114+
ab[coefs[i].c],
115+
aux
116+
);
117+
}
118+
});
148119
LOG_TRACE("Calculating c");
149-
#pragma omp parallel for
150-
for (u_int32_t i=0; i<domainSize; i++) {
151-
E.fr.mul(
152-
c[i],
153-
a[i],
154-
b[i]
155-
);
156-
}
120+
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
121+
for (u_int64_t i=begin; i<end; i++) {
122+
E.fr.mul(
123+
c[i],
124+
a[i],
125+
b[i]
126+
);
127+
}
128+
});
157129

158130
LOG_TRACE("Initializing fft");
159131
u_int32_t domainPower = fft->log2(domainSize);
@@ -164,10 +136,13 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
164136
LOG_DEBUG(E.fr.toString(a[0]).c_str());
165137
LOG_DEBUG(E.fr.toString(a[1]).c_str());
166138
LOG_TRACE("Start Shift A");
167-
#pragma omp parallel for
168-
for (u_int64_t i=0; i<domainSize; i++) {
169-
E.fr.mul(a[i], a[i], fft->root(domainPower+1, i));
170-
}
139+
140+
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
141+
for (u_int64_t i=begin; i<end; i++) {
142+
E.fr.mul(a[i], a[i], fft->root(domainPower+1, i));
143+
}
144+
});
145+
171146
LOG_TRACE("a After shift:");
172147
LOG_DEBUG(E.fr.toString(a[0]).c_str());
173148
LOG_DEBUG(E.fr.toString(a[1]).c_str());
@@ -182,10 +157,11 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
182157
LOG_DEBUG(E.fr.toString(b[0]).c_str());
183158
LOG_DEBUG(E.fr.toString(b[1]).c_str());
184159
LOG_TRACE("Start Shift B");
185-
#pragma omp parallel for
186-
for (u_int64_t i=0; i<domainSize; i++) {
187-
E.fr.mul(b[i], b[i], fft->root(domainPower+1, i));
188-
}
160+
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
161+
for (u_int64_t i=begin; i<end; i++) {
162+
E.fr.mul(b[i], b[i], fft->root(domainPower+1, i));
163+
}
164+
});
189165
LOG_TRACE("b After shift:");
190166
LOG_DEBUG(E.fr.toString(b[0]).c_str());
191167
LOG_DEBUG(E.fr.toString(b[1]).c_str());
@@ -201,10 +177,11 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
201177
LOG_DEBUG(E.fr.toString(c[0]).c_str());
202178
LOG_DEBUG(E.fr.toString(c[1]).c_str());
203179
LOG_TRACE("Start Shift C");
204-
#pragma omp parallel for
205-
for (u_int64_t i=0; i<domainSize; i++) {
206-
E.fr.mul(c[i], c[i], fft->root(domainPower+1, i));
207-
}
180+
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
181+
for (u_int64_t i=begin; i<end; i++) {
182+
E.fr.mul(c[i], c[i], fft->root(domainPower+1, i));
183+
}
184+
});
208185
LOG_TRACE("c After shift:");
209186
LOG_DEBUG(E.fr.toString(c[0]).c_str());
210187
LOG_DEBUG(E.fr.toString(c[1]).c_str());
@@ -215,12 +192,13 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
215192
LOG_DEBUG(E.fr.toString(c[1]).c_str());
216193

217194
LOG_TRACE("Start ABC");
218-
#pragma omp parallel for
219-
for (u_int64_t i=0; i<domainSize; i++) {
220-
E.fr.mul(a[i], a[i], b[i]);
221-
E.fr.sub(a[i], a[i], c[i]);
222-
E.fr.fromMontgomery(a[i], a[i]);
223-
}
195+
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
196+
for (u_int64_t i=begin; i<end; i++) {
197+
E.fr.mul(a[i], a[i], b[i]);
198+
E.fr.sub(a[i], a[i], c[i]);
199+
E.fr.fromMontgomery(a[i], a[i]);
200+
}
201+
});
224202
LOG_TRACE("abc:");
225203
LOG_DEBUG(E.fr.toString(a[0]).c_str());
226204
LOG_DEBUG(E.fr.toString(a[1]).c_str());
@@ -230,7 +208,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
230208

231209
LOG_TRACE("Start Multiexp H");
232210
typename Engine::G1Point pih;
233-
E.g1.multiMulByScalar(pih, pointsH, (uint8_t *)a, sizeof(a[0]), domainSize);
211+
E.g1.multiMulByScalarMSM(pih, pointsH, (uint8_t *)a, sizeof(a[0]), domainSize);
234212
std::ostringstream ss1;
235213
ss1 << "pih: " << E.g1.toString(pih);
236214
LOG_DEBUG(ss1);
@@ -247,13 +225,6 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
247225
randombytes_buf((void *)&(r.v[0]), sizeof(r)-1);
248226
randombytes_buf((void *)&(s.v[0]), sizeof(s)-1);
249227

250-
#ifndef USE_OPENMP
251-
pA_future.get();
252-
pB1_future.get();
253-
pB2_future.get();
254-
pC_future.get();
255-
#endif
256-
257228
typename Engine::G1Point p1;
258229
typename Engine::G2Point p2;
259230

0 commit comments

Comments
 (0)