File CVE-2023-52323-side_channel-RSA_decrypt.patch of Package python-pycryptodome.32675

From afb5e27a15efe59e33c2825d40ef44995c13b8bc Mon Sep 17 00:00:00 2001
From: Helder Eijs <helderijs@gmail.com>
Date: Wed, 20 Dec 2023 20:46:08 +0100
Subject: [PATCH] Fix side-channel leakage in RSA decryption

---
 lib/Crypto/Cipher/PKCS1_OAEP.py                  |    6 -
 lib/Crypto/Cipher/PKCS1_v1_5.py                  |    6 -
 lib/Crypto/Math/_IntegerBase.py                  |   20 +++
 lib/Crypto/Math/_IntegerBase.pyi                 |    4 
 lib/Crypto/Math/_IntegerCustom.py                |   56 +++++++++-
 lib/Crypto/Math/_IntegerGMP.py                   |   20 +++
 lib/Crypto/Math/_IntegerNative.py                |   12 ++
 lib/Crypto/PublicKey/RSA.py                      |   10 +
 lib/Crypto/SelfTest/Math/__init__.py             |    2 
 lib/Crypto/SelfTest/Math/test_Numbers.py         |   28 +++++
 lib/Crypto/SelfTest/Math/test_modmult.py         |  120 +++++++++++++++++++++++
 lib/Crypto/SelfTest/PublicKey/test_RSA.py        |    4 
 lib/Crypto/SelfTest/PublicKey/test_import_RSA.py |    6 -
 lib/Crypto/Signature/pkcs1_15.py                 |    9 -
 lib/Crypto/Signature/pss.py                      |   11 +-
 src/modexp.c                                     |   68 ++++++++++++-
 16 files changed, 349 insertions(+), 33 deletions(-)
 create mode 100644 lib/Crypto/SelfTest/Math/test_modmult.py

--- a/lib/Crypto/Cipher/PKCS1_OAEP.py
+++ b/lib/Crypto/Cipher/PKCS1_OAEP.py
@@ -167,10 +167,8 @@ class PKCS1OAEP_Cipher:
             raise ValueError("Ciphertext with incorrect length.")
         # Step 2a (O2SIP)
         ct_int = bytes_to_long(ciphertext)
-        # Step 2b (RSADP)
-        m_int = self._key._decrypt(ct_int)
-        # Complete step 2c (I2OSP)
-        em = long_to_bytes(m_int, k)
+        # Step 2b (RSADP) and step 2c (I2OSP)
+        em = self._key._decrypt(ct_int)
         # Step 3a
         lHash = self._hashObj.new(self._label).digest()
         # Step 3b
--- a/lib/Crypto/Cipher/PKCS1_v1_5.py
+++ b/lib/Crypto/Cipher/PKCS1_v1_5.py
@@ -165,10 +165,8 @@ class PKCS115_Cipher:
             raise ValueError("Ciphertext with incorrect length.")
         # Step 2a (O2SIP)
         ct_int = bytes_to_long(ciphertext)
-        # Step 2b (RSADP)
-        m_int = self._key._decrypt(ct_int)
-        # Complete step 2c (I2OSP)
-        em = long_to_bytes(m_int, k)
+        # Step 2b (RSADP) and Step 2c (I2OSP)
+        em = self._key._decrypt(ct_int)
         # Step 3
         sep = em.find(b'\x00', 2)
         if  not em.startswith(b'\x00\x02') or sep < 10:
--- a/lib/Crypto/Math/_IntegerBase.py
+++ b/lib/Crypto/Math/_IntegerBase.py
@@ -390,3 +390,23 @@ class IntegerBase(ABC):
                                     )
         return norm_candidate + min_inclusive
 
