NTT 的 C/C++ 实现
文章目录
- NTT (C ref)
  - ntt_ref.h
  - ntt_ref.c
- NTT (AVX2)
  - ntt_avx2.h
  - ntt_avx2.c
- Test
  - cputimer.h
- Result
NTT (C ref)
ntt_ref.h
#ifndef NTT_H
#define NTT_H

/*
 * ntt_ref.h - parameters, reduction macros and prototypes for the portable
 * (reference C) NTT implementation.
 */

/* Short fixed-size aliases used throughout the implementation.
 * NOTE(review): `char` signedness is implementation-defined; int8 is kept as
 * plain `char` so existing callers passing `char *` keep compiling. */
typedef char int8;
typedef short int16;
typedef int int32;
typedef long long int64;

typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int uint32;
typedef unsigned long long uint64;

/* ## Parameters ## */

#define NTT_NEG 1 /* 0: cyclic NTT, 1: negacyclic NTT */

#define NTT_Q 12289
#define NTT_N 1024

#define NTT_ROUND 10
#define NTT_ORDER (1 << (NTT_ROUND + 1))
#define NTT_BASELEN (NTT_N >> NTT_ROUND)

#define NTT_ZETA 7

/* ## Fast modular reduction ## */

#define MONT_L 16
#define MONT_R (1LL << MONT_L)
#define MONT 4091     /* MONT_R mod q */
#define NEGQINV 12287 /* -q^-1 mod MONT_R */

#define BARR_R (1LL << 32)
#define BARR 349497 /* round(2^32 / q) */

/*
 * Montgomery reduction: evaluates to a value congruent to a * R^-1 (mod q),
 * with R = MONT_R.  The result has type int64.
 *
 * -q^-1 (NEGQINV) is used instead of q^-1 so that a non-negative input gives
 * a non-negative result; the NTT butterflies rely on that.  The low MONT_L
 * bits of a * NEGQINV are taken with a mask rather than through a narrowing
 * cast, whose result would be implementation-defined.
 *
 * The argument is evaluated twice: do not pass expressions with side effects.
 */
#define montgomery_reduce(a) \
    (((a) + (((int64)(a) * NEGQINV) & (MONT_R - 1)) * NTT_Q) >> MONT_L)

/*
 * Barrett reduction: evaluates to a value congruent to a (mod q) by
 * subtracting ((BARR * a) >> 32) multiples of q.  The result has type int64.
 * The argument is evaluated twice.
 */
#define barrett_reduce(a) ((a) - ((BARR * (int64)(a)) >> 32) * NTT_Q)

/* ## Function declarations ## */

void get_ntt_param(int32 q, int32 n, int32 r);

void ntt(int16* f);
void intt(int16* f);
void nttmul(int16* r, const int16* a, const int16* b);

int32 print_bytes(int8* arr, int32 len);
int32 print_coeffs(int16* arr, int32 len);

#endif
ntt_ref.c
#include <stdio.h>
#include <stdlib.h>
#include "ntt_ref.h"

/* ## Parameters ## */

/*
 * Precomputed tables.  The listing leaves the initializers empty; they are
 * meant to be filled with the "zetas", "zetas_mont" and "bitrev_list" lists
 * printed by get_ntt_param().  With all-zero tables the transforms below do
 * not compute a meaningful NTT.
 */
int16 zetas[NTT_ORDER + 1] = { 0 };
int16 zetas_mont[NTT_ORDER + 1] = { 0 };
int16 bitrev_list[NTT_ORDER] = { 0 };

/* Inverse-transform scaling: 2^-NTT_ROUND mod q, and the same value times
 * MONT (see get_intt_factor). */
int32 factor = 12277, factor_mont = 64;

/* ## Generic helpers ## */

/* Returns the low `l` bits of `b` in reversed order. */
int32 brv(int32 b, int32 l)
{
    int32 bb = 0;
    for (int32 i = 0; i < l; i++)
    {
        bb <<= 1;
        bb |= (b & 1);
        b >>= 1;
    }
    return bb;
}

/*
 * Square-and-multiply: returns a^b mod q for b >= 0.
 * A negative exponent is treated as 0 (returns 1); the loop would otherwise
 * never terminate, because shifting a negative value right never reaches 0.
 */
int64 fast_pow(int64 a, int64 b, int64 q)
{
    int64 result = 1;
    while (b > 0)
    {
        if (b & 1)
            result = (result * a) % q;
        a = (a * a) % q;
        b >>= 1;
    }
    return result;
}

/*
 * Extended Euclid: returns gcd(a, b) and stores x, y such that
 * a*x + b*y == gcd(a, b).
 */
int64 exgcd(int64* x, int64* y, int64 a, int64 b)
{
    if (b == 0)
    {
        *x = 1;
        *y = 0;
        return a;
    }
    int64 ret = exgcd(x, y, b, a % b);
    int64 tmp = *x;
    *x = *y;
    *y = tmp - (a / b) * (*y);
    return ret;
}

/* Prints `len` bytes as "[ a, b, ... ]".  Always returns 0.
 * An empty array (len <= 0) prints "[  ]" instead of reading arr[0]. */
int32 print_bytes(int8* arr, int32 len)
{
    printf("[ ");
    for (int32 i = 0; i < len; i++)
    {
        if (i == 0)
            printf("%d", arr[i]);
        else
            printf(", %d", arr[i]);
    }
    printf(" ]");
    return 0;
}

/* Prints `len` coefficients as "[ a, b, ... ]".  Always returns 0.
 * An empty array (len <= 0) prints "[  ]" instead of reading arr[0]. */
int32 print_coeffs(int16* arr, int32 len)
{
    printf("[ ");
    for (int32 i = 0; i < len; i++)
    {
        if (i == 0)
            printf("%d", arr[i]);
        else
            printf(", %d", arr[i]);
    }
    printf(" ]");
    return 0;
}

/* ## Parameter precomputation ## */

/*
 * Searches w = 2, 3, ... for the first w with w^ord == 1 and w^(ord/2) != 1
 * (mod q), prints it and returns it.  Returns 0 when no such w < q exists.
 */
int64 find_root(int64 q, int64 ord)
{
    int64 w = 2;
    while (w < q)
    {
        if (fast_pow(w, ord, q) == 1 && fast_pow(w, ord >> 1, q) != 1)
        {
            printf("%lld-th root = %lld\n\n", ord, w);
            return w;
        }
        w++;
    }
    return 0;
}

/* Fills zetas[i] = zeta^i mod q for i = 0..ord (ord + 1 entries) and prints
 * the list as a C initializer. */
void get_zetas(int32* zetas, int32 zeta, int32 q, int32 ord)
{
    int64 wi = 1;
    int64 w = zeta;
    zetas[0] = 1;
    printf("zetas = { %d", zetas[0]);
    for (int64 i = 1; i <= ord; i++)
    {
        wi = (wi * w) % q;
        zetas[i] = (int32)wi;
        printf(", %lld", wi);
    }
    printf(" };\n\n");
}

/* Fills zetas_mont[i] = mont * zetas[i] mod q for i = 0..ord and prints the
 * list as a C initializer. */
void get_zetas_mont(int32* zetas_mont, int32* zetas, int64 q, int64 ord, int64 mont)
{
    int64 wi_pre = mont * zetas[0] % q;
    zetas_mont[0] = (int32)wi_pre;
    printf("zetas_mont = { %lld", wi_pre);
    for (int64 i = 1; i <= ord; i++)
    {
        wi_pre = mont * zetas[i] % q;
        zetas_mont[i] = (int32)wi_pre;
        printf(", %lld", wi_pre);
    }
    printf(" };\n\n");
}

/* Prints brv(i, bits) for i = 0..2^bits - 1 as a C initializer. */
void get_brv_table(int32 bits)
{
    printf("bitrev_list = { 0");
    int32 len = (int32)(1LL << bits);
    for (int i = 1; i < len; i++)
        printf(", %d", brv(i, bits));
    printf(" };\n\n");
}

/* Prints factor = (2^r)^-1 mod q and factor_mont = factor * mont mod q. */
void get_intt_factor(int64 q, int64 r, int64 mont)
{
    int64 inv, pinv;
    if (exgcd(&inv, &pinv, 1LL << r, q) != 1)
    {
        printf("gcd(2^r, q) != 1\n");
        return;
    }
    /* exgcd may return a negative representative; normalise into [0, q). */
    inv = ((inv % q) + q) % q;
    int64 inv_mont = (inv * mont) % q;
    printf("factor = %lld, factor_mont = %lld\n\n", inv, inv_mont);
}

/*
 * Prints every precomputed constant and table for modulus q, length n and
 * r rounds: Montgomery/Barrett constants, a root of order 2^(r+1), its
 * powers (plain and Montgomery form), the bit-reversal table and the
 * inverse-transform factor.  Nothing is stored; the output is meant to be
 * pasted into the tables at the top of this file.
 */
void get_ntt_param(int32 q, int32 n, int32 r)
{
    printf("/* get ntt params */\n\n");
    printf("NTT_Q = %d, NTT_N = %d, NTT_ROUND = %d, MONT_R = %lld, BARR_R = %lld\n\n", q, n, r, MONT_R, BARR_R);

    if (q < 2)
    {
        /* Every computation below reduces modulo q. */
        printf("NTT_Q must be >= 2\n");
        return;
    }

    int64 d, x, y;
    int64 mont = MONT_R % q;
    d = exgcd(&x, &y, q, MONT_R);
    if (d != 1)
    {
        printf("gcd(NTT_Q, MONT_R) != 1\n");
        return;
    }
    printf("mont = %lld mod q\n\nqinv = %lld mod R\n\n", mont, x);
    printf("barret = 2^32/q = %lld\n\n", (BARR_R + (q >> 1)) / q);

    int32 order = 1 << (r + 1);
    int64 Zeta = find_root(q, order);
    if (Zeta == 0)
    {
        printf("no root of order %d found mod %d\n", order, q);
        return;
    }

    int32* Zetas = malloc(sizeof *Zetas * ((size_t)order + 1));
    int32* Zetas_mont = malloc(sizeof *Zetas_mont * ((size_t)order + 1));
    if (Zetas == NULL || Zetas_mont == NULL)
    {
        printf("out of memory\n");
        free(Zetas);
        free(Zetas_mont);
        return;
    }

    get_zetas(Zetas, (int32)Zeta, q, order);
    get_zetas_mont(Zetas_mont, Zetas, q, order, mont);
    get_brv_table(r + 1);
    get_intt_factor(q, r, mont);
    printf("//* get ntt params *//\n\n");

    free(Zetas);
    free(Zetas_mont);
}

