diff --git a/rsa/core.py b/rsa/core.py index 7649f11..2cd23fb 100644 --- a/rsa/core.py +++ b/rsa/core.py @@ -28,6 +28,14 @@ def assert_int(var: int, name: str) -> None: raise TypeError("{} should be an integer, not {}".format(name, var.__class__)) +def assert_length(message: int, n: int) -> None: + if message < 0: + raise ValueError("Only non-negative numbers are supported") + + if message >= n: + raise OverflowError("The message %i is too long for n=%i" % (message, n)) + + def encrypt_int(message: int, ekey: int, n: int) -> int: """Encrypts a message using encryption key 'ekey', working modulo n""" @@ -35,11 +43,7 @@ def encrypt_int(message: int, ekey: int, n: int) -> int: assert_int(ekey, "ekey") assert_int(n, "n") - if message < 0: - raise ValueError("Only non-negative numbers are supported") - - if message >= n: - raise OverflowError("The message %i is too long for n=%i" % (message, n)) + assert_length(message, n) return pow(message, ekey, n) @@ -51,6 +55,8 @@ def decrypt_int(cyphertext: int, dkey: int, n: int) -> int: assert_int(dkey, "dkey") assert_int(n, "n") + assert_length(cyphertext, n) + message = pow(cyphertext, dkey, n) return message @@ -60,6 +66,7 @@ def decrypt_int_fast( rs: typing.List[int], ds: typing.List[int], ts: typing.List[int], + n: int, ) -> int: """Decrypts a cypher text more quickly using the Chinese Remainder Theorem.""" @@ -70,6 +77,8 @@ def decrypt_int_fast( assert_int(d, "d") for t in ts: assert_int(t, "t") + + assert_length(cyphertext, n) p, q, rs = rs[0], rs[1], rs[2:] exp1, exp2, ds = ds[0], ds[1], ds[2:] diff --git a/rsa/key.py b/rsa/key.py index fd30447..97a34e3 100644 --- a/rsa/key.py +++ b/rsa/key.py @@ -551,6 +551,7 @@ def blinded_decrypt(self, encrypted: int) -> int: [self.p, self.q] + self.rs, [self.exp1, self.exp2] + self.ds, [self.coef] + self.ts, + self.n, ) return self.unblind(decrypted, blindfac_inverse)