This commit is contained in:
Jorijn van der Graaf 2026-03-31 14:22:18 +02:00
commit a16f8ffbde
7 changed files with 251 additions and 133 deletions

View file

@ -24,7 +24,6 @@ export module Crafter.Math:VectorF32;
import std;
import :Common;
#ifdef __AVX512FP16__
namespace Crafter {
export template <std::uint8_t Len, std::uint8_t Packing>
struct VectorF32 : public VectorBase<Len, Packing, float> {
@ -38,6 +37,9 @@ namespace Crafter {
constexpr VectorF32(const float* vB) {
Load(vB);
};
constexpr VectorF32(const _Float16* vB) {
Load(vB);
};
constexpr VectorF32(float val) {
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
this->v = _mm_set1_ps(val);
@ -66,8 +68,55 @@ namespace Crafter {
}
}
constexpr std::array<float, VectorBase<Len, Packing, float>::AlignmentElement> Store() const {
std::array<float, VectorBase<Len, Packing, float>::AlignmentElement> returnArray;
constexpr void Load(const _Float16* vB) {
#ifdef __F16C__
if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
this->v = _mm_cvtph_ps(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(vB)));
} else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
this->v = _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(vB)));
} else {
this->v = _mm512_cvtph_ps(_mm256_loadu_si256(reinterpret_cast<const __m256i*>(vB)));
}
#else
alignas(64) float tmp[Len];
for (int i = 0; i < Len; ++i)
tmp[i] = static_cast<float>(vB[i]);
if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
this->v = _mm_load_ps(tmp);
} else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
this->v = _mm256_load_ps(tmp);
} else {
this->v = _mm512_load_ps(tmp);
}
#endif
}
constexpr void Store(_Float16* vB) const {
#ifdef __F16C__
if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
_mm_storel_epi64(reinterpret_cast<__m128i*>(vB), _mm_cvtps_ph(this->v, _MM_FROUND_TO_NEAREST_INT));
} else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
_mm_storeu_si128(reinterpret_cast<__m128i*>(vB), _mm256_cvtps_ph(this->v, _MM_FROUND_TO_NEAREST_INT));
} else {
_mm256_storeu_si256(reinterpret_cast<__m256i*>(vB), _mm512_cvtps_ph(this->v, _MM_FROUND_TO_NEAREST_INT));
}
#else
alignas(64) float tmp[Len];
if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
_mm_store_ps(tmp, this->v);
} else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
_mm256_store_ps(tmp, this->v);
} else {
_mm512_store_ps(tmp, this->v);
}
for (int i = 0; i < Len; ++i)
vB[i] = static_cast<_Float16>(tmp[i]);
#endif
}
template<typename T>
constexpr std::array<T, VectorBase<Len, Packing, float>::AlignmentElement> Store() const {
std::array<T, VectorBase<Len, Packing, float>::AlignmentElement> returnArray;
Store(returnArray.data());
return returnArray;
}
@ -96,36 +145,41 @@ namespace Crafter {
if constexpr(std::is_same_v<typename VectorBase<BLen, BPacking, float>::VectorType, __m128>) {
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask = VectorBase<Len, Packing, float>::template GetExtractLoMaskEpi8<BLen>();
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
__m128i shuffleVec = _mm_loadu_si128(reinterpret_cast<const __m128i*>(shuffleMask.data()));
return VectorF32<BLen, BPacking>(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(this->v), shuffleVec)));
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask =VectorBase<Len, Packing, float>::template GetExtractLoMaskepi32<BLen>();
__m256i permIdx = _mm256_loadu_epi32(permMask.data());
__m256i permIdx = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(permMask.data()));
__m256i result = _mm256_permutexvar_epi32(permIdx, _mm_castps_si256(this->v));
return VectorF32<BLen, BPacking>(_mm_castsi128_ps(_mm256_castsi256_si128(result)));
#ifdef __AVX512F__
} else {
constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask = VectorBase<Len, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
__m512i permIdx = _mm512_loadu_epi32(permMask.data());
__m512i result = _mm512_permutexvar_epi32(permIdx, _mm512_castps_si512(this->v));
return VectorF32<BLen, BPacking>(_mm_castsi128_ps(_mm512_castsi512_si128(result)));
#endif
}
} else if constexpr(std::is_same_v<typename VectorBase<BLen, BPacking, float>::VectorType, __m256>) {
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
constexpr std::array<std::uint32_t, VectorBase<BLen, Packing, float>::AlignmentElement> permMask = VectorBase<BLen, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
__m256i permIdx = _mm256_loadu_epi32(permMask.data());
__m256i permIdx = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(permMask.data()));
__m256i result = _mm256_permutexvar_epi32(permIdx, _mm256_castsi128_si256(_mm_castps_si128(this->v)));
return VectorF32<BLen, BPacking>(_mm256_castsi256_ps(result));
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
constexpr std::array<std::uint32_t, VectorBase<BLen, Packing, float>::AlignmentElement> permMask = VectorBase<BLen, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
__m256i permIdx = _mm256_loadu_epi32(permMask.data());
__m256i permIdx = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(permMask.data()));
__m256i result = _mm256_permutexvar_epi32(permIdx, _mm256_castps_si256(this->v));
return VectorF32<BLen, BPacking>(_mm256_castsi256_ps(result));
#ifdef __AVX512F__
} else {
constexpr std::array<std::uint32_t, VectorBase<BLen, Packing, float>::AlignmentElement> permMask = VectorBase<BLen, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
__m256i permIdx = _mm512_loadu_epi32(permMask.data());
__m256i result = _mm512_permutexvar_epi32(permIdx, _mm512_castsi512_si256(_mm512_castps_si512(this->v)));
return VectorF32<BLen, BPacking>(_mm256_castsi256_ps(result));
#endif
}
#ifdef __AVX512F__
} else {
if constexpr(std::is_same_v<typename VectorBase<BLen, BPacking, float>::VectorType, __m128>) {
constexpr std::array<std::uint32_t, VectorBase<BLen, Packing, float>::AlignmentElement> permMask = VectorBase<BLen, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
@ -143,6 +197,7 @@ namespace Crafter {
__m512i result = _mm512_permutexvar_epi32(permIdx, _mm512_castps_si512(this->v));
return VectorF32<BLen, BPacking>(_mm512_castsi512_ps(result));
}
#endif
}
}
}
@ -272,25 +327,27 @@ namespace Crafter {
return Negate<VectorBase<Len, Packing, float>::GetAllTrue()>();
}
constexpr bool operator==(VectorF32<Len, Packing> b) const {
constexpr bool operator==(VectorF32<Len, Packing> b) const {
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
return _mm_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) == 15;
#ifdef __AVX512VL__
return _mm_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) == 0xF;
#else
return _mm_movemask_ps(_mm_cmpeq_ps(this->v, b.v)) == 0xF;
#endif
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
return _mm256_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) == 255;
#ifdef __AVX512VL__
return _mm256_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) == 0xFF;
#else
return _mm256_movemask_ps(_mm256_cmp_ps(this->v, b.v, _CMP_EQ_OQ)) == 0xFF;
#endif
} else {
return _mm512_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) == 65535;
return _mm512_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) == 0xFFFF;
}
}
}
constexpr bool operator!=(VectorF32<Len, Packing> b) const {
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
return _mm_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) != 15;
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
return _mm256_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) != 255;
} else {
return _mm512_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) != 65535;
}
}
constexpr bool operator!=(VectorF32<Len, Packing> b) const {
return !(*this == b);
}
template<std::uint32_t ExtractLen>
constexpr VectorF32<ExtractLen, Packing> ExtractLo() const {
@ -301,7 +358,7 @@ namespace Crafter {
return VectorF32<ExtractLen, Packing>(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(this->v), shuffleVec)));
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask = VectorBase<Len, Packing, float>::template GetExtractLoMaskEpi32<ExtractLen>();
__m256i permIdx = _mm256_loadu_epi32(permMask.data());
__m256i permIdx = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(permMask.data()));
__m256i result = _mm256_permutexvar_epi32(permIdx, _mm256_castps_si256(this->v));
if constexpr(std::is_same_v<typename VectorBase<ExtractLen, Packing, float>::VectorType, __m128>) {
return VectorF32<ExtractLen, Packing>(_mm256_castps256_ps128(_mm256_castsi256_ps(result)));
@ -323,10 +380,12 @@ namespace Crafter {
} else {
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256> && std::is_same_v<typename VectorBase<ExtractLen, Packing, float>::VectorType, __m128>) {
return VectorF32<ExtractLen, Packing>(_mm256_castps256_ps128(this->v));
#ifdef __AVX512F__
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m512> && std::is_same_v<typename VectorBase<ExtractLen, Packing, float>::VectorType, __m128>) {
return VectorF32<ExtractLen, Packing>(_mm512_castps512_ps128(this->v));
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m512> && std::is_same_v<typename VectorBase<ExtractLen, Packing, float>::VectorType, __m256>) {
return VectorF32<ExtractLen, Packing>(_mm512_castps512_ps256(this->v));
#endif
} else {
return VectorF32<ExtractLen, Packing>(this->v);
}
@ -338,8 +397,10 @@ namespace Crafter {
return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::cos_f32x4(this->v));
} else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::cos_f32x8(this->v));
#ifdef __AVX512F__
} else {
return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::cos_f32x16(this->v));
#endif
}
}
@ -348,8 +409,10 @@ namespace Crafter {
return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::sin_f32x4(this->v));
} else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::sin_f32x8(this->v));
#ifdef __AVX512F__
} else {
return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::sin_f32x16(this->v));
#endif
}
}
@ -369,7 +432,7 @@ namespace Crafter {
VectorF32<Len, Packing>(s),
VectorF32<Len, Packing>(c)
};
#ifdef __AVX512F__
} else {
__m512 s, c;
VectorBase<Len, Packing, float>::sincos_f32x16(this->v, s, c);
@ -377,6 +440,7 @@ namespace Crafter {
VectorF32<Len, Packing>(s),
VectorF32<Len, Packing>(c)
};
#endif
}
}
@ -384,11 +448,13 @@ namespace Crafter {
constexpr VectorF32<Len, Packing> Negate() {
std::array<float, VectorBase<Len, Packing, float>::AlignmentElement> mask = VectorBase<Len, Packing, float>::template GetNegateMask<values>();
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_xor_si128(_mm_castps_si128(this->v), _mm_loadu_epi32(mask.data()))));
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_xor_si128(_mm_castps_si128(this->v), _mm_loadu_si128(reinterpret_cast<__m128i*>(mask.data())))));
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(this->v), _mm256_loadu_epi32(mask.data()))));
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(this->v), _mm256_loadu_si256(reinterpret_cast<__m256i*>(mask.data())))));
#ifdef __AVX512F__
} else {
return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(this->v), _mm512_loadu_epi32(mask.data()))));
#endif
}
}
@ -397,8 +463,10 @@ namespace Crafter {
return VectorF32<Len, Packing>(_mm_fmadd_ps(a.v, b.v, add.v));
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
return VectorF32<Len, Packing>(_mm256_fmadd_ps(a.v, b.v, add.v));
#ifdef __AVX512F__
} else {
return VectorF32<Len, Packing>(_mm512_fmadd_ps(a.v, b.v, add.v));
#endif
}
}
@ -407,55 +475,22 @@ namespace Crafter {
return VectorF32<Len, Packing>(_mm_fmsub_ps(a.v, b.v, sub.v));
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
return VectorF32<Len, Packing>(_mm256_fmsub_ps(a.v, b.v, sub.v));
#ifdef __AVX512F__
} else {
return VectorF32<Len, Packing>(_mm512_fmsub_ps(a.v, b.v, sub.v));
#endif
}
}
constexpr static VectorF32<Len, Packing> Cross(VectorF32<Len, Packing> a, VectorF32<Len, Packing> b) requires(Len == 3) {
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask1 = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<{{1,2,0}}>();
__m128i shuffleVec1 = _mm_loadu_epi8(shuffleMask1.data());
__m128 row1 = _mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(a.v), shuffleVec1));
__m128 row4 = _mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(b.v), shuffleVec1));
VectorF32<Len, Packing> row1 = a.template Shuffle<{{1,2,0}}>();
VectorF32<Len, Packing> row4 = b.template Shuffle<{{1,2,0}}>();
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask3 = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<{{2,0,1}}>();
__m128i shuffleVec3 = _mm_loadu_epi8(shuffleMask3.data());
__m128 row3 = _mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(a.v), shuffleVec3));
__m128 row2 = _mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(b.v), shuffleVec3));
VectorF32<Len, Packing> row3 = a.template Shuffle<{{2,0,1}}>();
VectorF32<Len, Packing> row2 = b.template Shuffle<{{2,0,1}}>();
__m128 result = _mm_mul_ps(row3, row4);
return _mm_fmsub_ps(row1,row2,result);
} else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask1 = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<{{1,2,0}}>();
__m512i shuffleVec1 = _mm512_castsi256_si512(_mm256_loadu_epi8(shuffleMask1.data()));
__m256 row1 = _mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(a.v)), shuffleVec1)));
__m256 row4 = _mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(b.v)), shuffleVec1)));
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask3 = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<{{2,0,1}}>();
__m512i shuffleVec3 = _mm512_castsi256_si512(_mm256_loadu_epi8(shuffleMask3.data()));
__m256 row3 = _mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(a.v)), shuffleVec3)));
__m256 row2 = _mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(b.v)), shuffleVec3)));
__m256 result = _mm256_mul_ps(row3, row4);
return _mm256_fmsub_ps(row1,row2,result);
} else {
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask1 = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<{{1,2,0}}>();
__m512i shuffleVec1 = _mm512_loadu_epi8(shuffleMask1.data());
__m512 row1 = _mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(a.v), shuffleVec1));
__m512 row4 = _mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(b.v), shuffleVec1));
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask3 = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<{{2,0,1}}>();
__m512i shuffleVec3 = _mm512_loadu_epi8(shuffleMask3.data());
__m512 row3 = _mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(a.v), shuffleVec3));
__m512 row2 = _mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(b.v), shuffleVec3));
__m512 result = _mm512_mul_ps(row3, row4);
return _mm512_fmsub_ps(row1,row2,result);
}
VectorF32<Len, Packing> result = row3 * row4;
return VectorF32<Len, Packing>::MulitplySub(row1, row2, result);
}
template <const std::array<std::uint8_t, Len> ShuffleValues>
@ -465,21 +500,31 @@ namespace Crafter {
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(this->v), imm)));
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(this->v), imm)));
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(this->v), imm)));
#ifdef __AVX512F__
} else {
return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(this->v), imm)));
#endif
}
} else if constexpr(VectorBase<Len, Packing, float>::template CheckEpi8Shuffle<ShuffleValues>()){
} else if constexpr(VectorBase<Len, Packing, float>::template CheckEpi8Shuffle<ShuffleValues>()) {
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<ShuffleValues>();
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(this->v), shuffleVec)));
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
__m256i shuffleVec = _mm256_loadu_epi8(shuffleMask.data());
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(this->v)), _mm512_castsi256_si512(shuffleVec)))));
#ifdef __AVX512BW__
__m256i shuffleVec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(shuffleMask.data()));
return VectorF32<Len, Packing>(_mm256_castsi256_ps( _mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(this->v)),_mm512_castsi256_si512(shuffleVec)))));
#else
constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask = VectorBase<Len, Packing, float>::template GetPermuteMaskEpi32<ShuffleValues>();
__m256i permIdx = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(permMask.data()));
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_permutevar8x32_epi32(_mm256_castps_si256(this->v), permIdx)));
#endif
#ifdef __AVX512F__
} else {
__m512i shuffleVec = _mm512_loadu_epi8(shuffleMask.data());
__m512i shuffleVec = _mm512_loadu_si512(reinterpret_cast<const __m256i*>(shuffleMask.data()));
return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(this->v), shuffleVec)));
#endif
}
} else {
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
@ -488,15 +533,17 @@ namespace Crafter {
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(this->v), shuffleVec)));
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask = VectorBase<Len, Packing, float>::template GetPermuteMaskEpi32<ShuffleValues>();
__m256i permIdx = _mm256_loadu_epi32(permMask.data());
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_permutexvar_epi32(permIdx, _mm256_castps_si256(this->v))));
__m256i permIdx = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(permMask.data()));
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_permutevar8x32_epi32(_mm256_castps_si256(this->v), permIdx)));
#ifdef __AVX512F__
} else {
constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask = VectorBase<Len, Packing, float>::template GetPermuteMaskEpi32<ShuffleValues>();
__m512i permIdx = _mm512_loadu_epi32(permMask.data());
return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_permutexvar_epi32(permIdx, _mm512_castps_si512(this->v))));
#endif
}
}
}
}
constexpr static std::tuple<VectorF32<Len, Packing>, VectorF32<Len, Packing>, VectorF32<Len, Packing>, VectorF32<Len, Packing>> Normalize(
VectorF32<Len, Packing> A,
@ -539,6 +586,7 @@ namespace Crafter {
_mm256_mul_ps(C.v, fLenghtC.v),
_mm256_mul_ps(D.v, fLenghtD.v)
};
#if defined(__AVX512F__)
} else {
VectorF32<1, 16> lenght = LengthNoShuffle(A, C, B, D);
constexpr float oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
@ -558,6 +606,7 @@ namespace Crafter {
VectorF32<Len, Packing>(_mm512_mul_ps(C.v, fLenghtC.v)),
VectorF32<Len, Packing>(_mm512_mul_ps(D.v, fLenghtD.v)),
};
#endif
}
}
@ -609,6 +658,7 @@ namespace Crafter {
};
}
#ifdef __AVX512F__
constexpr static std::tuple<VectorF32<Len, Packing>, VectorF32<Len, Packing>, VectorF32<Len, Packing>> Normalize(
VectorF32<Len, Packing> A,
VectorF32<Len, Packing> B,
@ -629,6 +679,7 @@ namespace Crafter {
_mm512_mul_ps(C.v, fLenghtC.v),
};
}
#endif
constexpr static std::tuple<VectorF32<Len, Packing>, VectorF32<Len, Packing>> Normalize(
VectorF32<Len, Packing> A,
@ -660,6 +711,7 @@ namespace Crafter {
_mm256_mul_ps(A.v, fLenghtA.v),
_mm256_mul_ps(B.v, fLenghtB.v),
};
#ifdef __AVX512F__
} else {
VectorF32<1, 16> lenght = LengthNoShuffle(A, B);
constexpr float oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
@ -673,6 +725,7 @@ namespace Crafter {
_mm512_mul_ps(A.v, fLenghtA.v),
_mm512_mul_ps(B.v, fLenghtB.v),
};
#endif
}
}
@ -712,6 +765,7 @@ namespace Crafter {
return VectorF32<1, Packing*4>(_mm256_sqrt_ps(lenghtSq.v));
}
#ifdef __AVX512F__
constexpr static VectorF32<1, 15> Length(
VectorF32<Len, Packing> A,
VectorF32<Len, Packing> B,
@ -720,6 +774,7 @@ namespace Crafter {
VectorF32<1, 15> lenghtSq = LengthSq(A, B, C);
return VectorF32<1, 15>(_mm512_sqrt_ps(lenghtSq.v));
}
#endif
constexpr static VectorF32<1, Packing*2> Length(
VectorF32<Len, Packing> A,
@ -730,8 +785,10 @@ namespace Crafter {
return VectorF32<1, Packing*2>(_mm_sqrt_ps(lenghtSq.v));
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
return VectorF32<1, Packing*2>(_mm256_sqrt_ps(lenghtSq.v));
#ifdef __AVX512F__
} else {
return VectorF32<1, Packing*2>(_mm512_sqrt_ps(lenghtSq.v));
#endif
}
}
@ -762,6 +819,7 @@ namespace Crafter {
return Dot(A, A, B, B, C, C, D, D);
}
#ifdef __AVX512F__
constexpr static VectorF32<1, 15> LengthSq(
VectorF32<Len, Packing> A,
VectorF32<Len, Packing> B,
@ -769,6 +827,7 @@ namespace Crafter {
) requires(Len == 3 && Packing == 5) {
return Dot(A, A, B, B, C, C);
}
#endif
constexpr static VectorF32<1, Packing*2> LengthSq(
VectorF32<Len, Packing> A,
@ -792,6 +851,7 @@ namespace Crafter {
1,5,3,7,
}}>();
return vec.v;
#ifdef __AVX512F__
} else {
VectorF32<16, 1> vec(DotNoShuffle(A0, A1, B0, B1, C0, C1, D0, D1).v);
vec = vec.template Shuffle<{{
@ -801,6 +861,7 @@ namespace Crafter {
3,7,11,15
}}>();
return vec.v;
#endif
}
}
@ -955,6 +1016,7 @@ namespace Crafter {
return row1;
}
#ifdef __AVX512F__
constexpr static VectorF32<1, 15> Dot(
VectorF32<Len, Packing> A0, VectorF32<Len, Packing> A1,
VectorF32<Len, Packing> B0, VectorF32<Len, Packing> B1,
@ -1044,6 +1106,7 @@ namespace Crafter {
return row1;
}
#endif
constexpr static VectorF32<1, Packing*2> Dot(
VectorF32<Len, Packing> A0, VectorF32<Len, Packing> A1,
@ -1058,6 +1121,7 @@ namespace Crafter {
2,3, 6,7,
}}>();
return vec.v;
#ifdef __AVX512F__
} else {
VectorF32<16, 1> vec(DotNoShuffle(A0, A1, C0, C1).v);
vec = vec.template Shuffle<{{
@ -1067,6 +1131,7 @@ namespace Crafter {
10,11, 14,15
}}>();
return vec.v;
#endif
}
}
@ -1083,8 +1148,10 @@ namespace Crafter {
return VectorF32<1, Packing*4>(_mm_sqrt_ps(lenghtSq.v));
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
return VectorF32<1, Packing*4>(_mm256_sqrt_ps(lenghtSq.v));
#ifdef __AVX512F__
} else {
return VectorF32<1, Packing*4>(_mm512_sqrt_ps(lenghtSq.v));
#endif
}
}
@ -1097,8 +1164,10 @@ namespace Crafter {
return VectorF32<1, Packing*2>(_mm_sqrt_ps(lenghtSq.v));
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
return VectorF32<1, Packing*2>(_mm256_sqrt_ps(lenghtSq.v));
#ifdef __AVX512F__
} else {
return VectorF32<1, Packing*2>(_mm512_sqrt_ps(lenghtSq.v));
#endif
}
}
@ -1172,6 +1241,7 @@ namespace Crafter {
row1 = _mm256_add_ps(row1, row4);
return row1;
#ifdef __AVX512F__
} else {
__m512 mulA = _mm512_mul_ps(A0.v, A1.v);
__m512 mulB = _mm512_mul_ps(B0.v, B1.v);
@ -1195,6 +1265,7 @@ namespace Crafter {
row1 = _mm512_add_ps(row1, row4);
return row1;
#endif
}
}
@ -1226,6 +1297,7 @@ namespace Crafter {
row56Temp1 = _mm256_unpackhi_epi32(row1TempTemp1, row56Temp1); // A2 B2 C2 D2
return _mm256_add_ps(row12Temp1, row56Temp1);
#ifdef __AVX512F__
} else {
__m512 mulA = _mm512_mul_ps(A0.v, A1.v);
__m512 mulC = _mm512_mul_ps(C0.v, C1.v);
@ -1238,6 +1310,7 @@ namespace Crafter {
row56Temp1 = _mm512_unpackhi_epi32(row1TempTemp1, row56Temp1); // A2 B2 C2 D2
return _mm512_add_ps(row12Temp1, row56Temp1);
#endif
}
}
public:
@ -1245,19 +1318,20 @@ namespace Crafter {
template <std::array<bool, Len> ShuffleValues>
constexpr static VectorF32<Len, Packing> Blend(VectorF32<Len, Packing> a, VectorF32<Len, Packing> b) {
constexpr auto mask = VectorBase<Len, Packing, float>::template GetBlendMaskEpi32<ShuffleValues>();
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
return _mm_castsi128_ps(_mm_blend_epi32(_mm_castps_si128(a.v), _mm_castps_si128(b.v), mask));
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
#ifndef __AVX512BW__
#ifndef __AVX512VL__
static_assert(false, "No __AVX512BW__ and __AVX512VL__ support");
#endif
#endif
return _mm256_castsi256_ps(_mm256_mask_blend_epi32(mask, _mm256_castps_si256(a.v), _mm256_castps_si256(b.v)));
} else {
} else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
return _mm256_castsi256_ps(_mm256_blend_epi32(_mm256_castps_si256(a.v), _mm256_castps_si256(b.v), mask));
#ifdef __AVX512F__
} else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m512>) {
return _mm512_castsi512_ps(_mm512_mask_blend_epi32(mask, _mm512_castps_si512(a.v), _mm512_castps_si512(b.v)));
#endif
}
}
}
constexpr static VectorF32<Len, Packing> Rotate(VectorF32<3, Packing> v, VectorF32<4, Packing> q) requires(Len == 3) {
VectorF32<3, Packing> qv(q);
@ -1314,7 +1388,7 @@ namespace Crafter {
export template <std::uint32_t Len, std::uint32_t Packing>
struct std::formatter<Crafter::VectorF32<Len, Packing>> : std::formatter<std::string> {
constexpr auto format(const Crafter::VectorF32<Len, Packing>& obj, format_context& ctx) const {
std::array<float, Crafter::VectorF32<Len, Packing>::AlignmentElement> vec = obj.Store();
std::array<float, Crafter::VectorF32<Len, Packing>::AlignmentElement> vec = obj.template Store<float>();
std::string out = "{";
for(std::uint32_t i = 0; i < Packing; i++) {
out += "{";
@ -1327,5 +1401,4 @@ struct std::formatter<Crafter::VectorF32<Len, Packing>> : std::formatter<std::st
out += "}";
return std::formatter<std::string>::format(out, ctx);
}
};
#endif
};