/* ## NTT transform ## */

/*
 * In-place forward transform of the NTT_N coefficients in f.
 *
 * When NTT_ROUND is odd, one radix-2 layer runs first; the remaining layers
 * run two at a time (radix-4).  Per the original author's notes the radix-4
 * stage expects inputs in [0, 2q) and leaves outputs in [0, 4q) before the
 * interleaved conditional subtractions; those bounds are not re-derived here.
 *
 * The conditional subtractions use ((q - 1 - x) >> 31) & q, which assumes a
 * 32-bit int32 with arithmetic right shift.
 */
void ntt(int16* f)
{
    int32 Blocknum = 1;
    int32 Blocksize = NTT_N;
    int32 Round = 0;

    /* Radix-2:  X = X + W*Y,  Y = X - W*Y */
    if ((NTT_ROUND & 1) == 1)
    {
        int32 offset = Blocksize >> 1;
        int32 X, WY;
        /* Same bit-reversed root indexing as the radix-4 stage and intt();
         * indexing zetas_mont directly by block number only matched it for
         * NTT_ROUND == 1. */
        int32 zeta_mont = zetas_mont[bitrev_list[Blocknum * NTT_NEG] >> 1];
        int16* pf = f;
        for (int32 k = 0; k < offset; k++)
        {
            X = pf[k];
            WY = pf[k + offset] * zeta_mont;
            WY = montgomery_reduce(WY);
            pf[k] = X + WY;
            pf[k + offset] = X + NTT_Q - WY;
        }
        Blocknum <<= 1;
        Blocksize >>= 1;
        Round++;
    }

    /*
     * Radix-4 (two layers per pass), with W the root of layer Round and
     * W0/W1 the roots of layer Round+1:
     *   X1 = (X1 + W*Y1) + W0*(X2 + W*Y2)
     *   X2 = (X1 + W*Y1) - W0*(X2 + W*Y2)
     *   Y1 = (X1 - W*Y1) + W1*(X2 - W*Y2)
     *   Y2 = (X1 - W*Y1) - W1*(X2 - W*Y2)
     * X1 is conditionally reduced first, then (X1 + W*Y1) and (X1 - W*Y1):
     * three conditional subtractions per butterfly.
     */
    for (; Round < NTT_ROUND; Round += 2, Blocksize >>= 2, Blocknum <<= 2)
    {
        int32 offset = Blocksize >> 2;
        int32 X1, X2, Y1, Y2, WY;
        int32 zeta_mont, zeta1_mont, zeta2_mont;
        for (int32 i = 0; i < Blocknum; i++)
        {
            int16* pf = f + i * Blocksize;
            /* Roots are looked up through the bit-reversal table; the >> 1
             * halves the (NTT_ROUND+1)-bit reversed index. */
            zeta_mont = zetas_mont[bitrev_list[Blocknum * NTT_NEG + i] >> 1];                 /* layer Round, block i */
            zeta1_mont = zetas_mont[bitrev_list[2 * Blocknum * NTT_NEG + i * 2] >> 1];        /* layer Round+1, block 2i */
            zeta2_mont = zetas_mont[bitrev_list[2 * Blocknum * NTT_NEG + i * 2 + 1] >> 1];    /* layer Round+1, block 2i+1 */
            for (int k = 0; k < offset; k++)
            {
                X1 = pf[k];
                X2 = pf[k + offset];
                Y1 = pf[k + offset * 2];
                Y2 = pf[k + offset * 3];

                /* First layer: pairs (X1, Y1) and (X2, Y2) with W. */
                X1 -= ((NTT_Q - X1 - 1) >> 31) & NTT_Q;
                WY = montgomery_reduce(Y1 * zeta_mont);
                Y1 = X1 + NTT_Q - WY;
                X1 += WY;
                X1 -= ((NTT_Q - X1 - 1) >> 31) & NTT_Q;
                Y1 -= ((NTT_Q - Y1 - 1) >> 31) & NTT_Q;

                WY = montgomery_reduce(Y2 * zeta_mont);
                Y2 = X2 + NTT_Q - WY;
                X2 += WY;

                /* Second layer: pair (X1, X2) with W0, pair (Y1, Y2) with W1. */
                WY = montgomery_reduce(X2 * zeta1_mont);
                X2 = X1 + NTT_Q - WY;
                X1 += WY;

                WY = montgomery_reduce(Y2 * zeta2_mont);
                Y2 = Y1 + NTT_Q - WY;
                Y1 += WY;

                pf[k] = X1;
                pf[k + offset] = X2;
                pf[k + offset * 2] = Y1;
                pf[k + offset * 3] = Y2;
            }
        }
    }
    /* The original listing leaves a final [0, 2q) -> [0, q) pass over f
     * disabled; it is intentionally not performed here. */
}

/*
 * In-place inverse transform of the NTT_N coefficients in f, including the
 * final multiplication by factor_mont (2^-NTT_ROUND in Montgomery form).
 *
 * Layers are undone two at a time (radix-4); when NTT_ROUND is odd a last
 * radix-2 layer follows.
 * NOTE(review): the differences fed to montgomery_reduce can be negative, so
 * intermediate values and outputs are only known to be congruent mod q;
 * the range claims in the original notes are not re-derived here.
 */
void intt(int16* f)
{
    int32 Blocknum = 1 << NTT_ROUND;
    int32 Blocksize = NTT_N >> NTT_ROUND;
    int32 Round = NTT_ROUND;
    int32 Qtimes2 = NTT_Q * 2;

    /* The first radix-4 pass merges four base blocks into one. */
    Blocksize <<= 2;
    Blocknum >>= 2;

    /*
     * Radix-4, with IW0/IW1 the inverse roots of layer Round+1 and IW the
     * inverse root of layer Round:
     *   X1 = (X1 + X2) + (Y1 + Y2)
     *   X2 = IW0*(X1 - X2) + IW1*(Y1 - Y2)
     *   Y1 = IW*((X1 + X2) - (Y1 + Y2))
     *   Y2 = IW*(IW0*(X1 - X2) - IW1*(Y1 - Y2))
     * (X1 + X2), (Y1 + Y2) and their sum are each conditionally reduced by
     * 2q: three conditional subtractions per butterfly.
     */
    for (; Round > 1; Round -= 2, Blocksize <<= 2, Blocknum >>= 2)
    {
        int32 offset = Blocksize >> 2;
        int32 X1, X2, Y1, Y2, T;
        int32 zeta_mont, zeta1_mont, zeta2_mont;
        for (int32 i = 0; i < Blocknum; i++)
        {
            int16* pf = f + i * Blocksize;
            /* Inverse roots: entry NTT_ORDER - e of the power table, where e
             * is the exponent the forward transform used for this block. */
            zeta_mont = zetas_mont[NTT_ORDER - (bitrev_list[Blocknum * NTT_NEG + i] >> 1)];               /* layer Round, block i */
            zeta1_mont = zetas_mont[NTT_ORDER - (bitrev_list[2 * Blocknum * NTT_NEG + i * 2] >> 1)];      /* layer Round+1, block 2i */
            zeta2_mont = zetas_mont[NTT_ORDER - (bitrev_list[2 * Blocknum * NTT_NEG + i * 2 + 1] >> 1)];  /* layer Round+1, block 2i+1 */
            for (int k = 0; k < offset; k++)
            {
                X1 = pf[k];
                X2 = pf[k + offset];
                Y1 = pf[k + offset * 2];
                Y2 = pf[k + offset * 3];

                T = (X1 - X2) * zeta1_mont;
                X1 += X2;
                X2 = montgomery_reduce(T);
                X1 -= ((Qtimes2 - X1 - 1) >> 31) & Qtimes2; /* conditional -2q */

                T = (Y1 - Y2) * zeta2_mont;
                Y1 += Y2;
                Y2 = montgomery_reduce(T);
                Y1 -= ((Qtimes2 - Y1 - 1) >> 31) & Qtimes2; /* conditional -2q */

                T = (X1 - Y1) * zeta_mont;
                X1 += Y1;
                Y1 = montgomery_reduce(T);
                X1 -= ((Qtimes2 - X1 - 1) >> 31) & Qtimes2; /* conditional -2q */

                T = (X2 - Y2) * zeta_mont;
                X2 += Y2;
                Y2 = montgomery_reduce(T);

                pf[k] = X1;
                pf[k + offset] = X2;
                pf[k + offset * 2] = Y1;
                pf[k + offset * 3] = Y2;
            }
        }
    }

    /* Radix-2:  X = X + Y,  Y = IW*(X - Y) */
    if ((NTT_ROUND & 1) == 1)
    {
        /*
         * This undoes the single radix-2 layer ntt() performs first: one
         * block spanning all of f, halves NTT_N/2 apart.  The loop above
         * leaves Blocksize/Blocknum one step past that layer, so they must
         * not be used here: doing so gave an out-of-range offset and the
         * wrong root.
         */
        const int32 offset = NTT_N >> 1;
        int32 X, Y, T;
        int32 zeta_mont = zetas_mont[NTT_ORDER - (bitrev_list[1 * NTT_NEG] >> 1)];
        int16* pf = f;
        for (int32 k = 0; k < offset; k++)
        {
            X = pf[k];
            Y = pf[k + offset];
            T = (X - Y) * zeta_mont;
            pf[k] = X + Y;
            pf[k + offset] = montgomery_reduce(T);
        }
    }

    /* Inverse-transform factor, followed by one conditional subtraction of q. */
    for (int32 k = 0; k < NTT_N; k++)
    {
        int32 X = f[k] * factor_mont;
        X = montgomery_reduce(X);
        f[k] = X - (((NTT_Q - X - 1) >> 31) & NTT_Q);
    }
}

