> 文章列表 > Intel SIMD: AVX2

Intel SIMD: AVX2

Intel SIMD: AVX2

AVX2

资料:

  • Intel 内部指令 — AVX和AVX2学习笔记
  • Intel Intrinsics — AVX & AVX2 Learning Notes
  • Module x86

AVX 向量寄存器有三种:

  1. 128-bit (XMM forms),AVX2 支持,符号 __m128, __m128d, __m128i
  2. 256-bit (YMM forms),AVX2 支持,符号 __m256, __m256d, __m256i
  3. 512-bit 的向量寄存器,AVX2 不支持,这需要 AVX-512 架构

YMM 实际上是两个 XMM,在运算时会分成 2 个 128 bits 的区域。YMM 支持 16x16, 8x32, 4x64 的 SIMD,实测 add_epi16a+b 快 20 倍(包括了 for 的花销)。

AVX 的数据类型包括:

  1. ps – packed single precision
  2. pd – packed double precision
  3. epi32 – packed 32-bit integers
  4. epu32 – packed 32-bit unsigned integers
  5. epi64 – packed 64-bit integers

一些常用指令(有些指令 __m256 不支持但 __m128 支持):

  1. 全局:zeroall(所有的 YMM 置零), zeroupper(所有的 YMM 高位置零)
  2. 加载:load, loadu, i32gather(根据索引加载), i64gather
  3. 存储:store(可以直接用 char* 作为 __m256*
  4. 设置:broadcast(广播), setzero, set1(复制), set(反序), setr(正序)
  5. 转换:cast(在 __m128, __m256 之间转换), cvt(在 ps, epi32 之间转换)
  6. 加减:add, sub, hadd(水平), hsub
  7. 乘除:mul, mullo(低位结果), mulhi(高位结果), div
  8. 混合运算(只有 ps, pd,不支持 epi):fmadd, fmsub, fnmadd, fnmsub
  9. 逻辑运算:cmp, cmpeq, cmpneq, cmpge(大于等于), cmpgt(严格大于), cmple, cmplt,
  10. 位运算:and, andnot, or, xor, sll(左移), srl(右移), slli, srli, sllv, srlv, bslli, bsrli
  11. 统计学:max, min, avg, ceil, floor, round, lzcnt
  12. 数学:abs, getexp, sqrt, rsqrt, sin, cos
  13. 置换:shuffle, permute(根据控制位写入), insert(根据控制位插入)
  14. 另外还有 mask 的版本,但是 AVX2 似乎都不支持?

代码样例

#include <stdio.h>
#include <time.h>#include <xmmintrin.h>  // __m128
#include <immintrin.h>  // __m256
#include <zmmintrin.h>  // __m512time_t TM_start, TM_end;#define Timer(code) TM_start = clock(); code; TM_end = clock(); printf("cpu cycles = %lld\\n", TM_end - TM_start); //对code部分计时
#define Loop(loop, code) Timer(for(int ind=0; ind<loop; ind++) {code;})#define pn printf("\\n\\n")/*全局:zeroall, zeroupper加载:load, loadu, i32gather(根据索引加载), i64gather存储:store设置:broadcast, setzero, set1(复制), set(反序), setr(正序)转换:cast, cvt加减:add, sub, hadd(水平), hsub乘除:mul, mullo(低位结果), mulhi(高位结果), div混合运算:fmadd, fmsub, fnmadd, fnmsub逻辑运算:cmp, cmpeq, cmpneq, cmpge(大于等于), cmpgt(严格大于), cmple, cmplt,位运算:and, andnot, or, xor, sll(左移), srl(右移), slli, srli, sllv, srlv, bslli, bsrli统计学:max, min, avg, ceil, floor, round, lzcnt数学:abs, getexp, sqrt, rsqrt, sin, cos置换:shuffle, permute(根据控制位写入), insert(根据控制位插入)
*/void print_m256i_i16(__m256i* arr) {printf("[ %d", arr->m256i_i16[0]);for (int i = 1; i < 16; i++)printf(", %d", arr->m256i_i16[i]);printf(" ]\\n\\n");
}void print_m256i_i32(__m256i* arr) {printf("[ %d", arr->m256i_i32[0]);for (int i = 1; i < 8; i++)printf(", %d", arr->m256i_i32[i]);printf(" ]\\n\\n");
}int main()
{// AVX2 不支持 m512;这需要 AVX-512 指令集!/*__m512i a3;a3 = _mm512_set1_epi32(123);print_m256i_i32(&a3);*/int arr1[64], arr2[64];for (int i = 0; i < 64; i++)arr1[i] = i+1;__m256i a, b, c;printf("Test set/load/store: \\n\\n");a = _mm256_setzero_si256();     // 全零print_m256i_i32(&a);b = _mm256_set1_epi32(123);     // Copyprint_m256i_i32(&b);/*在 m256 中,包含两个 m128。每个 m128 里,i8[15] 在最左边,i8[0] 在最右边。*/a = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);   // 反序:L0 赋给 i32[7],L7 赋给 i32[0]print_m256i_i32(&a);b = _mm256_setr_epi32(1, 2, 3, 4, 5, 6, 7, 8);  // 正序(reverse order)print_m256i_i32(&b);c = _mm256_load_si256(arr1);    // 不必强制类型转换,直接写 32 字节数组即可print_m256i_i32(&c);_mm256_store_si256(arr2, c);    // 同理,直接写 32 字节数组print_m256i_i32(arr2);printf("Test broadcast/cvt/cast: \\n\\n");__m256 d;float e = 123.5;float f[4] = { 1.e1,1.e2,1.e3,1.e4 };d = _mm256_broadcast_ss(&e);    // m32 的广播a = _mm256_cvtps_epi32(d);      // 类型转换,园整print_m256i_i32(&a);d = _mm256_broadcast_ps(&f);    // m128 的广播a = _mm256_cvtps_epi32(d);      // 类型转换,园整print_m256i_i32(&a);__m128i g;a.m256i_i32[0] = 123;a.m256i_i32[7] = 456;g = _mm256_castsi256_si128(a);  // m256 的前一半写到 m128 上print_m256i_i32(&a);print_m256i_i32(&g);g = *(__m128i*)arr1;    // 强行赋值a = _mm256_castsi128_si256(g);  // m128 写到 m256 的前一半,后一半置零print_m256i_i32(&a);printf("Test gather: \\n\\n");__m256i index = _mm256_setr_epi32(1, 3, 5, 7, 2, 4, 6, 8);a = _mm256_i32gather_epi32(arr1, index, 4); // index 是按字节寻址的,第三个参数是每个数据项的字节长度(epi32 是 4)print_m256i_i32(&index);print_m256i_i32(arr1);print_m256i_i32(&a);printf("Test shuffle: \\n\\n");/*按 m128 分区,每个区的4个数置换控制位 IMM8,共4*2比特,最低的2比特控制m128[0]0b10110001:reverse(01,00,11,10)*/b = _mm256_shuffle_epi32(*(__m256i*)arr1, 0b10110001);      // m128 视为 4 个 int 进行置换print_m256i_i32(arr1);print_m256i_i32(&b);b = _mm256_shufflehi_epi16(*(__m256i*)arr1, 0b10110001);    // m128 的高64比特,视为 4 个 short 进行置换print_m256i_i32(arr1);print_m256i_i32(&b);b = _mm256_shufflelo_epi16(*(__m256i*)arr1, 0b10110001);    // m128 的低64比特,视为 4 个 short 进行置换print_m256i_i32(arr1);print_m256i_i32(&b);printf("Test permute: \\n\\n");b = _mm256_permute4x64_epi64(*(__m256i*)arr1,0b00010001);   // 类似 shuffle 的 IMM8 控制符print_m256i_i32(arr1);print_m256i_i32(&b);index = _mm256_setr_epi32(1, 3, 5, 7, 2, 4, 6, 10);b = _mm256_permutevar8x32_epi32(*(__m256i*)arr1, index);    // 越界的 index,会自动模8(截取了低3比特)print_m256i_i32(&index);print_m256i_i32(arr1);print_m256i_i32(&b);printf("Test insert: \\n\\n");b = _mm256_insert_epi32(*(__m256i*)arr1, 321, 9);   // 插入一个数据,index 越界则自动模8(截取低3比特)print_m256i_i32(arr1);print_m256i_i32(&b);b = _mm256_insert_epi16(*(__m256i*)arr1, 321, 17);   // 插入一个数据,index 越界则自动模16(截取低4比特)print_m256i_i16(arr1);print_m256i_i16(&b);printf("Test add/sub/mul: \\n\\n");a = _mm256_setr_epi32(1,2,3,4,5,6,7,8);b = _mm256_setr_epi32(7,6,5,4,3,2,1,0);print_m256i_i32(&a);print_m256i_i32(&b);c = _mm256_add_epi16(a, b);     // 普通的加法,会溢出print_m256i_i32(&c);a = _mm256_set1_epi16(32767);b = _mm256_set1_epi16(32767);c = _mm256_adds_epi16(a, b);    // 范围受限,如果越界那么被限制在最大值上print_m256i_i16(&c);a = _mm256_setr_epi32(1, 2, 3, 4, 5, 6, 7, 8);b = _mm256_setr_epi32(7, 6, 5, 4, 3, 2, 1, 0);c = _mm256_hadd_epi32(a, b);    // 水平加法,连续两个 int 相加。按照 m128,前 2 个是 a 的,后 2 个是 b 的print_m256i_i32(&c);c = _mm256_sub_epi32(a, b);print_m256i_i32(&c);c = _mm256_mul_epi32(a, b);     // 连续的 8 字节,只取出第一个 4 字节的 int,将 long 的乘积结果写在对应的 8 字节里print_m256i_i32(&c);c = _mm256_mulhi_epi16(a, b);   // 截取 16*2 比特结果的高 16 位print_m256i_i32(&c);c = _mm256_mullo_epi16(a, b);   // 截取 16*2 比特结果的低 16 位print_m256i_i32(&c);printf("Test div/rem: \\n\\n");c = _mm256_div_epi32(b, a);     // 整数除法print_m256i_i32(&c);__m256 rem;c = _mm256_divrem_epi32(&rem, b, a);    // 带余除法print_m256i_i32(&rem);print_m256i_i32(&c);c = _mm256_rem_epi32(b, a);     // 余数,b mod aprint_m256i_i32(&c);printf("Test fmadd/fnmadd: \\n\\n");__m256 a2 = _mm256_set1_ps(1.5);__m256 b2 = _mm256_set1_ps(1.5);__m256 c2 = _mm256_set1_ps(3);c2 = _mm256_fmadd_ps(a2,b2,c2);   // a*b+c, 只有 ps/pd 有混合运算c = _mm256_cvtps_epi32(c2);print_m256i_i32(&c);c2 = _mm256_fnmadd_ps(a2, b2, c2);   // -a*b+c, 只有 ps/pd 有混合运算c = _mm256_cvtps_epi32(c2);print_m256i_i32(&c);printf("Test and/or/xor: \\n\\n");print_m256i_i32(&a);print_m256i_i32(&b);c = _mm256_and_epi32(a, b);     // bitwise ANDprint_m256i_i32(&c);c = _mm256_or_epi32(a, b);      // bitwise ORprint_m256i_i32(&c);c = _mm256_xor_epi32(a, b);     // bitwise XORprint_m256i_i32(&c);printf("Test sll/srl: \\n\\n");a = _mm256_setr_epi32(-1, -2, -3, -4, -5, -6, -7, -8);print_m256i_i32(&a);c = _mm256_slli_epi32(a, 1);    // 每个 int 左移,空位补零print_m256i_i32(&c);c = _mm256_srli_epi32(a, 1);    // 每个 int 右移,空位简单补零(负数也不补1)print_m256i_i32(&c);//c = _mm256_rol_epi32(a, 1);     // 循环左移,AVX2 不支持:非法指令//print_m256i_i32(&c);//c = _mm256_ror_epi32(a, 1);     // 循环右移,AVX2 不支持:非法指令//print_m256i_i32(&c);c = _mm256_slli_si256(a, 1);    // 字节水平的左移(m256i_i8[31]在最左边,m256i_i8[0]在最右边)。按m128,空字节全零print_m256i_i32(&c);c = _mm256_srli_si256(a, 4);    // 字节水平的右移。按m128,空字节全零print_m256i_i32(&c);// 这是啥?怎么结果全是 0 啊?/*__m128i offset = _mm_setr_epi32(1,2,3,4);a = _mm256_setr_epi32(1, 2, 3, 4, 5, 6, 7, 8);c = _mm256_srl_epi32(a, offset);print_m256i_i32(&a);print_m256i_i32(&c);*/// AVX2 不支持,需要 AVX-512//printf("Test mask: \\n\\n");//print_m256i_i16(arr1);//print_m256i_i16(&a);//print_m256i_i16(&b);//c = _mm256_mask_mullo_epi16(*(__m256i*)arr1, 0b0101010101010101, a, b);   // 根据 mask16 各个比特,选择是否在 source 上写入 a*b 结果//print_m256i_i16(&c);//print_m256i_i32(arr1);//print_m256i_i32(&a);//print_m256i_i32(&b);//c = _mm256_mask_mullo_epi32(*(__m256i*)arr1, 0b01010101, a, b);   // 根据 mask6 各个比特,选择是否在 source 上写入 a*b 结果//print_m256i_i32(&c);printf("Test cmp: \\n\\n");a = _mm256_setr_epi32(1, 2, 3, 4, 5, 6, 7, 8);b = _mm256_setr_epi32(7, 6, 5, 4, 3, 2, 1, 0);print_m256i_i32(&a);print_m256i_i32(&b);c = _mm256_cmpeq_epi32(a, b);   // 满足条件,全1(-1);不满足条件,全0(0)print_m256i_i32(&c);c = _mm256_cmpgt_epi32(a, b);   // 满足条件,全1(-1);不满足条件,全0(0)print_m256i_i32(&c);// AVX2 不支持/*__m128i a3 = _mm256_castsi256_si128(a);__m128i b3 = _mm256_castsi256_si128(b);int mask = _mm_cmpge_epi32_mask(a3, b3);*/return 0;
}