more tests

This commit is contained in:
Jorijn van der Graaf 2026-03-26 03:53:30 +01:00
commit cc2c13f7a5
3 changed files with 393 additions and 307 deletions

View file

@ -45,11 +45,141 @@ namespace Crafter {
VectorType v;
public:
template <std::uint32_t Len2, std::uint32_t Packing2>
friend class VectorF16;
static constexpr std::uint32_t MaxSize = 32;
static constexpr std::uint8_t Alignment = GetAlingment();
static_assert(Len * Packing <= MaxSize, "Len * Packing exceeds MaxSize");
private:
template <std::array<bool, Len> values>
static consteval std::array<std::uint16_t, Alignment> GetNegateMask() {
std::array<std::uint16_t, Alignment> mask{0};
for(std::uint8_t i2 = 0; i2 < Packing; i2++) {
for(std::uint8_t i = 0; i < Len; i++) {
if(values[i]) {
mask[i2*Len+i] = 0b1000000000000000;
} else {
mask[i2*Len+i] = 0;
}
}
}
return mask;
}
static consteval std::array<std::uint16_t, Alignment> GetNegateMaskAll() {
std::array<std::uint16_t, Alignment> mask{0};
for(std::uint8_t i = 0; i < Packing*Len; i++) {
mask[i] = 0b1000000000000000;
}
return mask;
}
template <std::array<std::uint8_t, Len> ShuffleValues>
static consteval bool GetShuffleMaskEpi32() {
std::uint8_t mask = 0;
for(std::uint8_t i = 0; i < std::min(Len, std::uint32_t(8)); i+=2) {
mask = mask | (ShuffleValues[i] & 0b11) << i;
}
return mask;
}
template <std::array<std::uint8_t, Len> ShuffleValues>
static consteval std::array<std::uint16_t, VectorF16<Len, Packing>::Alignment> GetPermuteMaskEpi16() {
std::array<std::uint16_t, VectorF16<Len, Packing>::Alignment> shuffleMask {{0}};
for(std::uint8_t i2 = 0; i2 < Packing; i2++) {
for(std::uint8_t i = 0; i < Len; i++) {
shuffleMask[i2*Len+i] = ShuffleValues[i]+i2*Len;
}
}
return shuffleMask;
}
static consteval std::array<bool, Len> GetAllTrue() {
std::array<bool, Len> arr{};
arr.fill(true);
return arr;
}
template <std::array<std::uint8_t, Len> ShuffleValues>
static consteval bool CheckEpi32Shuffle() {
for(std::uint8_t i = 1; i < Len; i+=2) {
if(ShuffleValues[i-1] != ShuffleValues[i] - 1) {
return false;
}
}
for(std::uint8_t i = 0; i < Len; i++) {
for(std::uint8_t i2 = 0; i2 < Len; i2 += 8) {
if(ShuffleValues[i] != ShuffleValues[i2]) {
return false;
}
}
}
return true;
}
template <std::array<std::uint8_t, Len> ShuffleValues>
static consteval bool CheckEpi8Shuffle() {
for(std::uint8_t i = 0; i < Len; i++) {
std::uint8_t lane = i / 8;
if(ShuffleValues[i] < lane * 8 || ShuffleValues[i] > lane * 8 + 7) {
return false;
}
}
return true;
}
template <std::array<bool, Len> ShuffleValues>
static consteval std::uint8_t GetBlendMaskEpi16() requires (std::is_same_v<VectorType, __m128h>){
std::uint8_t mask = 0;
for (std::uint8_t i2 = 0; i2 < Packing; i2++) {
for (std::uint8_t i = 0; i < Len; i++) {
if (ShuffleValues[i]) {
mask |= (1u << (i2 * Len + i));
}
}
}
return mask;
}
template <std::array<bool, Len> ShuffleValues>
static consteval std::uint16_t GetBlendMaskEpi16() requires (std::is_same_v<VectorType, __m256h>){
std::uint16_t mask = 0;
for (std::uint8_t i2 = 0; i2 < Packing; i2++) {
for (std::uint8_t i = 0; i < Len; i++) {
if (ShuffleValues[i]) {
mask |= (1u << (i2 * Len + i));
}
}
}
return mask;
}
template <std::array<bool, Len> ShuffleValues>
static consteval std::uint32_t GetBlendMaskEpi16() requires (std::is_same_v<VectorType, __m512h>){
std::uint32_t mask = 0;
for (std::uint8_t i2 = 0; i2 < Packing; i2++) {
for (std::uint8_t i = 0; i < Len; i++) {
if (ShuffleValues[i]) {
mask |= (1u << (i2 * Len + i));
}
}
}
return mask;
}
template <std::array<std::uint8_t, Len> ShuffleValues>
static consteval std::array<std::uint8_t, VectorF16<Len, Packing>::Alignment*2> GetShuffleMaskEpi8() {
std::array<std::uint8_t, VectorF16<Len, Packing>::Alignment*2> shuffleMask {{0}};
for(std::uint8_t i2 = 0; i2 < Packing; i2++) {
for(std::uint8_t i = 0; i < Len; i++) {
shuffleMask[(i2*Len*2)+(i*2)] = ShuffleValues[i]*2+(i2*Len*2);
shuffleMask[(i2*Len*2)+(i*2+1)] = ShuffleValues[i]*2+1+(i2*Len*2);
}
}
return shuffleMask;
}
public:
template <std::uint32_t Len2, std::uint32_t Packing2>
friend class VectorF16;
constexpr VectorF16() = default;
constexpr VectorF16(VectorType v) : v(v) {}
@ -108,8 +238,60 @@ namespace Crafter {
} else {
return VectorF16<BLen, BPacking>(v);
}
} else {
} else if constexpr (BLen <= Len) {
return this->template ExtractLo<BLen>();
} else {
if constexpr(std::is_same_v<typename VectorF16<BLen, BPacking>::VectorType, __m128h>) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
constexpr std::array<std::uint8_t, VectorF16<BLen, Packing>::Alignment*2> shuffleMask = GetExtractLoMaskEpi8<BLen>();
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
return VectorF16<BLen, BPacking>(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(v), shuffleVec)));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
constexpr std::array<std::uint16_t, VectorF16<BLen, Packing>::Alignment> permMask = GetExtractLoMaskEpi16<BLen>();
__m256i permIdx = _mm256_loadu_epi16(permMask.data());
__m256i result = _mm256_permutexvar_epi16(permIdx, _mm_castph_si256(v));
return VectorF16<BLen, BPacking>(_mm_castsi128_ph(_mm256_castsi256_si128(result)));
} else {
constexpr std::array<std::uint16_t, VectorF16<BLen, Packing>::Alignment> permMask = GetExtractLoMaskEpi16<BLen>();
__m512i permIdx = _mm512_loadu_epi16(permMask.data());
__m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castph_si512(v));
return VectorF16<BLen, BPacking>(_mm_castsi128_ph(_mm512_castsi512_si128(result)));
}
} else if constexpr(std::is_same_v<typename VectorF16<BLen, BPacking>::VectorType, __m256h>) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
constexpr std::array<std::uint16_t, VectorF16<BLen, Packing>::Alignment> permMask = GetExtractLoMaskEpi16<BLen>();
__m256i permIdx = _mm256_loadu_epi16(permMask.data());
__m256i result = _mm256_permutexvar_epi16(permIdx, _mm256_castsi128_si256(_mm_castph_si128(v)));
return VectorF16<BLen, BPacking>(_mm256_castsi256_ph(result));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
constexpr std::array<std::uint16_t, VectorF16<BLen, Packing>::Alignment> permMask = GetExtractLoMaskEpi16<BLen>();
__m256i permIdx = _mm256_loadu_epi16(permMask.data());
__m256i result = _mm256_permutexvar_epi16(permIdx, _mm256_castph_si256(v));
return VectorF16<BLen, BPacking>(_mm256_castsi256_ph(result));
} else {
constexpr std::array<std::uint16_t, VectorF16<BLen, Packing>::Alignment> permMask = GetExtractLoMaskEpi16<BLen>();
__m256i permIdx = _mm512_loadu_epi16(permMask.data());
__m256i result = _mm512_permutexvar_epi16(permIdx, _mm512_castsi512_si256(_mm512_castph_si512(v)));
return VectorF16<BLen, BPacking>(_mm256_castsi256_ph(result));
}
} else {
if constexpr(std::is_same_v<VectorType, __m128h>) {
constexpr std::array<std::uint16_t, VectorF16<BLen, Packing>::Alignment> permMask = GetExtractLoMaskEpi16<BLen>();
__m512i permIdx = _mm512_loadu_epi16(permMask.data());
__m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castsi128_si512(_mm_castph_si128(v)));
return VectorF16<BLen, BPacking>(_mm512_castsi512_ph(result));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
constexpr std::array<std::uint16_t, VectorF16<BLen, Packing>::Alignment> permMask = GetExtractLoMaskEpi16<BLen>();
__m512i permIdx = _mm512_loadu_epi16(permMask.data());
__m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castsi256_si512(_mm256_castph_si256(v)));
return VectorF16<BLen, BPacking>(_mm512_castsi512_ph(result));
} else {
constexpr std::array<std::uint16_t, VectorF16<BLen, Packing>::Alignment> permMask = GetExtractLoMaskEpi16<BLen>();
__m512i permIdx = _mm512_loadu_epi16(permMask.data());
__m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castph_si512(v));
return VectorF16<BLen, BPacking>(_mm512_castsi512_ph(result));
}
}
}
}
@ -442,33 +624,6 @@ namespace Crafter {
}
}
template <const std::array<std::uint8_t, Len> ShuffleValues>
constexpr VectorF16<Len, Packing> Shuffle() {
if constexpr(VectorF16<Len, Packing>::CheckEpi32Shuffle<ShuffleValues>()) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF16<Len, Packing>(_mm_castsi128_ph(_mm_shuffle_epi32(_mm_castph_si128(v), GetShuffleMaskEpi32<ShuffleValues>())));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
return VectorF16<Len, Packing>(_mm256_castsi256_ph(_mm256_shuffle_epi32(_mm256_castph_si256(v), GetShuffleMaskEpi32<ShuffleValues>())));
} else {
return VectorF16<Len, Packing>(_mm512_castsi512_ph(_mm512_shuffle_epi32(_mm512_castph_si512(v), GetShuffleMaskEpi32<ShuffleValues>())));
}
} else {
if constexpr(std::is_same_v<VectorType, __m128h>) {
constexpr std::array<std::uint8_t, VectorF16<Len, Packing>::Alignment*2> shuffleMask = GetShuffleMaskEpi8<ShuffleValues>();
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
return VectorF16<Len, Packing>(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(v), shuffleVec)));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
constexpr std::array<std::uint8_t, VectorF16<Len, Packing>::Alignment*2> shuffleMask = GetShuffleMaskEpi8<ShuffleValues>();
__m256i shuffleVec = _mm256_loadu_epi8(shuffleMask.data());
return VectorF16<Len, Packing>(_mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castph_si256(v)), _mm512_castsi256_si512(shuffleVec)))));
} else {
constexpr std::array<std::uint8_t, VectorF16<Len, Packing>::Alignment*2> shuffleMask = GetShuffleMaskEpi8<ShuffleValues>();
__m512i shuffleVec = _mm512_loadu_epi8(shuffleMask.data());
return VectorF16<Len, Packing>(_mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(v), shuffleVec)));
}
}
}
static constexpr VectorF16<Len, Packing> MulitplyAdd(VectorF16<Len, Packing> a, VectorF16<Len, Packing> b, VectorF16<Len, Packing> add) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF16<Len, Packing>(_mm_fmadd_ph(a.v, b.v, add.v));
@ -549,6 +704,47 @@ namespace Crafter {
}
}
template <const std::array<std::uint8_t, Len> ShuffleValues>
constexpr VectorF16<Len, Packing> Shuffle() {
if constexpr(CheckEpi32Shuffle<ShuffleValues>()) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF16<Len, Packing>(_mm_castsi128_ph(_mm_shuffle_epi32(_mm_castph_si128(v), GetShuffleMaskEpi32<ShuffleValues>())));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
return VectorF16<Len, Packing>(_mm256_castsi256_ph(_mm256_shuffle_epi32(_mm256_castph_si256(v), GetShuffleMaskEpi32<ShuffleValues>())));
} else {
return VectorF16<Len, Packing>(_mm512_castsi512_ph(_mm512_shuffle_epi32(_mm512_castph_si512(v), GetShuffleMaskEpi32<ShuffleValues>())));
}
} else if constexpr(CheckEpi8Shuffle<ShuffleValues>()){
if constexpr(std::is_same_v<VectorType, __m128h>) {
constexpr std::array<std::uint8_t, VectorF16<Len, Packing>::Alignment*2> shuffleMask = GetShuffleMaskEpi8<ShuffleValues>();
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
return VectorF16<Len, Packing>(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(v), shuffleVec)));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
constexpr std::array<std::uint8_t, VectorF16<Len, Packing>::Alignment*2> shuffleMask = GetShuffleMaskEpi8<ShuffleValues>();
__m256i shuffleVec = _mm256_loadu_epi8(shuffleMask.data());
return VectorF16<Len, Packing>(_mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castph_si256(v)), _mm512_castsi256_si512(shuffleVec)))));
} else {
constexpr std::array<std::uint8_t, VectorF16<Len, Packing>::Alignment*2> shuffleMask = GetShuffleMaskEpi8<ShuffleValues>();
__m512i shuffleVec = _mm512_loadu_epi8(shuffleMask.data());
return VectorF16<Len, Packing>(_mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(v), shuffleVec)));
}
} else {
if constexpr(std::is_same_v<VectorType, __m128h>) {
constexpr std::array<std::uint8_t, VectorF16<Len, Packing>::Alignment*2> shuffleMask = GetShuffleMaskEpi8<ShuffleValues>();
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
return VectorF16<Len, Packing>(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(v), shuffleVec)));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
constexpr std::array<std::uint16_t, VectorF16<Len, Packing>::Alignment> permMask = GetPermuteMaskEpi16<ShuffleValues>();
__m256i permIdx = _mm256_loadu_epi16(permMask.data());
return VectorF16<Len, Packing>(_mm256_castsi256_ph(_mm256_permutexvar_epi16(permIdx, _mm256_castph_si256(v))));
} else {
constexpr std::array<std::uint16_t, VectorF16<Len, Packing>::Alignment> permMask = GetPermuteMaskEpi16<ShuffleValues>();
__m512i permIdx = _mm512_loadu_epi16(permMask.data());
return VectorF16<Len, Packing>(_mm512_castsi512_ph(_mm512_permutexvar_epi16(permIdx, _mm512_castph_si512(v))));
}
}
}
constexpr static std::tuple<VectorF16<Len, Packing>, VectorF16<Len, Packing>, VectorF16<Len, Packing>, VectorF16<Len, Packing>, VectorF16<Len, Packing>, VectorF16<Len, Packing>, VectorF16<Len, Packing>, VectorF16<Len, Packing>> Normalize(
VectorF16<Len, Packing> A,
VectorF16<Len, Packing> B,
@ -558,7 +754,7 @@ namespace Crafter {
VectorF16<Len, Packing> F,
VectorF16<Len, Packing> G,
VectorF16<Len, Packing> H
) requires(Len == 8) {
) requires(Len == 8 && Packing*Len == Alignment) {
constexpr std::array<std::uint8_t, Alignment*2> shuffleMaskA = GetShuffleMaskEpi8<{{0,0,0,0,0,0,0,0}}>();
constexpr std::array<std::uint8_t, Alignment*2> shuffleMaskB = GetShuffleMaskEpi8<{{1,1,1,1,1,1,1,1}}>();
constexpr std::array<std::uint8_t, Alignment*2> shuffleMaskC = GetShuffleMaskEpi8<{{2,2,2,2,2,2,2,2}}>();
@ -696,7 +892,7 @@ namespace Crafter {
VectorF16<Len, Packing> C,
VectorF16<Len, Packing> E,
VectorF16<Len, Packing> G
) requires(Len == 4) {
) requires(Len == 4 && Packing*Len == Alignment) {
constexpr std::array<std::uint8_t, Alignment*2> shuffleMaskA = GetShuffleMaskEpi8<{{0,0,0,0}}>();
constexpr std::array<std::uint8_t, Alignment*2> shuffleMaskC = GetShuffleMaskEpi8<{{1,1,1,1}}>();
constexpr std::array<std::uint8_t, Alignment*2> shuffleMaskE = GetShuffleMaskEpi8<{{2,2,2,2}}>();
@ -780,62 +976,50 @@ namespace Crafter {
constexpr static std::tuple<VectorF16<Len, Packing>, VectorF16<Len, Packing>> Normalize(
VectorF16<Len, Packing> A,
VectorF16<Len, Packing> E
) requires(Len == 2) {
constexpr std::array<std::uint8_t, Alignment*2> shuffleMaskA = GetShuffleMaskEpi8<{{0,0}}>();
constexpr std::array<std::uint8_t, Alignment*2> shuffleMaskE = GetShuffleMaskEpi8<{{1,1}}>();
) requires(Len == 2 && Packing*Len == Alignment) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
VectorF16<Len, Packing> lenght = Length(A, E);
VectorF16<1, 8> lenght = Length(A, E);
constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1};
__m128h one = _mm_loadu_ph(oneArr);
__m128h fLenght = _mm_div_ph(one, lenght.v);
VectorF16<8, 1> fLenght(_mm_div_ph(one, lenght.v));
__m128i shuffleVecA = _mm_loadu_epi8(shuffleMaskA.data());
__m128h fLenghtA = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecA));
__m128i shuffleVecE = _mm_loadu_epi8(shuffleMaskE.data());
__m128h fLenghtE = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecE));
VectorF16<8, 1> fLenghtA = fLenght.template Shuffle<{{0,0,1,1,2,2,3,3}}>();
VectorF16<8, 1> fLenghtE = fLenght.template Shuffle<{{4,4,5,5,6,6,7,7}}>();
return {
_mm_mul_ph(A.v, fLenghtA),
_mm_mul_ph(E.v, fLenghtE),
_mm_mul_ph(A.v, fLenghtA.v),
_mm_mul_ph(E.v, fLenghtE.v),
};
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
VectorF16<Len, Packing> lenght = Length(A, E);
VectorF16<1, 16> lenght = Length(A, E);
constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
__m256h one = _mm256_loadu_ph(oneArr);
__m256h fLenght = _mm256_div_ph(one, lenght.v);
VectorF16<16, 1> fLenght(_mm256_div_ph(one, lenght.v));
__m256i shuffleVecA = _mm256_loadu_epi8(shuffleMaskA.data());
__m256h fLenghtA = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecA));
__m256i shuffleVecE = _mm256_loadu_epi8(shuffleMaskE.data());
__m256h fLenghtE = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecE));
VectorF16<16, 1> fLenghtA = fLenght.template Shuffle<{{0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7}}>();
VectorF16<16, 1> fLenghtE = fLenght.template Shuffle<{{8,8,9,9,10,10,11,11,12,12,13,13,14,14,15,15}}>();
return {
_mm256_mul_ph(A.v, fLenghtA),
_mm256_mul_ph(E.v, fLenghtE),
_mm256_mul_ph(A.v, fLenghtA.v),
_mm256_mul_ph(E.v, fLenghtE.v),
};
} else {
VectorF16<Len, Packing> lenght = Length(A, E);
VectorF16<1, 32> lenght = Length(A, E);
constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
__m512h one = _mm512_loadu_ph(oneArr);
__m512h fLenght = _mm512_div_ph(one, lenght.v);
VectorF16<32, 1> fLenght(_mm512_div_ph(one, lenght.v));
__m512i shuffleVecA = _mm512_loadu_epi8(shuffleMaskA.data());
__m512h fLenghtA = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecA));
__m512i shuffleVecE = _mm512_loadu_epi8(shuffleMaskE.data());
__m512h fLenghtE = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecE));
VectorF16<32, 1> fLenghtA = fLenght.template Shuffle<{{0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13,14,14,15,15}}>();
VectorF16<32, 1> fLenghtE = fLenght.template Shuffle<{{16,16,17,17,18,18,19,19,20,20,21,21,22,22,23,23,24,24,25,25,26,26,27,27,28,28,29,29,30,30,31,31}}>();
return {
_mm512_mul_ph(A.v, fLenghtA),
_mm512_mul_ph(E.v, fLenghtE),
_mm512_mul_ph(A.v, fLenghtA.v),
_mm512_mul_ph(E.v, fLenghtE.v),
};
}
}
constexpr static VectorF16<Len, Packing> Length(
constexpr static VectorF16<1, Packing*8> Length(
VectorF16<Len, Packing> A,
VectorF16<Len, Packing> B,
VectorF16<Len, Packing> C,
@ -844,48 +1028,48 @@ namespace Crafter {
VectorF16<Len, Packing> F,
VectorF16<Len, Packing> G,
VectorF16<Len, Packing> H
) requires(Len == 8) {
VectorF16<Len, Packing> lenghtSq = LengthSq(A, B, C, D, E, F, G, H);
) requires(Len == 8 && Packing*Len == Alignment) {
VectorF16<1, Packing*8> lenghtSq = LengthSq(A, B, C, D, E, F, G, H);
if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF16<Len, Packing>(_mm_sqrt_ph(lenghtSq.v));
return VectorF16<1, Packing*8>(_mm_sqrt_ph(lenghtSq.v));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
return VectorF16<Len, Packing>(_mm256_sqrt_ph(lenghtSq.v));
return VectorF16<1, Packing*8>(_mm256_sqrt_ph(lenghtSq.v));
} else {
return VectorF16<Len, Packing>(_mm512_sqrt_ph(lenghtSq.v));
return VectorF16<1, Packing*8>(_mm512_sqrt_ph(lenghtSq.v));
}
}
constexpr static VectorF16<Len, Packing> Length(
constexpr static VectorF16<1, Packing*4> Length(
VectorF16<Len, Packing> A,
VectorF16<Len, Packing> C,
VectorF16<Len, Packing> E,
VectorF16<Len, Packing> G
) requires(Len == 4) {
VectorF16<Len, Packing> lenghtSq = LengthSq(A, C, E, G);
) requires(Len == 4 && Packing*Len == Alignment) {
VectorF16<1, Packing*4> lenghtSq = LengthSq(A, C, E, G);
if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF16<Len, Packing>(_mm_sqrt_ph(lenghtSq.v));
return VectorF16<1, Packing*4>(_mm_sqrt_ph(lenghtSq.v));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
return VectorF16<Len, Packing>(_mm256_sqrt_ph(lenghtSq.v));
return VectorF16<1, Packing*4>(_mm256_sqrt_ph(lenghtSq.v));
} else {
return VectorF16<Len, Packing>(_mm512_sqrt_ph(lenghtSq.v));
return VectorF16<1, Packing*4>(_mm512_sqrt_ph(lenghtSq.v));
}
}
constexpr static VectorF16<Len, Packing> Length(
constexpr static VectorF16<1, Packing*2> Length(
VectorF16<Len, Packing> A,
VectorF16<Len, Packing> E
) requires(Len == 2) {
VectorF16<Len, Packing> lenghtSq = LengthSq(A, E);
) requires(Len == 2 && Packing*Len == Alignment) {
VectorF16<1, Packing*2> lenghtSq = LengthSq(A, E);
if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF16<Len, Packing>(_mm_sqrt_ph(lenghtSq.v));
return VectorF16<1, Packing*2>(_mm_sqrt_ph(lenghtSq.v));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
return VectorF16<Len, Packing>(_mm256_sqrt_ph(lenghtSq.v));
return VectorF16<1, Packing*2>(_mm256_sqrt_ph(lenghtSq.v));
} else {
return VectorF16<Len, Packing>(_mm512_sqrt_ph(lenghtSq.v));
return VectorF16<1, Packing*2>(_mm512_sqrt_ph(lenghtSq.v));
}
}
constexpr static VectorF16<Len, Packing> LengthSq(
constexpr static VectorF16<1, Packing*8> LengthSq(
VectorF16<Len, Packing> A,
VectorF16<Len, Packing> B,
VectorF16<Len, Packing> C,
@ -894,27 +1078,27 @@ namespace Crafter {
VectorF16<Len, Packing> F,
VectorF16<Len, Packing> G,
VectorF16<Len, Packing> H
) requires(Len == 8) {
) requires(Len == 8 && Packing*Len == Alignment) {
return Dot(A, A, B, B, C, C, D, D, E, E, F, F, G, G, H, H);
}
constexpr static VectorF16<Len, Packing> LengthSq(
constexpr static VectorF16<1, Packing*4> LengthSq(
VectorF16<Len, Packing> A,
VectorF16<Len, Packing> C,
VectorF16<Len, Packing> E,
VectorF16<Len, Packing> G
) requires(Len == 4) {
) requires(Len == 4 && Packing*Len == Alignment) {
return Dot(A, A, C, C, E, E, G, G);
}
constexpr static VectorF16<Len, Packing> LengthSq(
constexpr static VectorF16<1, Packing*2> LengthSq(
VectorF16<Len, Packing> A,
VectorF16<Len, Packing> E
) requires(Len == 2) {
) requires(Len == 2 && Packing*Len == Alignment) {
return Dot(A, A, E, E);
}
constexpr static VectorF16<Len, Packing> Dot(
constexpr static VectorF16<1, Packing*8> Dot(
VectorF16<Len, Packing> A0, VectorF16<Len, Packing> A1,
VectorF16<Len, Packing> B0, VectorF16<Len, Packing> B1,
VectorF16<Len, Packing> C0, VectorF16<Len, Packing> C1,
@ -923,7 +1107,7 @@ namespace Crafter {
VectorF16<Len, Packing> F0, VectorF16<Len, Packing> F1,
VectorF16<Len, Packing> G0, VectorF16<Len, Packing> G1,
VectorF16<Len, Packing> H0, VectorF16<Len, Packing> H1
) requires(Len == 8) {
) requires(Len == 8 && Packing*Len == Alignment) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
__m128h mulA = _mm_mul_ph(A0.v, A1.v);
__m128h mulB = _mm_mul_ph(B0.v, B1.v);
@ -1086,12 +1270,12 @@ namespace Crafter {
}
}
constexpr static VectorF16<Len, Packing> Dot(
constexpr static VectorF16<1, Packing*4> Dot(
VectorF16<Len, Packing> A0, VectorF16<Len, Packing> A1,
VectorF16<Len, Packing> C0, VectorF16<Len, Packing> C1,
VectorF16<Len, Packing> E0, VectorF16<Len, Packing> E1,
VectorF16<Len, Packing> G0, VectorF16<Len, Packing> G1
) requires(Len == 4) {
) requires(Len == 4 && Packing*Len == Alignment) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
__m128h mulA = _mm_mul_ph(A0.v, A1.v);
__m128h mulC = _mm_mul_ph(C0.v, C1.v);
@ -1179,10 +1363,10 @@ namespace Crafter {
}
}
constexpr static VectorF16<Len, Packing> Dot(
constexpr static VectorF16<1, Packing*2> Dot(
VectorF16<Len, Packing> A0, VectorF16<Len, Packing> A1,
VectorF16<Len, Packing> E0, VectorF16<Len, Packing> E1
) requires(Len == 2) {
) requires(Len == 2 && Packing*Len == Alignment) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
__m128h mulA = _mm_mul_ph(A0.v, A1.v);
__m128h mulE = _mm_mul_ph(E0.v, E1.v);
@ -1200,7 +1384,9 @@ namespace Crafter {
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
__m256h mulA = _mm256_mul_ph(A0.v, A1.v);
__m256h mulE = _mm256_mul_ph(E0.v, E1.v);
__m256i row12Temp1 = _mm256_unpacklo_epi16(_mm256_castph_si256(mulA), _mm256_castph_si256(mulE)); // A1 E1 A2 E2 B1 F1 B2 F2
__m256i row12Temp2 = _mm256_unpackhi_epi16(_mm256_castph_si256(mulA), _mm256_castph_si256(mulE)); // C1 G1 C2 G2 D1 H1 D2 H2
__m256i row12Temp1Temp = row12Temp1;
@ -1209,8 +1395,12 @@ namespace Crafter {
__m256h row1 = _mm256_castsi256_ph(_mm256_unpacklo_epi16(row12Temp1, row12Temp2));// A1 B1 C1 D1 E1 F1 G1 H1
__m256h row2 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(row12Temp1, row12Temp2));// A2 B2 C2 D2 E2 F2 G2 H2
return _mm256_add_ph(row1, row2);
__m256h result = _mm256_add_ph(row1, row2);
VectorF16<16, 1> vec(result);
vec = vec.template Shuffle<{{0,1,2,3,8,9,10,11,4,5,6,7,12,13,14,15}}>();
return VectorF16<1, 16>(vec.v);
} else {
__m512h mulA = _mm512_mul_ph(A0.v, A1.v);
__m512h mulE = _mm512_mul_ph(E0.v, E1.v);
@ -1223,8 +1413,11 @@ namespace Crafter {
__m512h row1 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(row12Temp1, row12Temp2));// A1 B1 C1 D1 E1 F1 G1 H1
__m512h row2 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(row12Temp1, row12Temp2));// A2 B2 C2 D2 E2 F2 G2 H2
__m512h result = _mm512_add_ph(row1, row2);
return _mm512_add_ph(row1, row2);
VectorF16<32, 1> vec(result);
vec = vec.template Shuffle<{{0,1,2,3,8,9,10,11,16,17,18,19,24,25,26,27,4,5,6,7,12,13,14,15,20,21,22,23,28,29,30,31}}>();
return VectorF16<1, 32>(vec.v);
}
}
@ -1292,113 +1485,6 @@ namespace Crafter {
return row1;
}
private:
template <std::array<bool, Len> values>
static consteval std::array<std::uint16_t, Alignment> GetNegateMask() {
std::array<std::uint16_t, Alignment> mask{0};
for(std::uint8_t i2 = 0; i2 < Packing; i2++) {
for(std::uint8_t i = 0; i < Len; i++) {
if(values[i]) {
mask[i2*Len+i] = 0b1000000000000000;
} else {
mask[i2*Len+i] = 0;
}
}
}
return mask;
}
static consteval std::array<std::uint16_t, Alignment> GetNegateMaskAll() {
std::array<std::uint16_t, Alignment> mask{0};
for(std::uint8_t i = 0; i < Packing*Len; i++) {
mask[i] = 0b1000000000000000;
}
return mask;
}
template <std::array<std::uint8_t, Len> ShuffleValues>
static consteval bool GetShuffleMaskEpi32() {
std::uint8_t mask = 0;
for(std::uint8_t i = 0; i < std::min(Len, std::uint32_t(8)); i+=2) {
mask = mask | (ShuffleValues[i] & 0b11) << i;
}
return mask;
}
template <std::array<std::uint8_t, Len> ShuffleValues>
static consteval std::array<std::uint8_t, VectorF16<Len, Packing>::Alignment*2> GetShuffleMaskEpi8() {
std::array<std::uint8_t, VectorF16<Len, Packing>::Alignment*2> shuffleMask {{0}};
for(std::uint8_t i2 = 0; i2 < Packing; i2++) {
for(std::uint8_t i = 0; i < Len; i++) {
shuffleMask[(i2*Len*2)+(i*2)] = ShuffleValues[i]*2+(i2*Len*2);
shuffleMask[(i2*Len*2)+(i*2+1)] = ShuffleValues[i]*2+1+(i2*Len*2);
}
}
return shuffleMask;
}
static consteval std::array<bool, Len> GetAllTrue() {
std::array<bool, Len> arr{};
arr.fill(true);
return arr;
}
template <std::array<std::uint8_t, Len> ShuffleValues>
static consteval bool CheckEpi32Shuffle() {
for(std::uint8_t i = 1; i < Len; i+=2) {
if(ShuffleValues[i-1] != ShuffleValues[i] - 1) {
return false;
}
}
for(std::uint8_t i = 0; i < Len; i++) {
for(std::uint8_t i2 = 0; i2 < Len; i2 += 8) {
if(ShuffleValues[i] != ShuffleValues[i2]) {
return false;
}
}
}
return true;
}
template <std::array<bool, Len> ShuffleValues>
static consteval std::uint8_t GetBlendMaskEpi16() requires (std::is_same_v<VectorType, __m128h>){
std::uint8_t mask = 0;
for (std::uint8_t i2 = 0; i2 < Packing; i2++) {
for (std::uint8_t i = 0; i < Len; i++) {
if (ShuffleValues[i]) {
mask |= (1u << (i2 * Len + i));
}
}
}
return mask;
}
template <std::array<bool, Len> ShuffleValues>
static consteval std::uint16_t GetBlendMaskEpi16() requires (std::is_same_v<VectorType, __m256h>){
std::uint16_t mask = 0;
for (std::uint8_t i2 = 0; i2 < Packing; i2++) {
for (std::uint8_t i = 0; i < Len; i++) {
if (ShuffleValues[i]) {
mask |= (1u << (i2 * Len + i));
}
}
}
return mask;
}
template <std::array<bool, Len> ShuffleValues>
static consteval std::uint32_t GetBlendMaskEpi16() requires (std::is_same_v<VectorType, __m512h>){
std::uint32_t mask = 0;
for (std::uint8_t i2 = 0; i2 < Packing; i2++) {
for (std::uint8_t i = 0; i < Len; i++) {
if (ShuffleValues[i]) {
mask |= (1u << (i2 * Len + i));
}
}
}
return mask;
}
static constexpr float two_over_pi = 0.6366197723675814f;
static constexpr float pi_over_2_hi = 1.5707963267341256f;
static constexpr float pi_over_2_lo = 6.077100506506192e-11f;