/*
 * Schoolbook product of two length-NTT_BASELEN polynomials with wrap-around
 * terms multiplied by zeta; writes NTT_BASELEN Barrett-reduced coefficients
 * to r.
 *
 * `static inline`: a plain `inline` definition provides no external
 * definition in C99/C11 and can fail to link when the call is not inlined.
 */
static inline void basemul(int16* r, const int16* a, const int16* b, int16 zeta)
{
    int32 res; /* wider accumulator so the reduction can be delayed */
    int32 s;
    for (int32 i = 0; i < NTT_BASELEN; i++)
    {
        res = 0;
        s = NTT_BASELEN + i;
        for (int32 j = 0; j <= i; j++)
            res += b[j] * a[i - j];
        for (int32 j = i + 1; j < NTT_BASELEN; j++)
            res += zeta * barrett_reduce(b[j] * a[s - j]);
        r[i] = barrett_reduce(res);
    }
}

/*
 * Pointwise product in the NTT domain: r, a and b are treated as
 * 2^NTT_ROUND consecutive polynomials of NTT_BASELEN coefficients each, and
 * every pair is multiplied using the root looked up for that block.
 */
void nttmul(int16* r, const int16* a, const int16* b)
{
    int32 num = 1 << NTT_ROUND;
    for (int32 i = 0; i < num; i++)
    {
#if (NTT_BASELEN == 1)
        int32 tmp = *a * *b;
        *r = barrett_reduce(tmp);
#elif (NTT_BASELEN == 2)
        /* Root used by block num + i of the last layer. */
        int32 zeta = zetas[bitrev_list[num * NTT_NEG + i]];
        int32 tmp0 = a[0] * b[0] + zeta * barrett_reduce(a[1] * b[1]);
        int32 tmp1 = a[0] * b[1] + a[1] * b[0];
        r[0] = barrett_reduce(tmp0);
        r[1] = barrett_reduce(tmp1);
#else
        /* Root used by block num + i of the last layer. */
        int32 zeta = zetas[bitrev_list[num * NTT_NEG + i]];
        basemul(r, a, b, zeta);
#endif
        r += NTT_BASELEN;
        a += NTT_BASELEN;
        b += NTT_BASELEN;
    }
}
NTT (AVX2)
ntt_avx2.h
#ifndef NTT_H
#define NTT_H

/*
 * ntt_avx2.h - parameters, reduction macros and prototypes for the AVX2 NTT
 * implementation.  It uses the same include guard as ntt_ref.h, so the two
 * headers cannot both take effect in one translation unit.
 */

/* Short fixed-size aliases used throughout the implementation.
 * NOTE(review): `char` signedness is implementation-defined; int8 is kept as
 * plain `char` so existing callers passing `char *` keep compiling. */
typedef char int8;
typedef short int16;
typedef int int32;
typedef long long int64;

typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int uint32;
typedef unsigned long long uint64;

/* ## Parameters ## */

#define NTT_NEG 1 /* 0: cyclic NTT, 1: negacyclic NTT */

#define NTT_Q 12289
#define NTT_N 1024

#define NTT_ROUND 10
#define NTT_ORDER (1 << (NTT_ROUND + 1))
#define NTT_BASELEN (NTT_N >> NTT_ROUND)

#define NTT_ZETA 7

/* ## Fast modular reduction ## */

#define MONT_L 16
#define MONT_R (1LL << MONT_L)
#define MONT 4091     /* MONT_R mod q */
#define QINV (-12287) /* q^-1 mod MONT_R (negative representative) */
#define NEGQINV 12287 /* -q^-1 mod MONT_R */

#define BARR_epi16 5      /* round(2^16 / q) */
#define BARR_epi32 349497 /* round(2^32 / q) */

/*
 * Montgomery reduction: evaluates to a value congruent to a * R^-1 (mod q),
 * with R = MONT_R.  The result has type int64.
 *
 * Variant kept from the original for reference, using QINV:
 *   (((a) - (int32)((int16)((int64)(a)*QINV))*NTT_Q) >> MONT_L)
 * With q^-1 the result can be negative, which conflicts with the NTT
 * butterflies (they expect non-negative values), so -q^-1 (NEGQINV) is used
 * instead.
 *
 * The argument is evaluated twice: do not pass expressions with side effects.
 */
#define montgomery_reduce(a) \
    (((a) + (((int64)(a) * NEGQINV) & (MONT_R - 1)) * NTT_Q) >> MONT_L)

/*
 * Barrett reduction: evaluates to a value congruent to a (mod q) by
 * subtracting ((BARR_epi32 * a) >> 32) multiples of q.  The result has type
 * int64.  The argument is evaluated twice.
 */
#define barrett_reduce(a) ((a) - ((BARR_epi32 * (int64)(a)) >> 32) * NTT_Q)

/* ## Function declarations ## */

void get_ntt_param(int32 q, int32 n, int32 r);

void ntt(int16* f);
void intt(int16* f, int8 mont);
void nttmul(int16* r, const int16* a, const int16* b, int8 mont);

int32 print_bytes(int8* arr, int32 len);
int32 print_coeffs(int16* arr, int32 len);

#endif
ntt_avx2.c
#include <stdio.h>
#include <stdlib.h>
#include <xmmintrin.h> // __m128
#include <immintrin.h> // __m256
//#include <zmmintrin.h> // __m512
#include "ntt_avx2.h"

/* ## Parameters ## */

/*
 * Precomputed tables.  The listing leaves the initializers empty; they are
 * meant to be filled with the "zetas", "zetas_mont" and "bitrev_list" lists
 * printed by get_ntt_param().  With all-zero tables the transforms do not
 * compute a meaningful NTT.
 */
const int16 zetas[NTT_ORDER + 1] = { 0 };
const int16 zetas_mont[NTT_ORDER + 1] = { 0 };
const int16 bitrev_list[NTT_ORDER] = { 0 };

/* Inverse-transform scaling: 1/2^r, R/2^r and R^2/2^r mod q
 * (see get_intt_factor). */
const int32 factor = 12277, factor_mont = 64, factor_mont2 = 3755;

/* ## Generic helpers ## */

/* Returns the low `l` bits of `b` in reversed order. */
int32 brv(int32 b, int32 l)
{
    int32 bb = 0;
    for (int32 i = 0; i < l; i++)
    {
        bb <<= 1;
        bb |= (b & 1);
        b >>= 1;
    }
    return bb;
}

/*
 * Square-and-multiply: returns a^b mod q for b >= 0.
 * A negative exponent is treated as 0 (returns 1); the loop would otherwise
 * never terminate, because shifting a negative value right never reaches 0.
 */
int64 fast_pow(int64 a, int64 b, int64 q)
{
    int64 result = 1;
    while (b > 0)
    {
        if (b & 1)
            result = (result * a) % q;
        a = (a * a) % q;
        b >>= 1;
    }
    return result;
}

/*
 * Extended Euclid: returns gcd(a, b) and stores x, y such that
 * a*x + b*y == gcd(a, b).
 */
int64 exgcd(int64* x, int64* y, int64 a, int64 b)
{
    if (b == 0)
    {
        *x = 1;
        *y = 0;
        return a;
    }
    int64 ret = exgcd(x, y, b, a % b);
    int64 tmp = *x;
    *x = *y;
    *y = tmp - (a / b) * (*y);
    return ret;
}

/* Prints `len` bytes as "[ a, b, ... ]".  Always returns 0.
 * An empty array (len <= 0) prints "[  ]" instead of reading arr[0]. */
int32 print_bytes(int8* arr, int32 len)
{
    printf("[ ");
    for (int32 i = 0; i < len; i++)
    {
        if (i == 0)
            printf("%d", arr[i]);
        else
            printf(", %d", arr[i]);
    }
    printf(" ]");
    return 0;
}

/* Prints `len` coefficients as "[ a, b, ... ]".  Always returns 0.
 * An empty array (len <= 0) prints "[  ]" instead of reading arr[0]. */
int32 print_coeffs(int16* arr, int32 len)
{
    printf("[ ");
    for (int32 i = 0; i < len; i++)
    {
        if (i == 0)
            printf("%d", arr[i]);
        else
            printf(", %d", arr[i]);
    }
    printf(" ]");
    return 0;
}

/* ## Parameter precomputation ## */

/*
 * Searches w = 2, 3, ... for the first w with w^ord == 1 and w^(ord/2) != 1
 * (mod q), prints it and returns it.  Returns 0 when no such w < q exists.
 */
int64 find_root(int64 q, int64 ord)
{
    int64 w = 2;
    while (w < q)
    {
        if (fast_pow(w, ord, q) == 1 && fast_pow(w, ord >> 1, q) != 1)
        {
            printf("%lld-th root = %lld\n\n", ord, w);
            return w;
        }
        w++;
    }
    return 0;
}

/* Fills zetas[i] = zeta^i mod q for i = 0..ord (ord + 1 entries) and prints
 * the list as a C initializer. */
void get_zetas(int32* zetas, int32 zeta, int32 q, int32 ord)
{
    int64 wi = 1;
    int64 w = zeta;
    zetas[0] = 1;
    printf("zetas = { %d", zetas[0]);
    for (int64 i = 1; i <= ord; i++)
    {
        wi = (wi * w) % q;
        zetas[i] = (int32)wi;
        printf(", %lld", wi);
    }
    printf(" };\n\n");
}

/* Fills zetas_mont[i] = mont * zetas[i] mod q for i = 0..ord and prints the
 * list as a C initializer. */
void get_zetas_mont(int32* zetas_mont, int32* zetas, int64 q, int64 ord, int64 mont)
{
    int64 wi_pre = mont * zetas[0] % q;
    zetas_mont[0] = (int32)wi_pre;
    printf("zetas_mont = { %lld", wi_pre);
    for (int64 i = 1; i <= ord; i++)
    {
        wi_pre = mont * zetas[i] % q;
        zetas_mont[i] = (int32)wi_pre;
        printf(", %lld", wi_pre);
    }
    printf(" };\n\n");
}

/* Prints brv(i, bits) for i = 0..2^bits - 1 as a C initializer. */
void get_brv_table(int32 bits)
{
    printf("bitrev_list = { 0");
    int32 len = (int32)(1LL << bits);
    for (int i = 1; i < len; i++)
        printf(", %d", brv(i, bits));
    printf(" };\n\n");
}