+    @staticmethod
+    @abc.abstractmethod
+    def _mult_modulo_bytes(term1, term2, modulus):
+        """Multiply two integers, take the modulo, and encode as big endian.
+        This specialized method is used for RSA decryption.
+
+        Args:
+          term1 : integer
+            The first term of the multiplication, non-negative.
+          term2 : integer
+            The second term of the multiplication, non-negative.
+          modulus: integer
+            The modulus, a positive odd number.
+        :Returns:
+            A byte string, with the result of the modular multiplication
+            encoded in big endian mode.
+            It is as long as the modulus would be, with zero padding
+            on the left if needed.
+        """
+        pass
--- a/lib/Crypto/Math/_IntegerBase.pyi
+++ b/lib/Crypto/Math/_IntegerBase.pyi
@@ -58,4 +58,8 @@ class IntegerBase:
     def random(cls, **kwargs: Union[int,RandFunc]) -> IntegerBase : ...
     @classmethod
     def random_range(cls, **kwargs: Union[int,RandFunc]) -> IntegerBase : ...
+    @staticmethod
+    def _mult_modulo_bytes(term1: Union[IntegerBase, int],
+                           term2:  Union[IntegerBase, int],
+                           modulus: Union[IntegerBase, int]) -> bytes: ...
 
--- a/lib/Crypto/Math/_IntegerCustom.py
+++ b/lib/Crypto/Math/_IntegerCustom.py
@@ -41,12 +41,18 @@ from Crypto.Util._raw_api import (load_p
 from Crypto.Random.random import getrandbits
 
 c_defs = """
-int monty_pow(const uint8_t *base,
-               const uint8_t *exp,
-               const uint8_t *modulus,
-               uint8_t       *out,
-               size_t len,
-               uint64_t seed);
+int monty_pow(uint8_t       *out,
+              const uint8_t *base,
+              const uint8_t *exp,
+              const uint8_t *modulus,
+              size_t        len,
+              uint64_t      seed);
+
+int monty_multiply(uint8_t       *out,
+                   const uint8_t *term1,
+                   const uint8_t *term2,
+                   const uint8_t *modulus,
+                   size_t        len);
 """
 
 
@@ -109,3 +115,41 @@ class IntegerCustom(IntegerNative):
         result = bytes_to_long(get_raw_buffer(out))
         self._value = result
         return self
+
+    @staticmethod
+    def _mult_modulo_bytes(term1, term2, modulus):
+
+        # With modular reduction
+        mod_value = int(modulus)
+        if mod_value < 0:
+            raise ValueError("Modulus must be positive")
+        if mod_value == 0:
+            raise ZeroDivisionError("Modulus cannot be zero")
+
+        # C extension only works with odd moduli
+        if (mod_value & 1) == 0:
+            raise ValueError("Odd modulus is required")
+
+        # C extension only works with non-negative terms smaller than modulus
+        if term1 >= mod_value or term1 < 0:
+            term1 %= mod_value
+        if term2 >= mod_value or term2 < 0:
+            term2 %= mod_value
+
+        modulus_b = long_to_bytes(mod_value)
+        numbers_len = len(modulus_b)
+        term1_b = long_to_bytes(term1, numbers_len)
+        term2_b = long_to_bytes(term2, numbers_len)
+        out = create_string_buffer(numbers_len)
+
+        error = _raw_montgomery.monty_multiply(
+                    out,
+                    term1_b,
+                    term2_b,
+                    modulus_b,
+                    c_size_t(numbers_len)
+                    )
+        if error:
+            raise ValueError("monty_multiply failed with error: %d" % error)
+
+        return get_raw_buffer(out)
--- a/lib/Crypto/Math/_IntegerGMP.py
+++ b/lib/Crypto/Math/_IntegerGMP.py
@@ -695,6 +695,26 @@ class IntegerGMP(IntegerBase):
             raise ValueError("n must be positive even for the Jacobi symbol")
         return _gmp.mpz_jacobi(a._mpz_p, n._mpz_p)
 
+    @staticmethod
+    def _mult_modulo_bytes(term1, term2, modulus):
+        if not isinstance(term1, IntegerGMP):
+            term1 = IntegerGMP(term1)
+        if not isinstance(term2, IntegerGMP):
+            term2 = IntegerGMP(term2)
+        if not isinstance(modulus, IntegerGMP):
+            modulus = IntegerGMP(modulus)
+
+        if modulus < 0:
+            raise ValueError("Modulus must be positive")
+        if modulus == 0:
+            raise ZeroDivisionError("Modulus cannot be zero")
+        if (modulus & 1) == 0:
+            raise ValueError("Odd modulus is required")
+
+        numbers_len = len(modulus.to_bytes())
+        result = ((term1 * term2) % modulus).to_bytes(numbers_len)
+        return result
+
     # Clean-up
     def __del__(self):
 
--- a/lib/Crypto/Math/_IntegerNative.py
+++ b/lib/Crypto/Math/_IntegerNative.py
@@ -378,3 +378,15 @@ class IntegerNative(IntegerBase):
         n1 = n % a1
         # Step 8
         return s * IntegerNative.jacobi_symbol(n1, a1)
+
+    @staticmethod
+    def _mult_modulo_bytes(term1, term2, modulus):
+        if modulus < 0:
+            raise ValueError("Modulus must be positive")
+        if modulus == 0:
+            raise ZeroDivisionError("Modulus cannot be zero")
+        if (modulus & 1) == 0:
+            raise ValueError("Odd modulus is required")
+
+        number_len = len(long_to_bytes(modulus))
+        return long_to_bytes((term1 * term2) % modulus, number_len)
--- a/lib/Crypto/PublicKey/RSA.py
+++ b/lib/Crypto/PublicKey/RSA.py
@@ -38,6 +38,7 @@ from Crypto import Random
 from Crypto.IO import PKCS8, PEM
 from Crypto.Util.py3compat import tobytes, bord, tostr
 from Crypto.Util.asn1 import DerSequence
+from Crypto.Util.number import bytes_to_long
 
 from Crypto.Math.Numbers import Integer
 from Crypto.Math.Primality import (test_probable_prime,
@@ -165,10 +166,11 @@ class RsaKey(object):
         h = (h * self._u) % self._q
         mp = h * self._p + m1
         # Step 4: Compute m = m**(r-1) mod n
-        result = (r.inverse(self._n) * mp) % self._n
-        # Verify no faults occured
-        if ciphertext != pow(result, self._e, self._n):
-            raise ValueError("Fault detected in RSA decryption")
+        # then encode into a big endian byte string
+        result = Integer._mult_modulo_bytes(
+                    r.inverse(self._n),
+                    mp,
+                    self._n)
         return result
 
     def has_private(self):
--- a/lib/Crypto/SelfTest/Math/__init__.py
+++ b/lib/Crypto/SelfTest/Math/__init__.py
@@ -38,9 +38,11 @@ def get_tests(config={}):
     from Crypto.SelfTest.Math import test_Numbers
     from Crypto.SelfTest.Math import test_Primality
     from Crypto.SelfTest.Math import test_modexp
+    from Crypto.SelfTest.Math import test_modmult
     tests += test_Numbers.get_tests(config=config)
     tests += test_Primality.get_tests(config=config)
     tests += test_modexp.get_tests(config=config)
+    tests += test_modmult.get_tests(config=config)
     return tests
 
 if __name__ == '__main__':
--- a/lib/Crypto/SelfTest/Math/test_Numbers.py
+++ b/lib/Crypto/SelfTest/Math/test_Numbers.py
@@ -672,6 +672,34 @@ class TestIntegerBase(unittest.TestCase)
         v1, = self.Integers(0x10)
         self.assertEqual(hex(v1), "0x10")
 
+    def test_mult_modulo_bytes(self):
+        modmult = self.Integer._mult_modulo_bytes
+
+        res = modmult(4, 5, 19)
+        self.assertEqual(res, b'\x01')
+
+        res = modmult(4 - 19, 5, 19)
+        self.assertEqual(res, b'\x01')
+
+        res = modmult(4, 5 - 19, 19)
+        self.assertEqual(res, b'\x01')
+
+        res = modmult(4 + 19, 5, 19)
+        self.assertEqual(res, b'\x01')
+
+        res = modmult(4, 5 + 19, 19)
+        self.assertEqual(res, b'\x01')
+
+        modulus = 2**512 - 1    # 64 bytes
+        t1 = 13**100
+        t2 = 17**100
+        expect = b"\xfa\xb2\x11\x87\xc3(y\x07\xf8\xf1n\xdepq\x0b\xca\xf3\xd3B,\xef\xf2\xfbf\xcc)\x8dZ*\x95\x98r\x96\xa8\xd5\xc3}\xe2q:\xa2'z\xf48\xde%\xef\t\x07\xbc\xc4[C\x8bUE2\x90\xef\x81\xaa:\x08"
+        self.assertEqual(expect, modmult(t1, t2, modulus))
+
+        self.assertRaises(ZeroDivisionError, modmult, 4, 5, 0)
+        self.assertRaises(ValueError, modmult, 4, 5, -1)
+        self.assertRaises(ValueError, modmult, 4, 5, 4)
+
 
 class TestIntegerInt(TestIntegerBase):
 
--- /dev/null
+++ b/lib/Crypto/SelfTest/Math/test_modmult.py
@@ -0,0 +1,120 @@
+#
+#  SelfTest/Math/test_modmult.py: Self-test for custom modular multiplication
+#
+# ===================================================================
+#
+# Copyright (c) 2023, Helder Eijs <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in
+#    the documentation and/or other materials provided with the
+#    distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""Self-test for the custom modular multiplication"""
+
+import unittest
+
+from Crypto.SelfTest.st_common import list_test_cases
+
+from Crypto.Util.number import long_to_bytes, bytes_to_long
+
+from Crypto.Util._raw_api import (create_string_buffer,
+                                  get_raw_buffer,
+                                  c_size_t)
+
+from Crypto.Math._IntegerCustom import _raw_montgomery
+
+
+class ExceptionModulus(ValueError):
+    pass
+
+
+def monty_mult(term1, term2, modulus):
+
+    if term1 >= modulus:
+        term1 %= modulus
+    if term2 >= modulus:
+        term2 %= modulus
+
+    modulus_b = long_to_bytes(modulus)
+    numbers_len = len(modulus_b)
+    term1_b = long_to_bytes(term1, numbers_len)
+    term2_b = long_to_bytes(term2, numbers_len)
+
+    out = create_string_buffer(numbers_len)
+    error = _raw_montgomery.monty_multiply(
+                out,
+                term1_b,
+                term2_b,
+                modulus_b,
+                c_size_t(numbers_len)
+                )
+
+    if error == 17:
+        raise ExceptionModulus()
+    if error:
+        raise ValueError("monty_multiply() failed with error: %d" % error)
+
+    return get_raw_buffer(out)
+
+
+modulus1 = 0xd66691b20071be4d66d4b71032b37fa007cfabf579fcb91e50bfc2753b3f0ce7be74e216aef7e26d4ae180bc20d7bd3ea88a6cbf6f87380e613c8979b5b043b200a8ff8856a3b12875e36e98a7569f3852d028e967551000b02c19e9fa52e83115b89309aabb1e1cf1e2cb6369d637d46775ce4523ea31f64ad2794cbc365dd8a35e007ed3b57695877fbf102dbeb8b3212491398e494314e93726926e1383f8abb5889bea954eb8c0ca1c62c8e9d83f41888095c5e645ed6d32515fe0c58c1368cad84694e18da43668c6f43e61d7c9bca633ddcda7aef5b79bc396d4a9f48e2a9abe0836cc455e435305357228e93d25aaed46b952defae0f57339bf26f5a9
+
+
+class TestModMultiply(unittest.TestCase):
+
+    def test_small(self):
+        self.assertEqual(b"\x01", monty_mult(5, 6, 29))
+
+    def test_large(self):
+        numbers_len = (modulus1.bit_length() + 7) // 8
+
+        t1 = modulus1 // 2
+        t2 = modulus1 - 90
+        expect = b'\x00' * (numbers_len - 1) + b'\x2d'
+        self.assertEqual(expect, monty_mult(t1, t2, modulus1))
+
+    def test_zero_term(self):
+        numbers_len = (modulus1.bit_length() + 7) // 8
+        expect = b'\x00' * numbers_len
+        self.assertEqual(expect, monty_mult(0x100, 0, modulus1))
+        self.assertEqual(expect, monty_mult(0, 0x100, modulus1))
+
+    def test_larger_term(self):
+        t1 = 2**2047
+        expect_int = 0x8edf4071f78e3d7ba622cdbbbef74612e301d69186776ae6bf87ff38c320d9aebaa64889c2f67de2324e6bccd2b10ad89e91fd21ba4bb523904d033eff5e70e62f01a84f41fa90a4f248ef249b82e1d2729253fdfc2a3b5b740198123df8bfbf7057d03e15244ad5f26eb9a099763b5c5972121ec076b0bf899f59bd95f7cc129abddccf24217bce52ca0f3a44c9ccc504765dbb89734205f3ae6a8cc560494a60ea84b27d8e00fa24bdd5b4f1d4232edb61e47d3d984c1fa50a3820a2e580fbc3fc8bc11e99df53b9efadf5a40ac75d384e400905aa6f1d88950cd53b1c54dc2222115ad84a27260fa4d978155c1434c551de1ee7361a17a2f79d4388f78a5d
+        res = bytes_to_long(monty_mult(t1, t1, modulus1))
+        self.assertEqual(res, expect_int)
+
+
+def get_tests(config={}):
+    tests = []
+    tests += list_test_cases(TestModMultiply)
+    return tests
+
+
+if __name__ == '__main__':
+    def suite():
+        return unittest.TestSuite(get_tests())
+    unittest.main(defaultTest='suite')
--- a/lib/Crypto/SelfTest/PublicKey/test_RSA.py
+++ b/lib/Crypto/SelfTest/PublicKey/test_RSA.py
@@ -274,7 +274,7 @@ class RSATest(unittest.TestCase):
         ciphertext = bytes_to_long(a2b_hex(self.ciphertext))
 
         # Test decryption
-        plaintext = rsaObj._decrypt(ciphertext)
+        plaintext = bytes_to_long(rsaObj._decrypt(ciphertext))
 
         # Test encryption (2 arguments)
         new_ciphertext2 = rsaObj._encrypt(plaintext)
@@ -299,7 +299,7 @@ class RSATest(unittest.TestCase):
         ciphertext = bytes_to_long(a2b_hex(self.ciphertext))
 
         # Test plain decryption
-        new_plaintext = rsaObj._decrypt(ciphertext)
+        new_plaintext = bytes_to_long(rsaObj._decrypt(ciphertext))
         self.assertEqual(plaintext, new_plaintext)
 
 
--- a/lib/Crypto/SelfTest/PublicKey/test_import_RSA.py
+++ b/lib/Crypto/SelfTest/PublicKey/test_import_RSA.py
@@ -26,7 +26,7 @@ import re
 from Crypto.PublicKey import RSA
 from Crypto.SelfTest.st_common import *
 from Crypto.Util.py3compat import *
-from Crypto.Util.number import inverse
+from Crypto.Util.number import inverse, bytes_to_long
 from Crypto.Util import asn1
 
 from Crypto.Util._file_system import pycryptodome_filename
@@ -214,13 +214,13 @@ Lr7UkvEtFrRhDDKMtuIIq19FrL4pUIMymPMSLBn3
     def testImportKey5(self):
         """Verifies that the imported key is still a valid RSA pair"""
         key = RSA.importKey(self.rsaKeyPEM)
-        idem = key._encrypt(key._decrypt(89))
+        idem = key._encrypt(bytes_to_long(key._decrypt(89)))
         self.assertEqual(idem, 89)
 
     def testImportKey6(self):
         """Verifies that the imported key is still a valid RSA pair"""
         key = RSA.importKey(self.rsaKeyDER)
-        idem = key._encrypt(key._decrypt(65))
+        idem = key._encrypt(bytes_to_long(key._decrypt(65)))
         self.assertEqual(idem, 65)
 
     def testImportKey7(self):
--- a/lib/Crypto/Signature/pkcs1_15.py
+++ b/lib/Crypto/Signature/pkcs1_15.py
@@ -77,10 +77,11 @@ class PKCS115_SigScheme:
         em = _EMSA_PKCS1_V1_5_ENCODE(msg_hash, k)
         # Step 2a (OS2IP)
         em_int = bytes_to_long(em)
-        # Step 2b (RSASP1)
-        m_int = self._key._decrypt(em_int)
-        # Step 2c (I2OSP)
-        signature = long_to_bytes(m_int, k)
+        # Step 2b (RSASP1) and Step 2c (I2OSP)
+        signature = self._key._decrypt(em_int)
+        # Verify no faults occurred
+        if em_int != pow(bytes_to_long(signature), self._key.e, self._key.n):
+            raise ValueError("Fault detected in RSA private key operation")
         return signature
 
     def verify(self, msg_hash, signature):
--- a/lib/Crypto/Signature/pss.py
+++ b/lib/Crypto/Signature/pss.py
@@ -107,10 +107,11 @@ class PSS_SigScheme:
         em = _EMSA_PSS_ENCODE(msg_hash, modBits-1, self._randfunc, mgf, sLen)
         # Step 2a (OS2IP)
         em_int = bytes_to_long(em)
-        # Step 2b (RSASP1)
-        m_int = self._key._decrypt(em_int)
-        # Step 2c (I2OSP)
-        signature = long_to_bytes(m_int, k)
+        # Step 2b (RSASP1) and Step 2c (I2OSP)
+        signature = self._key._decrypt(em_int)
+        # Verify no faults occurred
+        if em_int != pow(bytes_to_long(signature), self._key.e, self._key.n):
+            raise ValueError("Fault detected in RSA private key operation")
         return signature
 
     def verify(self, msg_hash, signature):
@@ -178,7 +179,7 @@ def MGF1(mgfSeed, maskLen, hash_gen):
 
     :return: the mask, as a *byte string*
     """
-    
+
     T = b""
     for counter in iter_range(ceil_div(maskLen, hash_gen.digest_size)):
         c = long_to_bytes(counter, 4)
--- a/src/modexp.c
+++ b/src/modexp.c
@@ -179,6 +179,71 @@ cleanup:
     return res;
 }
 
+/*
+ * Modular multiplication. All numbers are
+ * encoded in big endian form, possibly with
+ * zero padding on the left.
+ *
+ * @param out     The memory area where to store the result
+ * @param term1   First term of the multiplication, strictly smaller than the modulus
+ * @param term2   Second term of the multiplication, strictly smaller than the modulus
+ * @param modulus Modulus, it must be odd
+ * @param len     Size in bytes of out, term1, term2, and modulus
+ * @return        0 in case of success, the appropriate error code otherwise
+ */
+EXPORT_SYM int monty_multiply(
+               uint8_t       *out,
+               const uint8_t *term1,
+               const uint8_t *term2,
+               const uint8_t *modulus,
+               size_t        len)
+{
+    MontContext *ctx = NULL;
+    uint64_t *mont_term1 = NULL;
+    uint64_t *mont_term2 = NULL;
+    uint64_t *mont_output = NULL;
+    uint64_t *scratchpad = NULL;
+    int res;
+
+    if (!term1 || !term2 || !modulus || !out)
+        return ERR_NULL;
+
+    if (len == 0)
+        return ERR_NOT_ENOUGH_DATA;
+
+    /* Allocations **/
+    res = mont_context_init(&ctx, modulus, len);
+    if (res)
+        return res;
+
+    res = mont_from_bytes(&mont_term1, term1, len, ctx);
+    if (res) goto cleanup;
+
+    res = mont_from_bytes(&mont_term2, term2, len, ctx);
+    if (res) goto cleanup;
+
+    res = mont_number(&mont_output, 1, ctx);
+    if (res) goto cleanup;
+
+    res = mont_number(&scratchpad, SCRATCHPAD_NR, ctx);
+    if (res) goto cleanup;
+
+    /* Multiply, then transform result back into big-endian, byte form **/
+    res = mont_mult(mont_output, mont_term1, mont_term2, scratchpad, ctx);
+    if (res) goto cleanup;
+
+    res = mont_to_bytes(out, len, mont_output, ctx);
+
+cleanup:
+    mont_context_free(ctx);
+    free(mont_term1);
+    free(mont_term2);
+    free(mont_output);
+    free(scratchpad);
+
+    return res;
+}
+
 #ifdef MAIN
 int main(void)
 {
@@ -205,7 +270,7 @@ int main(void)
     res = fread(out, 1, length, stdin);
 
     result = monty_pow(out, base, exponent, modulus, length, 12);
-    
+
     free(base);
     free(modulus);
     free(exponent);
@@ -232,5 +297,6 @@ int main(void)
     monty_pow(out, base, exponent, modulus, length, 12);
     }
 
+    monty_multiply(out, base, out, modulus, length);
 }
 #endif
openSUSE Build Service is sponsored by