Skip to content

Commit e0569dc

Browse files
committed
Improvements to PKCS1Encoding
1 parent 811c490 commit e0569dc

File tree

3 files changed

+402
-133
lines changed

3 files changed

+402
-133
lines changed

core/src/main/java/org/bouncycastle/crypto/encodings/PKCS1Encoding.java

Lines changed: 115 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -233,52 +233,86 @@ private byte[] encodeBlock(
233233
}
234234

235235
/**
236-
* Checks if the argument is a correctly PKCS#1.5 encoded Plaintext
237-
* for encryption.
238-
*
239-
* @param encoded The Plaintext.
240-
* @param pLen Expected length of the plaintext.
241-
* @return Either 0, if the encoding is correct, or -1, if it is incorrect.
236+
* Check the argument is a valid encoding with type 1. Returns the plaintext length if valid, or -1 if invalid.
242237
*/
243-
private static int checkPkcs1Encoding(byte[] encoded, int pLen)
238+
private static int checkPkcs1Encoding1(byte[] buf)
244239
{
245-
int correct = 0;
246-
/*
247-
* Check if the first two bytes are 0 2
248-
*/
249-
correct |= (encoded[0] ^ 2);
240+
int foundZeroMask = 0;
241+
int lastPadPos = 0;
250242

251-
/*
252-
* Now the padding check, check for no 0 byte in the padding
253-
*/
254-
int plen = encoded.length - (
255-
pLen /* Length of the PMS */
256-
+ 1 /* Final 0-byte before PMS */
257-
);
243+
// The first byte should be 0x01
244+
int badPadSign = -((buf[0] & 0xFF) ^ 0x01);
258245

259-
for (int i = 1; i < plen; i++)
246+
// There must be a zero terminator for the padding somewhere
247+
for (int i = 1; i < buf.length; ++i)
260248
{
261-
int tmp = encoded[i];
262-
tmp |= tmp >> 1;
263-
tmp |= tmp >> 2;
264-
tmp |= tmp >> 4;
265-
correct |= (tmp & 1) - 1;
249+
int padByte = buf[i] & 0xFF;
250+
int is0x00Mask = ((padByte ^ 0x00) - 1) >> 31;
251+
int is0xFFMask = ((padByte ^ 0xFF) - 1) >> 31;
252+
lastPadPos ^= i & ~foundZeroMask & is0x00Mask;
253+
foundZeroMask |= is0x00Mask;
254+
badPadSign |= ~(foundZeroMask | is0xFFMask);
266255
}
267256

268-
/*
269-
* Make sure the padding ends with a 0 byte.
270-
*/
271-
correct |= encoded[encoded.length - (pLen + 1)];
257+
// The header should be at least 10 bytes
258+
badPadSign |= lastPadPos - 9;
272259

273-
/*
274-
* Return 0 or 1, depending on the result.
275-
*/
276-
correct |= correct >> 1;
277-
correct |= correct >> 2;
278-
correct |= correct >> 4;
279-
return ~((correct & 1) - 1);
260+
int plaintextLength = buf.length - 1 - lastPadPos;
261+
return plaintextLength | badPadSign >> 31;
280262
}
281263

264+
/**
265+
* Check the argument is a valid encoding with type 2. Returns the plaintext length if valid, or -1 if invalid.
266+
*/
267+
private static int checkPkcs1Encoding2(byte[] buf)
268+
{
269+
int foundZeroMask = 0;
270+
int lastPadPos = 0;
271+
272+
// The first byte should be 0x02
273+
int badPadSign = -((buf[0] & 0xFF) ^ 0x02);
274+
275+
// There must be a zero terminator for the padding somewhere
276+
for (int i = 1; i < buf.length; ++i)
277+
{
278+
int padByte = buf[i] & 0xFF;
279+
int is0x00Mask = ((padByte ^ 0x00) - 1) >> 31;
280+
lastPadPos ^= i & ~foundZeroMask & is0x00Mask;
281+
foundZeroMask |= is0x00Mask;
282+
}
283+
284+
// The header should be at least 10 bytes
285+
badPadSign |= lastPadPos - 9;
286+
287+
int plaintextLength = buf.length - 1 - lastPadPos;
288+
return plaintextLength | badPadSign >> 31;
289+
}
290+
291+
/**
292+
* Check the argument is a valid encoding with type 2 of a plaintext with the given length. Returns 0 if
293+
* valid, or -1 if invalid.
294+
*/
295+
private static int checkPkcs1Encoding2(byte[] buf, int plaintextLength)
296+
{
297+
// The first byte should be 0x02
298+
int badPadSign = -((buf[0] & 0xFF) ^ 0x02);
299+
300+
int lastPadPos = buf.length - 1 - plaintextLength;
301+
302+
// The header should be at least 10 bytes
303+
badPadSign |= lastPadPos - 9;
304+
305+
// All pad bytes before the last one should be non-zero
306+
for (int i = 1; i < lastPadPos; ++i)
307+
{
308+
badPadSign |= (buf[i] & 0xFF) - 1;
309+
}
310+
311+
// Last pad byte should be zero
312+
badPadSign |= -(buf[lastPadPos] & 0xFF);
313+
314+
return badPadSign >> 31;
315+
}
282316