/* Prints factor = 1/2^r, factor_mont = R/2^r and factor_mont2 = R^2/2^r,
 * all mod q, where R mod q is passed in as `mont`. */
void get_intt_factor(int64 q, int64 r, int64 mont)
{
    int64 inv, pinv;
    if (exgcd(&inv, &pinv, 1LL << r, q) != 1)
    {
        printf("gcd(2^r, q) != 1\n");
        return;
    }
    /* exgcd may return a negative representative; normalise into [0, q). */
    inv = ((inv % q) + q) % q;
    int64 inv_mont = (inv * mont) % q;
    int64 inv_mont2 = (inv_mont * mont) % q;
    printf("factor = %lld, factor_mont = %lld, factor_mont2 = %lld\n\n", inv, inv_mont, inv_mont2);
}

/*
 * Prints every precomputed constant and table for modulus q, length n and
 * r rounds: Montgomery/Barrett constants, a root of order 2^(r+1), its
 * powers (plain and Montgomery form), the bit-reversal table and the
 * inverse-transform factors.  Nothing is stored; the output is meant to be
 * pasted into the tables at the top of this file.
 */
void get_ntt_param(int32 q, int32 n, int32 r)
{
    printf("/* get ntt params */\n\n");
    printf("NTT_Q = %d, NTT_N = %d, NTT_ROUND = %d\n\n", q, n, r);

    if (q < 2)
    {
        /* Every computation below reduces modulo q. */
        printf("NTT_Q must be >= 2\n\n");
        return;
    }

    int64 d, x, y;
    int64 mont = MONT_R % q;
    d = exgcd(&x, &y, q, MONT_R);
    if (d != 1)
    {
        printf("gcd(NTT_Q, MONT_R) != 1\n\n");
        return;
    }
    printf("MONT_R = %lld\nMONT = %lld mod q\nQINV = %lld mod R\n\n", MONT_R, mont, x);
    /* round(2^16 / q) and round(2^32 / q): the dividends are the powers of
     * two themselves, not the exponents 16 and 32. */
    printf("BARR_R = 2^16, BARR_epi16 = 2^16/q = %lld\nBARR_R = 2^32, BARR_epi32 = 2^32/q = %lld\n\n",
           ((1LL << 16) + (q >> 1)) / q, ((1LL << 32) + (q >> 1)) / q);

    int32 order = 1 << (r + 1);
    int64 Zeta = find_root(q, order);
    if (Zeta == 0)
    {
        printf("no root of order %d found mod %d\n\n", order, q);
        return;
    }

    int32* Zetas = malloc(sizeof *Zetas * ((size_t)order + 1));
    int32* Zetas_mont = malloc(sizeof *Zetas_mont * ((size_t)order + 1));
    if (Zetas == NULL || Zetas_mont == NULL)
    {
        printf("out of memory\n\n");
        free(Zetas);
        free(Zetas_mont);
        return;
    }

    get_zetas(Zetas, (int32)Zeta, q, order);
    get_zetas_mont(Zetas_mont, Zetas, q, order, mont);
    get_brv_table(r + 1);
    get_intt_factor(q, r, mont);
    printf("//* get ntt params *//\n\n");

    free(Zetas);
    free(Zetas_mont);
}

/* ## Load/Store helpers ##
 *
 * The macros below rearrange 16-bit lanes across two YMM registers X and Y.
 * They expand to plain statement sequences that assign to their arguments,
 * so they must not be used as the unbraced body of an if/else/loop.
 */

/* Scratch register shared by every Half() expansion (file-scope, mutable). */
__m256i NTT_TMP;

/* X <- { X.lo128, Y.lo128 },  Y <- { X.hi128, Y.hi128 } */
#define Half(X,Y)\
    NTT_TMP = _mm256_permute2x128_si256(X, Y, 0x31);\
    X = _mm256_permute2x128_si256(X, Y, 0x20);\
    Y = NTT_TMP;

/* Reorders the four 64-bit lanes of X and of Y as 0, 2, 1, 3
 * (0xD8 == 0b11011000). */
#define Perm(X,Y)\
    X = _mm256_permute4x64_epi64(X, 0xD8);\
    Y = _mm256_permute4x64_epi64(Y, 0xD8);

/* Within each 128-bit half, reorders the four 32-bit lanes as 0, 2, 1, 3. */
#define Coll_32(X,Y)\
    X = _mm256_shuffle_epi32(X, 0xD8);\
    Y = _mm256_shuffle_epi32(Y, 0xD8);

/* Byte-shuffle control: within each 128-bit half, gathers the even-indexed
 * 16-bit lanes first, then the odd-indexed ones. */
const int8 CollIndex[32] = {
    0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15,
    0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15
};

/* The control vector is read with an unaligned load: CollIndex is a byte
 * array and carries no 32-byte alignment guarantee. */
#define Coll_16(X,Y)\
    X = _mm256_shuffle_epi8(X, _mm256_loadu_si256((const __m256i*)CollIndex));\
    Y = _mm256_shuffle_epi8(Y, _mm256_loadu_si256((const __m256i*)CollIndex));

/* Byte-shuffle control undoing CollIndex: interleaves the first four 16-bit
 * lanes of each 128-bit half with the last four. */
const int8 CollIndex_inv[32] = {
    0,1,8,9,2,3,10,11,4,5,12,13,6,7,14,15,
    0,1,8,9,2,3,10,11,4,5,12,13,6,7,14,15
};

#define Coll_16_inv(X,Y)\
    X = _mm256_shuffle_epi8(X, _mm256_loadu_si256((const __m256i*)CollIndex_inv));\
    Y = _mm256_shuffle_epi8(Y, _mm256_loadu_si256((const __m256i*)CollIndex_inv));

/* offsetK / offsetK_inv: lane rearrangement applied before / after a
 * butterfly whose partner coefficients lie K positions apart (as used by the
 * Block32..Block4 cases of ntt/intt). */
#define offset8(X,Y) Half(X, Y);
#define offset8_inv(X,Y) Half(X, Y);

#define offset4(X,Y) Perm(X, Y); Half(X, Y);
#define offset4_inv(X,Y) Half(X, Y); Perm(X, Y);

