@@ -86,18 +86,22 @@ void P1_times_Vt(const mayo_params_t* p, const uint64_t* P1, const unsigned char
86
86
mul_add_m_upper_triangular_mat_x_mat_trans (PARAM_m_vec_limbs (p ), P1 , V , acc , PARAM_v (p ), PARAM_v (p ), PARAM_k (p ), 1 );
87
87
}
88
88
89
+ #if defined(HAVE_STACKEFFICIENT ) || defined(PQM4 )
89
90
// compute P * S^t = [ P1 P2 ] * [S1] = [P1*S1 + P2*S2]
90
91
// [ 0 P3 ] [S2] [ P3*S2]
91
- static inline void mayo_generic_m_calculate_PS (const uint64_t * P1 , const uint64_t * P2 , const uint64_t * P3 , const unsigned char * S ,
92
- const int m , const int v , const int o , const int k , uint64_t * PS ) {
92
+ // compute S * PS = [ S1 S2 ] * [ PS_top ; PS_bot ] = [ S1*PS_top + S2*PS_bot ],
93
+ // where PS_top = P1*S1 + P2*S2 and PS_bot = P3*S2 are the two block rows of P * S^t computed above.
94
+ static inline void mayo_generic_m_calculate_PS_SPS (const uint64_t * P1 , const uint64_t * P2 , const uint64_t * P3 , const unsigned char * S ,
95
+ const int m , const int v , const int o , const int k , uint64_t * SPS ) {
93
96
94
97
const int n = o + v ;
95
98
const int m_vec_limbs = (m + 15 )/16 ;
96
99
97
- #if defined( HAVE_STACKEFFICIENT ) || defined( PQM4 )
100
+ uint64_t PS [( N_MAX + K_MAX ) * M_VEC_LIMBS_MAX ] = { 0 };
98
101
uint64_t accumulator [16 * ((M_MAX + 15 )/16 ) * N_MAX ] = {0 };
99
102
int P1_used ;
100
103
int P3_used ;
104
+
101
105
for (int col = 0 ; col < k ; col ++ ) {
102
106
for (unsigned int i = 0 ; i < sizeof (accumulator )/8 ; i ++ ) {
103
107
accumulator [i ] = 0 ;
@@ -123,11 +127,33 @@ static inline void mayo_generic_m_calculate_PS(const uint64_t *P1, const uint64_
123
127
}
124
128
}
125
129
126
- for (int row = 0 ; row < n ; row ++ ) {
127
- m_vec_multiply_bins (m_vec_limbs , accumulator + row * 16 * m_vec_limbs , PS + (row * k + col ) * m_vec_limbs );
128
- }
130
+ for (int row = 0 ; row < n ; row ++ ) {
131
+ m_vec_multiply_bins (m_vec_limbs , accumulator + row * 16 * m_vec_limbs , PS + (row + col ) * m_vec_limbs );
132
+ }
133
+
134
+ for (int row = 0 ; row < k ; row ++ ) {
135
+ for (unsigned int i = 0 ; i < 16 * ((M_MAX + 15 )/16 ); ++ i )
136
+ accumulator [i ] = 0 ;
137
+ for (int j = 0 ; j < n ; j ++ ) {
138
+ m_vec_add (m_vec_limbs , PS + (j + col ) * m_vec_limbs , accumulator + S [row * n + j ]* m_vec_limbs );
139
+ }
140
+ m_vec_multiply_bins (m_vec_limbs , accumulator , SPS + (row * k + col ) * m_vec_limbs );
141
+ }
142
+
129
143
}
130
- #else
144
+
145
+ }
146
+
147
+ #else
148
+
149
+ // compute P * S^t = [ P1 P2 ] * [S1] = [P1*S1 + P2*S2]
150
+ // [ 0 P3 ] [S2] [ P3*S2]
151
+ static inline void mayo_generic_m_calculate_PS (const uint64_t * P1 , const uint64_t * P2 , const uint64_t * P3 , const unsigned char * S ,
152
+ const int m , const int v , const int o , const int k , uint64_t * SPS ) {
153
+
154
+ const int n = o + v ;
155
+ const int m_vec_limbs = (m + 15 )/16 ;
156
+
131
157
uint64_t accumulator [16 * ((M_MAX + 15 )/16 ) * K_MAX * N_MAX ] = {0 };
132
158
int P1_used = 0 ;
133
159
for (int row = 0 ; row < v ; row ++ ) {
@@ -158,14 +184,14 @@ static inline void mayo_generic_m_calculate_PS(const uint64_t *P1, const uint64_
158
184
// multiply stuff according to the bins of the accumulator and add to PS (the output buffer, which this function receives as the parameter named SPS).
159
185
int i = 0 ;
160
186
while (i < n * k ) {
161
- m_vec_multiply_bins (m_vec_limbs , accumulator + i * 16 * m_vec_limbs , PS + i * m_vec_limbs );
187
+ m_vec_multiply_bins (m_vec_limbs , accumulator + i * 16 * m_vec_limbs , SPS + i * m_vec_limbs );
162
188
i ++ ;
163
189
}
164
190
165
- #endif
166
191
}
167
192
168
-
193
+ // compute S * PS = [ S1 S2 ] * [ PS_top ; PS_bot ] = [ S1*PS_top + S2*PS_bot ],
194
+ // where PS_top = P1*S1 + P2*S2 and PS_bot = P3*S2 are the two block rows of PS = P * S^t.
169
195
static inline void mayo_generic_m_calculate_SPS (const uint64_t * PS , const unsigned char * S , int m , int k , int n , uint64_t * SPS ){
170
196
uint64_t accumulator [16 * ((M_MAX + 15 )/16 )* K_MAX * K_MAX ] = {0 };
171
197
const int m_vec_limbs = (m + 15 )/ 16 ;
@@ -185,6 +211,8 @@ static inline void mayo_generic_m_calculate_SPS(const uint64_t *PS, const unsign
185
211
}
186
212
}
187
213
214
+ #endif
215
+
188
216
189
217
static inline
190
218
void P1P1t_times_O (const mayo_params_t * p , const uint64_t * P1 , const unsigned char * O , uint64_t * acc ){
@@ -252,11 +280,15 @@ static inline void m_calculate_PS_SPS(const mayo_params_t *p, const uint64_t *P1
252
280
#ifndef ENABLE_PARAMS_DYNAMIC
253
281
(void ) p ;
254
282
#endif
283
+ #if defined(HAVE_STACKEFFICIENT ) || defined(PQM4 )
284
+ mayo_generic_m_calculate_PS_SPS (P1 , P2 , P3 , s , PARAM_m (p ), PARAM_v (p ), PARAM_o (p ), PARAM_k (p ), SPS );
285
+ #else
255
286
uint64_t PS [N_MAX * K_MAX * M_VEC_LIMBS_MAX ] = { 0 };
256
287
mayo_generic_m_calculate_PS (P1 , P2 , P3 , s , PARAM_m (p ), PARAM_v (p ), PARAM_o (p ), PARAM_k (p ), PS );
257
288
258
289
// compute S * P * S = S* (P*S)
259
290
mayo_generic_m_calculate_SPS (PS , s , PARAM_m (p ), PARAM_k (p ), PARAM_n (p ), SPS );
291
+ #endif
260
292
}
261
293
262
294
#endif
0 commit comments