x86v3
This commit is contained in:
parent
143b71eeb9
commit
a16f8ffbde
7 changed files with 251 additions and 133 deletions
|
|
@ -24,7 +24,6 @@ export module Crafter.Math:VectorF32;
|
|||
import std;
|
||||
import :Common;
|
||||
|
||||
#ifdef __AVX512FP16__
|
||||
namespace Crafter {
|
||||
export template <std::uint8_t Len, std::uint8_t Packing>
|
||||
struct VectorF32 : public VectorBase<Len, Packing, float> {
|
||||
|
|
@ -38,6 +37,9 @@ namespace Crafter {
|
|||
constexpr VectorF32(const float* vB) {
|
||||
Load(vB);
|
||||
};
|
||||
constexpr VectorF32(const _Float16* vB) {
|
||||
Load(vB);
|
||||
};
|
||||
constexpr VectorF32(float val) {
|
||||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||||
this->v = _mm_set1_ps(val);
|
||||
|
|
@ -66,8 +68,55 @@ namespace Crafter {
|
|||
}
|
||||
}
|
||||
|
||||
constexpr std::array<float, VectorBase<Len, Packing, float>::AlignmentElement> Store() const {
|
||||
std::array<float, VectorBase<Len, Packing, float>::AlignmentElement> returnArray;
|
||||
constexpr void Load(const _Float16* vB) {
|
||||
#ifdef __F16C__
|
||||
if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||||
this->v = _mm_cvtph_ps(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(vB)));
|
||||
} else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||||
this->v = _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(vB)));
|
||||
} else {
|
||||
this->v = _mm512_cvtph_ps(_mm256_loadu_si256(reinterpret_cast<const __m256i*>(vB)));
|
||||
}
|
||||
#else
|
||||
alignas(64) float tmp[Len];
|
||||
for (int i = 0; i < Len; ++i)
|
||||
tmp[i] = static_cast<float>(vB[i]);
|
||||
if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||||
this->v = _mm_load_ps(tmp);
|
||||
} else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||||
this->v = _mm256_load_ps(tmp);
|
||||
} else {
|
||||
this->v = _mm512_load_ps(tmp);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Narrows the floats in `this->v` to half precision and writes them to `vB`.
///
/// Both code paths write one whole register: 4 halves for __m128, 8 for
/// __m256, 16 for __m512. `vB` must therefore have room for that many elements.
///
/// @param vB Destination buffer of `_Float16` values (no alignment requirement).
constexpr void Store(_Float16* vB) const {
    using Register = typename VectorBase<Len, Packing, float>::VectorType;
#ifdef __F16C__
    // Hardware conversion with round-to-nearest, then an unaligned integer store.
    if constexpr (std::is_same_v<Register, __m128>) {
        _mm_storel_epi64(reinterpret_cast<__m128i*>(vB), _mm_cvtps_ph(this->v, _MM_FROUND_TO_NEAREST_INT));
    } else if constexpr (std::is_same_v<Register, __m256>) {
        _mm_storeu_si128(reinterpret_cast<__m128i*>(vB), _mm256_cvtps_ph(this->v, _MM_FROUND_TO_NEAREST_INT));
#ifdef __AVX512F__
    } else {
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(vB), _mm512_cvtps_ph(this->v, _MM_FROUND_TO_NEAREST_INT));
#endif
    }
#else
    // Scalar fallback. The aligned stores below always write every lane of the
    // register, so the staging buffer must be register-sized; sizing it by Len
    // (as before) overflowed the stack array whenever Len was smaller than the
    // register lane count.
    // NOTE(review): assumes AlignmentElement is the float lane count of
    // VectorType (it is what the array-returning Store uses) - confirm in VectorBase.
    constexpr std::size_t lanes = VectorBase<Len, Packing, float>::AlignmentElement;
    alignas(64) float tmp[lanes];
    if constexpr (std::is_same_v<Register, __m128>) {
        _mm_store_ps(tmp, this->v);
    } else if constexpr (std::is_same_v<Register, __m256>) {
        _mm256_store_ps(tmp, this->v);
#ifdef __AVX512F__
    } else {
        _mm512_store_ps(tmp, this->v);
#endif
    }
    // Emit every lane so the output matches what the F16C path writes.
    for (std::size_t i = 0; i < lanes; ++i) {
        vB[i] = static_cast<_Float16>(tmp[i]);
    }
#endif
}
|
||||
|
||||
template<typename T>
|
||||
constexpr std::array<T, VectorBase<Len, Packing, float>::AlignmentElement> Store() const {
|
||||
std::array<T, VectorBase<Len, Packing, float>::AlignmentElement> returnArray;
|
||||
Store(returnArray.data());
|
||||
return returnArray;
|
||||
}
|
||||
|
|
@ -96,36 +145,41 @@ namespace Crafter {
|
|||
if constexpr(std::is_same_v<typename VectorBase<BLen, BPacking, float>::VectorType, __m128>) {
|
||||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||||
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask = VectorBase<Len, Packing, float>::template GetExtractLoMaskEpi8<BLen>();
|
||||
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
|
||||
__m128i shuffleVec = _mm_loadu_si128(reinterpret_cast<const __m128i*>(shuffleMask.data()));
|
||||
return VectorF32<BLen, BPacking>(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(this->v), shuffleVec)));
|
||||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||||
constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask =VectorBase<Len, Packing, float>::template GetExtractLoMaskepi32<BLen>();
|
||||
__m256i permIdx = _mm256_loadu_epi32(permMask.data());
|
||||
__m256i permIdx = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(permMask.data()));
|
||||
__m256i result = _mm256_permutexvar_epi32(permIdx, _mm_castps_si256(this->v));
|
||||
return VectorF32<BLen, BPacking>(_mm_castsi128_ps(_mm256_castsi256_si128(result)));
|
||||
#ifdef __AVX512F__
|
||||
} else {
|
||||
constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask = VectorBase<Len, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
|
||||
__m512i permIdx = _mm512_loadu_epi32(permMask.data());
|
||||
__m512i result = _mm512_permutexvar_epi32(permIdx, _mm512_castps_si512(this->v));
|
||||
return VectorF32<BLen, BPacking>(_mm_castsi128_ps(_mm512_castsi512_si128(result)));
|
||||
#endif
|
||||
}
|
||||
} else if constexpr(std::is_same_v<typename VectorBase<BLen, BPacking, float>::VectorType, __m256>) {
|
||||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||||
constexpr std::array<std::uint32_t, VectorBase<BLen, Packing, float>::AlignmentElement> permMask = VectorBase<BLen, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
|
||||
__m256i permIdx = _mm256_loadu_epi32(permMask.data());
|
||||
__m256i permIdx = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(permMask.data()));
|
||||
__m256i result = _mm256_permutexvar_epi32(permIdx, _mm256_castsi128_si256(_mm_castps_si128(this->v)));
|
||||
return VectorF32<BLen, BPacking>(_mm256_castsi256_ps(result));
|
||||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||||
constexpr std::array<std::uint32_t, VectorBase<BLen, Packing, float>::AlignmentElement> permMask = VectorBase<BLen, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
|
||||
__m256i permIdx = _mm256_loadu_epi32(permMask.data());
|
||||
__m256i permIdx = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(permMask.data()));
|
||||
__m256i result = _mm256_permutexvar_epi32(permIdx, _mm256_castps_si256(this->v));
|
||||
return VectorF32<BLen, BPacking>(_mm256_castsi256_ps(result));
|
||||
#ifdef __AVX512F__
|
||||
} else {
|
||||
constexpr std::array<std::uint32_t, VectorBase<BLen, Packing, float>::AlignmentElement> permMask = VectorBase<BLen, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
|
||||
__m256i permIdx = _mm512_loadu_epi32(permMask.data());
|
||||
__m256i result = _mm512_permutexvar_epi32(permIdx, _mm512_castsi512_si256(_mm512_castps_si512(this->v)));
|
||||
return VectorF32<BLen, BPacking>(_mm256_castsi256_ps(result));
|
||||
#endif
|
||||
}
|
||||
#ifdef __AVX512F__
|
||||
} else {
|
||||
if constexpr(std::is_same_v<typename VectorBase<BLen, BPacking, float>::VectorType, __m128>) {
|
||||
constexpr std::array<std::uint32_t, VectorBase<BLen, Packing, float>::AlignmentElement> permMask = VectorBase<BLen, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
|
||||
|
|
@ -143,6 +197,7 @@ namespace Crafter {
|
|||
__m512i result = _mm512_permutexvar_epi32(permIdx, _mm512_castps_si512(this->v));
|
||||
return VectorF32<BLen, BPacking>(_mm512_castsi512_ps(result));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -272,25 +327,27 @@ namespace Crafter {
|
|||
return Negate<VectorBase<Len, Packing, float>::GetAllTrue()>();
|
||||
}
|
||||
|
||||
constexpr bool operator==(VectorF32<Len, Packing> b) const {
|
||||
constexpr bool operator==(VectorF32<Len, Packing> b) const {
|
||||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||||
return _mm_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) == 15;
|
||||
#ifdef __AVX512VL__
|
||||
return _mm_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) == 0xF;
|
||||
#else
|
||||
return _mm_movemask_ps(_mm_cmpeq_ps(this->v, b.v)) == 0xF;
|
||||
#endif
|
||||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||||
return _mm256_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) == 255;
|
||||
#ifdef __AVX512VL__
|
||||
return _mm256_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) == 0xFF;
|
||||
#else
|
||||
return _mm256_movemask_ps(_mm256_cmp_ps(this->v, b.v, _CMP_EQ_OQ)) == 0xFF;
|
||||
#endif
|
||||
} else {
|
||||
return _mm512_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) == 65535;
|
||||
return _mm512_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) == 0xFFFF;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
constexpr bool operator!=(VectorF32<Len, Packing> b) const {
|
||||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||||
return _mm_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) != 15;
|
||||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||||
return _mm256_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) != 255;
|
||||
} else {
|
||||
return _mm512_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) != 65535;
|
||||
}
|
||||
}
|
||||
constexpr bool operator!=(VectorF32<Len, Packing> b) const {
|
||||
return !(*this == b);
|
||||
}
|
||||
|
||||
template<std::uint32_t ExtractLen>
|
||||
constexpr VectorF32<ExtractLen, Packing> ExtractLo() const {
|
||||
|
|
@ -301,7 +358,7 @@ namespace Crafter {
|
|||
return VectorF32<ExtractLen, Packing>(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(this->v), shuffleVec)));
|
||||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||||
constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask = VectorBase<Len, Packing, float>::template GetExtractLoMaskEpi32<ExtractLen>();
|
||||
__m256i permIdx = _mm256_loadu_epi32(permMask.data());
|
||||
__m256i permIdx = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(permMask.data()));
|
||||
__m256i result = _mm256_permutexvar_epi32(permIdx, _mm256_castps_si256(this->v));
|
||||
if constexpr(std::is_same_v<typename VectorBase<ExtractLen, Packing, float>::VectorType, __m128>) {
|
||||
return VectorF32<ExtractLen, Packing>(_mm256_castps256_ps128(_mm256_castsi256_ps(result)));
|
||||
|
|
@ -323,10 +380,12 @@ namespace Crafter {
|
|||
} else {
|
||||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256> && std::is_same_v<typename VectorBase<ExtractLen, Packing, float>::VectorType, __m128>) {
|
||||
return VectorF32<ExtractLen, Packing>(_mm256_castps256_ps128(this->v));
|
||||
#ifdef __AVX512F__
|
||||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m512> && std::is_same_v<typename VectorBase<ExtractLen, Packing, float>::VectorType, __m128>) {
|
||||
return VectorF32<ExtractLen, Packing>(_mm512_castps512_ps128(this->v));
|
||||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m512> && std::is_same_v<typename VectorBase<ExtractLen, Packing, float>::VectorType, __m256>) {
|
||||
return VectorF32<ExtractLen, Packing>(_mm512_castps512_ps256(this->v));
|
||||
#endif
|
||||
} else {
|
||||
return VectorF32<ExtractLen, Packing>(this->v);
|
||||
}
|
||||
|
|
@ -338,8 +397,10 @@ namespace Crafter {
|
|||
return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::cos_f32x4(this->v));
|
||||
} else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||||
return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::cos_f32x8(this->v));
|
||||
#ifdef __AVX512F__
|
||||
} else {
|
||||
return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::cos_f32x16(this->v));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -348,8 +409,10 @@ namespace Crafter {
|
|||
return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::sin_f32x4(this->v));
|
||||
} else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||||
return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::sin_f32x8(this->v));
|
||||
#ifdef __AVX512F__
|
||||
} else {
|
||||
return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::sin_f32x16(this->v));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -369,7 +432,7 @@ namespace Crafter {
|
|||
VectorF32<Len, Packing>(s),
|
||||
VectorF32<Len, Packing>(c)
|
||||
};
|
||||
|
||||
#ifdef __AVX512F__
|
||||
} else {
|
||||
__m512 s, c;
|
||||
VectorBase<Len, Packing, float>::sincos_f32x16(this->v, s, c);
|
||||
|
|
@ -377,6 +440,7 @@ namespace Crafter {
|
|||
VectorF32<Len, Packing>(s),
|
||||
VectorF32<Len, Packing>(c)
|
||||
};
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -384,11 +448,13 @@ namespace Crafter {
|
|||
constexpr VectorF32<Len, Packing> Negate() {
|
||||
std::array<float, VectorBase<Len, Packing, float>::AlignmentElement> mask = VectorBase<Len, Packing, float>::template GetNegateMask<values>();
|
||||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_xor_si128(_mm_castps_si128(this->v), _mm_loadu_epi32(mask.data()))));
|
||||
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_xor_si128(_mm_castps_si128(this->v), _mm_loadu_si128(reinterpret_cast<__m128i*>(mask.data())))));
|
||||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||||
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(this->v), _mm256_loadu_epi32(mask.data()))));
|
||||
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(this->v), _mm256_loadu_si256(reinterpret_cast<__m256i*>(mask.data())))));
|
||||
#ifdef __AVX512F__
|
||||
} else {
|
||||
return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(this->v), _mm512_loadu_epi32(mask.data()))));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -397,8 +463,10 @@ namespace Crafter {
|
|||
return VectorF32<Len, Packing>(_mm_fmadd_ps(a.v, b.v, add.v));
|
||||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||||
return VectorF32<Len, Packing>(_mm256_fmadd_ps(a.v, b.v, add.v));
|
||||
#ifdef __AVX512F__
|
||||
} else {
|
||||
return VectorF32<Len, Packing>(_mm512_fmadd_ps(a.v, b.v, add.v));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -407,55 +475,22 @@ namespace Crafter {
|
|||
return VectorF32<Len, Packing>(_mm_fmsub_ps(a.v, b.v, sub.v));
|
||||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||||
return VectorF32<Len, Packing>(_mm256_fmsub_ps(a.v, b.v, sub.v));
|
||||
#ifdef __AVX512F__
|
||||
} else {
|
||||
return VectorF32<Len, Packing>(_mm512_fmsub_ps(a.v, b.v, sub.v));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
constexpr static VectorF32<Len, Packing> Cross(VectorF32<Len, Packing> a, VectorF32<Len, Packing> b) requires(Len == 3) {
|
||||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||||
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask1 = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<{{1,2,0}}>();
|
||||
__m128i shuffleVec1 = _mm_loadu_epi8(shuffleMask1.data());
|
||||
__m128 row1 = _mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(a.v), shuffleVec1));
|
||||
__m128 row4 = _mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(b.v), shuffleVec1));
|
||||
VectorF32<Len, Packing> row1 = a.template Shuffle<{{1,2,0}}>();
|
||||
VectorF32<Len, Packing> row4 = b.template Shuffle<{{1,2,0}}>();
|
||||
|
||||
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask3 = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<{{2,0,1}}>();
|
||||
__m128i shuffleVec3 = _mm_loadu_epi8(shuffleMask3.data());
|
||||
__m128 row3 = _mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(a.v), shuffleVec3));
|
||||
__m128 row2 = _mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(b.v), shuffleVec3));
|
||||
VectorF32<Len, Packing> row3 = a.template Shuffle<{{2,0,1}}>();
|
||||
VectorF32<Len, Packing> row2 = b.template Shuffle<{{2,0,1}}>();
|
||||
|
||||
__m128 result = _mm_mul_ps(row3, row4);
|
||||
return _mm_fmsub_ps(row1,row2,result);
|
||||
} else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||||
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask1 = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<{{1,2,0}}>();
|
||||
__m512i shuffleVec1 = _mm512_castsi256_si512(_mm256_loadu_epi8(shuffleMask1.data()));
|
||||
__m256 row1 = _mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(a.v)), shuffleVec1)));
|
||||
__m256 row4 = _mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(b.v)), shuffleVec1)));
|
||||
|
||||
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask3 = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<{{2,0,1}}>();
|
||||
|
||||
__m512i shuffleVec3 = _mm512_castsi256_si512(_mm256_loadu_epi8(shuffleMask3.data()));
|
||||
__m256 row3 = _mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(a.v)), shuffleVec3)));
|
||||
__m256 row2 = _mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(b.v)), shuffleVec3)));
|
||||
|
||||
__m256 result = _mm256_mul_ps(row3, row4);
|
||||
return _mm256_fmsub_ps(row1,row2,result);
|
||||
} else {
|
||||
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask1 = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<{{1,2,0}}>();
|
||||
|
||||
__m512i shuffleVec1 = _mm512_loadu_epi8(shuffleMask1.data());
|
||||
__m512 row1 = _mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(a.v), shuffleVec1));
|
||||
__m512 row4 = _mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(b.v), shuffleVec1));
|
||||
|
||||
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask3 = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<{{2,0,1}}>();
|
||||
|
||||
__m512i shuffleVec3 = _mm512_loadu_epi8(shuffleMask3.data());
|
||||
__m512 row3 = _mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(a.v), shuffleVec3));
|
||||
__m512 row2 = _mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(b.v), shuffleVec3));
|
||||
|
||||
__m512 result = _mm512_mul_ps(row3, row4);
|
||||
return _mm512_fmsub_ps(row1,row2,result);
|
||||
}
|
||||
VectorF32<Len, Packing> result = row3 * row4;
|
||||
return VectorF32<Len, Packing>::MulitplySub(row1, row2, result);
|
||||
}
|
||||
|
||||
template <const std::array<std::uint8_t, Len> ShuffleValues>
|
||||
|
|
@ -465,21 +500,31 @@ namespace Crafter {
|
|||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(this->v), imm)));
|
||||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||||
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(this->v), imm)));
|
||||
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(this->v), imm)));
|
||||
#ifdef __AVX512F__
|
||||
} else {
|
||||
return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(this->v), imm)));
|
||||
#endif
|
||||
}
|
||||
} else if constexpr(VectorBase<Len, Packing, float>::template CheckEpi8Shuffle<ShuffleValues>()){
|
||||
} else if constexpr(VectorBase<Len, Packing, float>::template CheckEpi8Shuffle<ShuffleValues>()) {
|
||||
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<ShuffleValues>();
|
||||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||||
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
|
||||
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(this->v), shuffleVec)));
|
||||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||||
__m256i shuffleVec = _mm256_loadu_epi8(shuffleMask.data());
|
||||
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(this->v)), _mm512_castsi256_si512(shuffleVec)))));
|
||||
#ifdef __AVX512BW__
|
||||
__m256i shuffleVec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(shuffleMask.data()));
|
||||
return VectorF32<Len, Packing>(_mm256_castsi256_ps( _mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(this->v)),_mm512_castsi256_si512(shuffleVec)))));
|
||||
#else
|
||||
constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask = VectorBase<Len, Packing, float>::template GetPermuteMaskEpi32<ShuffleValues>();
|
||||
__m256i permIdx = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(permMask.data()));
|
||||
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_permutevar8x32_epi32(_mm256_castps_si256(this->v), permIdx)));
|
||||
#endif
|
||||
#ifdef __AVX512F__
|
||||
} else {
|
||||
__m512i shuffleVec = _mm512_loadu_epi8(shuffleMask.data());
|
||||
__m512i shuffleVec = _mm512_loadu_si512(reinterpret_cast<const __m256i*>(shuffleMask.data()));
|
||||
return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(this->v), shuffleVec)));
|
||||
#endif
|
||||
}
|
||||
} else {
|
||||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||||
|
|
@ -488,15 +533,17 @@ namespace Crafter {
|
|||
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(this->v), shuffleVec)));
|
||||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||||
constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask = VectorBase<Len, Packing, float>::template GetPermuteMaskEpi32<ShuffleValues>();
|
||||
__m256i permIdx = _mm256_loadu_epi32(permMask.data());
|
||||
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_permutexvar_epi32(permIdx, _mm256_castps_si256(this->v))));
|
||||
__m256i permIdx = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(permMask.data()));
|
||||
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_permutevar8x32_epi32(_mm256_castps_si256(this->v), permIdx)));
|
||||
#ifdef __AVX512F__
|
||||
} else {
|
||||
constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask = VectorBase<Len, Packing, float>::template GetPermuteMaskEpi32<ShuffleValues>();
|
||||
__m512i permIdx = _mm512_loadu_epi32(permMask.data());
|
||||
return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_permutexvar_epi32(permIdx, _mm512_castps_si512(this->v))));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
constexpr static std::tuple<VectorF32<Len, Packing>, VectorF32<Len, Packing>, VectorF32<Len, Packing>, VectorF32<Len, Packing>> Normalize(
|
||||
VectorF32<Len, Packing> A,
|
||||
|
|
@ -539,6 +586,7 @@ namespace Crafter {
|
|||
_mm256_mul_ps(C.v, fLenghtC.v),
|
||||
_mm256_mul_ps(D.v, fLenghtD.v)
|
||||
};
|
||||
#if defined(__AVX512F__)
|
||||
} else {
|
||||
VectorF32<1, 16> lenght = LengthNoShuffle(A, C, B, D);
|
||||
constexpr float oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
|
||||
|
|
@ -558,6 +606,7 @@ namespace Crafter {
|
|||
VectorF32<Len, Packing>(_mm512_mul_ps(C.v, fLenghtC.v)),
|
||||
VectorF32<Len, Packing>(_mm512_mul_ps(D.v, fLenghtD.v)),
|
||||
};
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -609,6 +658,7 @@ namespace Crafter {
|
|||
};
|
||||
}
|
||||
|
||||
#ifdef __AVX512F__
|
||||
constexpr static std::tuple<VectorF32<Len, Packing>, VectorF32<Len, Packing>, VectorF32<Len, Packing>> Normalize(
|
||||
VectorF32<Len, Packing> A,
|
||||
VectorF32<Len, Packing> B,
|
||||
|
|
@ -629,6 +679,7 @@ namespace Crafter {
|
|||
_mm512_mul_ps(C.v, fLenghtC.v),
|
||||
};
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr static std::tuple<VectorF32<Len, Packing>, VectorF32<Len, Packing>> Normalize(
|
||||
VectorF32<Len, Packing> A,
|
||||
|
|
@ -660,6 +711,7 @@ namespace Crafter {
|
|||
_mm256_mul_ps(A.v, fLenghtA.v),
|
||||
_mm256_mul_ps(B.v, fLenghtB.v),
|
||||
};
|
||||
#ifdef __AVX512F__
|
||||
} else {
|
||||
VectorF32<1, 16> lenght = LengthNoShuffle(A, B);
|
||||
constexpr float oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
|
||||
|
|
@ -673,6 +725,7 @@ namespace Crafter {
|
|||
_mm512_mul_ps(A.v, fLenghtA.v),
|
||||
_mm512_mul_ps(B.v, fLenghtB.v),
|
||||
};
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -712,6 +765,7 @@ namespace Crafter {
|
|||
return VectorF32<1, Packing*4>(_mm256_sqrt_ps(lenghtSq.v));
|
||||
}
|
||||
|
||||
#ifdef __AVX512F__
|
||||
constexpr static VectorF32<1, 15> Length(
|
||||
VectorF32<Len, Packing> A,
|
||||
VectorF32<Len, Packing> B,
|
||||
|
|
@ -720,6 +774,7 @@ namespace Crafter {
|
|||
VectorF32<1, 15> lenghtSq = LengthSq(A, B, C);
|
||||
return VectorF32<1, 15>(_mm512_sqrt_ps(lenghtSq.v));
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr static VectorF32<1, Packing*2> Length(
|
||||
VectorF32<Len, Packing> A,
|
||||
|
|
@ -730,8 +785,10 @@ namespace Crafter {
|
|||
return VectorF32<1, Packing*2>(_mm_sqrt_ps(lenghtSq.v));
|
||||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||||
return VectorF32<1, Packing*2>(_mm256_sqrt_ps(lenghtSq.v));
|
||||
#ifdef __AVX512F__
|
||||
} else {
|
||||
return VectorF32<1, Packing*2>(_mm512_sqrt_ps(lenghtSq.v));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -762,6 +819,7 @@ namespace Crafter {
|
|||
return Dot(A, A, B, B, C, C, D, D);
|
||||
}
|
||||
|
||||
#ifdef __AVX512F__
|
||||
constexpr static VectorF32<1, 15> LengthSq(
|
||||
VectorF32<Len, Packing> A,
|
||||
VectorF32<Len, Packing> B,
|
||||
|
|
@ -769,6 +827,7 @@ namespace Crafter {
|
|||
) requires(Len == 3 && Packing == 5) {
|
||||
return Dot(A, A, B, B, C, C);
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr static VectorF32<1, Packing*2> LengthSq(
|
||||
VectorF32<Len, Packing> A,
|
||||
|
|
@ -792,6 +851,7 @@ namespace Crafter {
|
|||
1,5,3,7,
|
||||
}}>();
|
||||
return vec.v;
|
||||
#ifdef __AVX512F__
|
||||
} else {
|
||||
VectorF32<16, 1> vec(DotNoShuffle(A0, A1, B0, B1, C0, C1, D0, D1).v);
|
||||
vec = vec.template Shuffle<{{
|
||||
|
|
@ -801,6 +861,7 @@ namespace Crafter {
|
|||
3,7,11,15
|
||||
}}>();
|
||||
return vec.v;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -955,6 +1016,7 @@ namespace Crafter {
|
|||
return row1;
|
||||
}
|
||||
|
||||
#ifdef __AVX512F__
|
||||
constexpr static VectorF32<1, 15> Dot(
|
||||
VectorF32<Len, Packing> A0, VectorF32<Len, Packing> A1,
|
||||
VectorF32<Len, Packing> B0, VectorF32<Len, Packing> B1,
|
||||
|
|
@ -1044,6 +1106,7 @@ namespace Crafter {
|
|||
|
||||
return row1;
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr static VectorF32<1, Packing*2> Dot(
|
||||
VectorF32<Len, Packing> A0, VectorF32<Len, Packing> A1,
|
||||
|
|
@ -1058,6 +1121,7 @@ namespace Crafter {
|
|||
2,3, 6,7,
|
||||
}}>();
|
||||
return vec.v;
|
||||
#ifdef __AVX512F__
|
||||
} else {
|
||||
VectorF32<16, 1> vec(DotNoShuffle(A0, A1, C0, C1).v);
|
||||
vec = vec.template Shuffle<{{
|
||||
|
|
@ -1067,6 +1131,7 @@ namespace Crafter {
|
|||
10,11, 14,15
|
||||
}}>();
|
||||
return vec.v;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1083,8 +1148,10 @@ namespace Crafter {
|
|||
return VectorF32<1, Packing*4>(_mm_sqrt_ps(lenghtSq.v));
|
||||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||||
return VectorF32<1, Packing*4>(_mm256_sqrt_ps(lenghtSq.v));
|
||||
#ifdef __AVX512F__
|
||||
} else {
|
||||
return VectorF32<1, Packing*4>(_mm512_sqrt_ps(lenghtSq.v));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1097,8 +1164,10 @@ namespace Crafter {
|
|||
return VectorF32<1, Packing*2>(_mm_sqrt_ps(lenghtSq.v));
|
||||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||||
return VectorF32<1, Packing*2>(_mm256_sqrt_ps(lenghtSq.v));
|
||||
#ifdef __AVX512F__
|
||||
} else {
|
||||
return VectorF32<1, Packing*2>(_mm512_sqrt_ps(lenghtSq.v));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1172,6 +1241,7 @@ namespace Crafter {
|
|||
row1 = _mm256_add_ps(row1, row4);
|
||||
|
||||
return row1;
|
||||
#ifdef __AVX512F__
|
||||
} else {
|
||||
__m512 mulA = _mm512_mul_ps(A0.v, A1.v);
|
||||
__m512 mulB = _mm512_mul_ps(B0.v, B1.v);
|
||||
|
|
@ -1195,6 +1265,7 @@ namespace Crafter {
|
|||
row1 = _mm512_add_ps(row1, row4);
|
||||
|
||||
return row1;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1226,6 +1297,7 @@ namespace Crafter {
|
|||
row56Temp1 = _mm256_unpackhi_epi32(row1TempTemp1, row56Temp1); // A2 B2 C2 D2
|
||||
|
||||
return _mm256_add_ps(row12Temp1, row56Temp1);
|
||||
#ifdef __AVX512F__
|
||||
} else {
|
||||
__m512 mulA = _mm512_mul_ps(A0.v, A1.v);
|
||||
__m512 mulC = _mm512_mul_ps(C0.v, C1.v);
|
||||
|
|
@ -1238,6 +1310,7 @@ namespace Crafter {
|
|||
row56Temp1 = _mm512_unpackhi_epi32(row1TempTemp1, row56Temp1); // A2 B2 C2 D2
|
||||
|
||||
return _mm512_add_ps(row12Temp1, row56Temp1);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
public:
|
||||
|
|
@ -1245,19 +1318,20 @@ namespace Crafter {
|
|||
template <std::array<bool, Len> ShuffleValues>
|
||||
constexpr static VectorF32<Len, Packing> Blend(VectorF32<Len, Packing> a, VectorF32<Len, Packing> b) {
|
||||
constexpr auto mask = VectorBase<Len, Packing, float>::template GetBlendMaskEpi32<ShuffleValues>();
|
||||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||||
|
||||
if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||||
return _mm_castsi128_ps(_mm_blend_epi32(_mm_castps_si128(a.v), _mm_castps_si128(b.v), mask));
|
||||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||||
#ifndef __AVX512BW__
|
||||
#ifndef __AVX512VL__
|
||||
static_assert(false, "No __AVX512BW__ and __AVX512VL__ support");
|
||||
#endif
|
||||
#endif
|
||||
return _mm256_castsi256_ps(_mm256_mask_blend_epi32(mask, _mm256_castps_si256(a.v), _mm256_castps_si256(b.v)));
|
||||
} else {
|
||||
|
||||
} else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||||
return _mm256_castsi256_ps(_mm256_blend_epi32(_mm256_castps_si256(a.v), _mm256_castps_si256(b.v), mask));
|
||||
|
||||
#ifdef __AVX512F__
|
||||
} else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m512>) {
|
||||
return _mm512_castsi512_ps(_mm512_mask_blend_epi32(mask, _mm512_castps_si512(a.v), _mm512_castps_si512(b.v)));
|
||||
#endif
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
constexpr static VectorF32<Len, Packing> Rotate(VectorF32<3, Packing> v, VectorF32<4, Packing> q) requires(Len == 3) {
|
||||
VectorF32<3, Packing> qv(q);
|
||||
|
|
@ -1314,7 +1388,7 @@ namespace Crafter {
|
|||
export template <std::uint32_t Len, std::uint32_t Packing>
|
||||
struct std::formatter<Crafter::VectorF32<Len, Packing>> : std::formatter<std::string> {
|
||||
constexpr auto format(const Crafter::VectorF32<Len, Packing>& obj, format_context& ctx) const {
|
||||
std::array<float, Crafter::VectorF32<Len, Packing>::AlignmentElement> vec = obj.Store();
|
||||
std::array<float, Crafter::VectorF32<Len, Packing>::AlignmentElement> vec = obj.template Store<float>();
|
||||
std::string out = "{";
|
||||
for(std::uint32_t i = 0; i < Packing; i++) {
|
||||
out += "{";
|
||||
|
|
@ -1327,5 +1401,4 @@ struct std::formatter<Crafter::VectorF32<Len, Packing>> : std::formatter<std::st
|
|||
out += "}";
|
||||
return std::formatter<std::string>::format(out, ctx);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
};
|
||||
Loading…
Add table
Add a link
Reference in a new issue