#define offset2(X,Y) Coll_32(X,Y); Perm(X,Y); Half(X, Y);
#define offset2_inv(X,Y) Half(X, Y); Perm(X, Y); Coll_32(X,Y);
#define offset1_inv(X,Y) Half(X, Y); Perm(X, Y); Coll_16_inv(X,Y); //## 快速模约减 ##/*
* montgomery_mul(a, zeta) = (a * zeta_mont - (R-1)&(a * zeta_mont * qinv) * q) >> t
* zeta_mont = zeta*R mod q,t=16,R=2^16
*/
__m256i montgomery_reduce_epi16(__m256i a, __m256i w) {__m256i q = _mm256_set1_epi16(NTT_Q);__m256i qinv = _mm256_set1_epi16(QINV);//正确性约束:(R - 1)*q + a*w < 2^32__m256i hi = _mm256_mulhi_epi16(a, w); //有符号的高位,epi与epu的乘法结果模2^32同余,比特表示相同__m256i lo = _mm256_mullo_epi16(a, w); //有符号的低位,是个无符号数,hi*65536 + lo/*只要不越界,_mm256_mullo_epi16 = _mm256_mullo_epu16但是,_mm256_mulhi_epi16 != _mm256_mulhi_epu16,注意高位补0还是补1*/__m256i tmp = _mm256_mullo_epi16(lo, qinv); //无符号模乘,R=2^16,无论lo和qinv是int16下的负数或正数tmp = _mm256_mulhi_epu16(tmp, q); //无符号乘法,要使用epu,将tmp识别为无符号数,抑制乘法的高位补1/*a*w和tmp*q的低16位相同,没有进位借位hi是负数,减去无符号tmp后还是负数hi是正数,减去无符号tmp后可能是正数也可能是负数*/hi = _mm256_sub_epi16(hi, tmp);//要让约减结果范围[0,q],不能出现负数(与Harvey蝴蝶冲突)tmp = _mm256_srai_epi16(hi, 15); //算数右移tmp = _mm256_and_si256(tmp, q);hi = _mm256_add_epi16(hi, tmp);return hi;
}/*
* a - ((m*a)>>t) * q
* m = R/q,t=16,R=2^16
*/
__m256i barrett_reduce_epi16(__m256i a) {__m256i q = _mm256_set1_epi16(NTT_Q);__m256i m = _mm256_set1_epi16(BARR_epi16);__m256i tmp = _mm256_mulhi_epi16(a, m); //有符号tmp = _mm256_mullo_epi16(tmp, q);a = _mm256_sub_epi16(a, tmp);return a; //范围乱变[-q, 2q)
}/*
* a - ((m*a)>>t) * q
* m = R/q,t=32,R=2^32
*/
__m256i barrett_reduce_epi32(__m256i a) {__m256i q = _mm256_set1_epi32(NTT_Q);__m256i m = _mm256_set1_epi32(BARR_epi32);__m256i tmp1 = _mm256_mul_epi32(a, m);tmp1 = _mm256_srli_epi64(tmp1, 32); //本应算数右移,为了重构方便采用逻辑右移,截断结果仍有符号__m256i tmp2 = _mm256_shuffle_epi32(a, 0b10110001);tmp2 = _mm256_mul_epi32(tmp2, m);tmp2 = _mm256_srli_epi64(tmp2, 32);tmp2 = _mm256_shuffle_epi32(tmp2, 0b10110001);tmp1 = _mm256_or_si256(tmp1, tmp2); //重构为epi32tmp1 = _mm256_mullo_epi32(tmp1, q);tmp1 = _mm256_sub_epi32(a, tmp1);return tmp1; //范围乱变[-q, 2q)
}__m256i iflt0_addq(__m256i a) {__m256i q = _mm256_set1_epi16(NTT_Q);__m256i tmp = _mm256_srai_epi16(a, 15);tmp = _mm256_and_si256(tmp, q);return _mm256_add_epi16(a, tmp);
}__m256i ifgeq_subq(__m256i a) {__m256i q = _mm256_set1_epi16(NTT_Q);__m256i tmp = _mm256_set1_epi16(NTT_Q - 1);tmp = _mm256_sub_epi16(tmp, a);tmp = _mm256_srai_epi16(tmp, 15);tmp = _mm256_and_si256(tmp, q);return _mm256_sub_epi16(a, tmp);
}__m256i ifge2q_sub2q(__m256i a) {__m256i q = _mm256_set1_epi16(2 * NTT_Q);__m256i tmp = _mm256_set1_epi16(2 * NTT_Q - 1);tmp = _mm256_sub_epi16(tmp, a);tmp = _mm256_srai_epi16(tmp, 15);tmp = _mm256_and_si256(tmp, q);return _mm256_sub_epi16(a, tmp);
}//## NTT变换 ##void ntt(int16* f) {int32 Blocknum = 1;int32 Blocksize = NTT_N;int32 Round = 0;__m256i T, Q = _mm256_set1_epi16(NTT_Q);/*Radix-2X = X + WYY = X - WY*/if ((NTT_ROUND & 1) == 1) {int32 offset = Blocksize >> 1;__m256i W = _mm256_set1_epi16(bitrev_list[Blocknum * NTT_NEG] >> 1);for (int32 k = 0; k < offset; k += 16) {__m256i X = _mm256_loadu_si256((__m256i*)(f + k));__m256i Y = _mm256_loadu_si256((__m256i*)(f + k + offset));T = montgomery_reduce_epi16(Y, W);Y = _mm256_add_epi16(X, Q);Y = _mm256_sub_epi16(Y, T);X = _mm256_add_epi16(X, T);_mm256_storeu_si256((__m256i*)(f + k), X);_mm256_storeu_si256((__m256i*)(f + k + offset), Y);}Blocknum <<= 1;Blocksize >>= 1;Round++;}/*Radix-4Harvey,输入输出范围[0,2q)X1 = (X1 + W*Y1) + W0*(X2 + W*Y2),范围[0,4q)X2 = (X1 + W*Y1) - W0*(X2 + W*Y2),范围[0,4q)Y1 = (X1 - W*Y1) + W1*(X2 - W*Y2),范围[0,4q)X2 = (X1 - W*Y1) - W1*(X2 - W*Y2),范围[0,4q)先约束X1范围[0,q),接着约束(X1 + W*Y1)和(X1 - W*Y1)范围[0,q),共三次模约减*/for (; Round < NTT_ROUND; Round += 2, Blocksize >>= 2, Blocknum <<= 2) {if (Blocksize >= 64)goto Block64;elseswitch (Blocksize){case 32: goto Block32;case 16: goto Block16;case 8: goto Block8;case 4: goto Block4;default:goto Error; //本代码仅处理:NTT_N 是2的幂次}Block64: //处理分块大小整除64的情况,使用4个YMM,处理1个块for (int32 i = 0; i < Blocknum; i++) {int32 offset = Blocksize >> 2;int32 num = offset >> 4; //16个系数1个YMMint16* pf = f + i * Blocksize;/*j=0是原始数组,第j次迭代中,j-1层第i个分块使用的单位根,w_{2^{j}}^{brv_{j}(2i)} = w_{2^{r}}^{2^{r-j}*brv_{j}(2i)}brv_{j}(2i) = brv_{r}/(r-j+1)因此 w_{2^{j}}^{brv_{j}(2i)} = w_{2^{r}}^{brv_{r}(i)/2}*/int32 ind = Blocknum * NTT_NEG + i;__m256i W = _mm256_set1_epi16(zetas_mont[bitrev_list[ind] >> 1]); //Round层第i块__m256i W0 = _mm256_set1_epi16(zetas_mont[bitrev_list[ind * 2] >> 1]); //Round+1层第2i块__m256i W1 = _mm256_set1_epi16(zetas_mont[bitrev_list[ind * 2 + 1] >> 1]); //Round+1层第2i+1块for (int32 k = 0; k < num; k++) {__m256i X1 = _mm256_loadu_si256((__m256i*)(pf + k * 16));__m256i X2 = _mm256_loadu_si256((__m256i*)(pf + k * 16 + offset));__m256i Y1 = 
_mm256_loadu_si256((__m256i*)(pf + k * 16 + offset * 2));__m256i Y2 = _mm256_loadu_si256((__m256i*)(pf + k * 16 + offset * 3));X1 = ifgeq_subq(X1);T = montgomery_reduce_epi16(Y1, W);Y1 = _mm256_add_epi16(X1, Q);Y1 = _mm256_sub_epi16(Y1, T);X1 = _mm256_add_epi16(X1, T);T = montgomery_reduce_epi16(Y2, W);Y2 = _mm256_add_epi16(X2, Q);Y2 = _mm256_sub_epi16(Y2, T);X2 = _mm256_add_epi16(X2, T);X1 = ifgeq_subq(X1);T = montgomery_reduce_epi16(X2, W0);X2 = _mm256_add_epi16(X1, Q);X2 = _mm256_sub_epi16(X2, T);X1 = _mm256_add_epi16(X1, T);Y1 = ifgeq_subq(Y1);T = montgomery_reduce_epi16(Y2, W1);Y2 = _mm256_add_epi16(Y1, Q);Y2 = _mm256_sub_epi16(Y2, T);Y1 = _mm256_add_epi16(Y1, T);_mm256_storeu_si256((__m256i*)(pf + k * 16), X1);_mm256_storeu_si256((__m256i*)(pf + k * 16 + offset), X2);_mm256_storeu_si256((__m256i*)(pf + k * 16 + offset * 2), Y1);_mm256_storeu_si256((__m256i*)(pf + k * 16 + offset * 3), Y2);}}continue;Block32: //处理分块大小为32的情况,使用2个YMM,处理1个块for (int32 i = 0; i < Blocknum; i++) {int16* pf = f + i * Blocksize;__m256i X = _mm256_loadu_si256((__m256i*)(pf));__m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));int32 ind = Blocknum * NTT_NEG + i;int16 w = zetas_mont[bitrev_list[ind] >> 1];__m256i W = _mm256_set1_epi16(w);X = ifgeq_subq(X);T = montgomery_reduce_epi16(Y, W);Y = _mm256_add_epi16(X, Q);Y = _mm256_sub_epi16(Y, T);X = _mm256_add_epi16(X, T);ind <<= 1;int16 w0 = zetas_mont[bitrev_list[ind] >> 1];int16 w1 = zetas_mont[bitrev_list[ind + 1] >> 1];W = _mm256_setr_epi16(w0, w0, w0, w0, w0, w0, w0, w0, w1, w1, w1, w1, w1, w1, w1, w1);offset8(X, Y); X = ifgeq_subq(X);T = montgomery_reduce_epi16(Y, W);Y = _mm256_add_epi16(X, Q);Y = _mm256_sub_epi16(Y, T);X = _mm256_add_epi16(X, T);offset8_inv(X, Y);_mm256_storeu_si256((__m256i*)(pf), X);_mm256_storeu_si256((__m256i*)(pf + 16), Y);}continue;Block16: //处理分块大小为16的情况,使用2个YMM,处理2个块for (int32 i = 0; i < Blocknum; i+=2) {int16* pf = f + i * Blocksize;__m256i X = _mm256_loadu_si256((__m256i*)(pf));__m256i Y = 
_mm256_loadu_si256((__m256i*)(pf + 16));int32 ind = Blocknum * NTT_NEG + i;int16 w0 = zetas_mont[bitrev_list[ind] >> 1];int16 w1 = zetas_mont[bitrev_list[ind + 1] >> 1];__m256i W = _mm256_setr_epi16(w0, w0, w0, w0, w0, w0, w0, w0, w1, w1, w1, w1, w1, w1, w1, w1);offset8(X, Y);X = ifgeq_subq(X);T = montgomery_reduce_epi16(Y, W);Y = _mm256_add_epi16(X, Q);Y = _mm256_sub_epi16(Y, T);X = _mm256_add_epi16(X, T);offset8_inv(X, Y);ind <<= 1;int16 w00 = zetas_mont[bitrev_list[ind] >> 1];int16 w01 = zetas_mont[bitrev_list[ind + 1] >> 1];int16 w10 = zetas_mont[bitrev_list[ind + 2] >> 1];int16 w11 = zetas_mont[bitrev_list[ind + 3] >> 1];W = _mm256_setr_epi16(w00, w00, w00, w00, w01, w01, w01, w01, w10, w10, w10, w10, w11, w11, w11, w11);offset4(X, Y);X = ifgeq_subq(X);T = montgomery_reduce_epi16(Y, W);Y = _mm256_add_epi16(X, Q);Y = _mm256_sub_epi16(Y, T);X = _mm256_add_epi16(X, T);offset4_inv(X, Y);_mm256_storeu_si256((__m256i*)(pf), X);_mm256_storeu_si256((__m256i*)(pf + 16), Y);}continue;Block8: //处理分块大小为8的情况,使用2个YMM,处理4个块for (int32 i = 0; i < Blocknum; i += 4) {int16* pf = f + i * Blocksize;__m256i X = _mm256_loadu_si256((__m256i*)(pf));__m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));int32 ind = Blocknum * NTT_NEG + i;int16 w0 = zetas_mont[bitrev_list[ind] >> 1];int16 w1 = zetas_mont[bitrev_list[ind + 1] >> 1];int16 w2 = zetas_mont[bitrev_list[ind + 2] >> 1];int16 w3 = zetas_mont[bitrev_list[ind + 3] >> 1];__m256i W = _mm256_setr_epi16(w0, w0, w0, w0, w1, w1, w1, w1, w2, w2, w2, w2, w3, w3, w3, w3);offset4(X, Y);X = ifgeq_subq(X);T = montgomery_reduce_epi16(Y, W);Y = _mm256_add_epi16(X, Q);Y = _mm256_sub_epi16(Y, T);X = _mm256_add_epi16(X, T);offset4_inv(X, Y);ind <<= 1;int16 w00 = zetas_mont[bitrev_list[ind] >> 1];int16 w01 = zetas_mont[bitrev_list[ind + 1] >> 1];int16 w10 = zetas_mont[bitrev_list[ind + 2] >> 1];int16 w11 = zetas_mont[bitrev_list[ind + 3] >> 1];int16 w20 = zetas_mont[bitrev_list[ind + 4] >> 1];int16 w21 = zetas_mont[bitrev_list[ind + 5] >> 1];int16 
w30 = zetas_mont[bitrev_list[ind + 6] >> 1];int16 w31 = zetas_mont[bitrev_list[ind + 7] >> 1];W = _mm256_setr_epi16(w00, w00, w01, w01, w10, w10, w11, w11, w20, w20, w21, w21, w30, w30, w31, w31);offset2(X, Y);X = ifgeq_subq(X);T = montgomery_reduce_epi16(Y, W);Y = _mm256_add_epi16(X, Q);Y = _mm256_sub_epi16(Y, T);X = _mm256_add_epi16(X, T);offset2_inv(X, Y);_mm256_storeu_si256((__m256i*)(pf), X);_mm256_storeu_si256((__m256i*)(pf + 16), Y);}continue;Block4: //处理分块大小为4的情况,使用2个YMM,处理8个块for (int32 i = 0; i < Blocknum; i += 8) {int16* pf = f + i * Blocksize;__m256i X = _mm256_loadu_si256((__m256i*)(pf));__m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));int32 ind = Blocknum * NTT_NEG + i;int16 w0 = zetas_mont[bitrev_list[ind] >> 1];int16 w1 = zetas_mont[bitrev_list[ind + 1] >> 1];int16 w2 = zetas_mont[bitrev_list[ind + 2] >> 1];int16 w3 = zetas_mont[bitrev_list[ind + 3] >> 1];int16 w4 = zetas_mont[bitrev_list[ind + 4] >> 1];int16 w5 = zetas_mont[bitrev_list[ind + 5] >> 1];int16 w6 = zetas_mont[bitrev_list[ind + 6] >> 1];int16 w7 = zetas_mont[bitrev_list[ind + 7] >> 1];__m256i W = _mm256_setr_epi16(w0, w0, w1, w1, w2, w2, w3, w3, w4, w4, w5, w5, w6, w6, w7, w7);offset2(X, Y);X = ifgeq_subq(X);T = montgomery_reduce_epi16(Y, W);Y = _mm256_add_epi16(X, Q);Y = _mm256_sub_epi16(Y, T);X = _mm256_add_epi16(X, T);offset2_inv(X, Y);ind <<= 1;int16 w00 = zetas_mont[bitrev_list[ind] >> 1];int16 w01 = zetas_mont[bitrev_list[ind + 1] >> 1];int16 w10 = zetas_mont[bitrev_list[ind + 2] >> 1];int16 w11 = zetas_mont[bitrev_list[ind + 3] >> 1];int16 w20 = zetas_mont[bitrev_list[ind + 4] >> 1];int16 w21 = zetas_mont[bitrev_list[ind + 5] >> 1];int16 w30 = zetas_mont[bitrev_list[ind + 6] >> 1];int16 w31 = zetas_mont[bitrev_list[ind + 7] >> 1];int16 w40 = zetas_mont[bitrev_list[ind + 8] >> 1];int16 w41 = zetas_mont[bitrev_list[ind + 9] >> 1];int16 w50 = zetas_mont[bitrev_list[ind + 10] >> 1];int16 w51 = zetas_mont[bitrev_list[ind + 11] >> 1];int16 w60 = zetas_mont[bitrev_list[ind + 12] >> 
1];int16 w61 = zetas_mont[bitrev_list[ind + 13] >> 1];int16 w70 = zetas_mont[bitrev_list[ind + 14] >> 1];int16 w71 = zetas_mont[bitrev_list[ind + 15] >> 1];W = _mm256_setr_epi16(w00, w01, w10, w11, w20, w21, w30, w31, w40, w41, w50, w51, w60, w61, w70, w71);offset1(X, Y);X = ifgeq_subq(X);T = montgomery_reduce_epi16(Y, W);Y = _mm256_add_epi16(X, Q);Y = _mm256_sub_epi16(Y, T);X = _mm256_add_epi16(X, T);offset1_inv(X, Y);_mm256_storeu_si256((__m256i*)(pf), X);_mm256_storeu_si256((__m256i*)(pf + 16), Y);}continue;Error: //捕获块大小错误printf("Blocksize isn't power of 2.\\n");}for (int32 k = 0; k < NTT_N; k += 16) {__m256i X = _mm256_loadu_si256((__m256i*)(f + k)); //模约减,从[0,2q)约减到[0,q)X = ifgeq_subq(X);_mm256_storeu_si256((__m256i*)(f + k), X);}
}void intt(int16* f, int8 mont) {int32 Blocknum = 1 << NTT_ROUND;int32 Blocksize = NTT_N >> NTT_ROUND;int32 Round = NTT_ROUND;int32 Qtimes2 = NTT_Q * 2;Blocksize <<= 2;Blocknum >>= 2;__m256i T, Q = _mm256_set1_epi16(NTT_Q);/*Radix-4Harvey,输入输出范围[0,2q)X1 = (X1 + X2) + (Y1 + Y2),范围[0,8q)X2 = IW0*(X1 - X2) + IW1*(Y1 - Y2),范围[0,2q)Y1 = IW*((X1 + X2) + (Y1 + Y2)),范围[0,q)Y2 = IW*(IW0*(X1 - X2) + IW1*(Y1 - Y2)),范围[0,q)先约束(X1 + X2)和(Y1 + Y2)范围[0,2q),接着约束(X1 + X2) + (Y1 + Y2)范围[0,2q),共三次模约减*/for (; Round > 1; Round -= 2, Blocksize <<= 2, Blocknum >>= 2) {if (Blocksize >= 64)goto Block64;elseswitch (Blocksize){case 32: goto Block32;case 16: goto Block16;case 8: goto Block8;case 4: goto Block4;default:goto Error; //本代码仅处理:NTT_N 是2的幂次}Block64: //处理分块大小整除64的情况,使用4个YMM,处理1个块for (int32 i = 0; i < Blocknum; i++) {int32 offset = Blocksize >> 2;int32 num = offset >> 4; //16个系数1个YMMint16* pf = f + i * Blocksize;/*j=0是原始数组,第j次迭代中,j-1层第i个分块使用的单位根,w_{2^{j}}^{brv_{j}(2i)} = w_{2^{r}}^{2^{r-j}*brv_{j}(2i)}brv_{j}(2i) = brv_{r}/(r-j+1)因此 w_{2^{j}}^{brv_{j}(2i)} = w_{2^{r}}^{brv_{r}(i)/2}*/int32 ind = Blocknum * NTT_NEG + i;__m256i W = _mm256_set1_epi16(zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)]); //Round层第i块__m256i W0 = _mm256_set1_epi16(zetas_mont[NTT_ORDER - (bitrev_list[ind * 2] >> 1)]); //Round+1层第2i块__m256i W1 = _mm256_set1_epi16(zetas_mont[NTT_ORDER - (bitrev_list[ind * 2 + 1] >> 1)]); //Round+1层第2i+1块for (int32 k = 0; k < num; k++) {__m256i X1 = _mm256_loadu_si256((__m256i*)(pf + k * 16));__m256i X2 = _mm256_loadu_si256((__m256i*)(pf + k * 16 + offset));__m256i Y1 = _mm256_loadu_si256((__m256i*)(pf + k * 16 + offset * 2));__m256i Y2 = _mm256_loadu_si256((__m256i*)(pf + k * 16 + offset * 3));T = _mm256_sub_epi16(X1, X2);X1 = _mm256_add_epi16(X1, X2);X2 = montgomery_reduce_epi16(T, W0);X1 = ifgeq_subq(X1);T = _mm256_sub_epi16(Y1, Y2);Y1 = _mm256_add_epi16(Y1, Y2);Y2 = montgomery_reduce_epi16(T, W1);Y1 = ifgeq_subq(Y1);T = _mm256_sub_epi16(X1, Y1);X1 = _mm256_add_epi16(X1, Y1);Y1 
= montgomery_reduce_epi16(T, W);X1 = ifgeq_subq(X1);T = _mm256_sub_epi16(X2, Y2);X2 = _mm256_add_epi16(X2, Y2);Y2 = montgomery_reduce_epi16(T, W);X2 = ifgeq_subq(X2);_mm256_storeu_si256((__m256i*)(pf + k * 16), X1);_mm256_storeu_si256((__m256i*)(pf + k * 16 + offset), X2);_mm256_storeu_si256((__m256i*)(pf + k * 16 + offset * 2), Y1);_mm256_storeu_si256((__m256i*)(pf + k * 16 + offset * 3), Y2);}}continue;Block32: //处理分块大小为32的情况,使用2个YMM,处理1个块for (int32 i = 0; i < Blocknum; i++) {int16* pf = f + i * Blocksize;__m256i X = _mm256_loadu_si256((__m256i*)(pf));__m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));int32 ind = (Blocknum * NTT_NEG + i) * 2;int16 w0 = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];int16 w1 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 1] >> 1)];__m256i W = _mm256_setr_epi16(w0, w0, w0, w0, w0, w0, w0, w0, w1, w1, w1, w1, w1, w1, w1, w1);offset8(X, Y);T = _mm256_sub_epi16(X, Y);X = _mm256_add_epi16(X, Y);Y = montgomery_reduce_epi16(T, W);X = ifgeq_subq(X);offset8_inv(X, Y);ind >>= 1;int16 w = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];W = _mm256_set1_epi16(w);T = _mm256_sub_epi16(X, Y);X = _mm256_add_epi16(X, Y);Y = montgomery_reduce_epi16(T, W);X = ifgeq_subq(X);_mm256_storeu_si256((__m256i*)(pf), X);_mm256_storeu_si256((__m256i*)(pf + 16), Y);}continue;Block16: //处理分块大小为16的情况,使用2个YMM,处理2个块for (int32 i = 0; i < Blocknum; i += 2) {int16* pf = f + i * Blocksize;__m256i X = _mm256_loadu_si256((__m256i*)(pf));__m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));int32 ind = (Blocknum * NTT_NEG + i) * 2;int16 w00 = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];int16 w01 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 1] >> 1)];int16 w10 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 2] >> 1)];int16 w11 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 3] >> 1)];__m256i W = _mm256_setr_epi16(w00, w00, w00, w00, w01, w01, w01, w01, w10, w10, w10, w10, w11, w11, w11, w11);offset4(X, Y);T = _mm256_sub_epi16(X, Y);X = _mm256_add_epi16(X, Y);Y = 
montgomery_reduce_epi16(T, W);X = ifgeq_subq(X);offset4_inv(X, Y);ind >>= 1;int16 w0 = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];int16 w1 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 1] >> 1)];W = _mm256_setr_epi16(w0, w0, w0, w0, w0, w0, w0, w0, w1, w1, w1, w1, w1, w1, w1, w1);offset8(X, Y);T = _mm256_sub_epi16(X, Y);X = _mm256_add_epi16(X, Y);Y = montgomery_reduce_epi16(T, W);X = ifgeq_subq(X);offset8_inv(X, Y);_mm256_storeu_si256((__m256i*)(pf), X);_mm256_storeu_si256((__m256i*)(pf + 16), Y);}continue;Block8: //处理分块大小为8的情况,使用2个YMM,处理4个块for (int32 i = 0; i < Blocknum; i += 4) {int16* pf = f + i * Blocksize;__m256i X = _mm256_loadu_si256((__m256i*)(pf));__m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));int32 ind = (Blocknum * NTT_NEG + i) * 2;int16 w00 = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];int16 w01 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 1] >> 1)];int16 w10 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 2] >> 1)];int16 w11 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 3] >> 1)];int16 w20 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 4] >> 1)];int16 w21 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 5] >> 1)];int16 w30 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 6] >> 1)];int16 w31 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 7] >> 1)];__m256i W = _mm256_setr_epi16(w00, w00, w01, w01, w10, w10, w11, w11, w20, w20, w21, w21, w30, w30, w31, w31);offset2(X, Y);T = _mm256_sub_epi16(X, Y);X = _mm256_add_epi16(X, Y);Y = montgomery_reduce_epi16(T, W);X = ifgeq_subq(X);offset2_inv(X, Y);ind >>= 1;int16 w0 = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];int16 w1 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 1] >> 1)];int16 w2 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 2] >> 1)];int16 w3 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 3] >> 1)];W = _mm256_setr_epi16(w0, w0, w0, w0, w1, w1, w1, w1, w2, w2, w2, w2, w3, w3, w3, w3);offset4(X, Y);T = _mm256_sub_epi16(X, Y);X = _mm256_add_epi16(X, Y);Y = montgomery_reduce_epi16(T, W);X = ifgeq_subq(X);offset4_inv(X, 
Y);_mm256_storeu_si256((__m256i*)(pf), X);_mm256_storeu_si256((__m256i*)(pf + 16), Y);}continue;Block4: //处理分块大小为4的情况,使用2个YMM,处理8个块for (int32 i = 0; i < Blocknum; i += 8) {int16* pf = f + i * Blocksize;__m256i X = _mm256_loadu_si256((__m256i*)(pf));__m256i Y = _mm256_loadu_si256((__m256i*)(pf + 16));int32 ind = (Blocknum * NTT_NEG + i) * 2;int16 w00 = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];int16 w01 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 1] >> 1)];int16 w10 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 2] >> 1)];int16 w11 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 3] >> 1)];int16 w20 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 4] >> 1)];int16 w21 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 5] >> 1)];int16 w30 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 6] >> 1)];int16 w31 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 7] >> 1)];int16 w40 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 8] >> 1)];int16 w41 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 9] >> 1)];int16 w50 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 10] >> 1)];int16 w51 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 11] >> 1)];int16 w60 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 12] >> 1)];int16 w61 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 13] >> 1)];int16 w70 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 14] >> 1)];int16 w71 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 15] >> 1)];__m256i W = _mm256_setr_epi16(w00, w01, w10, w11, w20, w21, w30, w31, w40, w41, w50, w51, w60, w61, w70, w71);offset1(X, Y);T = _mm256_sub_epi16(X, Y);X = _mm256_add_epi16(X, Y);Y = montgomery_reduce_epi16(T, W);X = ifgeq_subq(X);offset1_inv(X, Y);ind >>= 1;int16 w0 = zetas_mont[NTT_ORDER - (bitrev_list[ind] >> 1)];int16 w1 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 1] >> 1)];int16 w2 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 2] >> 1)];int16 w3 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 3] >> 1)];int16 w4 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 4] >> 1)];int16 w5 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 5] >> 
1)];int16 w6 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 6] >> 1)];int16 w7 = zetas_mont[NTT_ORDER - (bitrev_list[ind + 7] >> 1)];W = _mm256_setr_epi16(w0, w0, w1, w1, w2, w2, w3, w3, w4, w4, w5, w5, w6, w6, w7, w7);offset2(X, Y);T = _mm256_sub_epi16(X, Y);X = _mm256_add_epi16(X, Y);Y = montgomery_reduce_epi16(T, W);X = ifgeq_subq(X);offset2_inv(X, Y);_mm256_storeu_si256((__m256i*)(pf), X);_mm256_storeu_si256((__m256i*)(pf + 16), Y);}continue;Error: //捕获块大小错误printf("Blocksize isn't power of 2.\\n");}/*Radix-2X = X + YY = IW*(X - Y)*/if ((NTT_ROUND & 1) == 1) {int32 offset = Blocksize >> 1;int32 num = offset >> 4; //16个系数1个YMM__m256i W = _mm256_set1_epi16(zetas_mont[NTT_ORDER - (bitrev_list[Blocknum * NTT_NEG] >> 1)]);for (int32 k = 0; k < offset; k += 16) {__m256i X = _mm256_loadu_si256((__m256i*)(f + k));__m256i Y = _mm256_loadu_si256((__m256i*)(f + k + offset));T = _mm256_sub_epi16(X, Y);X = _mm256_add_epi16(X, Y);Y = montgomery_reduce_epi16(T, W);_mm256_storeu_si256((__m256i*)(f + k), X);_mm256_storeu_si256((__m256i*)(f + k + offset), Y);}}//逆变换因子__m256i F = _mm256_set1_epi16(factor_mont);if(mont != 0)F = _mm256_set1_epi16(factor_mont2); //执行了 montgomery 版本的 nttmul,需额外再乘一个 mont = R mod qfor (int32 k = 0; k < NTT_N; k += 16) {__m256i X = _mm256_loadu_si256((__m256i*)(f + k));X = montgomery_reduce_epi16(X, F);_mm256_storeu_si256((__m256i*)(f + k), X);}
}   /* closes the function that begins above this block */