283317
/**
284318
* Decode PKCS#1.5 encoding, and return a random value if the padding is not correct.
@@ -298,132 +332,94 @@ private byte[] decodeBlockOrRandom(byte[] in, int inOff, int inLen)
298332
throw new InvalidCipherTextException("sorry, this method is only for decryption, not for signing");
299333
}
300334

301-
byte[] block = engine.processBlock(in, inOff, inLen);
302-
byte[] random;
303-
if (this.fallback == null)
335+
int plaintextLength = this.pLen;
336+
337+
byte[] random = fallback;
338+
if (fallback == null)
304339
{
305-
random = new byte[this.pLen];
340+
random = new byte[plaintextLength];
306341
this.random.nextBytes(random);
307342
}
308-
else
343+
344+
int badPadMask = 0;
345+
int strictBlockSize = engine.getOutputBlockSize();
346+
byte[] block = engine.processBlock(in, inOff, inLen);
347+
348+
byte[] data = block;
349+
if (block.length != strictBlockSize)
309350
{
310-
random = fallback;
351+
if (useStrictLength || block.length < strictBlockSize)
352+
{
353+
data = blockBuffer;
354+
}
311355
}
312356

313-
byte[] data = (useStrictLength & (block.length != engine.getOutputBlockSize())) ? blockBuffer : block;
357+
badPadMask |= checkPkcs1Encoding2(data, plaintextLength);
314358

315-
/*
316-
* Check the padding.
317-
*/
318-
int correct = PKCS1Encoding.checkPkcs1Encoding(data, this.pLen);
319-
320359
/*
321360
* Now, to a constant time constant memory copy of the decrypted value
322361
* or the random value, depending on the validity of the padding.
323362
*/
324-
byte[] result = new byte[this.pLen];
325-
for (int i = 0; i < this.pLen; i++)
363+
int dataOff = data.length - plaintextLength;
364+
byte[] result = new byte[plaintextLength];
365+
for (int i = 0; i < plaintextLength; ++i)
326366
{
327-
result[i] = (byte)((data[i + (data.length - pLen)] & (~correct)) | (random[i] & correct));
367+
result[i] = (byte)((data[dataOff + i] & ~badPadMask) | (random[i] & badPadMask));
328368
}
329369

330-
Arrays.fill(data, (byte)0);
370+
Arrays.fill(block, (byte)0);
371+
Arrays.fill(blockBuffer, 0, Math.max(0, blockBuffer.length - block.length), (byte)0);
331372

332373
return result;
333374
}
334375

335376
/**
336377
* @throws InvalidCipherTextException if the decrypted block is not in PKCS1 format.
337378
*/
338-
private byte[] decodeBlock(
339-
byte[] in,
340-
int inOff,
341-
int inLen)
379+
private byte[] decodeBlock(byte[] in, int inOff, int inLen)
342380
throws InvalidCipherTextException
343381
{
344382
/*
345383
* If the length of the expected plaintext is known, we use a constant-time decryption.
346384
* If the decryption fails, we return a random value.
347385
*/
348-
if (this.pLen != -1)
386+
if (forPrivateKey && this.pLen != -1)
349387
{
350388
return this.decodeBlockOrRandom(in, inOff, inLen);
351389
}
352390

391+
int strictBlockSize = engine.getOutputBlockSize();
353392
byte[] block = engine.processBlock(in, inOff, inLen);
354-
boolean incorrectLength = (useStrictLength & (block.length != engine.getOutputBlockSize()));
355-
356-
byte[] data;
357-
if (block.length < getOutputBlockSize())
358-
{
359-
data = blockBuffer;
360-
}
361-
else
362-
{
363-
data = block;
364-
}
365-
366-
byte type = data[0];
367-
368-
boolean badType;
369-
if (forPrivateKey)
370-
{
371-
badType = (type != 2);
372-
}
373-
else
374-
{
375-
badType = (type != 1);
376-
}
377-
378-
//
379-
// find and extract the message block.
380-
//
381-
int start = findStart(type, data);
382-
383-
start++; // data should start at the next byte
384393

385-
if (badType | start < HEADER_LENGTH)
386-
{
387-
Arrays.fill(data, (byte)0);
388-
throw new InvalidCipherTextException("block incorrect");
389-
}
394+
boolean incorrectLength = useStrictLength & (block.length != strictBlockSize);
390395

391-
// if we get this far, it's likely to be a genuine encoding error
392-
if (incorrectLength)
396+
byte[] data = block;
397+
if (block.length < strictBlockSize)
393398
{
394-
Arrays.fill(data, (byte)0);
395-
throw new InvalidCipherTextException("block incorrect size");
399+
data = blockBuffer;
396400
}
397401

398-
byte[] result = new byte[data.length - start];
399-
400-
System.arraycopy(data, start, result, 0, result.length);
401-
402-
return result;
403-
}
402+
int plaintextLength = forPrivateKey ? checkPkcs1Encoding2(data) : checkPkcs1Encoding1(data);
404403

