aboutsummaryrefslogtreecommitdiffstats
path: root/src/ressol.cpp
blob: ab4ad78760bc7a08a1281b112198fd334ed16799 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
/*************************************************
* Shanks-Tonelli (RESSOL) Source File            *
* (C) 2007-2008 Falko Strenzke, FlexSecure GmbH  *
* (C) 2008 Jack Lloyd                            *
*************************************************/

#include <botan/numthry.h>
#include <botan/reducer.h>

#include <iostream>

namespace Botan {

/*************************************************
* Shanks-Tonelli algorithm                       *
*************************************************/
/*
* Compute a square root of a modulo p, i.e. a value r with
* r*r == a (mod p).
*
* a: the value to take the root of; must be >= 0. Values >= p are
*    reduced modulo p first, so any multiple of p yields the root 0.
* p: the modulus; must be > 1 and is expected to be prime. For a
*    composite p the returned value is not guaranteed to be a square
*    root, but the squaring loop below is bounded, so it cannot spin
*    forever.
*
* Returns a square root of a mod p, or -1 if a is not a quadratic
* residue modulo p (or the algorithm detects an inconsistency that
* can only arise for a non-prime p).
* Throws Invalid_Argument if a < 0 or p <= 1; errors raised by jacobi()
* for an unsuitable p are propagated.
*/
BigInt ressol(const BigInt& a, const BigInt& p)
   {
   if(a < 0)
      throw Invalid_Argument("ressol(): a to solve for must be positive");
   if(p <= 1)
      throw Invalid_Argument("ressol(): prime must be > 1");

   // Work with the canonical residue so that a >= p is handled uniformly.
   // Without this, a nonzero multiple of p makes jacobi() return 0 and we
   // would wrongly report "no root", and for p == 2 an unreduced a would
   // be returned as-is.
   const BigInt a_mod_p = a % p;

   if(a_mod_p == 0)
      return 0;
   if(p == 2)
      return a_mod_p; // only nonzero residue mod 2 is 1, its own root

   if(jacobi(a_mod_p, p) != 1) // not a quadratic residue
      return BigInt("-1");

   // Easy case p == 3 (mod 4): the root is a^((p+1)/4)
   if(p % 4 == 3)
      return power_mod(a_mod_p, ((p+1) >> 2), p);

   // Write p - 1 = 2^s * Q with Q odd, then set q = (Q - 1) / 2
   u32bit s = low_zero_bits(p - 1);
   BigInt q = p >> s;

   q -= 1;
   q >>= 1;

   Modular_Reducer mod_p(p);

   // r = a^((Q+1)/2) and n = a^Q, so r^2 == a * n; r is a root once n == 1
   BigInt r = power_mod(a_mod_p, q, p);
   BigInt n = mod_p.multiply(a_mod_p, mod_p.square(r));
   r = mod_p.multiply(r, a_mod_p);

   if(n == 1)
      return r;

   // Find the smallest z >= 2 that is not a quadratic residue
   // (a deterministic search, not a random one)
   BigInt z = 2;
   while(jacobi(z, p) == 1) // while z is a quadratic residue
      ++z;

   // c = z^Q (since (q << 1) + 1 == Q)
   BigInt c = power_mod(z, (q << 1) + 1, p);

   while(n > 1)
      {
      // Find the least i with n^(2^i) == 1. For prime p this i is always
      // below s. Bailing out as soon as i reaches s gives the same result
      // as checking afterwards, and it also stops an endless loop when p
      // is not prime and n never squares to 1.
      q = n;

      u32bit i = 0;
      while(q != 1)
         {
         q = mod_p.square(q);
         ++i;

         if(i >= s)
            return BigInt("-1");
         }

      c = power_mod(c, BigInt(BigInt::Power2, s-i-1), p);
      r = mod_p.multiply(r, c);
      c = mod_p.square(c);
      n = mod_p.multiply(n, c);
      s = i; // strictly decreases, so the outer loop terminates
      }

   return r;
   }

}