/*
 * Base-case polynomial product using Montgomery reduction.
 *
 * For every output index i this accumulates
 *     sum_{j <= i} b[j] * a[i - j]
 *   + sum_{j >  i} montgomery_reduce(b[j] * zeta_mont) * a[NTT_BASELEN + i - j]
 * i.e. the wrapped-around terms are scaled by zeta_mont, and stores
 * montgomery_reduce(sum) in r[i].
 *
 *   r          output, NTT_BASELEN coefficients
 *   a, b       inputs, NTT_BASELEN coefficients each
 *   zeta_mont  twiddle factor taken from zetas_mont by the caller
 *
 * The result carries an extra factor 1/R (r = a*b/R).
 *
 * `static inline` rather than plain `inline`: in C99 and later a plain
 * `inline` definition provides no external definition, so a call that the
 * compiler decides not to inline (e.g. at -O0) fails at link time.
 *
 * NOTE(review): the 32-bit accumulator adds up NTT_BASELEN products without
 * intermediate reduction -- confirm this cannot overflow for large NTT_BASELEN.
 */
static inline void basemul_mont(int16* r, const int16* a, const int16* b, int16 zeta_mont)
{
    for (int32 i = 0; i < NTT_BASELEN; i++)
    {
        int32 res = 0;              // wider accumulator: the modular reduction is deferred
        int32 s = NTT_BASELEN + i;

        for (int32 j = 0; j <= i; j++)
            res += b[j] * a[i - j];
        for (int32 j = i + 1; j < NTT_BASELEN; j++)
            res += montgomery_reduce(b[j] * zeta_mont) * a[s - j];

        r[i] = montgomery_reduce(res);  // the result is r = a*b/R
    }
}

