ntl/ntl-loadtime-cpu.patch

2372 lines
79 KiB
Diff

--- doc/config.txt.orig 2021-06-20 15:05:49.000000000 -0600
+++ doc/config.txt 2021-06-23 19:59:29.902142132 -0600
@@ -420,6 +420,7 @@ NTL_AVOID_BRANCHING=off
NTL_GF2X_NOINLINE=off
NTL_GF2X_ALTCODE=off
NTL_GF2X_ALTCODE1=off
+NTL_LOADTIME_CPU=off
GMP_INCDIR=$(GMP_PREFIX)/include
GMP_LIBDIR=$(GMP_PREFIX)/lib
@@ -734,6 +735,10 @@ NTL_GF2X_ALTCODE1=off
# Yet another alternative implementation for GF2X multiplication.
+NTL_LOADTIME_CPU=off
+
+# switch to check CPU characteristics at load time and use routines
+# optimized for the executing CPU.
########## More GMP Options:
--- include/NTL/config.h.orig 2021-06-20 15:05:49.000000000 -0600
+++ include/NTL/config.h 2021-06-23 19:59:29.903142133 -0600
@@ -549,6 +549,19 @@ to be defined. Of course, to unset a f
#error "NTL_SAFE_VECTORS defined but not NTL_STD_CXX11 or NTL_STD_CXX14"
#endif
+#if 0
+#define NTL_LOADTIME_CPU
+
+/*
+ * With this flag enabled, detect advanced CPU features (SSSE3, PCLMUL, AVX,
+ * FMA, AVX2) at load time instead of at compile time and dispatch to routines
+ * built for the executing CPU. It is intended for distributions, which can
+ * compile for the lowest common denominator CPU but still exploit newer CPUs.
+ *
+ * This flag is supported only on x86_64 platforms with gcc 4.8 or later.
+ */
+
+#endif
--- include/NTL/ctools.h.orig 2021-06-20 15:05:49.000000000 -0600
+++ include/NTL/ctools.h 2021-06-23 19:59:29.904142134 -0600
@@ -518,6 +518,166 @@ char *_ntl_make_aligned(char *p, long al
// this should be big enough to satisfy any SIMD instructions,
// and it should also be as big as a cache line
+/* Determine CPU characteristics at runtime */
+#ifdef NTL_LOADTIME_CPU
+#if !defined(__x86_64__)
+#error Runtime CPU support is only available on x86_64.
+#endif
+#ifndef __GNUC__
+#error Runtime CPU support is only available with GCC.
+#endif
+/* The "fma" and "avx2" target attributes used below need gcc 4.7+; the
+ * documented minimum for NTL_LOADTIME_CPU is gcc 4.8, so enforce that. */
+#if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 8)
+#error Runtime CPU support is only available with GCC 4.8 or later.
+#endif
+
+#include <cpuid.h>
+#ifndef bit_SSSE3
+#define bit_SSSE3 (1 << 9)
+#endif
+#ifndef bit_PCLMUL
+#define bit_PCLMUL (1 << 1)
+#endif
+#ifndef bit_OSXSAVE
+#define bit_OSXSAVE (1 << 27)
+#endif
+#ifndef bit_AVX
+#define bit_AVX (1 << 28)
+#endif
+#ifndef bit_FMA
+#define bit_FMA (1 << 12)
+#endif
+#ifndef bit_AVX2
+#define bit_AVX2 (1 << 5)
+#endif
+
+/*
+ * Return nonzero if the OS saves and restores the SSE and AVX (YMM) register
+ * state.  The CPUID AVX/FMA/AVX2 bits alone are not sufficient: those
+ * instructions fault unless CPUID.1:ECX.OSXSAVE is set and XCR0 bits 1 and 2
+ * are both enabled.
+ */
+static inline int _ntl_os_avx_enabled(void)
+{
+  unsigned int eax, ebx, ecx, edx, xcr0_lo, xcr0_hi;
+  if (!__get_cpuid(1, &eax, &ebx, &ecx, &edx) || !(ecx & bit_OSXSAVE))
+    return 0;
+  /* XGETBV with ECX = 0 reads XCR0; emitted as raw bytes so that older
+   * assemblers lacking the mnemonic still accept it. */
+  __asm__ __volatile__ (".byte 0x0f, 0x01, 0xd0"
+                        : "=a" (xcr0_lo), "=d" (xcr0_hi) : "c" (0));
+  (void) xcr0_hi;
+  return (xcr0_lo & 0x6) == 0x6;
+}
+
+/*
+ * Probe CPUID leaf 1 on first use.  Expands inside a resolver and fills the
+ * have_pclmul/have_avx/have_fma variables that the including file declares
+ * (initialized to -1, meaning "not probed yet").
+ */
+#define NTL_PROBE_CPUID_LEAF1 \
+  if (__builtin_expect(have_pclmul, 0) < 0) { \
+    unsigned int eax, ebx, ecx, edx; \
+    if (__get_cpuid(1, &eax, &ebx, &ecx, &edx)) { \
+      int os_avx = _ntl_os_avx_enabled(); \
+      have_pclmul = ((ecx & bit_PCLMUL) != 0); \
+      have_avx = os_avx && ((ecx & bit_AVX) != 0); \
+      have_fma = os_avx && ((ecx & bit_FMA) != 0); \
+    } else { \
+      have_pclmul = 0; \
+      have_avx = 0; \
+      have_fma = 0; \
+    } \
+  }
+
+/*
+ * Probe CPUID leaf 7, sub-leaf 0, on first use and fill have_avx2.  Leaf 7
+ * is sub-leaf indexed, so ECX must be set explicitly (__get_cpuid leaves it
+ * unspecified), and the leaf must not exceed the CPU's maximum basic leaf.
+ */
+#define NTL_PROBE_CPUID_LEAF7 \
+  if (__builtin_expect(have_avx2, 0) < 0) { \
+    unsigned int eax, ebx, ecx, edx; \
+    if (__get_cpuid_max(0, 0) >= 7) { \
+      __cpuid_count(7, 0, eax, ebx, ecx, edx); \
+      have_avx2 = _ntl_os_avx_enabled() && ((ebx & bit_AVX2) != 0); \
+    } else { \
+      have_avx2 = 0; \
+    } \
+  }
+
+#define BASE_FUNC(type,name) static type name##_base
+#define TARGET_FUNC(arch,suffix,type,name) \
+  static type __attribute__((target (arch))) name##_##suffix
+#define SSSE3_FUNC(type,name) TARGET_FUNC("ssse3",ssse3,type,name)
+#define PCLMUL_FUNC(type,name) TARGET_FUNC("pclmul,ssse3",pclmul,type,name)
+#define AVX_FUNC(type,name) TARGET_FUNC("avx,pclmul,ssse3",avx,type,name)
+#define FMA_FUNC(type,name) TARGET_FUNC("fma,avx,pclmul,ssse3",fma,type,name)
+#define AVX2_FUNC(type,name) TARGET_FUNC("avx2,fma,avx,pclmul,ssse3",avx2,type,name)
+
+/*
+ * Each *_RESOLVER declares NAME as a GNU indirect function (ifunc): its
+ * resolver runs at load time and returns the variant to use on the
+ * executing CPU.  ST is NAME's storage class (static, or empty).
+ */
+#define SSSE3_RESOLVER(st,type,name,params) \
+  extern "C" { \
+    static type (*resolve_##name(void)) params { \
+      NTL_PROBE_CPUID_LEAF7 \
+      if (__builtin_expect(have_ssse3, 0) < 0) { \
+        unsigned int eax, ebx, ecx, edx; \
+        if (__get_cpuid(1, &eax, &ebx, &ecx, &edx)) { \
+          have_ssse3 = ((ecx & bit_SSSE3) != 0); \
+        } else { \
+          have_ssse3 = 0; \
+        } \
+      } \
+      if (have_avx2) return &name##_avx2; \
+      if (have_ssse3) return &name##_ssse3; \
+      return &name##_base; \
+    } \
+  } \
+  st type __attribute__((ifunc ("resolve_" #name))) name params
+/* The _avx variant is compiled for pclmul as well, so it needs both. */
+#define PCLMUL_RESOLVER(st,type,name,params) \
+  extern "C" { \
+    static type (*resolve_##name(void)) params { \
+      NTL_PROBE_CPUID_LEAF1 \
+      if (have_pclmul && have_avx) return &name##_avx; \
+      if (have_pclmul) return &name##_pclmul; \
+      return &name##_base; \
+    } \
+  } \
+  st type __attribute__((ifunc ("resolve_" #name))) name params
+#define AVX_RESOLVER(st,type,name,params) \
+  extern "C" { \
+    static type (*resolve_##name(void)) params { \
+      NTL_PROBE_CPUID_LEAF1 \
+      return have_avx ? &name##_avx : &name##_base; \
+    } \
+  } \
+  st type __attribute__((ifunc ("resolve_" #name))) name params
+/* FMA_RESOLVER and AVX2_RESOLVER have no _base fallback; presumably their
+ * callers only reach them when AVX (resp. FMA) is usable. */
+#define FMA_RESOLVER(st,type,name,params) \
+  extern "C" { \
+    static type (*resolve_##name(void)) params { \
+      NTL_PROBE_CPUID_LEAF1 \
+      return have_fma ? &name##_fma : &name##_avx; \
+    } \
+  } \
+  st type __attribute__((ifunc ("resolve_" #name))) name params
+#define AVX2_RESOLVER(st,type,name,params) \
+  extern "C" { \
+    static type (*resolve_##name(void)) params { \
+      NTL_PROBE_CPUID_LEAF7 \
+      NTL_PROBE_CPUID_LEAF1 \
+      return have_avx2 ? &name##_avx2 : &name##_fma; \
+    } \
+  } \
+  st type __attribute__((ifunc ("resolve_" #name))) name params
+#endif
#ifdef NTL_HAVE_BUILTIN_CLZL
--- include/NTL/MatPrime.h.orig 2021-06-20 15:05:49.000000000 -0600
+++ include/NTL/MatPrime.h 2021-06-23 19:59:29.904142134 -0600
@@ -20,7 +20,7 @@ NTL_OPEN_NNS
-#ifdef NTL_HAVE_AVX
+#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU) // AVX code may be chosen at load time, so presumably the AVX prime size is needed
#define NTL_MatPrime_NBITS (23)
#else
#define NTL_MatPrime_NBITS NTL_SP_NBITS
--- include/NTL/REPORT_ALL_FEATURES.h.orig 2021-06-20 15:05:49.000000000 -0600
+++ include/NTL/REPORT_ALL_FEATURES.h 2021-06-23 19:59:29.905142135 -0600
@@ -63,3 +63,6 @@
std::cerr << "NTL_HAVE_KMA\n";
#endif
+#ifdef NTL_LOADTIME_CPU // report load-time CPU dispatch like the other features
+ std::cerr << "NTL_LOADTIME_CPU\n";
+#endif
--- src/cfile.orig 2021-06-20 15:05:49.000000000 -0600
+++ src/cfile 2021-06-23 19:59:29.906142136 -0600
@@ -449,6 +449,19 @@ to be defined. Of course, to unset a f
#endif
+#if @{NTL_LOADTIME_CPU}
+#define NTL_LOADTIME_CPU
+
+/*
+ * With this flag enabled, detect advanced CPU features (SSSE3, PCLMUL, AVX,
+ * FMA, AVX2) at load time instead of at compile time and dispatch to routines
+ * built for the executing CPU. It is intended for distributions, which can
+ * compile for the lowest common denominator CPU but still exploit newer CPUs.
+ *
+ * This flag is supported only on x86_64 platforms with gcc 4.8 or later.
+ */
+
+#endif
#if @{NTL_CRT_ALTCODE}
--- src/DispSettings.cpp.orig 2021-06-20 15:05:49.000000000 -0600
+++ src/DispSettings.cpp 2021-06-23 19:59:29.906142136 -0600
@@ -192,6 +192,9 @@ cout << "Performance Options:\n";
cout << "NTL_RANDOM_AES256CTR\n";
#endif
+#ifdef NTL_LOADTIME_CPU // shown alongside the other performance options
+ cout << "NTL_LOADTIME_CPU\n";
+#endif
cout << "***************************/\n";
cout << "\n\n";
--- src/DoConfig.orig 2021-06-20 15:05:49.000000000 -0600
+++ src/DoConfig 2021-06-23 19:59:29.907142137 -0600
@@ -1,6 +1,7 @@
# This is a perl script, invoked from a shell
use warnings; # this doesn't work on older versions of perl
+use Config;
system("echo '*** CompilerOutput.log ***' > CompilerOutput.log");
@@ -92,6 +93,7 @@ system("echo '*** CompilerOutput.log ***
'NTL_GF2X_NOINLINE' => 'off',
'NTL_GF2X_ALTCODE' => 'off',
'NTL_GF2X_ALTCODE1' => 'off',
+'NTL_LOADTIME_CPU' => 'off',
'NTL_RANDOM_AES256CTR' => 'off',
@@ -176,6 +178,14 @@ if ($MakeVal{'CXXFLAGS'} =~ '-march=') {
$MakeFlag{'NATIVE'} = 'off';
}
+# special processing: NTL_LOADTIME_CPU is x86_64 (amd64) only and => NTL_GF2X_NOINLINE
+# (use !~ here: "!$Config{archname} =~ /re/" negates archname before matching)
+if ($ConfigFlag{'NTL_LOADTIME_CPU'} eq 'on') {
+ if ($Config{archname} !~ /x86_64|amd64/) {
+ die "Error: NTL_LOADTIME_CPU currently only available with x86_64...sorry\n";
+ }
+ $ConfigFlag{'NTL_GF2X_NOINLINE'} = 'on';
+}
# some special MakeVal values that are determined by SHARED
--- src/GF2EX.cpp.orig 2021-06-20 15:05:48.000000000 -0600
+++ src/GF2EX.cpp 2021-06-23 19:59:29.908142138 -0600
@@ -801,7 +801,7 @@ void mul(GF2EX& c, const GF2EX& a, const
if (GF2E::WordLength() <= 1) use_kron_mul = true;
-#if (defined(NTL_GF2X_LIB) && defined(NTL_HAVE_PCLMUL))
+#if (defined(NTL_GF2X_LIB) && (defined(NTL_HAVE_PCLMUL) || defined(NTL_LOADTIME_CPU))) // with load-time dispatch, assume pclmul may be available
// With gf2x library and pclmul, KronMul is better in a larger range, but
// it is very hard to characterize that range. The following is very
// conservative.
--- src/GF2X1.cpp.orig 2021-06-20 15:05:48.000000000 -0600
+++ src/GF2X1.cpp 2021-06-23 19:59:29.910142141 -0600
@@ -18,7 +18,7 @@
// simple scaling factor for some crossover points:
// we use a lower crossover of the underlying multiplication
// is faster
-#if (defined(NTL_GF2X_LIB) || defined(NTL_HAVE_PCLMUL))
+#if (defined(NTL_GF2X_LIB) || defined(NTL_HAVE_PCLMUL) || defined(NTL_LOADTIME_CPU)) // load-time dispatch: tune as if pclmul may be present
#define XOVER_SCALE (1L)
#else
#define XOVER_SCALE (2L)
--- src/GF2X.cpp.orig 2021-06-20 15:05:48.000000000 -0600
+++ src/GF2X.cpp 2021-06-23 19:59:29.911142142 -0600
@@ -27,6 +27,22 @@ pclmul_mul1 (unsigned long *c, unsigned
_mm_storeu_si128((__m128i*)c, _mm_clmulepi64_si128(aa, bb, 0));
}
+#elif defined(NTL_LOADTIME_CPU)
+// Load-time dispatch: pclmul is only used from target-attributed functions.
+#include <wmmintrin.h>
+// CPU feature flags filled in by the ifunc resolvers (-1 = not probed yet).
+static int have_pclmul = -1;
+static int have_avx = -1;
+static int have_fma = -1;
+
+#define NTL_INLINE inline
+// Macro (presumably so it inherits the caller's pclmul target): stores the 128-bit carry-less product of a and b to c[0] (low word) and c[1] (high word).
+#define pclmul_mul1(c,a,b) do { \
+ __m128i aa = _mm_setr_epi64( _mm_cvtsi64_m64(a), _mm_cvtsi64_m64(0)); \
+ __m128i bb = _mm_setr_epi64( _mm_cvtsi64_m64(b), _mm_cvtsi64_m64(0)); \
+ _mm_storeu_si128((__m128i*)(c), _mm_clmulepi64_si128(aa, bb, 0)); \
+} while (0)
+
#else
@@ -556,6 +572,27 @@ void add(GF2X& x, const G
+#ifdef NTL_LOADTIME_CPU
+// mul1: one-word by one-word carry-less product, written to c[0] (low) and c[1] (high).
+BASE_FUNC(void,mul1)(_ntl_ulong *c, _ntl_ulong a, _ntl_ulong b)
+{
+ NTL_EFF_BB_MUL_CODE0
+}
+// PCLMULQDQ variant.
+PCLMUL_FUNC(void,mul1)(_ntl_ulong *c, _ntl_ulong a, _ntl_ulong b)
+{
+ pclmul_mul1(c, a, b);
+}
+// Same code compiled with AVX enabled.
+AVX_FUNC(void,mul1)(_ntl_ulong *c, _ntl_ulong a, _ntl_ulong b)
+{
+ pclmul_mul1(c, a, b);
+}
+// The variant is chosen once at load time by PCLMUL_RESOLVER.
+PCLMUL_RESOLVER(static,void,mul1,(_ntl_ulong *c, _ntl_ulong a, _ntl_ulong b));
+
+#else
+
static NTL_INLINE
void mul1(_ntl_ulong *c, _ntl_ulong a, _ntl_ulong b)
{
@@ -568,6 +605,7 @@ NTL_EFF_BB_MUL_CODE0
}
+#endif
#ifdef NTL_GF2X_NOINLINE
@@ -592,6 +630,51 @@ NTL_EFF_BB_MUL_CODE0
#endif
+#ifdef NTL_LOADTIME_CPU
+// Mul1: cp[0..sb] = bp[0..sb-1] * a (multi-word times one word, carry-less).
+BASE_FUNC(void,Mul1)
+(_ntl_ulong *cp, const _ntl_ulong *bp, long sb, _ntl_ulong a)
+{
+ NTL_EFF_BB_MUL_CODE1
+}
+// PCLMULQDQ variant.
+PCLMUL_FUNC(void,Mul1)
+(_ntl_ulong *cp, const _ntl_ulong *bp, long sb, _ntl_ulong a)
+{
+ long i;
+ unsigned long carry, prod[2];
+
+ carry = 0;
+ for (i = 0; i < sb; i++) {
+ pclmul_mul1(prod, bp[i], a);
+ cp[i] = carry ^ prod[0];
+ carry = prod[1]; // high word of the product overlaps the next output word
+ }
+
+ cp[sb] = carry;
+}
+// Same code compiled with AVX enabled.
+AVX_FUNC(void,Mul1)
+(_ntl_ulong *cp, const _ntl_ulong *bp, long sb, _ntl_ulong a)
+{
+ long i;
+ unsigned long carry, prod[2];
+
+ carry = 0;
+ for (i = 0; i < sb; i++) {
+ pclmul_mul1(prod, bp[i], a);
+ cp[i] = carry ^ prod[0];
+ carry = prod[1];
+ }
+
+ cp[sb] = carry;
+}
+// The variant is chosen once at load time by PCLMUL_RESOLVER.
+PCLMUL_RESOLVER(static,void,Mul1,
+ (_ntl_ulong *cp, const _ntl_ulong *bp, long sb, _ntl_ulong a));
+
+#else
+
static
void Mul1(_ntl_ulong *cp, const _ntl_ulong *bp, long sb, _ntl_ulong a)
{
@@ -620,6 +703,53 @@ NTL_EFF_BB_MUL_CODE1
// warning #13200: No EMMS instruction before return
}
+#endif
+
+#ifdef NTL_LOADTIME_CPU
+// AddMul1: cp[0..sb] ^= bp[0..sb-1] * a (accumulating form of Mul1).
+BASE_FUNC(void,AddMul1)
+(_ntl_ulong *cp, const _ntl_ulong* bp, long sb, _ntl_ulong a)
+{
+ NTL_EFF_BB_MUL_CODE2
+}
+// PCLMULQDQ variant.
+PCLMUL_FUNC(void,AddMul1)
+(_ntl_ulong *cp, const _ntl_ulong* bp, long sb, _ntl_ulong a)
+{
+ long i;
+ unsigned long carry, prod[2];
+
+ carry = 0;
+ for (i = 0; i < sb; i++) {
+ pclmul_mul1(prod, bp[i], a);
+ cp[i] ^= carry ^ prod[0];
+ carry = prod[1]; // high word of the product overlaps the next output word
+ }
+
+ cp[sb] ^= carry;
+}
+// Same code compiled with AVX enabled.
+AVX_FUNC(void,AddMul1)
+(_ntl_ulong *cp, const _ntl_ulong* bp, long sb, _ntl_ulong a)
+{
+ long i;
+ unsigned long carry, prod[2];
+
+ carry = 0;
+ for (i = 0; i < sb; i++) {
+ pclmul_mul1(prod, bp[i], a);
+ cp[i] ^= carry ^ prod[0];
+ carry = prod[1];
+ }
+
+ cp[sb] ^= carry;
+}
+// The variant is chosen once at load time by PCLMUL_RESOLVER.
+PCLMUL_RESOLVER(static,void,AddMul1,
+ (_ntl_ulong *cp, const _ntl_ulong* bp, long sb, _ntl_ulong a));
+
+#else
+
static
void AddMul1(_ntl_ulong *cp, const _ntl_ulong* bp, long sb, _ntl_ulong a)
{
@@ -648,6 +778,52 @@ NTL_EFF_BB_MUL_CODE2
}
+#endif
+
+#ifdef NTL_LOADTIME_CPU
+// Mul1_short: same contract as Mul1; only the base variant differs (NTL_EFF_SHORT_BB_MUL_CODE1).
+BASE_FUNC(void,Mul1_short)
+(_ntl_ulong *cp, const _ntl_ulong *bp, long sb, _ntl_ulong a)
+{
+ NTL_EFF_SHORT_BB_MUL_CODE1
+}
+// PCLMULQDQ variant (identical to Mul1's).
+PCLMUL_FUNC(void,Mul1_short)
+(_ntl_ulong *cp, const _ntl_ulong *bp, long sb, _ntl_ulong a)
+{
+ long i;
+ unsigned long carry, prod[2];
+
+ carry = 0;
+ for (i = 0; i < sb; i++) {
+ pclmul_mul1(prod, bp[i], a);
+ cp[i] = carry ^ prod[0];
+ carry = prod[1];
+ }
+
+ cp[sb] = carry;
+}
+// Same code compiled with AVX enabled.
+AVX_FUNC(void,Mul1_short)
+(_ntl_ulong *cp, const _ntl_ulong *bp, long sb, _ntl_ulong a)
+{
+ long i;
+ unsigned long carry, prod[2];
+
+ carry = 0;
+ for (i = 0; i < sb; i++) {
+ pclmul_mul1(prod, bp[i], a);
+ cp[i] = carry ^ prod[0];
+ carry = prod[1];
+ }
+
+ cp[sb] = carry;
+}
+// The variant is chosen once at load time by PCLMUL_RESOLVER.
+PCLMUL_RESOLVER(static,void,Mul1_short,
+ (_ntl_ulong *cp, const _ntl_ulong *bp, long sb, _ntl_ulong a));
+
+#else
static
void Mul1_short(_ntl_ulong *cp, const _ntl_ulong *bp, long sb, _ntl_ulong a)
@@ -677,9 +853,29 @@ NTL_EFF_SHORT_BB_MUL_CODE1
// warning #13200: No EMMS instruction before return
}
+#endif
+#ifdef NTL_LOADTIME_CPU // load-time dispatched variants of mul_half
+BASE_FUNC(void,mul_half)(_ntl_ulong *c, _ntl_ulong a, _ntl_ulong b)
+{
+ NTL_EFF_HALF_BB_MUL_CODE0
+}
+// PCLMULQDQ variant: full carry-less product into c[0] (low) and c[1] (high).
+PCLMUL_FUNC(void,mul_half)(_ntl_ulong *c, _ntl_ulong a, _ntl_ulong b)
+{
+ pclmul_mul1(c, a, b);
+}
+// Same code compiled with AVX enabled.
+AVX_FUNC(void,mul_half)(_ntl_ulong *c, _ntl_ulong a, _ntl_ulong b)
+{
+ pclmul_mul1(c, a, b);
+}
+// The variant is chosen once at load time by PCLMUL_RESOLVER.
+PCLMUL_RESOLVER(static,void,mul_half,(_ntl_ulong *c, _ntl_ulong a, _ntl_ulong b));
+
+#else
static
void mul_half(_ntl_ulong *c, _ntl_ulong a, _ntl_ulong b)
@@ -694,6 +890,7 @@ NTL_EFF_HALF_BB_MUL_CODE0
}
+#endif
// mul2...mul8 hard-code 2x2...8x8 word multiplies.
// I adapted these routines from LiDIA (except mul3, see below).
@@ -1611,6 +1808,77 @@ static const _ntl_ulong sqrtab[256] = {
+#ifdef NTL_LOADTIME_CPU
+// sqr: c = a^2 over GF(2)[X]; input word i expands to output words 2i and 2i+1.
+BASE_FUNC(void,sqr)(GF2X& c, const GF2X& a)
+{
+ long sa = a.xrep.length();
+ if (sa <= 0) {
+ clear(c);
+ return;
+ }
+
+ c.xrep.SetLength(sa << 1);
+ _ntl_ulong *cp = c.xrep.elts();
+ const _ntl_ulong *ap = a.xrep.elts();
+ // Walk downwards: outputs 2i, 2i+1 are never below input i, so c may alias a (assuming SetLength keeps the existing words).
+ for (long i = sa-1; i >= 0; i--) {
+ _ntl_ulong *c = cp + (i << 1); // shadows c/a: NTL_BB_SQR_CODE presumably reads a and writes hi/lo by name
+ _ntl_ulong a = ap[i];
+ _ntl_ulong hi, lo;
+
+ NTL_BB_SQR_CODE
+
+ c[0] = lo;
+ c[1] = hi;
+ }
+
+ c.normalize();
+ return;
+}
+// PCLMULQDQ variant: squaring a word is the carry-less product of the word with itself.
+PCLMUL_FUNC(void,sqr)(GF2X& c, const GF2X& a)
+{
+ long sa = a.xrep.length();
+ if (sa <= 0) {
+ clear(c);
+ return;
+ }
+
+ c.xrep.SetLength(sa << 1);
+ _ntl_ulong *cp = c.xrep.elts();
+ const _ntl_ulong *ap = a.xrep.elts();
+ // Downward loop for the same aliasing reason as the base variant.
+ for (long i = sa-1; i >= 0; i--)
+ pclmul_mul1 (cp + (i << 1), ap[i], ap[i]);
+
+ c.normalize();
+ return;
+}
+// Same code compiled with AVX enabled.
+AVX_FUNC(void,sqr)(GF2X& c, const GF2X& a)
+{
+ long sa = a.xrep.length();
+ if (sa <= 0) {
+ clear(c);
+ return;
+ }
+
+ c.xrep.SetLength(sa << 1);
+ _ntl_ulong *cp = c.xrep.elts();
+ const _ntl_ulong *ap = a.xrep.elts();
+
+ for (long i = sa-1; i >= 0; i--)
+ pclmul_mul1 (cp + (i << 1), ap[i], ap[i]);
+
+ c.normalize();
+ return;
+}
+// sqr is an external function, so the storage-class argument is empty.
+PCLMUL_RESOLVER(,void,sqr,(GF2X& c, const GF2X& a));
+
+#else
+
static inline
void sqr1(_ntl_ulong *c, _ntl_ulong a)
{
@@ -1651,6 +1919,7 @@ void sqr(GF2X& c, const GF2X& a)
return;
}
+#endif
void LeftShift(GF2X& c, const GF2X& a, long n)
--- src/InitSettings.cpp.orig 2021-06-20 15:05:49.000000000 -0600
+++ src/InitSettings.cpp 2021-06-23 19:59:29.912142143 -0600
@@ -190,6 +190,11 @@ int main()
cout << "NTL_RANGE_CHECK=0\n";
#endif
+#ifdef NTL_LOADTIME_CPU // emitted as a 0/1 setting like the other flags
+ cout << "NTL_LOADTIME_CPU=1\n";
+#else
+ cout << "NTL_LOADTIME_CPU=0\n";
+#endif
// the following are not actual config flags, but help
--- src/mat_lzz_p.cpp.orig 2021-06-20 15:05:48.000000000 -0600
+++ src/mat_lzz_p.cpp 2021-06-23 19:59:29.915142146 -0600
@@ -9,6 +9,15 @@
#ifdef NTL_HAVE_AVX
#include <immintrin.h>
+#define AVX_ACTIVE 1 // AVX is known to be available at compile time
+#elif defined(NTL_LOADTIME_CPU)
+#include <immintrin.h>
+#define AVX_ACTIVE have_avx // known only once the ifunc resolvers have probed the CPU
+// CPU feature flags used by the *_RESOLVER macros (-1 = not probed yet).
+static int have_pclmul = -1;
+static int have_avx = -1;
+static int have_fma = -1;
+static int have_avx2 = -1;
#endif
NTL_START_IMPL
@@ -634,7 +643,7 @@ void mul(mat_zz_p& X, const mat_zz_p& A,
#ifdef NTL_HAVE_LL_TYPE
-#ifdef NTL_HAVE_AVX
+#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU) // AVX kernels may be selected at load time
#define MAX_DBL_INT ((1L << NTL_DOUBLE_PRECISION)-1)
// max int representable exactly as a double
@@ -648,10 +657,12 @@ void mul(mat_zz_p& X, const mat_zz_p& A,
// MUL_ADD(a, b, c): a += b*c
+#define FMA_MUL_ADD(a, b, c) a = _mm256_fmadd_pd(b, c, a) // fused multiply-add; needs FMA
+#define AVX_MUL_ADD(a, b, c) a = _mm256_add_pd(a, _mm256_mul_pd(b, c)) // separate multiply and add; AVX only
#ifdef NTL_HAVE_FMA
-#define MUL_ADD(a, b, c) a = _mm256_fmadd_pd(b, c, a)
+#define MUL_ADD(a, b, c) FMA_MUL_ADD(a, b, c)
#else
-#define MUL_ADD(a, b, c) a = _mm256_add_pd(a, _mm256_mul_pd(b, c))
+#define MUL_ADD(a, b, c) AVX_MUL_ADD(a, b, c)
#endif
@@ -931,6 +942,94 @@ void muladd3_by_16(double *x, const doub
#else
+#if defined(NTL_LOADTIME_CPU)
+// muladd1_by_32: x[j] += sum_{i<n} a[i] * b[32*i + j] for j = 0..31.
+AVX_FUNC(void,muladd1_by_32)
+(double *x, const double *a, const double *b, long n)
+{
+ __m256d avec, bvec;
+ // x and b must be 32-byte aligned (_mm256_load_pd / _mm256_store_pd).
+
+ __m256d acc0=_mm256_load_pd(x + 0*4);
+ __m256d acc1=_mm256_load_pd(x + 1*4);
+ __m256d acc2=_mm256_load_pd(x + 2*4);
+ __m256d acc3=_mm256_load_pd(x + 3*4);
+ __m256d acc4=_mm256_load_pd(x + 4*4);
+ __m256d acc5=_mm256_load_pd(x + 5*4);
+ __m256d acc6=_mm256_load_pd(x + 6*4);
+ __m256d acc7=_mm256_load_pd(x + 7*4);
+
+
+ for (long i = 0; i < n; i++) {
+ avec = _mm256_broadcast_sd(a); a++;
+ // one 32-double row of b per a[i]
+
+ bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc0, avec, bvec);
+ bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc1, avec, bvec);
+ bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc2, avec, bvec);
+ bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc3, avec, bvec);
+ bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc4, avec, bvec);
+ bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc5, avec, bvec);
+ bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc6, avec, bvec);
+ bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc7, avec, bvec);
+ }
+
+
+ _mm256_store_pd(x + 0*4, acc0);
+ _mm256_store_pd(x + 1*4, acc1);
+ _mm256_store_pd(x + 2*4, acc2);
+ _mm256_store_pd(x + 3*4, acc3);
+ _mm256_store_pd(x + 4*4, acc4);
+ _mm256_store_pd(x + 5*4, acc5);
+ _mm256_store_pd(x + 6*4, acc6);
+ _mm256_store_pd(x + 7*4, acc7);
+}
+// Same kernel with fused multiply-add, for CPUs with FMA.
+FMA_FUNC(void,muladd1_by_32)
+(double *x, const double *a, const double *b, long n)
+{
+ __m256d avec, bvec;
+
+
+ __m256d acc0=_mm256_load_pd(x + 0*4);
+ __m256d acc1=_mm256_load_pd(x + 1*4);
+ __m256d acc2=_mm256_load_pd(x + 2*4);
+ __m256d acc3=_mm256_load_pd(x + 3*4);
+ __m256d acc4=_mm256_load_pd(x + 4*4);
+ __m256d acc5=_mm256_load_pd(x + 5*4);
+ __m256d acc6=_mm256_load_pd(x + 6*4);
+ __m256d acc7=_mm256_load_pd(x + 7*4);
+
+
+ for (long i = 0; i < n; i++) {
+ avec = _mm256_broadcast_sd(a); a++;
+
+
+ bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc0, avec, bvec);
+ bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc1, avec, bvec);
+ bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc2, avec, bvec);
+ bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc3, avec, bvec);
+ bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc4, avec, bvec);
+ bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc5, avec, bvec);
+ bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc6, avec, bvec);
+ bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc7, avec, bvec);
+ }
+
+
+ _mm256_store_pd(x + 0*4, acc0);
+ _mm256_store_pd(x + 1*4, acc1);
+ _mm256_store_pd(x + 2*4, acc2);
+ _mm256_store_pd(x + 3*4, acc3);
+ _mm256_store_pd(x + 4*4, acc4);
+ _mm256_store_pd(x + 5*4, acc5);
+ _mm256_store_pd(x + 6*4, acc6);
+ _mm256_store_pd(x + 7*4, acc7);
+}
+// Resolved at load time: the FMA variant if have_fma, otherwise the AVX variant.
+FMA_RESOLVER(static,void,muladd1_by_32,
+ (double *x, const double *a, const double *b, long n));
+
+#else
static
void muladd1_by_32(double *x, const double *a, const double *b, long n)
@@ -973,6 +1072,167 @@ void muladd1_by_32(double *x, const doub
_mm256_store_pd(x + 7*4, acc7);
}
+#endif
+
+#ifdef NTL_LOADTIME_CPU
+// muladd2_by_32: for r = 0, 1 and j < 32, x[r*MAT_BLK_SZ + j] += sum_{i<n} a[r*MAT_BLK_SZ + i] * b[i*MAT_BLK_SZ + j].
+AVX_FUNC(void,muladd2_by_32)
+(double *x, const double *a, const double *b, long n)
+{
+ __m256d avec0, avec1, bvec;
+ __m256d acc00, acc01, acc02, acc03;
+ __m256d acc10, acc11, acc12, acc13;
+ // Columns 0..15 in round 0, 16..31 in round 1 (x offsets are hard-coded, b uses MAT_BLK_SZ/2: presumably MAT_BLK_SZ == 32).
+
+ // round 0
+
+ acc00=_mm256_load_pd(x + 0*4 + 0*MAT_BLK_SZ);
+ acc01=_mm256_load_pd(x + 1*4 + 0*MAT_BLK_SZ);
+ acc02=_mm256_load_pd(x + 2*4 + 0*MAT_BLK_SZ);
+ acc03=_mm256_load_pd(x + 3*4 + 0*MAT_BLK_SZ);
+
+ acc10=_mm256_load_pd(x + 0*4 + 1*MAT_BLK_SZ);
+ acc11=_mm256_load_pd(x + 1*4 + 1*MAT_BLK_SZ);
+ acc12=_mm256_load_pd(x + 2*4 + 1*MAT_BLK_SZ);
+ acc13=_mm256_load_pd(x + 3*4 + 1*MAT_BLK_SZ);
+
+ for (long i = 0; i < n; i++) {
+ avec0 = _mm256_broadcast_sd(&a[i]);
+ avec1 = _mm256_broadcast_sd(&a[i+MAT_BLK_SZ]);
+
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4]); AVX_MUL_ADD(acc00, avec0, bvec); AVX_MUL_ADD(acc10, avec1, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4]); AVX_MUL_ADD(acc01, avec0, bvec); AVX_MUL_ADD(acc11, avec1, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4]); AVX_MUL_ADD(acc02, avec0, bvec); AVX_MUL_ADD(acc12, avec1, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4]); AVX_MUL_ADD(acc03, avec0, bvec); AVX_MUL_ADD(acc13, avec1, bvec);
+ }
+
+
+ _mm256_store_pd(x + 0*4 + 0*MAT_BLK_SZ, acc00);
+ _mm256_store_pd(x + 1*4 + 0*MAT_BLK_SZ, acc01);
+ _mm256_store_pd(x + 2*4 + 0*MAT_BLK_SZ, acc02);
+ _mm256_store_pd(x + 3*4 + 0*MAT_BLK_SZ, acc03);
+
+ _mm256_store_pd(x + 0*4 + 1*MAT_BLK_SZ, acc10);
+ _mm256_store_pd(x + 1*4 + 1*MAT_BLK_SZ, acc11);
+ _mm256_store_pd(x + 2*4 + 1*MAT_BLK_SZ, acc12);
+ _mm256_store_pd(x + 3*4 + 1*MAT_BLK_SZ, acc13);
+
+ // round 1
+
+ acc00=_mm256_load_pd(x + 4*4 + 0*MAT_BLK_SZ);
+ acc01=_mm256_load_pd(x + 5*4 + 0*MAT_BLK_SZ);
+ acc02=_mm256_load_pd(x + 6*4 + 0*MAT_BLK_SZ);
+ acc03=_mm256_load_pd(x + 7*4 + 0*MAT_BLK_SZ);
+
+ acc10=_mm256_load_pd(x + 4*4 + 1*MAT_BLK_SZ);
+ acc11=_mm256_load_pd(x + 5*4 + 1*MAT_BLK_SZ);
+ acc12=_mm256_load_pd(x + 6*4 + 1*MAT_BLK_SZ);
+ acc13=_mm256_load_pd(x + 7*4 + 1*MAT_BLK_SZ);
+
+ for (long i = 0; i < n; i++) {
+ avec0 = _mm256_broadcast_sd(&a[i]);
+ avec1 = _mm256_broadcast_sd(&a[i+MAT_BLK_SZ]);
+
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4+MAT_BLK_SZ/2]); AVX_MUL_ADD(acc00, avec0, bvec); AVX_MUL_ADD(acc10, avec1, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4+MAT_BLK_SZ/2]); AVX_MUL_ADD(acc01, avec0, bvec); AVX_MUL_ADD(acc11, avec1, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4+MAT_BLK_SZ/2]); AVX_MUL_ADD(acc02, avec0, bvec); AVX_MUL_ADD(acc12, avec1, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4+MAT_BLK_SZ/2]); AVX_MUL_ADD(acc03, avec0, bvec); AVX_MUL_ADD(acc13, avec1, bvec);
+ }
+
+
+ _mm256_store_pd(x + 4*4 + 0*MAT_BLK_SZ, acc00);
+ _mm256_store_pd(x + 5*4 + 0*MAT_BLK_SZ, acc01);
+ _mm256_store_pd(x + 6*4 + 0*MAT_BLK_SZ, acc02);
+ _mm256_store_pd(x + 7*4 + 0*MAT_BLK_SZ, acc03);
+
+ _mm256_store_pd(x + 4*4 + 1*MAT_BLK_SZ, acc10);
+ _mm256_store_pd(x + 5*4 + 1*MAT_BLK_SZ, acc11);
+ _mm256_store_pd(x + 6*4 + 1*MAT_BLK_SZ, acc12);
+ _mm256_store_pd(x + 7*4 + 1*MAT_BLK_SZ, acc13);
+
+}
+// Same kernel with fused multiply-add, for CPUs with FMA.
+FMA_FUNC(void,muladd2_by_32)
+(double *x, const double *a, const double *b, long n)
+{
+ __m256d avec0, avec1, bvec;
+ __m256d acc00, acc01, acc02, acc03;
+ __m256d acc10, acc11, acc12, acc13;
+
+
+ // round 0
+
+ acc00=_mm256_load_pd(x + 0*4 + 0*MAT_BLK_SZ);
+ acc01=_mm256_load_pd(x + 1*4 + 0*MAT_BLK_SZ);
+ acc02=_mm256_load_pd(x + 2*4 + 0*MAT_BLK_SZ);
+ acc03=_mm256_load_pd(x + 3*4 + 0*MAT_BLK_SZ);
+
+ acc10=_mm256_load_pd(x + 0*4 + 1*MAT_BLK_SZ);
+ acc11=_mm256_load_pd(x + 1*4 + 1*MAT_BLK_SZ);
+ acc12=_mm256_load_pd(x + 2*4 + 1*MAT_BLK_SZ);
+ acc13=_mm256_load_pd(x + 3*4 + 1*MAT_BLK_SZ);
+
+ for (long i = 0; i < n; i++) {
+ avec0 = _mm256_broadcast_sd(&a[i]);
+ avec1 = _mm256_broadcast_sd(&a[i+MAT_BLK_SZ]);
+
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4]); FMA_MUL_ADD(acc00, avec0, bvec); FMA_MUL_ADD(acc10, avec1, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4]); FMA_MUL_ADD(acc01, avec0, bvec); FMA_MUL_ADD(acc11, avec1, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4]); FMA_MUL_ADD(acc02, avec0, bvec); FMA_MUL_ADD(acc12, avec1, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4]); FMA_MUL_ADD(acc03, avec0, bvec); FMA_MUL_ADD(acc13, avec1, bvec);
+ }
+
+
+ _mm256_store_pd(x + 0*4 + 0*MAT_BLK_SZ, acc00);
+ _mm256_store_pd(x + 1*4 + 0*MAT_BLK_SZ, acc01);
+ _mm256_store_pd(x + 2*4 + 0*MAT_BLK_SZ, acc02);
+ _mm256_store_pd(x + 3*4 + 0*MAT_BLK_SZ, acc03);
+
+ _mm256_store_pd(x + 0*4 + 1*MAT_BLK_SZ, acc10);
+ _mm256_store_pd(x + 1*4 + 1*MAT_BLK_SZ, acc11);
+ _mm256_store_pd(x + 2*4 + 1*MAT_BLK_SZ, acc12);
+ _mm256_store_pd(x + 3*4 + 1*MAT_BLK_SZ, acc13);
+
+ // round 1
+
+ acc00=_mm256_load_pd(x + 4*4 + 0*MAT_BLK_SZ);
+ acc01=_mm256_load_pd(x + 5*4 + 0*MAT_BLK_SZ);
+ acc02=_mm256_load_pd(x + 6*4 + 0*MAT_BLK_SZ);
+ acc03=_mm256_load_pd(x + 7*4 + 0*MAT_BLK_SZ);
+
+ acc10=_mm256_load_pd(x + 4*4 + 1*MAT_BLK_SZ);
+ acc11=_mm256_load_pd(x + 5*4 + 1*MAT_BLK_SZ);
+ acc12=_mm256_load_pd(x + 6*4 + 1*MAT_BLK_SZ);
+ acc13=_mm256_load_pd(x + 7*4 + 1*MAT_BLK_SZ);
+
+ for (long i = 0; i < n; i++) {
+ avec0 = _mm256_broadcast_sd(&a[i]);
+ avec1 = _mm256_broadcast_sd(&a[i+MAT_BLK_SZ]);
+
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc00, avec0, bvec); FMA_MUL_ADD(acc10, avec1, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc01, avec0, bvec); FMA_MUL_ADD(acc11, avec1, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc02, avec0, bvec); FMA_MUL_ADD(acc12, avec1, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc03, avec0, bvec); FMA_MUL_ADD(acc13, avec1, bvec);
+ }
+
+
+ _mm256_store_pd(x + 4*4 + 0*MAT_BLK_SZ, acc00);
+ _mm256_store_pd(x + 5*4 + 0*MAT_BLK_SZ, acc01);
+ _mm256_store_pd(x + 6*4 + 0*MAT_BLK_SZ, acc02);
+ _mm256_store_pd(x + 7*4 + 0*MAT_BLK_SZ, acc03);
+
+ _mm256_store_pd(x + 4*4 + 1*MAT_BLK_SZ, acc10);
+ _mm256_store_pd(x + 5*4 + 1*MAT_BLK_SZ, acc11);
+ _mm256_store_pd(x + 6*4 + 1*MAT_BLK_SZ, acc12);
+ _mm256_store_pd(x + 7*4 + 1*MAT_BLK_SZ, acc13);
+
+}
+// Resolved at load time: the FMA variant if have_fma, otherwise the AVX variant.
+FMA_RESOLVER(static,void,muladd2_by_32,
+ (double *x, const double *a, const double *b, long n));
+
+#else
+
static
void muladd2_by_32(double *x, const double *a, const double *b, long n)
{
@@ -1049,6 +1309,212 @@ void muladd2_by_32(double *x, const doub
}
+#endif
+// muladd3_by_32: as muladd2_by_32, but for three rows (r = 0, 1, 2) of a and x at once.
+#ifdef NTL_LOADTIME_CPU
+FMA_FUNC(void,muladd3_by_32)
+(double *x, const double *a, const double *b, long n)
+{
+ __m256d avec0, avec1, avec2, bvec;
+ __m256d acc00, acc01, acc02, acc03;
+ __m256d acc10, acc11, acc12, acc13;
+ __m256d acc20, acc21, acc22, acc23;
+ // 12 accumulators + 3 broadcasts + 1 b vector: all 16 ymm registers.
+
+ // round 0
+
+ acc00=_mm256_load_pd(x + 0*4 + 0*MAT_BLK_SZ);
+ acc01=_mm256_load_pd(x + 1*4 + 0*MAT_BLK_SZ);
+ acc02=_mm256_load_pd(x + 2*4 + 0*MAT_BLK_SZ);
+ acc03=_mm256_load_pd(x + 3*4 + 0*MAT_BLK_SZ);
+
+ acc10=_mm256_load_pd(x + 0*4 + 1*MAT_BLK_SZ);
+ acc11=_mm256_load_pd(x + 1*4 + 1*MAT_BLK_SZ);
+ acc12=_mm256_load_pd(x + 2*4 + 1*MAT_BLK_SZ);
+ acc13=_mm256_load_pd(x + 3*4 + 1*MAT_BLK_SZ);
+
+ acc20=_mm256_load_pd(x + 0*4 + 2*MAT_BLK_SZ);
+ acc21=_mm256_load_pd(x + 1*4 + 2*MAT_BLK_SZ);
+ acc22=_mm256_load_pd(x + 2*4 + 2*MAT_BLK_SZ);
+ acc23=_mm256_load_pd(x + 3*4 + 2*MAT_BLK_SZ);
+
+ for (long i = 0; i < n; i++) {
+ avec0 = _mm256_broadcast_sd(&a[i]);
+ avec1 = _mm256_broadcast_sd(&a[i+MAT_BLK_SZ]);
+ avec2 = _mm256_broadcast_sd(&a[i+2*MAT_BLK_SZ]);
+
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4]); FMA_MUL_ADD(acc00, avec0, bvec); FMA_MUL_ADD(acc10, avec1, bvec); FMA_MUL_ADD(acc20, avec2, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4]); FMA_MUL_ADD(acc01, avec0, bvec); FMA_MUL_ADD(acc11, avec1, bvec); FMA_MUL_ADD(acc21, avec2, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4]); FMA_MUL_ADD(acc02, avec0, bvec); FMA_MUL_ADD(acc12, avec1, bvec); FMA_MUL_ADD(acc22, avec2, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4]); FMA_MUL_ADD(acc03, avec0, bvec); FMA_MUL_ADD(acc13, avec1, bvec); FMA_MUL_ADD(acc23, avec2, bvec);
+ }
+
+
+ _mm256_store_pd(x + 0*4 + 0*MAT_BLK_SZ, acc00);
+ _mm256_store_pd(x + 1*4 + 0*MAT_BLK_SZ, acc01);
+ _mm256_store_pd(x + 2*4 + 0*MAT_BLK_SZ, acc02);
+ _mm256_store_pd(x + 3*4 + 0*MAT_BLK_SZ, acc03);
+
+ _mm256_store_pd(x + 0*4 + 1*MAT_BLK_SZ, acc10);
+ _mm256_store_pd(x + 1*4 + 1*MAT_BLK_SZ, acc11);
+ _mm256_store_pd(x + 2*4 + 1*MAT_BLK_SZ, acc12);
+ _mm256_store_pd(x + 3*4 + 1*MAT_BLK_SZ, acc13);
+
+ _mm256_store_pd(x + 0*4 + 2*MAT_BLK_SZ, acc20);
+ _mm256_store_pd(x + 1*4 + 2*MAT_BLK_SZ, acc21);
+ _mm256_store_pd(x + 2*4 + 2*MAT_BLK_SZ, acc22);
+ _mm256_store_pd(x + 3*4 + 2*MAT_BLK_SZ, acc23);
+
+ // round 1
+
+ acc00=_mm256_load_pd(x + 4*4 + 0*MAT_BLK_SZ);
+ acc01=_mm256_load_pd(x + 5*4 + 0*MAT_BLK_SZ);
+ acc02=_mm256_load_pd(x + 6*4 + 0*MAT_BLK_SZ);
+ acc03=_mm256_load_pd(x + 7*4 + 0*MAT_BLK_SZ);
+
+ acc10=_mm256_load_pd(x + 4*4 + 1*MAT_BLK_SZ);
+ acc11=_mm256_load_pd(x + 5*4 + 1*MAT_BLK_SZ);
+ acc12=_mm256_load_pd(x + 6*4 + 1*MAT_BLK_SZ);
+ acc13=_mm256_load_pd(x + 7*4 + 1*MAT_BLK_SZ);
+
+ acc20=_mm256_load_pd(x + 4*4 + 2*MAT_BLK_SZ);
+ acc21=_mm256_load_pd(x + 5*4 + 2*MAT_BLK_SZ);
+ acc22=_mm256_load_pd(x + 6*4 + 2*MAT_BLK_SZ);
+ acc23=_mm256_load_pd(x + 7*4 + 2*MAT_BLK_SZ);
+
+ for (long i = 0; i < n; i++) {
+ avec0 = _mm256_broadcast_sd(&a[i]);
+ avec1 = _mm256_broadcast_sd(&a[i+MAT_BLK_SZ]);
+ avec2 = _mm256_broadcast_sd(&a[i+2*MAT_BLK_SZ]);
+
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc00, avec0, bvec); FMA_MUL_ADD(acc10, avec1, bvec); FMA_MUL_ADD(acc20, avec2, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc01, avec0, bvec); FMA_MUL_ADD(acc11, avec1, bvec); FMA_MUL_ADD(acc21, avec2, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc02, avec0, bvec); FMA_MUL_ADD(acc12, avec1, bvec); FMA_MUL_ADD(acc22, avec2, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc03, avec0, bvec); FMA_MUL_ADD(acc13, avec1, bvec); FMA_MUL_ADD(acc23, avec2, bvec);
+ }
+
+
+ _mm256_store_pd(x + 4*4 + 0*MAT_BLK_SZ, acc00);
+ _mm256_store_pd(x + 5*4 + 0*MAT_BLK_SZ, acc01);
+ _mm256_store_pd(x + 6*4 + 0*MAT_BLK_SZ, acc02);
+ _mm256_store_pd(x + 7*4 + 0*MAT_BLK_SZ, acc03);
+
+ _mm256_store_pd(x + 4*4 + 1*MAT_BLK_SZ, acc10);
+ _mm256_store_pd(x + 5*4 + 1*MAT_BLK_SZ, acc11);
+ _mm256_store_pd(x + 6*4 + 1*MAT_BLK_SZ, acc12);
+ _mm256_store_pd(x + 7*4 + 1*MAT_BLK_SZ, acc13);
+
+ _mm256_store_pd(x + 4*4 + 2*MAT_BLK_SZ, acc20);
+ _mm256_store_pd(x + 5*4 + 2*MAT_BLK_SZ, acc21);
+ _mm256_store_pd(x + 6*4 + 2*MAT_BLK_SZ, acc22);
+ _mm256_store_pd(x + 7*4 + 2*MAT_BLK_SZ, acc23);
+
+}
+// Identical body compiled for AVX2; AVX2_RESOLVER prefers it when have_avx2 is set.
+AVX2_FUNC(void,muladd3_by_32)
+(double *x, const double *a, const double *b, long n)
+{
+ __m256d avec0, avec1, avec2, bvec;
+ __m256d acc00, acc01, acc02, acc03;
+ __m256d acc10, acc11, acc12, acc13;
+ __m256d acc20, acc21, acc22, acc23;
+
+
+ // round 0
+
+ acc00=_mm256_load_pd(x + 0*4 + 0*MAT_BLK_SZ);
+ acc01=_mm256_load_pd(x + 1*4 + 0*MAT_BLK_SZ);
+ acc02=_mm256_load_pd(x + 2*4 + 0*MAT_BLK_SZ);
+ acc03=_mm256_load_pd(x + 3*4 + 0*MAT_BLK_SZ);
+
+ acc10=_mm256_load_pd(x + 0*4 + 1*MAT_BLK_SZ);
+ acc11=_mm256_load_pd(x + 1*4 + 1*MAT_BLK_SZ);
+ acc12=_mm256_load_pd(x + 2*4 + 1*MAT_BLK_SZ);
+ acc13=_mm256_load_pd(x + 3*4 + 1*MAT_BLK_SZ);
+
+ acc20=_mm256_load_pd(x + 0*4 + 2*MAT_BLK_SZ);
+ acc21=_mm256_load_pd(x + 1*4 + 2*MAT_BLK_SZ);
+ acc22=_mm256_load_pd(x + 2*4 + 2*MAT_BLK_SZ);
+ acc23=_mm256_load_pd(x + 3*4 + 2*MAT_BLK_SZ);
+
+ for (long i = 0; i < n; i++) {
+ avec0 = _mm256_broadcast_sd(&a[i]);
+ avec1 = _mm256_broadcast_sd(&a[i+MAT_BLK_SZ]);
+ avec2 = _mm256_broadcast_sd(&a[i+2*MAT_BLK_SZ]);
+
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4]); FMA_MUL_ADD(acc00, avec0, bvec); FMA_MUL_ADD(acc10, avec1, bvec); FMA_MUL_ADD(acc20, avec2, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4]); FMA_MUL_ADD(acc01, avec0, bvec); FMA_MUL_ADD(acc11, avec1, bvec); FMA_MUL_ADD(acc21, avec2, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4]); FMA_MUL_ADD(acc02, avec0, bvec); FMA_MUL_ADD(acc12, avec1, bvec); FMA_MUL_ADD(acc22, avec2, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4]); FMA_MUL_ADD(acc03, avec0, bvec); FMA_MUL_ADD(acc13, avec1, bvec); FMA_MUL_ADD(acc23, avec2, bvec);
+ }
+
+
+ _mm256_store_pd(x + 0*4 + 0*MAT_BLK_SZ, acc00);
+ _mm256_store_pd(x + 1*4 + 0*MAT_BLK_SZ, acc01);
+ _mm256_store_pd(x + 2*4 + 0*MAT_BLK_SZ, acc02);
+ _mm256_store_pd(x + 3*4 + 0*MAT_BLK_SZ, acc03);
+
+ _mm256_store_pd(x + 0*4 + 1*MAT_BLK_SZ, acc10);
+ _mm256_store_pd(x + 1*4 + 1*MAT_BLK_SZ, acc11);
+ _mm256_store_pd(x + 2*4 + 1*MAT_BLK_SZ, acc12);
+ _mm256_store_pd(x + 3*4 + 1*MAT_BLK_SZ, acc13);
+
+ _mm256_store_pd(x + 0*4 + 2*MAT_BLK_SZ, acc20);
+ _mm256_store_pd(x + 1*4 + 2*MAT_BLK_SZ, acc21);
+ _mm256_store_pd(x + 2*4 + 2*MAT_BLK_SZ, acc22);
+ _mm256_store_pd(x + 3*4 + 2*MAT_BLK_SZ, acc23);
+
+ // round 1
+
+ acc00=_mm256_load_pd(x + 4*4 + 0*MAT_BLK_SZ);
+ acc01=_mm256_load_pd(x + 5*4 + 0*MAT_BLK_SZ);
+ acc02=_mm256_load_pd(x + 6*4 + 0*MAT_BLK_SZ);
+ acc03=_mm256_load_pd(x + 7*4 + 0*MAT_BLK_SZ);
+
+ acc10=_mm256_load_pd(x + 4*4 + 1*MAT_BLK_SZ);
+ acc11=_mm256_load_pd(x + 5*4 + 1*MAT_BLK_SZ);
+ acc12=_mm256_load_pd(x + 6*4 + 1*MAT_BLK_SZ);
+ acc13=_mm256_load_pd(x + 7*4 + 1*MAT_BLK_SZ);
+
+ acc20=_mm256_load_pd(x + 4*4 + 2*MAT_BLK_SZ);
+ acc21=_mm256_load_pd(x + 5*4 + 2*MAT_BLK_SZ);
+ acc22=_mm256_load_pd(x + 6*4 + 2*MAT_BLK_SZ);
+ acc23=_mm256_load_pd(x + 7*4 + 2*MAT_BLK_SZ);
+
+ for (long i = 0; i < n; i++) {
+ avec0 = _mm256_broadcast_sd(&a[i]);
+ avec1 = _mm256_broadcast_sd(&a[i+MAT_BLK_SZ]);
+ avec2 = _mm256_broadcast_sd(&a[i+2*MAT_BLK_SZ]);
+
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc00, avec0, bvec); FMA_MUL_ADD(acc10, avec1, bvec); FMA_MUL_ADD(acc20, avec2, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc01, avec0, bvec); FMA_MUL_ADD(acc11, avec1, bvec); FMA_MUL_ADD(acc21, avec2, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc02, avec0, bvec); FMA_MUL_ADD(acc12, avec1, bvec); FMA_MUL_ADD(acc22, avec2, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc03, avec0, bvec); FMA_MUL_ADD(acc13, avec1, bvec); FMA_MUL_ADD(acc23, avec2, bvec);
+ }
+
+
+ _mm256_store_pd(x + 4*4 + 0*MAT_BLK_SZ, acc00);
+ _mm256_store_pd(x + 5*4 + 0*MAT_BLK_SZ, acc01);
+ _mm256_store_pd(x + 6*4 + 0*MAT_BLK_SZ, acc02);
+ _mm256_store_pd(x + 7*4 + 0*MAT_BLK_SZ, acc03);
+
+ _mm256_store_pd(x + 4*4 + 1*MAT_BLK_SZ, acc10);
+ _mm256_store_pd(x + 5*4 + 1*MAT_BLK_SZ, acc11);
+ _mm256_store_pd(x + 6*4 + 1*MAT_BLK_SZ, acc12);
+ _mm256_store_pd(x + 7*4 + 1*MAT_BLK_SZ, acc13);
+
+ _mm256_store_pd(x + 4*4 + 2*MAT_BLK_SZ, acc20);
+ _mm256_store_pd(x + 5*4 + 2*MAT_BLK_SZ, acc21);
+ _mm256_store_pd(x + 6*4 + 2*MAT_BLK_SZ, acc22);
+ _mm256_store_pd(x + 7*4 + 2*MAT_BLK_SZ, acc23);
+
+}
+// Resolved at load time: AVX2 variant if have_avx2, else the FMA variant (no AVX-only fallback, so presumably callers require FMA).
+AVX2_RESOLVER(static,void,muladd3_by_32,
+ (double *x, const double *a, const double *b, long n));
+
+#else
+
// NOTE: this makes things slower on an AVX1 platform --- not enough registers
// it could be faster on AVX2/FMA, where there should be enough registers
static
@@ -1150,6 +1616,75 @@ void muladd3_by_32(double *x, const doub
}
+#endif
+
+#ifdef NTL_LOADTIME_CPU
+// muladd1_by_16, AVX (non-FMA) variant: updates the 16 doubles x[0..15] from a[0..n-1] and the rows of b.
+AVX_FUNC(void,muladd1_by_16)
+(double *x, const double *a, const double *b, long n)
+{
+ __m256d avec, bvec;
+ // The 16 outputs stay in four 4-wide accumulators for the whole loop;
+ // x and b go through aligned AVX loads/stores, so they must be 32-byte aligned.
+ __m256d acc0=_mm256_load_pd(x + 0*4);
+ __m256d acc1=_mm256_load_pd(x + 1*4);
+ __m256d acc2=_mm256_load_pd(x + 2*4);
+ __m256d acc3=_mm256_load_pd(x + 3*4);
+ // Step i: broadcast a[i] and combine it with b[i*32 .. i*32+15] via AVX_MUL_ADD
+ // (defined elsewhere; presumably acc += a*b with a separate multiply and add).
+ for (long i = 0; i < n; i++) {
+ avec = _mm256_broadcast_sd(a); a++;
+
+
+ bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc0, avec, bvec);
+ bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc1, avec, bvec);
+ bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc2, avec, bvec);
+ bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc3, avec, bvec);
+ b += 16; // b advances 32 doubles per step; only the first 16 of each are read
+ }
+
+ // write the 16 accumulated values back to x
+ _mm256_store_pd(x + 0*4, acc0);
+ _mm256_store_pd(x + 1*4, acc1);
+ _mm256_store_pd(x + 2*4, acc2);
+ _mm256_store_pd(x + 3*4, acc3);
+}
+// muladd1_by_16, FMA variant: same data flow as the AVX_FUNC version, accumulating with FMA_MUL_ADD.
+FMA_FUNC(void,muladd1_by_16)
+(double *x, const double *a, const double *b, long n)
+{
+ __m256d avec, bvec;
+ // The 16 outputs stay in four 4-wide accumulators for the whole loop;
+ // x and b go through aligned AVX loads/stores, so they must be 32-byte aligned.
+ __m256d acc0=_mm256_load_pd(x + 0*4);
+ __m256d acc1=_mm256_load_pd(x + 1*4);
+ __m256d acc2=_mm256_load_pd(x + 2*4);
+ __m256d acc3=_mm256_load_pd(x + 3*4);
+ // Step i: broadcast a[i] and combine it with b[i*32 .. i*32+15] via FMA_MUL_ADD
+ // (defined elsewhere; presumably a fused multiply-add acc += a*b).
+ for (long i = 0; i < n; i++) {
+ avec = _mm256_broadcast_sd(a); a++;
+
+
+ bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc0, avec, bvec);
+ bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc1, avec, bvec);
+ bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc2, avec, bvec);
+ bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc3, avec, bvec);
+ b += 16; // b advances 32 doubles per step; only the first 16 of each are read
+ }
+
+ // write the 16 accumulated values back to x
+ _mm256_store_pd(x + 0*4, acc0);
+ _mm256_store_pd(x + 1*4, acc1);
+ _mm256_store_pd(x + 2*4, acc2);
+ _mm256_store_pd(x + 3*4, acc3);
+}
+// Load-time dispatch for muladd1_by_16 via FMA_RESOLVER (defined in ctools.h); presumably picks the FMA_FUNC variant on FMA-capable CPUs, else the AVX_FUNC one.
+FMA_RESOLVER(static,void,muladd1_by_16,
+ (double *x, const double *a, const double *b, long n));
+
+#else
+
static
void muladd1_by_16(double *x, const double *a, const double *b, long n)
{
@@ -1180,10 +1715,11 @@ void muladd1_by_16(double *x, const doub
_mm256_store_pd(x + 3*4, acc3);
}
+#endif
-static
-void muladd2_by_16(double *x, const double *a, const double *b, long n)
+static void __attribute__((target ("avx,pclmul")))
+muladd2_by_16(double *x, const double *a, const double *b, long n)
{
__m256d avec0, avec1, bvec;
__m256d acc00, acc01, acc02, acc03;
@@ -1206,10 +1742,10 @@ void muladd2_by_16(double *x, const doub
avec0 = _mm256_broadcast_sd(&a[i]);
avec1 = _mm256_broadcast_sd(&a[i+MAT_BLK_SZ]);
- bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4]); MUL_ADD(acc00, avec0, bvec); MUL_ADD(acc10, avec1, bvec);
- bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4]); MUL_ADD(acc01, avec0, bvec); MUL_ADD(acc11, avec1, bvec);
- bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4]); MUL_ADD(acc02, avec0, bvec); MUL_ADD(acc12, avec1, bvec);
- bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4]); MUL_ADD(acc03, avec0, bvec); MUL_ADD(acc13, avec1, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4]); AVX_MUL_ADD(acc00, avec0, bvec); AVX_MUL_ADD(acc10, avec1, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4]); AVX_MUL_ADD(acc01, avec0, bvec); AVX_MUL_ADD(acc11, avec1, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4]); AVX_MUL_ADD(acc02, avec0, bvec); AVX_MUL_ADD(acc12, avec1, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4]); AVX_MUL_ADD(acc03, avec0, bvec); AVX_MUL_ADD(acc13, avec1, bvec);
}
@@ -1226,8 +1762,8 @@ void muladd2_by_16(double *x, const doub
}
-static
-void muladd3_by_16(double *x, const double *a, const double *b, long n)
+static void __attribute__((target("fma,pclmul")))
+muladd3_by_16(double *x, const double *a, const double *b, long n)
{
__m256d avec0, avec1, avec2, bvec;
__m256d acc00, acc01, acc02, acc03;
@@ -1257,10 +1793,10 @@ void muladd3_by_16(double *x, const doub
avec1 = _mm256_broadcast_sd(&a[i+MAT_BLK_SZ]);
avec2 = _mm256_broadcast_sd(&a[i+2*MAT_BLK_SZ]);
- bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4]); MUL_ADD(acc00, avec0, bvec); MUL_ADD(acc10, avec1, bvec); MUL_ADD(acc20, avec2, bvec);
- bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4]); MUL_ADD(acc01, avec0, bvec); MUL_ADD(acc11, avec1, bvec); MUL_ADD(acc21, avec2, bvec);
- bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4]); MUL_ADD(acc02, avec0, bvec); MUL_ADD(acc12, avec1, bvec); MUL_ADD(acc22, avec2, bvec);
- bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4]); MUL_ADD(acc03, avec0, bvec); MUL_ADD(acc13, avec1, bvec); MUL_ADD(acc23, avec2, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4]); FMA_MUL_ADD(acc00, avec0, bvec); FMA_MUL_ADD(acc10, avec1, bvec); FMA_MUL_ADD(acc20, avec2, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4]); FMA_MUL_ADD(acc01, avec0, bvec); FMA_MUL_ADD(acc11, avec1, bvec); FMA_MUL_ADD(acc21, avec2, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4]); FMA_MUL_ADD(acc02, avec0, bvec); FMA_MUL_ADD(acc12, avec1, bvec); FMA_MUL_ADD(acc22, avec2, bvec);
+ bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4]); FMA_MUL_ADD(acc03, avec0, bvec); FMA_MUL_ADD(acc13, avec1, bvec); FMA_MUL_ADD(acc23, avec2, bvec);
}
@@ -1289,6 +1825,29 @@ void muladd3_by_16(double *x, const doub
+#ifdef NTL_LOADTIME_CPU
+static inline
+void muladd_all_by_32(long first, long last, double *x, const double *a, const double *b, long n)
+{
+   // Update rows [first, last) of x (stride MAT_BLK_SZ) from the matching rows of a and all of b.
+   long row = first;
+   if (!have_fma) {
+      // AVX only: not enough registers for more than two rows at a time
+      for (; row + 2 <= last; row += 2)
+         muladd2_by_32(x + row*MAT_BLK_SZ, a + row*MAT_BLK_SZ, b, n);
+   } else {
+      // FMA available: process three rows at a time
+      for (; row + 3 <= last; row += 3)
+         muladd3_by_32(x + row*MAT_BLK_SZ, a + row*MAT_BLK_SZ, b, n);
+   }
+
+   // any remaining rows are handled one at a time
+   for (; row < last; row++)
+      muladd1_by_32(x + row*MAT_BLK_SZ, a + row*MAT_BLK_SZ, b, n);
+}
+
+#else
+
static inline
void muladd_all_by_32(long first, long last, double *x, const double *a, const double *b, long n)
{
@@ -1308,6 +1867,30 @@ void muladd_all_by_32(long first, long l
#endif
}
+#endif
+
+#ifdef NTL_LOADTIME_CPU
+
+static inline
+void muladd_all_by_16(long first, long last, double *x, const double *a, const double *b, long n)
+{
+   // Update rows [first, last) of x (stride MAT_BLK_SZ) from the matching rows of a and all of b.
+   long row = first;
+   if (!have_fma) {
+      // AVX only: not enough registers for more than two rows at a time
+      for (; row + 2 <= last; row += 2)
+         muladd2_by_16(x + row*MAT_BLK_SZ, a + row*MAT_BLK_SZ, b, n);
+   } else {
+      // FMA available: processing three rows at a time is faster
+      for (; row + 3 <= last; row += 3)
+         muladd3_by_16(x + row*MAT_BLK_SZ, a + row*MAT_BLK_SZ, b, n);
+   }
+   // any remaining rows are handled one at a time
+   for (; row < last; row++)
+      muladd1_by_16(x + row*MAT_BLK_SZ, a + row*MAT_BLK_SZ, b, n);
+}
+
+#else
static inline
void muladd_all_by_16(long first, long last, double *x, const double *a, const double *b, long n)
@@ -1328,6 +1911,8 @@ void muladd_all_by_16(long first, long l
#endif
}
+#endif
+
static inline
void muladd_all_by_32_width(long first, long last, double *x, const double *a, const double *b, long n, long width)
{
@@ -1343,6 +1928,74 @@ void muladd_all_by_32_width(long first,
// this assumes n is a multiple of 16
+#ifdef NTL_LOADTIME_CPU
+AVX_FUNC(void,muladd_interval)
+(double * NTL_RESTRICT x, double * NTL_RESTRICT y, double c, long n)
+{
+ __m256d xvec0, xvec1, xvec2, xvec3;
+ __m256d yvec0, yvec1, yvec2, yvec3;
+ // AVX (non-FMA) variant; presumably x[i] += y[i]*c for 0 <= i < n, 16 doubles per pass.
+ __m256d cvec = _mm256_broadcast_sd(&c);
+ // No tail handling: n must be a multiple of 16; x and y must be 32-byte aligned.
+ for (long i = 0; i < n; i += 16, x += 16, y += 16) {
+ xvec0 = _mm256_load_pd(x+0*4);
+ xvec1 = _mm256_load_pd(x+1*4);
+ xvec2 = _mm256_load_pd(x+2*4);
+ xvec3 = _mm256_load_pd(x+3*4);
+
+ yvec0 = _mm256_load_pd(y+0*4);
+ yvec1 = _mm256_load_pd(y+1*4);
+ yvec2 = _mm256_load_pd(y+2*4);
+ yvec3 = _mm256_load_pd(y+3*4);
+ // AVX_MUL_ADD is defined elsewhere (presumably a separate multiply and add).
+ AVX_MUL_ADD(xvec0, yvec0, cvec);
+ AVX_MUL_ADD(xvec1, yvec1, cvec);
+ AVX_MUL_ADD(xvec2, yvec2, cvec);
+ AVX_MUL_ADD(xvec3, yvec3, cvec);
+
+ _mm256_store_pd(x + 0*4, xvec0);
+ _mm256_store_pd(x + 1*4, xvec1);
+ _mm256_store_pd(x + 2*4, xvec2);
+ _mm256_store_pd(x + 3*4, xvec3);
+ }
+}
+// muladd_interval, FMA variant: same data flow as the AVX_FUNC version, using FMA_MUL_ADD.
+FMA_FUNC(void,muladd_interval)
+(double * NTL_RESTRICT x, double * NTL_RESTRICT y, double c, long n)
+{
+ __m256d xvec0, xvec1, xvec2, xvec3;
+ __m256d yvec0, yvec1, yvec2, yvec3;
+
+ __m256d cvec = _mm256_broadcast_sd(&c);
+ // No tail handling: n must be a multiple of 16; x and y must be 32-byte aligned.
+ for (long i = 0; i < n; i += 16, x += 16, y += 16) {
+ xvec0 = _mm256_load_pd(x+0*4);
+ xvec1 = _mm256_load_pd(x+1*4);
+ xvec2 = _mm256_load_pd(x+2*4);
+ xvec3 = _mm256_load_pd(x+3*4);
+
+ yvec0 = _mm256_load_pd(y+0*4);
+ yvec1 = _mm256_load_pd(y+1*4);
+ yvec2 = _mm256_load_pd(y+2*4);
+ yvec3 = _mm256_load_pd(y+3*4);
+ // FMA_MUL_ADD is defined elsewhere (presumably a fused multiply-add).
+ FMA_MUL_ADD(xvec0, yvec0, cvec);
+ FMA_MUL_ADD(xvec1, yvec1, cvec);
+ FMA_MUL_ADD(xvec2, yvec2, cvec);
+ FMA_MUL_ADD(xvec3, yvec3, cvec);
+
+ _mm256_store_pd(x + 0*4, xvec0);
+ _mm256_store_pd(x + 1*4, xvec1);
+ _mm256_store_pd(x + 2*4, xvec2);
+ _mm256_store_pd(x + 3*4, xvec3);
+ }
+}
+// Load-time dispatch for muladd_interval via FMA_RESOLVER (defined in ctools.h); presumably FMA_FUNC on FMA-capable CPUs, else AVX_FUNC.
+FMA_RESOLVER(static,void,muladd_interval,
+ (double * NTL_RESTRICT x, double * NTL_RESTRICT y, double c, long n));
+
+#else
+
static inline
void muladd_interval(double * NTL_RESTRICT x, double * NTL_RESTRICT y, double c, long n)
{
@@ -1374,6 +2027,106 @@ void muladd_interval(double * NTL_RESTRI
}
}
+#endif
+
+#ifdef NTL_LOADTIME_CPU
+AVX_FUNC(void,muladd_interval1)
+(double * NTL_RESTRICT x, double * NTL_RESTRICT y, double c, long n)
+{
+ // AVX (non-FMA) variant for arbitrary n: x[i] += y[i]*c, 16 then 4 doubles per step, then a scalar tail.
+ __m256d xvec0, xvec1, xvec2, xvec3;
+ __m256d yvec0, yvec1, yvec2, yvec3;
+ __m256d cvec;
+
+ if (n >= 4)
+ cvec = _mm256_broadcast_sd(&c);
+ // cvec is read only by the vector loops below, which execute only when n >= 4
+ long i=0;
+ for (; i <= n-16; i += 16, x += 16, y += 16) {
+ xvec0 = _mm256_load_pd(x+0*4);
+ xvec1 = _mm256_load_pd(x+1*4);
+ xvec2 = _mm256_load_pd(x+2*4);
+ xvec3 = _mm256_load_pd(x+3*4);
+
+ yvec0 = _mm256_load_pd(y+0*4);
+ yvec1 = _mm256_load_pd(y+1*4);
+ yvec2 = _mm256_load_pd(y+2*4);
+ yvec3 = _mm256_load_pd(y+3*4);
+
+ AVX_MUL_ADD(xvec0, yvec0, cvec);
+ AVX_MUL_ADD(xvec1, yvec1, cvec);
+ AVX_MUL_ADD(xvec2, yvec2, cvec);
+ AVX_MUL_ADD(xvec3, yvec3, cvec);
+
+ _mm256_store_pd(x + 0*4, xvec0);
+ _mm256_store_pd(x + 1*4, xvec1);
+ _mm256_store_pd(x + 2*4, xvec2);
+ _mm256_store_pd(x + 3*4, xvec3);
+ }
+ // remaining groups of 4 (aligned loads: x and y must be 32-byte aligned)
+ for (; i <= n-4; i += 4, x += 4, y += 4) {
+ xvec0 = _mm256_load_pd(x+0*4);
+ yvec0 = _mm256_load_pd(y+0*4);
+ AVX_MUL_ADD(xvec0, yvec0, cvec);
+ _mm256_store_pd(x + 0*4, xvec0);
+ }
+ // scalar tail for the last n % 4 elements
+ for (; i < n; i++, x++, y++) {
+ *x += (*y)*c;
+ }
+}
+// muladd_interval1, FMA variant: same data flow as the AVX_FUNC version, using FMA_MUL_ADD.
+FMA_FUNC(void,muladd_interval1)
+(double * NTL_RESTRICT x, double * NTL_RESTRICT y, double c, long n)
+{
+ // Arbitrary n: x[i] += y[i]*c, 16 then 4 doubles per step, then a scalar tail.
+ __m256d xvec0, xvec1, xvec2, xvec3;
+ __m256d yvec0, yvec1, yvec2, yvec3;
+ __m256d cvec;
+
+ if (n >= 4)
+ cvec = _mm256_broadcast_sd(&c);
+ // cvec is read only by the vector loops below, which execute only when n >= 4
+ long i=0;
+ for (; i <= n-16; i += 16, x += 16, y += 16) {
+ xvec0 = _mm256_load_pd(x+0*4);
+ xvec1 = _mm256_load_pd(x+1*4);
+ xvec2 = _mm256_load_pd(x+2*4);
+ xvec3 = _mm256_load_pd(x+3*4);
+
+ yvec0 = _mm256_load_pd(y+0*4);
+ yvec1 = _mm256_load_pd(y+1*4);
+ yvec2 = _mm256_load_pd(y+2*4);
+ yvec3 = _mm256_load_pd(y+3*4);
+
+ FMA_MUL_ADD(xvec0, yvec0, cvec);
+ FMA_MUL_ADD(xvec1, yvec1, cvec);
+ FMA_MUL_ADD(xvec2, yvec2, cvec);
+ FMA_MUL_ADD(xvec3, yvec3, cvec);
+
+ _mm256_store_pd(x + 0*4, xvec0);
+ _mm256_store_pd(x + 1*4, xvec1);
+ _mm256_store_pd(x + 2*4, xvec2);
+ _mm256_store_pd(x + 3*4, xvec3);
+ }
+ // remaining groups of 4 (aligned loads: x and y must be 32-byte aligned)
+ for (; i <= n-4; i += 4, x += 4, y += 4) {
+ xvec0 = _mm256_load_pd(x+0*4);
+ yvec0 = _mm256_load_pd(y+0*4);
+ FMA_MUL_ADD(xvec0, yvec0, cvec);
+ _mm256_store_pd(x + 0*4, xvec0);
+ }
+ // scalar tail for the last n % 4 elements
+ for (; i < n; i++, x++, y++) {
+ *x += (*y)*c;
+ }
+}
+// Load-time dispatch for muladd_interval1 via FMA_RESOLVER (defined in ctools.h); presumably FMA_FUNC on FMA-capable CPUs, else AVX_FUNC.
+FMA_RESOLVER(static,void,muladd_interval1,
+ (double * NTL_RESTRICT x, double * NTL_RESTRICT y, double c, long n));
+
+#else
+
// this one is more general: does not assume that n is a
// multiple of 16
static inline
@@ -1422,6 +2175,7 @@ void muladd_interval1(double * NTL_RESTR
}
}
+#endif
#endif
@@ -3009,10 +3763,10 @@ void alt_mul_LL(const mat_window_zz_p& X
}
-#ifdef NTL_HAVE_AVX
+#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU)
-static
-void blk_mul_DD(const mat_window_zz_p& X,
+static void __attribute__((target("avx,pclmul")))
+blk_mul_DD(const mat_window_zz_p& X,
const const_mat_window_zz_p& A, const const_mat_window_zz_p& B)
{
long n = A.NumRows();
@@ -3351,12 +4105,13 @@ void mul_base (const mat_window_zz_p& X,
long p = zz_p::modulus();
long V = MAT_BLK_SZ*4;
-#ifdef NTL_HAVE_AVX
+#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU)
// experimentally, blk_mul_DD beats all the alternatives
// if each dimension is at least 16
- if (n >= 16 && l >= 16 && m >= 16 &&
+ if (AVX_ACTIVE &&
+ n >= 16 && l >= 16 && m >= 16 &&
p-1 <= MAX_DBL_INT &&
V <= (MAX_DBL_INT-(p-1))/(p-1) &&
V*(p-1) <= (MAX_DBL_INT-(p-1))/(p-1))
@@ -3451,7 +4206,8 @@ void mul_strassen(const mat_window_zz_p&
// this code determines if mul_base triggers blk_mul_DD,
// in which case a higher crossover is used
-#if (defined(NTL_HAVE_LL_TYPE) && defined(NTL_HAVE_AVX))
+#if (defined(NTL_HAVE_LL_TYPE) && (defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU)))
+ if (AVX_ACTIVE)
{
long V = MAT_BLK_SZ*4;
long p = zz_p::modulus();
@@ -3950,10 +4706,10 @@ void alt_inv_L(zz_p& d, mat_zz_p& X, con
-#ifdef NTL_HAVE_AVX
+#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU)
-static
-void alt_inv_DD(zz_p& d, mat_zz_p& X, const mat_zz_p& A, bool relax)
+static void __attribute__((target("avx,pclmul")))
+alt_inv_DD(zz_p& d, mat_zz_p& X, const mat_zz_p& A, bool relax)
{
long n = A.NumRows();
@@ -4118,10 +4874,10 @@ void alt_inv_DD(zz_p& d, mat_zz_p& X, co
-#ifdef NTL_HAVE_AVX
+#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU)
-static
-void blk_inv_DD(zz_p& d, mat_zz_p& X, const mat_zz_p& A, bool relax)
+static void __attribute__((target("avx,pclmul")))
+blk_inv_DD(zz_p& d, mat_zz_p& X, const mat_zz_p& A, bool relax)
{
long n = A.NumRows();
@@ -4879,8 +5635,9 @@ void relaxed_inv(zz_p& d, mat_zz_p& X, c
else if (n/MAT_BLK_SZ < 4) {
long V = 64;
-#ifdef NTL_HAVE_AVX
- if (p-1 <= MAX_DBL_INT &&
+#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU)
+ if (AVX_ACTIVE &&
+ p-1 <= MAX_DBL_INT &&
V <= (MAX_DBL_INT-(p-1))/(p-1) &&
V*(p-1) <= (MAX_DBL_INT-(p-1))/(p-1)) {
@@ -4905,8 +5662,9 @@ void relaxed_inv(zz_p& d, mat_zz_p& X, c
else {
long V = 4*MAT_BLK_SZ;
-#ifdef NTL_HAVE_AVX
- if (p-1 <= MAX_DBL_INT &&
+#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU)
+ if (AVX_ACTIVE &&
+ p-1 <= MAX_DBL_INT &&
V <= (MAX_DBL_INT-(p-1))/(p-1) &&
V*(p-1) <= (MAX_DBL_INT-(p-1))/(p-1)) {
@@ -5312,10 +6070,10 @@ void alt_tri_L(zz_p& d, const mat_zz_p&
-#ifdef NTL_HAVE_AVX
+#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU)
-static
-void alt_tri_DD(zz_p& d, const mat_zz_p& A, const vec_zz_p *bp,
+static void __attribute__((target("avx,pclmul")))
+alt_tri_DD(zz_p& d, const mat_zz_p& A, const vec_zz_p *bp,
vec_zz_p *xp, bool trans, bool relax)
{
long n = A.NumRows();
@@ -5502,10 +6260,10 @@ void alt_tri_DD(zz_p& d, const mat_zz_p&
-#ifdef NTL_HAVE_AVX
+#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU)
-static
-void blk_tri_DD(zz_p& d, const mat_zz_p& A, const vec_zz_p *bp,
+static void __attribute__((target("avx,pclmul")))
+blk_tri_DD(zz_p& d, const mat_zz_p& A, const vec_zz_p *bp,
vec_zz_p *xp, bool trans, bool relax)
{
long n = A.NumRows();
@@ -6316,8 +7074,9 @@ void tri(zz_p& d, const mat_zz_p& A, con
else if (n/MAT_BLK_SZ < 4) {
long V = 64;
-#ifdef NTL_HAVE_AVX
- if (p-1 <= MAX_DBL_INT &&
+#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU)
+ if (AVX_ACTIVE &&
+ p-1 <= MAX_DBL_INT &&
V <= (MAX_DBL_INT-(p-1))/(p-1) &&
V*(p-1) <= (MAX_DBL_INT-(p-1))/(p-1)) {
@@ -6342,8 +7101,9 @@ void tri(zz_p& d, const mat_zz_p& A, con
else {
long V = 4*MAT_BLK_SZ;
-#ifdef NTL_HAVE_AVX
- if (p-1 <= MAX_DBL_INT &&
+#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU)
+ if (AVX_ACTIVE &&
+ p-1 <= MAX_DBL_INT &&
V <= (MAX_DBL_INT-(p-1))/(p-1) &&
V*(p-1) <= (MAX_DBL_INT-(p-1))/(p-1)) {
@@ -6589,7 +7349,7 @@ long elim_basic(const mat_zz_p& A, mat_z
#ifdef NTL_HAVE_LL_TYPE
-#ifdef NTL_HAVE_AVX
+#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU)
static inline
@@ -8057,8 +8817,9 @@ long elim(const mat_zz_p& A, mat_zz_p *i
else {
long V = 4*MAT_BLK_SZ;
-#ifdef NTL_HAVE_AVX
- if (p-1 <= MAX_DBL_INT &&
+#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU)
+ if (AVX_ACTIVE &&
+ p-1 <= MAX_DBL_INT &&
V <= (MAX_DBL_INT-(p-1))/(p-1) &&
V*(p-1) <= (MAX_DBL_INT-(p-1))/(p-1)) {
--- src/QuickTest.cpp.orig 2021-06-20 15:05:49.000000000 -0600
+++ src/QuickTest.cpp 2021-06-23 19:59:29.916142147 -0600
@@ -326,6 +326,9 @@ cerr << "Performance Options:\n";
cerr << "NTL_GF2X_NOINLINE\n";
#endif
+#ifdef NTL_LOADTIME_CPU
+ cerr << "NTL_LOADTIME_CPU\n";
+#endif
cerr << "\n\n";
--- src/WizardAux.orig 2021-06-20 15:05:49.000000000 -0600
+++ src/WizardAux 2021-06-23 19:59:29.916142147 -0600
@@ -89,6 +89,7 @@ system("$ARGV[0] InitSettings");
'NTL_GF2X_NOINLINE' => 0,
'NTL_FFT_BIGTAB' => 0,
'NTL_FFT_LAZYMUL' => 0,
+'NTL_LOADTIME_CPU' => 0,
'WIZARD_HACK' => '#define NTL_WIZARD_HACK',
--- src/ZZ.cpp.orig 2021-06-20 15:05:48.000000000 -0600
+++ src/ZZ.cpp 2021-06-23 19:59:29.918142149 -0600
@@ -14,6 +14,13 @@
#elif defined(NTL_HAVE_SSSE3)
#include <emmintrin.h>
#include <tmmintrin.h>
+#elif defined(NTL_LOADTIME_CPU)
+#include <immintrin.h>
+#include <emmintrin.h>
+#include <tmmintrin.h>
+// Load-time CPU feature flags; -1 means "not probed yet" (presumably set to 0/1 by the resolver macros from ctools.h).
+static int have_avx2 = -1;
+static int have_ssse3 = -1;
#endif
#if defined(NTL_HAVE_KMA)
@@ -3268,6 +3275,590 @@ struct RandomStream_impl {
};
+#elif defined(NTL_LOADTIME_CPU)
+
+// Number of ChaCha rounds used by the SIMD generators (override with -DCHACHA_RNDS=...):
+// 8: low security - high speed
+// 12: mid security - mid speed
+// 20: high security - low speed (the default)
+#ifndef CHACHA_RNDS
+#define CHACHA_RNDS 20
+#endif
+// Vector primitives: SSSE3_* use 128-bit vectors (one ChaCha block at a time); AVX2_* use 256-bit vectors holding two blocks, one per 128-bit lane.
+typedef __m128i ssse3_ivec_t;
+typedef __m256i avx2_ivec_t;
+// counter step added (as 64-bit lanes) to row 3 for each generated vector state
+#define SSSE3_DELTA _mm_set_epi32(0,0,0,1)
+#define AVX2_DELTA _mm256_set_epi64x(0,2,0,2)
+// initial row 3 before any nonce is set; the AVX2 lanes start at counters 0 and 1
+#define SSSE3_START _mm_setzero_si128()
+#define AVX2_START _mm256_set_epi64x(0,1,0,0)
+// row 3 for a given nonce: low 64 bits = counter, high 64 bits = nonce
+#define SSSE3_NONCE(nonce) _mm_set_epi64x(nonce,0)
+#define AVX2_NONCE(nonce) _mm256_set_epi64x(nonce, 1, nonce, 0)
+// unaligned stores, aligned stores, aligned loads
+#define SSSE3_STOREU_VEC(m,r) _mm_storeu_si128((__m128i*)(m), r)
+#define AVX2_STOREU_VEC(m,r) _mm256_storeu_si256((__m256i*)(m), r)
+
+#define SSSE3_STORE_VEC(m,r) _mm_store_si128((__m128i*)(m), r)
+#define AVX2_STORE_VEC(m,r) _mm256_store_si256((__m256i*)(m), r)
+
+#define SSSE3_LOAD_VEC(r,m) r = _mm_load_si128((const __m128i *)(m))
+#define AVX2_LOAD_VEC(r,m) r = _mm256_load_si256((const __m256i *)(m))
+// load 16 unaligned bytes; AVX2 broadcasts them into both 128-bit lanes
+#define SSSE3_LOADU_VEC_128(r, m) r = _mm_loadu_si128((const __m128i*)(m))
+#define AVX2_LOADU_VEC_128(r, m) r = _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i*)(m)))
+
+#define SSSE3_ADD_VEC_32(a,b) _mm_add_epi32(a, b)
+#define AVX2_ADD_VEC_32(a,b) _mm256_add_epi32(a, b)
+
+#define SSSE3_ADD_VEC_64(a,b) _mm_add_epi64(a, b)
+#define AVX2_ADD_VEC_64(a,b) _mm256_add_epi64(a, b)
+
+#define SSSE3_XOR_VEC(a,b) _mm_xor_si128(a, b)
+#define AVX2_XOR_VEC(a,b) _mm256_xor_si256(a, b)
+// rotate the four 32-bit words within each 128-bit lane by 1, 2 or 3 positions (row diagonalization)
+#define SSSE3_ROR_VEC_V1(x) _mm_shuffle_epi32(x,_MM_SHUFFLE(0,3,2,1))
+#define AVX2_ROR_VEC_V1(x) _mm256_shuffle_epi32(x,_MM_SHUFFLE(0,3,2,1))
+
+#define SSSE3_ROR_VEC_V2(x) _mm_shuffle_epi32(x,_MM_SHUFFLE(1,0,3,2))
+#define AVX2_ROR_VEC_V2(x) _mm256_shuffle_epi32(x,_MM_SHUFFLE(1,0,3,2))
+
+#define SSSE3_ROR_VEC_V3(x) _mm_shuffle_epi32(x,_MM_SHUFFLE(2,1,0,3))
+#define AVX2_ROR_VEC_V3(x) _mm256_shuffle_epi32(x,_MM_SHUFFLE(2,1,0,3))
+// rotate each 32-bit word left: by 7 and 12 with shifts, by 8 and 16 with byte shuffles
+#define SSSE3_ROL_VEC_7(x) SSSE3_XOR_VEC(_mm_slli_epi32(x, 7), _mm_srli_epi32(x,25))
+#define AVX2_ROL_VEC_7(x) AVX2_XOR_VEC(_mm256_slli_epi32(x, 7), _mm256_srli_epi32(x,25))
+
+#define SSSE3_ROL_VEC_12(x) SSSE3_XOR_VEC(_mm_slli_epi32(x,12), _mm_srli_epi32(x,20))
+#define AVX2_ROL_VEC_12(x) AVX2_XOR_VEC(_mm256_slli_epi32(x,12), _mm256_srli_epi32(x,20))
+
+#define SSSE3_ROL_VEC_8(x) _mm_shuffle_epi8(x,_mm_set_epi8(14,13,12,15,10,9,8,11,6,5,4,7,2,1,0,3))
+#define AVX2_ROL_VEC_8(x) _mm256_shuffle_epi8(x,_mm256_set_epi8(14,13,12,15,10,9,8,11,6,5,4,7,2,1,0,3,14,13,12,15,10,9,8,11,6,5,4,7,2,1,0,3))
+
+#define SSSE3_ROL_VEC_16(x) _mm_shuffle_epi8(x,_mm_set_epi8(13,12,15,14,9,8,11,10,5,4,7,6,1,0,3,2))
+#define AVX2_ROL_VEC_16(x) _mm256_shuffle_epi8(x,_mm256_set_epi8(13,12,15,14,9,8,11,10,5,4,7,6,1,0,3,2,13,12,15,14,9,8,11,10,5,4,7,6,1,0,3,2))
+// store rows v0..v3 at byte offset d; AVX2 writes the lane-0 block (64 bytes) first, then the lane-1 block
+#define SSSE3_WRITEU_VEC(op, d, v0, v1, v2, v3) \
+ SSSE3_STOREU_VEC(op + (d + 0*4), v0); \
+ SSSE3_STOREU_VEC(op + (d + 4*4), v1); \
+ SSSE3_STOREU_VEC(op + (d + 8*4), v2); \
+ SSSE3_STOREU_VEC(op + (d +12*4), v3);
+#define AVX2_WRITEU_VEC(op, d, v0, v1, v2, v3) \
+ AVX2_STOREU_VEC(op + (d + 0*4), _mm256_permute2x128_si256(v0, v1, 0x20)); \
+ AVX2_STOREU_VEC(op + (d + 8*4), _mm256_permute2x128_si256(v2, v3, 0x20)); \
+ AVX2_STOREU_VEC(op + (d +16*4), _mm256_permute2x128_si256(v0, v1, 0x31)); \
+ AVX2_STOREU_VEC(op + (d +24*4), _mm256_permute2x128_si256(v2, v3, 0x31));
+
+#define SSSE3_WRITE_VEC(op, d, v0, v1, v2, v3) \
+ SSSE3_STORE_VEC(op + (d + 0*4), v0); \
+ SSSE3_STORE_VEC(op + (d + 4*4), v1); \
+ SSSE3_STORE_VEC(op + (d + 8*4), v2); \
+ SSSE3_STORE_VEC(op + (d +12*4), v3);
+#define AVX2_WRITE_VEC(op, d, v0, v1, v2, v3) \
+ AVX2_STORE_VEC(op + (d + 0*4), _mm256_permute2x128_si256(v0, v1, 0x20)); \
+ AVX2_STORE_VEC(op + (d + 8*4), _mm256_permute2x128_si256(v2, v3, 0x20)); \
+ AVX2_STORE_VEC(op + (d +16*4), _mm256_permute2x128_si256(v0, v1, 0x31)); \
+ AVX2_STORE_VEC(op + (d +24*4), _mm256_permute2x128_si256(v2, v3, 0x31));
+
+#define SSSE3_SZ_VEC (16)
+#define AVX2_SZ_VEC (32)
+
+#define SSSE3_RANSTREAM_NCHUNKS (4)
+// leads to a BUFSZ of 512
+
+#define AVX2_RANSTREAM_NCHUNKS (2)
+// leads to a BUFSZ of 512
+// One ChaCha double round on rows a,b,c,d: column round, rotate rows into diagonal form, diagonal round, rotate back.
+#define SSSE3_DQROUND_VECTORS_VEC(a,b,c,d) \
+ a = SSSE3_ADD_VEC_32(a,b); d = SSSE3_XOR_VEC(d,a); d = SSSE3_ROL_VEC_16(d); \
+ c = SSSE3_ADD_VEC_32(c,d); b = SSSE3_XOR_VEC(b,c); b = SSSE3_ROL_VEC_12(b); \
+ a = SSSE3_ADD_VEC_32(a,b); d = SSSE3_XOR_VEC(d,a); d = SSSE3_ROL_VEC_8(d); \
+ c = SSSE3_ADD_VEC_32(c,d); b = SSSE3_XOR_VEC(b,c); b = SSSE3_ROL_VEC_7(b); \
+ b = SSSE3_ROR_VEC_V1(b); c = SSSE3_ROR_VEC_V2(c); d = SSSE3_ROR_VEC_V3(d); \
+ a = SSSE3_ADD_VEC_32(a,b); d = SSSE3_XOR_VEC(d,a); d = SSSE3_ROL_VEC_16(d); \
+ c = SSSE3_ADD_VEC_32(c,d); b = SSSE3_XOR_VEC(b,c); b = SSSE3_ROL_VEC_12(b); \
+ a = SSSE3_ADD_VEC_32(a,b); d = SSSE3_XOR_VEC(d,a); d = SSSE3_ROL_VEC_8(d); \
+ c = SSSE3_ADD_VEC_32(c,d); b = SSSE3_XOR_VEC(b,c); b = SSSE3_ROL_VEC_7(b); \
+ b = SSSE3_ROR_VEC_V3(b); c = SSSE3_ROR_VEC_V2(c); d = SSSE3_ROR_VEC_V1(d);
+
+#define AVX2_DQROUND_VECTORS_VEC(a,b,c,d) \
+ a = AVX2_ADD_VEC_32(a,b); d = AVX2_XOR_VEC(d,a); d = AVX2_ROL_VEC_16(d); \
+ c = AVX2_ADD_VEC_32(c,d); b = AVX2_XOR_VEC(b,c); b = AVX2_ROL_VEC_12(b); \
+ a = AVX2_ADD_VEC_32(a,b); d = AVX2_XOR_VEC(d,a); d = AVX2_ROL_VEC_8(d); \
+ c = AVX2_ADD_VEC_32(c,d); b = AVX2_XOR_VEC(b,c); b = AVX2_ROL_VEC_7(b); \
+ b = AVX2_ROR_VEC_V1(b); c = AVX2_ROR_VEC_V2(c); d = AVX2_ROR_VEC_V3(d); \
+ a = AVX2_ADD_VEC_32(a,b); d = AVX2_XOR_VEC(d,a); d = AVX2_ROL_VEC_16(d); \
+ c = AVX2_ADD_VEC_32(c,d); b = AVX2_XOR_VEC(b,c); b = AVX2_ROL_VEC_12(b); \
+ a = AVX2_ADD_VEC_32(a,b); d = AVX2_XOR_VEC(d,a); d = AVX2_ROL_VEC_8(d); \
+ c = AVX2_ADD_VEC_32(c,d); b = AVX2_XOR_VEC(b,c); b = AVX2_ROL_VEC_7(b); \
+ b = AVX2_ROR_VEC_V3(b); c = AVX2_ROR_VEC_V2(c); d = AVX2_ROR_VEC_V1(d);
+// sizes in bytes: state = 4 rows; a chunk = output of two states; buffer = NCHUNKS chunks
+#define SSSE3_RANSTREAM_STATESZ (4*SSSE3_SZ_VEC)
+#define AVX2_RANSTREAM_STATESZ (4*AVX2_SZ_VEC)
+
+#define SSSE3_RANSTREAM_CHUNKSZ (2*SSSE3_RANSTREAM_STATESZ)
+#define AVX2_RANSTREAM_CHUNKSZ (2*AVX2_RANSTREAM_STATESZ)
+
+#define SSSE3_RANSTREAM_BUFSZ (SSSE3_RANSTREAM_NCHUNKS*SSSE3_RANSTREAM_CHUNKSZ)
+#define AVX2_RANSTREAM_BUFSZ (AVX2_RANSTREAM_NCHUNKS*AVX2_RANSTREAM_CHUNKSZ)
+// Size state_store/buf_store for the SIMD generator selected at load time (AVX2 or SSSE3).
+static void allocate_space(AlignedArray<unsigned char> &state_store,
+                           AlignedArray<unsigned char> &buf_store)
+{
+   if (have_avx2) { // an unprobed flag (-1) also lands here: AVX2 sizes are >= SSSE3 sizes, so over-allocating is harmless
+      state_store.SetLength(AVX2_RANSTREAM_STATESZ);
+      buf_store.SetLength(AVX2_RANSTREAM_BUFSZ);
+   } else {
+      state_store.SetLength(SSSE3_RANSTREAM_STATESZ);
+      buf_store.SetLength(SSSE3_RANSTREAM_BUFSZ);
+   }
+}
+// Portable fallback: initializes the scalar word array `state` from key via salsa20_init(); the aligned stores are not used.
+BASE_FUNC(void, randomstream_impl_init)
+(_ntl_uint32 *state,
+ AlignedArray<unsigned char> &state_store __attribute__((unused)),
+ AlignedArray<unsigned char> &buf_store __attribute__((unused)),
+ const unsigned char *key)
+{
+ salsa20_init(state, key);
+}
+// SSSE3 variant: sizes the aligned stores, then writes the 4-row vector state (constant, key[0..15], key[16..31], row 3 = SSSE3_START) into state_store.
+SSSE3_FUNC(void, randomstream_impl_init)
+(_ntl_uint32 *state_ignored __attribute__((unused)),
+ AlignedArray<unsigned char> &state_store,
+ AlignedArray<unsigned char> &buf_store,
+ const unsigned char *key)
+{
+ allocate_space(state_store, buf_store);
+
+ unsigned char *state = state_store.elts();
+ // "expand 32-byte k" as little-endian 32-bit words (the ChaCha constant)
+ unsigned int chacha_const[] = {
+ 0x61707865,0x3320646E,0x79622D32,0x6B206574
+ };
+ // rows 0-2: constant, then the two 16-byte halves of the 32-byte key
+ ssse3_ivec_t d0, d1, d2, d3;
+ SSSE3_LOADU_VEC_128(d0, chacha_const);
+ SSSE3_LOADU_VEC_128(d1, key);
+ SSSE3_LOADU_VEC_128(d2, key+16);
+
+ d3 = SSSE3_START;
+ // aligned stores: relies on state_store.elts() being at least 16-byte aligned
+ SSSE3_STORE_VEC(state + 0*SSSE3_SZ_VEC, d0);
+ SSSE3_STORE_VEC(state + 1*SSSE3_SZ_VEC, d1);
+ SSSE3_STORE_VEC(state + 2*SSSE3_SZ_VEC, d2);
+ SSSE3_STORE_VEC(state + 3*SSSE3_SZ_VEC, d3);
+}
+// AVX2 variant: like the SSSE3 one, but each 16-byte row is broadcast to both 128-bit lanes and row 3 = AVX2_START (lane counters 0 and 1).
+AVX2_FUNC(void, randomstream_impl_init)
+(_ntl_uint32 *state_ignored __attribute__((unused)),
+ AlignedArray<unsigned char> &state_store,
+ AlignedArray<unsigned char> &buf_store,
+ const unsigned char *key)
+{
+ allocate_space(state_store, buf_store);
+
+ unsigned char *state = state_store.elts();
+ // "expand 32-byte k" as little-endian 32-bit words (the ChaCha constant)
+ unsigned int chacha_const[] = {
+ 0x61707865,0x3320646E,0x79622D32,0x6B206574
+ };
+ // rows 0-2: constant, then the two 16-byte halves of the 32-byte key
+ avx2_ivec_t d0, d1, d2, d3;
+ AVX2_LOADU_VEC_128(d0, chacha_const);
+ AVX2_LOADU_VEC_128(d1, key);
+ AVX2_LOADU_VEC_128(d2, key+16);
+
+ d3 = AVX2_START;
+ // aligned stores: relies on state_store.elts() being at least 32-byte aligned
+ AVX2_STORE_VEC(state + 0*AVX2_SZ_VEC, d0);
+ AVX2_STORE_VEC(state + 1*AVX2_SZ_VEC, d1);
+ AVX2_STORE_VEC(state + 2*AVX2_SZ_VEC, d2);
+ AVX2_STORE_VEC(state + 3*AVX2_SZ_VEC, d3);
+}
+// Load-time dispatch for randomstream_impl_init via SSSE3_RESOLVER (defined in ctools.h); presumably chooses among the BASE_, SSSE3_ and AVX2_ variants.
+SSSE3_RESOLVER(static, void, randomstream_impl_init,
+ (_ntl_uint32 *state, AlignedArray<unsigned char> &state_store,
+ AlignedArray<unsigned char> &buf_store, const unsigned char *key));
+// Portable fallback: copies n keystream bytes to res, first from the 64-byte buf starting at pos, then from fresh salsa20_apply() blocks; returns the new pos.
+BASE_FUNC(long, randomstream_get_bytes)
+(_ntl_uint32 *state,
+ unsigned char *buf,
+ AlignedArray<unsigned char> &state_store __attribute__((unused)),
+ AlignedArray<unsigned char> &buf_store __attribute__((unused)),
+ long &chunk_count __attribute__((unused)),
+ unsigned char *NTL_RESTRICT res,
+ long n,
+ long pos)
+{
+ if (n < 0) LogicError("RandomStream::get: bad args");
+
+ long i, j;
+ // fast path: the request fits in the unread part of buf
+ if (n <= 64-pos) {
+ for (i = 0; i < n; i++) res[i] = buf[pos+i];
+ pos += n;
+ return pos;
+ }
+
+ // read remainder of buffer
+ for (i = 0; i < 64-pos; i++) res[i] = buf[pos+i];
+ n -= 64-pos;
+ res += 64-pos;
+ pos = 64;
+
+ _ntl_uint32 wdata[16];
+
+ // read 64-byte chunks
+ for (i = 0; i <= n-64; i += 64) {
+ salsa20_apply(state, wdata);
+ for (j = 0; j < 16; j++)
+ FROMLE(res + i + 4*j, wdata[j]);
+ }
+ // partial final block: refill buf and hand out its first n-i bytes
+ if (i < n) {
+ salsa20_apply(state, wdata);
+
+ for (j = 0; j < 16; j++)
+ FROMLE(buf + 4*j, wdata[j]);
+
+ pos = n-i;
+ for (j = 0; j < pos; j++)
+ res[i+j] = buf[j];
+ }
+ // new read position within buf (64 means buf is exhausted)
+ return pos;
+}
+// SSSE3 variant: serves n bytes from buf_store (SSSE3_RANSTREAM_BUFSZ bytes, read from pos) and generates more ChaCha output as needed; returns the new pos.
+SSSE3_FUNC(long, randomstream_get_bytes)
+(_ntl_uint32 *state_ignored __attribute__((unused)),
+ unsigned char *buf_ignored __attribute__((unused)),
+ AlignedArray<unsigned char> &state_store,
+ AlignedArray<unsigned char> &buf_store,
+ long &chunk_count,
+ unsigned char *NTL_RESTRICT res,
+ long n,
+ long pos)
+{
+ if (n < 0) LogicError("RandomStream::get: bad args");
+ if (n == 0) return pos;
+
+ unsigned char *NTL_RESTRICT buf = buf_store.elts();
+ // fast path: the request fits in the unread part of buf
+ if (n <= SSSE3_RANSTREAM_BUFSZ-pos) {
+ std::memcpy(&res[0], &buf[pos], n);
+ pos += n;
+ return pos;
+ }
+
+ unsigned char *NTL_RESTRICT state = state_store.elts();
+ // load the 4 state rows; only row 3 (d3, the counter) changes and is written back at the end
+ ssse3_ivec_t d0, d1, d2, d3;
+ SSSE3_LOAD_VEC(d0, state + 0*SSSE3_SZ_VEC);
+ SSSE3_LOAD_VEC(d1, state + 1*SSSE3_SZ_VEC);
+ SSSE3_LOAD_VEC(d2, state + 2*SSSE3_SZ_VEC);
+ SSSE3_LOAD_VEC(d3, state + 3*SSSE3_SZ_VEC);
+
+ // read remainder of buffer
+ std::memcpy(&res[0], &buf[pos], SSSE3_RANSTREAM_BUFSZ-pos);
+ n -= SSSE3_RANSTREAM_BUFSZ-pos;
+ res += SSSE3_RANSTREAM_BUFSZ-pos;
+ pos = SSSE3_RANSTREAM_BUFSZ;
+ // generate whole buffers' worth of output directly into res
+ long i = 0;
+ for (; i <= n-SSSE3_RANSTREAM_BUFSZ; i += SSSE3_RANSTREAM_BUFSZ) {
+ chunk_count |= SSSE3_RANSTREAM_NCHUNKS; // disable small buffer strategy
+
+ for (long j = 0; j < SSSE3_RANSTREAM_NCHUNKS; j++) {
+ ssse3_ivec_t v0=d0, v1=d1, v2=d2, v3=d3;
+ ssse3_ivec_t v4=d0, v5=d1, v6=d2, v7=SSSE3_ADD_VEC_64(d3, SSSE3_DELTA);
+
+ for (long k = 0; k < CHACHA_RNDS/2; k++) {
+ SSSE3_DQROUND_VECTORS_VEC(v0,v1,v2,v3)
+ SSSE3_DQROUND_VECTORS_VEC(v4,v5,v6,v7)
+ }
+ // add the input rows back in, emit each block, and advance the counter per block
+ SSSE3_WRITEU_VEC(res+i+j*(8*SSSE3_SZ_VEC), 0, SSSE3_ADD_VEC_32(v0,d0), SSSE3_ADD_VEC_32(v1,d1), SSSE3_ADD_VEC_32(v2,d2), SSSE3_ADD_VEC_32(v3,d3))
+ d3 = SSSE3_ADD_VEC_64(d3, SSSE3_DELTA);
+ SSSE3_WRITEU_VEC(res+i+j*(8*SSSE3_SZ_VEC), 4*SSSE3_SZ_VEC, SSSE3_ADD_VEC_32(v4,d0), SSSE3_ADD_VEC_32(v5,d1), SSSE3_ADD_VEC_32(v6,d2), SSSE3_ADD_VEC_32(v7,d3))
+ d3 = SSSE3_ADD_VEC_64(d3, SSSE3_DELTA);
+ }
+
+ }
+
+ if (i < n) {
+ // refill buf; until NCHUNKS chunks have ever been generated, make only as many chunks as needed (small buffer strategy)
+ long nchunks;
+
+ if (chunk_count < SSSE3_RANSTREAM_NCHUNKS) {
+ nchunks = long(cast_unsigned((n-i)+SSSE3_RANSTREAM_CHUNKSZ-1)/SSSE3_RANSTREAM_CHUNKSZ);
+ chunk_count += nchunks;
+ }
+ else
+ nchunks = SSSE3_RANSTREAM_NCHUNKS;
+
+ long pos_offset = SSSE3_RANSTREAM_BUFSZ - nchunks*SSSE3_RANSTREAM_CHUNKSZ;
+ buf += pos_offset;
+ // fresh chunks fill the end of buf, so later reads still stop at SSSE3_RANSTREAM_BUFSZ
+ for (long j = 0; j < nchunks; j++) {
+ ssse3_ivec_t v0=d0, v1=d1, v2=d2, v3=d3;
+ ssse3_ivec_t v4=d0, v5=d1, v6=d2, v7=SSSE3_ADD_VEC_64(d3, SSSE3_DELTA);
+
+ for (long k = 0; k < CHACHA_RNDS/2; k++) {
+ SSSE3_DQROUND_VECTORS_VEC(v0,v1,v2,v3)
+ SSSE3_DQROUND_VECTORS_VEC(v4,v5,v6,v7)
+ }
+
+ SSSE3_WRITE_VEC(buf+j*(8*SSSE3_SZ_VEC), 0, SSSE3_ADD_VEC_32(v0,d0), SSSE3_ADD_VEC_32(v1,d1), SSSE3_ADD_VEC_32(v2,d2), SSSE3_ADD_VEC_32(v3,d3))
+ d3 = SSSE3_ADD_VEC_64(d3, SSSE3_DELTA);
+ SSSE3_WRITE_VEC(buf+j*(8*SSSE3_SZ_VEC), 4*SSSE3_SZ_VEC, SSSE3_ADD_VEC_32(v4,d0), SSSE3_ADD_VEC_32(v5,d1), SSSE3_ADD_VEC_32(v6,d2), SSSE3_ADD_VEC_32(v7,d3))
+ d3 = SSSE3_ADD_VEC_64(d3, SSSE3_DELTA);
+ }
+
+ pos = n-i+pos_offset;
+ std::memcpy(&res[i], &buf[0], n-i);
+ }
+ // persist the advanced counter row
+ SSSE3_STORE_VEC(state + 3*SSSE3_SZ_VEC, d3);
+
+ return pos;
+}
+// AVX2 variant: same buffering scheme as the SSSE3 one, but each vector state yields two 64-byte blocks (one per 128-bit lane); returns the new pos.
+AVX2_FUNC(long, randomstream_get_bytes)
+(_ntl_uint32 *state_ignored __attribute__((unused)),
+ unsigned char *buf_ignored __attribute__((unused)),
+ AlignedArray<unsigned char> &state_store,
+ AlignedArray<unsigned char> &buf_store,
+ long &chunk_count,
+ unsigned char *NTL_RESTRICT res,
+ long n,
+ long pos)
+{
+ if (n < 0) LogicError("RandomStream::get: bad args");
+ if (n == 0) return pos;
+
+ unsigned char *NTL_RESTRICT buf = buf_store.elts();
+ // fast path: the request fits in the unread part of buf
+ if (n <= AVX2_RANSTREAM_BUFSZ-pos) {
+ std::memcpy(&res[0], &buf[pos], n);
+ pos += n;
+ return pos;
+ }
+
+ unsigned char *NTL_RESTRICT state = state_store.elts();
+ // load the 4 state rows; only row 3 (d3, the per-lane counters) changes and is written back at the end
+ avx2_ivec_t d0, d1, d2, d3;
+ AVX2_LOAD_VEC(d0, state + 0*AVX2_SZ_VEC);
+ AVX2_LOAD_VEC(d1, state + 1*AVX2_SZ_VEC);
+ AVX2_LOAD_VEC(d2, state + 2*AVX2_SZ_VEC);
+ AVX2_LOAD_VEC(d3, state + 3*AVX2_SZ_VEC);
+
+ // read remainder of buffer
+ std::memcpy(&res[0], &buf[pos], AVX2_RANSTREAM_BUFSZ-pos);
+ n -= AVX2_RANSTREAM_BUFSZ-pos;
+ res += AVX2_RANSTREAM_BUFSZ-pos;
+ pos = AVX2_RANSTREAM_BUFSZ;
+ // generate whole buffers' worth of output directly into res
+ long i = 0;
+ for (; i <= n-AVX2_RANSTREAM_BUFSZ; i += AVX2_RANSTREAM_BUFSZ) {
+ chunk_count |= AVX2_RANSTREAM_NCHUNKS; // disable small buffer strategy
+
+ for (long j = 0; j < AVX2_RANSTREAM_NCHUNKS; j++) {
+ avx2_ivec_t v0=d0, v1=d1, v2=d2, v3=d3;
+ avx2_ivec_t v4=d0, v5=d1, v6=d2, v7=AVX2_ADD_VEC_64(d3, AVX2_DELTA);
+
+ for (long k = 0; k < CHACHA_RNDS/2; k++) {
+ AVX2_DQROUND_VECTORS_VEC(v0,v1,v2,v3)
+ AVX2_DQROUND_VECTORS_VEC(v4,v5,v6,v7)
+ }
+ // add the input rows back in, emit two blocks per vector state, and advance both lane counters
+ AVX2_WRITEU_VEC(res+i+j*(8*AVX2_SZ_VEC), 0, AVX2_ADD_VEC_32(v0,d0), AVX2_ADD_VEC_32(v1,d1), AVX2_ADD_VEC_32(v2,d2), AVX2_ADD_VEC_32(v3,d3))
+ d3 = AVX2_ADD_VEC_64(d3, AVX2_DELTA);
+ AVX2_WRITEU_VEC(res+i+j*(8*AVX2_SZ_VEC), 4*AVX2_SZ_VEC, AVX2_ADD_VEC_32(v4,d0), AVX2_ADD_VEC_32(v5,d1), AVX2_ADD_VEC_32(v6,d2), AVX2_ADD_VEC_32(v7,d3))
+ d3 = AVX2_ADD_VEC_64(d3, AVX2_DELTA);
+ }
+
+ }
+
+ if (i < n) {
+ // refill buf; until NCHUNKS chunks have ever been generated, make only as many chunks as needed (small buffer strategy)
+ long nchunks;
+
+ if (chunk_count < AVX2_RANSTREAM_NCHUNKS) {
+ nchunks = long(cast_unsigned((n-i)+AVX2_RANSTREAM_CHUNKSZ-1)/AVX2_RANSTREAM_CHUNKSZ);
+ chunk_count += nchunks;
+ }
+ else
+ nchunks = AVX2_RANSTREAM_NCHUNKS;
+
+ long pos_offset = AVX2_RANSTREAM_BUFSZ - nchunks*AVX2_RANSTREAM_CHUNKSZ;
+ buf += pos_offset;
+ // fresh chunks fill the end of buf, so later reads still stop at AVX2_RANSTREAM_BUFSZ
+ for (long j = 0; j < nchunks; j++) {
+ avx2_ivec_t v0=d0, v1=d1, v2=d2, v3=d3;
+ avx2_ivec_t v4=d0, v5=d1, v6=d2, v7=AVX2_ADD_VEC_64(d3, AVX2_DELTA);
+
+ for (long k = 0; k < CHACHA_RNDS/2; k++) {
+ AVX2_DQROUND_VECTORS_VEC(v0,v1,v2,v3)
+ AVX2_DQROUND_VECTORS_VEC(v4,v5,v6,v7)
+ }
+
+ AVX2_WRITE_VEC(buf+j*(8*AVX2_SZ_VEC), 0, AVX2_ADD_VEC_32(v0,d0), AVX2_ADD_VEC_32(v1,d1), AVX2_ADD_VEC_32(v2,d2), AVX2_ADD_VEC_32(v3,d3))
+ d3 = AVX2_ADD_VEC_64(d3, AVX2_DELTA);
+ AVX2_WRITE_VEC(buf+j*(8*AVX2_SZ_VEC), 4*AVX2_SZ_VEC, AVX2_ADD_VEC_32(v4,d0), AVX2_ADD_VEC_32(v5,d1), AVX2_ADD_VEC_32(v6,d2), AVX2_ADD_VEC_32(v7,d3))
+ d3 = AVX2_ADD_VEC_64(d3, AVX2_DELTA);
+ }
+
+ pos = n-i+pos_offset;
+ std::memcpy(&res[i], &buf[0], n-i);
+ }
+ // persist the advanced counter row
+ AVX2_STORE_VEC(state + 3*AVX2_SZ_VEC, d3);
+
+ return pos;
+}
+
+SSSE3_RESOLVER(static, long, randomstream_get_bytes,  // resolver: presumably binds randomstream_get_bytes to its BASE/SSSE3/AVX2 variant by CPU feature
+ (_ntl_uint32 *state, unsigned char *buf,  // scalar state and output buffer (portable path)
+ AlignedArray<unsigned char> &state_store,  // aligned state vectors (SSSE3/AVX2 paths)
+ AlignedArray<unsigned char> &buf_store,  // aligned output buffer (SSSE3/AVX2 paths)
+ long &chunk_count,  // counter driving the small-chunk start-up strategy
+ unsigned char *NTL_RESTRICT res,  // receives the n requested bytes
+ long n,  // number of bytes to produce
+ long pos));  // read offset into the buffer; the updated offset is returned
+
+BASE_FUNC(void, randomstream_set_nonce)
+(_ntl_uint32 *state,
+ AlignedArray<unsigned char> &state_store __attribute__((unused)),
+ long &chunk_count __attribute__((unused)),
+ unsigned long nonce)
+{
+  // Portable variant: only the scalar state[] is touched; state_store and
+  // chunk_count exist solely to share the SSSE3/AVX2 variants' signature.
+  _ntl_uint32 lo = nonce;          // low 32 bits of the nonce
+  lo = INT32MASK(lo);
+
+  _ntl_uint32 hi = 0;              // high 32 bits; stays 0 for a 32-bit long
+#if (NTL_BITS_PER_LONG > 32)
+  hi = nonce >> 32;
+  hi = INT32MASK(hi);
+#endif
+
+  // Words 14-15 receive the nonce halves; words 12-13 are cleared.
+  state[14] = lo;
+  state[15] = hi;
+  state[12] = 0;
+  state[13] = 0;
+}
+
+SSSE3_FUNC(void, randomstream_set_nonce)
+(_ntl_uint32 *state_ignored __attribute__((unused)),
+ AlignedArray<unsigned char> &state_store,
+ long &chunk_count,
+ unsigned long nonce)
+{
+  // SSSE3 variant: the vector built by SSSE3_NONCE is stored as the fourth
+  // vector (offset 3*SSSE3_SZ_VEC) of the aligned state; state[] is unused.
+  ssse3_ivec_t row3 = SSSE3_NONCE(nonce);
+  SSSE3_STORE_VEC(state_store.elts() + 3*SSSE3_SZ_VEC, row3);
+  chunk_count = 0;  // reset the counter so the small-chunk start-up strategy applies again
+}
+
+AVX2_FUNC(void, randomstream_set_nonce)
+(_ntl_uint32 *state_ignored __attribute__((unused)),
+ AlignedArray<unsigned char> &state_store,
+ long &chunk_count,
+ unsigned long nonce)
+{
+  // AVX2 variant: AVX2_NONCE builds the last of the four state vectors, which
+  // is stored at offset 3*AVX2_SZ_VEC; the scalar state array is unused.
+  avx2_ivec_t row3 = AVX2_NONCE(nonce);
+  AVX2_STORE_VEC(state_store.elts() + 3*AVX2_SZ_VEC, row3);
+  chunk_count = 0;  // reset the counter so the small-chunk start-up strategy applies again
+}
+
+SSSE3_RESOLVER(, void, randomstream_set_nonce,  // resolver for the BASE_FUNC/SSSE3_FUNC/AVX2_FUNC randomstream_set_nonce variants above
+ (_ntl_uint32 *state,  // NOTE(review): first macro argument is empty here but `static` for
+ AlignedArray<unsigned char> &state_store,  // the randomstream_get_bytes resolver; if it is the storage class,
+ long &chunk_count,  // this symbol gets external linkage -- confirm that is intended
+ unsigned long nonce));
+
+struct RandomStream_impl {
+  // The SSSE3/AVX2 paths keep their state and output in these aligned
+  // buffers; the portable path uses state[] and buf[] instead.
+  AlignedArray<unsigned char> state_store;
+  AlignedArray<unsigned char> buf_store;
+  long chunk_count;
+  _ntl_uint32 state[16];
+  unsigned char buf[64];
+
+  explicit
+  RandomStream_impl(const unsigned char *key)
+  {
+    randomstream_impl_init(state, state_store, buf_store, key);
+    chunk_count = 0;
+  }
+
+  // Allocate the aligned buffers (vector paths only) before copying into them.
+  RandomStream_impl(const RandomStream_impl& other)
+  {
+    if (have_avx2 || have_ssse3) allocate_space(state_store, buf_store);
+    *this = other;
+  }
+
+  // Copy the representation selected by have_avx2/have_ssse3.  The portable
+  // path must copy state[]/buf[] as well, or a copy starts from garbage.
+  RandomStream_impl& operator=(const RandomStream_impl& other)
+  {
+    if (this == &other) return *this;  // memcpy must not see overlapping ranges
+    if (have_avx2) {
+      std::memcpy(state_store.elts(), other.state_store.elts(), AVX2_RANSTREAM_STATESZ);
+      std::memcpy(buf_store.elts(), other.buf_store.elts(), AVX2_RANSTREAM_BUFSZ);
+    } else if (have_ssse3) {
+      std::memcpy(state_store.elts(), other.state_store.elts(), SSSE3_RANSTREAM_STATESZ);
+      std::memcpy(buf_store.elts(), other.buf_store.elts(), SSSE3_RANSTREAM_BUFSZ);
+    } else {
+      std::memcpy(state, other.state, sizeof(state));
+      std::memcpy(buf, other.buf, sizeof(buf));
+    }
+    chunk_count = other.chunk_count;
+    return *this;
+  }
+
+  // Start of the output buffer used by the active implementation.
+  const unsigned char *get_buf() const
+  {
+    return (have_avx2 || have_ssse3) ? buf_store.elts() : &buf[0];
+  }
+
+  // Size in bytes of the buffer returned by get_buf().
+  long get_buf_len() const
+  {
+    if (have_avx2) return AVX2_RANSTREAM_BUFSZ;
+    if (have_ssse3) return SSSE3_RANSTREAM_BUFSZ;
+    return 64;
+  }
+
+  // bytes are generated in chunks of RANSTREAM_BUFSZ bytes, except that
+  // initially, we may generate a few chunks of RANSTREAM_CHUNKSZ
+  // bytes. This optimizes a bit for short bursts following a reset.
+
+  long
+  get_bytes(unsigned char *NTL_RESTRICT res,
+            long n, long pos)
+  {
+    return randomstream_get_bytes(state, buf, state_store, buf_store,
+                                  chunk_count, res, n, pos);
+  }
+
+  // Forwards to the resolved randomstream_set_nonce variant.
+  void
+  set_nonce(unsigned long nonce)
+  {
+    randomstream_set_nonce(state, state_store, chunk_count, nonce);
+  }
+};
#else