405-
private int findStart(byte type, byte[] block)
406-
throws InvalidCipherTextException
407-
{
408-
int start = -1;
409-
boolean padErr = false;
410-
411-
for (int i = 1; i != block.length; i++)
404+
try
412405
{
413-
byte pad = block[i];
414-
415-
if (pad == 0 & start < 0)
406+
if (plaintextLength < 0)
416407
{
417-
start = i;
408+
throw new InvalidCipherTextException("block incorrect");
409+
}
410+
if (incorrectLength)
411+
{
412+
throw new InvalidCipherTextException("block incorrect size");
418413
}
419-
padErr |= (type == 1 & start < 0 & pad != (byte)0xff);
420-
}
421414

422-
if (padErr)
415+
byte[] result = new byte[plaintextLength];
416+
System.arraycopy(data, data.length - plaintextLength, result, 0, plaintextLength);
417+
return result;
418+
}
419+
finally
423420
{
424-
return -1;
421+
Arrays.fill(block, (byte)0);
422+
Arrays.fill(blockBuffer, 0, Math.max(0, blockBuffer.length - block.length), (byte)0);
425423
}
426-
427-
return start;
428424
}
429425
}

prov/src/main/java/org/bouncycastle/jcajce/provider/asymmetric/rsa/CipherSpi.java

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import org.bouncycastle.crypto.InvalidCipherTextException;
3030
import org.bouncycastle.crypto.encodings.ISO9796d1Encoding;
3131
import org.bouncycastle.crypto.encodings.OAEPEncoding;
32-
import org.bouncycastle.crypto.encodings.PKCS1Encoding;
3332
import org.bouncycastle.crypto.engines.RSABlindedEngine;
3433
import org.bouncycastle.crypto.params.ParametersWithRandom;
3534
import org.bouncycastle.jcajce.provider.asymmetric.util.BaseCipherSpi;
@@ -200,7 +199,7 @@ protected void engineSetPadding(
200199
}
201200
else if (pad.equals("PKCS1PADDING"))
202201
{
203-
cipher = new PKCS1Encoding(new RSABlindedEngine());
202+
cipher = new CustomPKCS1Encoding(new RSABlindedEngine());
204203
}
205204
else if (pad.equals("ISO9796-1PADDING"))
206205
{
@@ -531,15 +530,26 @@ private byte[] getOutput()
531530
{
532531
try
533532
{
534-
return cipher.processBlock(bOut.getBuf(), 0, bOut.size());
535-
}
536-
catch (InvalidCipherTextException e)
537-
{
538-
throw new BadBlockException("unable to decrypt block", e);
539-
}
540-
catch (ArrayIndexOutOfBoundsException e)
541-
{
542-
throw new BadBlockException("unable to decrypt block", e);
533+
byte[] output;
534+
try
535+
{
536+
output = cipher.processBlock(bOut.getBuf(), 0, bOut.size());
537+
}
538+
catch (InvalidCipherTextException e)
539+
{
540+
throw new BadBlockException("unable to decrypt block", e);
541+
}
542+
catch (ArrayIndexOutOfBoundsException e)
543+
{
544+
throw new BadBlockException("unable to decrypt block", e);
545+
}
546+
547+
if (output == null)
548+
{
549+
throw new BadBlockException("unable to decrypt block", null);
550+
}
551+
552+
return output;
543553
}
544554
finally
545555
{
@@ -565,7 +575,7 @@ static public class PKCS1v1_5Padding
565575
{
566576
public PKCS1v1_5Padding()
567577
{
568-
super(new PKCS1Encoding(new RSABlindedEngine()));
578+
super(new CustomPKCS1Encoding(new RSABlindedEngine()));
569579
}
570580
}
571581

@@ -574,7 +584,7 @@ static public class PKCS1v1_5Padding_PrivateOnly
574584
{
575585
public PKCS1v1_5Padding_PrivateOnly()
576586
{
577-
super(false, true, new PKCS1Encoding(new RSABlindedEngine()));
587+
super(false, true, new CustomPKCS1Encoding(new RSABlindedEngine()));
578588
}
579589
}
580590

@@ -583,7 +593,7 @@ static public class PKCS1v1_5Padding_PublicOnly
583593
{
584594
public PKCS1v1_5Padding_PublicOnly()
585595
{
586-
super(true, false, new PKCS1Encoding(new RSABlindedEngine()));
596+
super(true, false, new CustomPKCS1Encoding(new RSABlindedEngine()));
587597
}
588598
}
589599

0 commit comments

Comments
 (0)