/*
 * Multiplication of two transformed polynomials, Montgomery flavour.
 *
 * The inputs are handled as (1 << NTT_ROUND) groups of NTT_BASELEN
 * coefficients; r, a and b are advanced group by group.
 *
 *   NTT_BASELEN == 1 : 16 coefficients per iteration through
 *                      montgomery_reduce_epi16 (defined earlier in this file;
 *                      presumably a lane-wise product followed by a Montgomery
 *                      reduction -- verify against its definition).
 *   NTT_BASELEN == 2 : open-coded two-coefficient product.
 *   otherwise        : basemul_mont() per group.
 *
 * NOTE(review): the AVX2 branch always moves 16 coefficients, so it assumes
 * the coefficient count is a multiple of 16.
 */
void nttmul_mont(int16* r, const int16* a, const int16* b)
{
    // 2^{r-1} small polynomials of length n/2^{r-1}; NTT_ROUND = r-1
    int32 num = 1 << NTT_ROUND;

#if (NTT_BASELEN == 1)      // AVX2 implementation
    for (int32 i = 0; i < num; i += 16) {
        // Explicit casts: the load/store intrinsics take __m256i pointers,
        // and int16* does not convert to them implicitly.
        __m256i X = _mm256_loadu_si256((const __m256i*)a);
        __m256i Y = _mm256_loadu_si256((const __m256i*)b);
        X = montgomery_reduce_epi16(X, Y);
        _mm256_storeu_si256((__m256i*)r, X);
        r += 16;
        a += 16;
        b += 16;
    }
#elif (NTT_BASELEN == 2)    // plain scalar implementation
    for (int32 i = 0; i < num; i++) {
        // Root of unity used by the (2^{r-1}+i)-th polynomial of layer r:
        // w_{2^r}^{brv_r(2^{r-1}+i)}, NTT_ROUND = r-1
        int32 zeta = zetas_mont[bitrev_list[num * NTT_NEG + i]];
        int32 tmp0 = a[0] * b[0] + montgomery_reduce(zeta * a[1]) * b[1];
        int32 tmp1 = a[0] * b[1] + a[1] * b[0];
        r[0] = montgomery_reduce(tmp0);
        r[1] = montgomery_reduce(tmp1);
        r += NTT_BASELEN;
        a += NTT_BASELEN;
        b += NTT_BASELEN;
    }
#else
    for (int32 i = 0; i < num; i++) {
        int32 zeta = zetas_mont[bitrev_list[num * NTT_NEG + i]];
        basemul_mont(r, a, b, zeta);    // plain scalar implementation
        r += NTT_BASELEN;
        a += NTT_BASELEN;
        b += NTT_BASELEN;
    }
#endif
}

