aboutsummaryrefslogtreecommitdiffstats
path: root/src/scripts/comba.py
blob: 4d3befa2601a21e7142a47f7c14ac62b387ec2f8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/python

import sys

# Used to generate src/lib/math/mp/mp_comba.cpp

def comba_indexes(N):
    """Return the limb-index pairs contributing to each column of an NxN product.

    The result has ``2*N`` entries. Entry ``i`` lists every pair ``(j, i - j)``
    with both indices in ``[0, N)``, in ascending order of ``j``. For N >= 1
    the last entry is empty: the top output word receives only the carry.
    """
    # range() rather than xrange() so the generator also runs under Python 3;
    # N is small, so materialising the range on Python 2 costs nothing.
    return [[(j, i - j) for j in range(max(0, i - N + 1), min(N, i + 1))]
            for i in range(2 * N)]

def comba_sqr_indexes(N):
    """Return the limb-index pairs for each column of an NxN squaring.

    Like the product indexes, but every pair is written in canonical
    ``(low, high)`` order. As a result, each off-diagonal pair appears twice
    (once per mirror image), and the two copies are adjacent after sorting.
    The diagonal pair ``(k, k)``, when present, appears once and sorts last.
    The duplicates are kept on purpose so the output stays unchanged for
    existing consumers.
    """
    indexes = []
    # range() rather than xrange() so the generator also runs under Python 3.
    for i in range(2 * N):
        column = [(min(j, i - j), max(j, i - j))
                  for j in range(max(0, i - N + 1), min(N, i + 1))]
        indexes.append(sorted(column))
    return indexes

def comba_multiply_code(N):
    """Print the C statements forming the body of an NxN Comba multiplication.

    Each column ``i`` of the product is accumulated with ``word3_muladd``
    calls into a three-word accumulator. The low word of that accumulator is
    then stored into ``z[i]``.
    """
    # C variable names of the accumulator words (w0 = low, w2 = high). The
    # names are rotated after each column instead of emitting code that
    # shifts the values between variables.
    w2, w1, w0 = 'w2', 'w1', 'w0'

    for i, column in enumerate(comba_indexes(N)):
        for (xi, yi) in column:
            # Single-argument print(...) prints identically on Python 2 and 3.
            print("   word3_muladd(&%s, &%s, &%s, x[%2d], y[%2d]);"
                  % (w2, w1, w0, xi, yi))

        if i < 2 * N - 2:
            # Clear the stored word: after rotation it becomes the new high word.
            print("   z[%2d] = %s; %s = 0;\n" % (i, w0, w0))
        else:
            # The final column has no products, so no reset is needed here.
            print("   z[%2d] = %s;" % (i, w0))

        w0, w1, w2 = w1, w2, w0

def comba_square_code(N):
    """Print the C statements forming the body of an NxN Comba squaring.

    Diagonal terms ``x[k]*x[k]`` are emitted with ``word3_muladd``. Each
    distinct off-diagonal pair ``(a, b)`` with ``a < b`` is emitted once with
    ``word3_muladd_2``, which presumably adds the product twice to cover
    ``x[a]*x[b] + x[b]*x[a]`` (confirm against mp_asmi.h). After each column,
    the low accumulator word is stored into ``z[column]``.
    """
    # C variable names of the accumulator words (w0 = low, w2 = high), rotated
    # after each column.
    w2, w1, w0 = 'w2', 'w1', 'w0'

    for rnd, column in enumerate(comba_sqr_indexes(N)):
        # The index list may contain each off-diagonal pair once per mirror
        # image. Deduplicate explicitly instead of assuming the copies sit at
        # alternating positions. Sorting keeps ascending order, so the
        # diagonal (k, k) still comes last.
        for (a, b) in sorted(set(column)):
            op = 'word3_muladd' if a == b else 'word3_muladd_2'
            # Single-argument print(...) prints identically on Python 2 and 3.
            print("   %s(&%s, &%s, &%s, x[%2d], x[%2d]);" % (op, w2, w1, w0, a, b))

        if rnd < 2 * N - 2:
            # Clear the stored word: after rotation it becomes the new high word.
            print("   z[%2d] = %s; %s = 0;\n" % (rnd, w0, w0))
        else:
            print("   z[%2d] = %s;" % (rnd, w0))

        w0, w1, w2 = w1, w2, w0

def main(args=None):
    """Write the generated mp_comba.cpp source to standard output.

    The output opens with the file header and the ``namespace Botan`` /
    ``extern "C"`` scaffolding. For each size n in (4, 6, 8, 9, 16) it then
    emits ``bigint_comba_sqrN`` followed by ``bigint_comba_mulN``, and finally
    closes both scopes.

    args: kept for the conventional main(argv) signature. It is currently
    ignored.
    """
    # Single-argument print(...) prints identically on Python 2 and 3.
    print("""/*
* Comba Multiplication and Squaring
* (C) 1999-2007,2011,2014 Jack Lloyd
*
* Distributed under the terms of the Botan license
*/

#include <botan/internal/mp_core.h>
#include <botan/internal/mp_asmi.h>

namespace Botan {

extern "C" {
""")

    for n in (4, 6, 8, 9, 16):
        print("/*\n* Comba %dx%d Squaring\n*/" % (n, n))
        print("void bigint_comba_sqr%d(word z[%d], const word x[%d])" % (n, 2*n, n))
        print("   {")
        print("   word w2 = 0, w1 = 0, w0 = 0;\n")
        comba_square_code(n)
        print("   }\n")

        print("/*\n* Comba %dx%d Multiplication\n*/" % (n, n))
        print("void bigint_comba_mul%d(word z[%d], const word x[%d], const word y[%d])"
              % (n, 2*n, n, n))
        print("   {")
        print("   word w2 = 0, w1 = 0, w0 = 0;\n")
        comba_multiply_code(n)
        print("   }\n")

    # Close the extern "C" block, then the Botan namespace.
    print("}\n\n}")

if __name__ == '__main__':
    # Equivalent to sys.exit(main()): main() returns None, so the exit status is 0.
    raise SystemExit(main())