/*
 * Base-case polynomial product using Barrett reduction.
 *
 * Same index pattern as basemul_mont(), but the wrapped-around terms are
 * computed as zeta * barrett_reduce(b[j] * a[NTT_BASELEN + i - j]) and the
 * accumulated sum is passed through barrett_reduce before being stored.
 *
 *   r     output, NTT_BASELEN coefficients
 *   a, b  inputs, NTT_BASELEN coefficients each
 *   zeta  twiddle factor taken from zetas by the caller
 *
 * `static inline` for the same linkage reason as basemul_mont().
 */
static inline void basemul(int16* r, const int16* a, const int16* b, int16 zeta)
{
    for (int32 i = 0; i < NTT_BASELEN; i++)
    {
        int32 res = 0;              // wider accumulator: the modular reduction is deferred
        int32 s = NTT_BASELEN + i;

        for (int32 j = 0; j <= i; j++)
            res += b[j] * a[i - j];
        for (int32 j = i + 1; j < NTT_BASELEN; j++)
            res += zeta * barrett_reduce(b[j] * a[s - j]);

        r[i] = barrett_reduce(res);
    }
}

/*
 * Multiplication of two transformed polynomials.
 *
 *   r     output coefficients
 *   a, b  input coefficients
 *   mont  0       -> Barrett-reduction path implemented here
 *         nonzero -> delegate to nttmul_mont()
 */
void nttmul(int16* r, const int16* a, const int16* b, int8 mont)
{
    if (mont == 0) {
        // 2^{r-1} small polynomials of length n/2^{r-1}; NTT_ROUND = r-1
        int32 num = 1 << NTT_ROUND;

        for (int32 i = 0; i < num; i++)
        {
#if (NTT_BASELEN == 1)
            int32 tmp = *a * *b;
            *r = barrett_reduce(tmp);
#elif (NTT_BASELEN == 2)
            // Root of unity used by the (2^{r-1}+i)-th polynomial of layer r:
            // w_{2^r}^{brv_r(2^{r-1}+i)}, NTT_ROUND = r-1
            int32 zeta = zetas[bitrev_list[num * NTT_NEG + i]];
            int32 tmp0 = a[0] * b[0] + zeta * barrett_reduce(a[1] * b[1]);
            int32 tmp1 = a[0] * b[1] + a[1] * b[0];
            r[0] = barrett_reduce(tmp0);
            r[1] = barrett_reduce(tmp1);
#else
            // Root of unity used by the (2^{r-1}+i)-th polynomial of layer r:
            // w_{2^r}^{brv_r(2^{r-1}+i)}, NTT_ROUND = r-1
            int32 zeta = zetas[bitrev_list[num * NTT_NEG + i]];
            basemul(r, a, b, zeta);
#endif
            r += NTT_BASELEN;
            a += NTT_BASELEN;
            b += NTT_BASELEN;
        }
    }
    else {
        nttmul_mont(r, a, b);
    }
}
Test
cputimer.h
#ifndef CPUTIMER
#define CPUTIMER

#include <stdio.h>

#if defined(__linux__)
// Linux
#include <unistd.h>
#elif defined(_WIN32)
// Windows
#include <intrin.h>
#include <windows.h>
#endif

/* Unit: milliseconds */
/*
 * Suspend the caller for `time` milliseconds.
 *
 * Negative values return immediately instead of being converted to a huge
 * unsigned interval. On platforms that are neither Linux nor Windows no
 * sleep call is compiled in and the function returns at once.
 */
void sleepms(int time) {
    if (time < 0) {
        return;
    }
#if defined(__linux__)
    // Linux: whole seconds go through sleep() so the microsecond argument
    // stays below 1,000,000 (larger values may be rejected by usleep() on
    // some systems) and `time * 1000` cannot overflow an int.
    if (time >= 1000) {
        sleep((unsigned int)(time / 1000));
    }
    usleep((unsigned int)(time % 1000) * 1000u);
#elif defined(_WIN32)
    // Windows
    Sleep((DWORD)time);
#endif
}

/* Needs echo 2 > /sys/devices/cpu/rdpmc */
/*
 * Read the x86 time-stamp counter (RDTSC) and return it as a 64-bit value.
 *
 * Windows builds use the __rdtsc() intrinsic. Under 32-bit MSVC the same
 * value could be obtained in three equivalent ways -- an `__asm { rdtsc }`
 * block, `__asm RDTSC`, or emitting the opcode bytes 0x0F 0x31 -- but those
 * only work on x86, because x64 does not support inline assembly there.
 *
 * Every other build uses GNU-style inline assembly: RDTSC leaves the low
 * half in EAX and the high half in EDX, which are combined below.
 */
unsigned long long cputimer(void) {
#if defined(_WIN32)
    return __rdtsc();
#else
    unsigned int lo, hi;
    __asm__ volatile ("rdtsc" : "=a" (lo), "=d" (hi));
    return ((unsigned long long)hi << 32) | lo;
#endif
}
// unsigned long long cputimer(); // alternative: implemented as standalone assembly code
/*
align 16
_cputimer:
    rdtsc
    shl rdx, 32
    or rax, rdx
    ret
*/
/* Counter ticks recorded by GetFrequency(); the timing macros divide by it to convert cycles to seconds. */
unsigned long long CPUFrequency;
// Measure the CPU clock frequency (describes the function that follows)
/*
 * Estimate how many time-stamp-counter ticks elapse per second.
 *
 * Samples cputimer() before and after sleepms(1000), stores the difference
 * in CPUFrequency and returns it.
 */
unsigned long long GetFrequency(void) {
    unsigned long long t1 = cputimer();
    sleepms(1000);
    unsigned long long t2 = cputimer();
    CPUFrequency = t2 - t1;
    return CPUFrequency;
}

/* Print two newline characters. */
#define pn printf("\n\n")

/* Counter samples taken before and after the measured code by the Timer and Loop macros. */
unsigned long long TM_start, TM_end;
/*
 * Time a single execution of `code` and print the elapsed cycles and seconds.
 *
 * Expands to several statements (not wrapped in a block), so it must not be
 * used as the unbraced body of an if/else or loop. The seconds figure divides
 * by CPUFrequency, which is filled in by GetFrequency().
 * The counters are unsigned, hence %llu.
 */
#define Timer(code) TM_start = cputimer(); code; TM_end = cputimer(); \
    printf("time = %llu cycles (%f s)\n", TM_end - TM_start, (double)(TM_end - TM_start)/CPUFrequency);

/* Per-iteration cycle counts recorded by the Loop macro (at most 10000 samples). */
unsigned long long TM_mem[10000];
#define Loop(loop, code) for(int i=0; i<loop; i++) {\\TM_start = cputimer(); code; TM_end = cputimer(); TM_mem[i] = TM_end - TM_start;} Analyis_TM(loop); void __quick_sort(unsigned long long* arr, int begin, int end) //快速排序,简化版
{if (begin >= end)return;unsigned long long temp1 = arr[begin], temp2;int k = begin;for (int i = begin + 1; i <= end; i++){if (temp1 > arr[i]){temp2 = arr[i];int j;for (j = i - 1; j >= k; j--)arr[j + 1] = arr[j];arr[j + 1] = temp2;k++;}}__quick_sort(arr, begin, k - 1);__quick_sort(arr, k + 1, end);
}void quick_sort(unsigned long long* arr, int size)
{__quick_sort(arr, 0, size - 1);
}void Analyis_TM(int loop) //分析代码性能
{unsigned long long min, max, med, aver = 0;quick_sort(TM_mem, loop);min = TM_mem[0];max = TM_mem[loop-1];med = TM_mem[loop >> 1];for (int i = 0; i < loop; i++) {aver += TM_mem[i];}aver /= loop;printf("Time:\\n\\tMinimum\\t%10lld cycles,%10.6f ms\\n\\tMaximum\\t%10lld cycles,%10.6f ms\\n\\tMedian\\t%10lld cycles,%10.6f ms\\n\\tAverage\\t%10lld cycles,%10.6f ms\\n", min, (double)min / CPUFrequency * 1000, max, (double)max / CPUFrequency * 1000, med, (double)med / CPUFrequency * 1000, aver, (double)aver / CPUFrequency * 1000);
}#endif
Result
在 WSL 下用 gcc 编译:
gcc ntt_avx2.c test_ntt_avx2.c -o test_ntt_avx2 -O3 -fopt-info-vec-optimized -mavx2
执行 ./test_ntt_avx2 的结果为:
CPU Frequency = 2918844449
Time: ntt
	Minimum	      2588 cycles,  0.000887 ms
	Maximum	     39908 cycles,  0.013672 ms
	Median	      2636 cycles,  0.000903 ms
	Average	      2881 cycles,  0.000987 ms
Time: intt
	Minimum	      2546 cycles,  0.000872 ms
	Maximum	     29334 cycles,  0.010050 ms
	Median	      2586 cycles,  0.000886 ms
	Average	      2650 cycles,  0.000908 ms
Time: nttmul
	Minimum	      1828 cycles,  0.000626 ms
	Maximum	     91908 cycles,  0.031488 ms
	Median	      1838 cycles,  0.000630 ms
	Average	      1865 cycles,  0.000639 ms
Time: nttmul_mont
	Minimum	       164 cycles,  0.000056 ms
	Maximum	      7684 cycles,  0.002632 ms
	Median	       186 cycles,  0.000064 ms
	Average	       199 cycles,  0.000068 ms