all tests passing again
This commit is contained in:
parent
d09155736f
commit
8999c8b9ec
7 changed files with 912 additions and 914 deletions
|
|
@ -18,14 +18,15 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
|||
*/
|
||||
module;
|
||||
#ifdef __x86_64
|
||||
#include <immintrin.h>
|
||||
#include <immintrin.>
|
||||
#endif
|
||||
export module Crafter.Math:VectorF32;
|
||||
import std;
|
||||
import :Vector;
|
||||
import :Common;
|
||||
|
||||
namespace Crafter {
|
||||
export template <std::uint32_t Len, std::uint32_t Packing, std::uint32_t Repeats>
|
||||
export template <std::uint32_t Len, std::uint32_t Packing>
|
||||
struct VectorF32 {
|
||||
#ifdef __AVX512F__
|
||||
static constexpr std::uint32_t MaxSize = 16;
|
||||
|
|
@ -45,8 +46,6 @@ namespace Crafter {
|
|||
return 16;
|
||||
}
|
||||
static_assert(Len * Packing <= 16, "Len * Packing is larger than supported max size of 16");
|
||||
static_assert(Len * Packing <= 4, "Len * Packing is larger than supported packed size of 4");
|
||||
static_assert(Len * Packing * Repeats <= 16, "Len * Packing * Repeats is larger than supported max of 16");
|
||||
#else
|
||||
if constexpr (Len * Packing <= 4) {
|
||||
return 4;
|
||||
|
|
@ -55,17 +54,12 @@ namespace Crafter {
|
|||
return 8;
|
||||
}
|
||||
static_assert(Len * Packing <= 8, "Len * Packing is larger than supported max size of 8");
|
||||
static_assert(Len * Packing <= 4, "Len * Packing is larger than supported packed size of 4");
|
||||
static_assert(Len * Packing * Repeats <= 8, "Len * Packing * Repeats is larger than supported max of 8");
|
||||
#endif
|
||||
}
|
||||
static consteval std::uint32_t GetTotalSize() {
|
||||
return GetAlignment() * Repeats;
|
||||
}
|
||||
|
||||
using VectorType = std::conditional_t<
|
||||
(GetTotalSize() == 16), __m512,
|
||||
std::conditional_t<(GetTotalSize() == 8), __m256, __m128>
|
||||
(Len * Packing > 8), __m512h,
|
||||
std::conditional_t<(Len * Packing > 4), __m256h, __m128>
|
||||
>;
|
||||
|
||||
VectorType v;
|
||||
|
|
@ -107,91 +101,96 @@ namespace Crafter {
|
|||
}
|
||||
constexpr void Load(const _Float16* vB) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
v = _mm_cvtph_ps(_mm_loadu_si128(reinterpret_cast<__m128i const*>(vB)));
|
||||
v = _mm_cvtps_ps(_mm_loadu_si128(reinterpret_cast<__m128i const*>(vB)));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
v = _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<__m128i const*>(vB)));
|
||||
v = _mm256_cvtps_ps(_mm_loadu_si128(reinterpret_cast<__m128i const*>(vB)));
|
||||
} else {
|
||||
v = _mm512_cvtph_ps(_mm256_loadu_si256(reinterpret_cast<__m256i const*>(vB)));
|
||||
v = _mm512_cvtps_ps(_mm256_loadu_si256(reinterpret_cast<__m256i const*>(vB)));
|
||||
}
|
||||
}
|
||||
constexpr void Store(_Float16* vB) const {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
_mm_storeu_si128(_mm_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT), v);
|
||||
_mm_storeu_si128(_mm_cvtps_ps(v, _MM_FROUND_TO_NEAREST_INT), v);
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
_mm_storeu_si128(_mm256_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT), v);
|
||||
_mm_storeu_si128(_mm256_cvtps_ps(v, _MM_FROUND_TO_NEAREST_INT), v);
|
||||
} else {
|
||||
_mm256_storeu_si256(_mm512_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT), v);
|
||||
_mm256_storeu_si256(_mm512_cvtps_ps(v, _MM_FROUND_TO_NEAREST_INT), v);
|
||||
}
|
||||
}
|
||||
|
||||
template <std::uint32_t VLen, std::uint32_t VAlign>
|
||||
constexpr Vector<float, VLen, VAlign> Store() const {
|
||||
Vector<float, VLen, VAlign> returnVec;
|
||||
Store(returnVec.v);
|
||||
return returnVec;
|
||||
constexpr std::array<float, Alignment> Store() const {
|
||||
std::array<float, Alignment> returnArray;
|
||||
Store(returnArray.data());
|
||||
return returnArray;
|
||||
}
|
||||
|
||||
template <std::uint32_t BLen, std::uint32_t BPacking, std::uint32_t BRepeats>
|
||||
constexpr operator VectorF32<BLen, BPacking, BRepeats>() const {
|
||||
if constexpr(std::is_same_v<VectorType, __m256> && std::is_same_v<typename VectorF32<BLen, BPacking, BRepeats>::VectorType, __m128>) {
|
||||
return VectorF32<BLen, BPacking, BRepeats>(_mm256_castps256_ps128(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m512> && std::is_same_v<typename VectorF32<BLen, BPacking, BRepeats>::VectorType, __m128>) {
|
||||
return VectorF32<BLen, BPacking, BRepeats>(_mm512_castps512_ps128(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m512> && std::is_same_v<typename VectorF32<BLen, BPacking, BRepeats>::VectorType, __m256>) {
|
||||
return VectorF32<BLen, BPacking, BRepeats>(_mm512_castps512_ps256(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m128> && std::is_same_v<typename VectorF32<BLen, BPacking, BRepeats>::VectorType, __m256>) {
|
||||
return VectorF32<BLen, BPacking, BRepeats>(_mm256_castps128_ps256(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m128> && std::is_same_v<typename VectorF32<BLen, BPacking, BRepeats>::VectorType, __m512>) {
|
||||
return VectorF32<BLen, BPacking, BRepeats>(_mm512_castps128_ps512(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256> && std::is_same_v<typename VectorF32<BLen, BPacking, BRepeats>::VectorType, __m512>) {
|
||||
return VectorF32<BLen, BPacking, BRepeats>(_mm512_castps256_ps512(v));
|
||||
template <std::uint32_t BLen, std::uint32_t BPacking>
|
||||
constexpr operator VectorF32<BLen, BPacking>() const {
|
||||
if constexpr (Len == BLen) {
|
||||
if constexpr(std::is_same_v<VectorType, __m256> && std::is_same_v<typename VectorF32<BLen, BPacking>::VectorType, __m128>) {
|
||||
return VectorF32<BLen, BPacking>(_mm256_castps256_ps128(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m512> && std::is_same_v<typename VectorF32<BLen, BPacking>::VectorType, __m128>) {
|
||||
return VectorF32<BLen, BPacking>(_mm512_castps512_ps128(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m512> && std::is_same_v<typename VectorF32<BLen, BPacking>::VectorType, __m256>) {
|
||||
return VectorF32<BLen, BPacking>(_mm512_castps512_ps256(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m128> && std::is_same_v<typename VectorF32<BLen, BPacking>::VectorType, __m256>) {
|
||||
return VectorF32<BLen, BPacking>(_mm256_castps128_ps256(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m128> && std::is_same_v<typename VectorF32<BLen, BPacking>::VectorType, __m512>) {
|
||||
return VectorF32<BLen, BPacking>(_mm512_castps128_ps512(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256> && std::is_same_v<typename VectorF32<BLen, BPacking>::VectorType, __m512>) {
|
||||
return VectorF32<BLen, BPacking>(_mm512_castps256_ps512(v));
|
||||
} else {
|
||||
return VectorF32<BLen, BPacking>(v);
|
||||
}
|
||||
} else if constexpr (BLen <= Len) {
|
||||
return this->template ExtractLo<BLen>();
|
||||
} else {
|
||||
return VectorF32<BLen, BPacking, BRepeats>(v);
|
||||
return VectorF32<BLen, BPacking>(v);
|
||||
}
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> operator+(VectorF32<Len, Packing, Repeats> b) const {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm_add_ps(v, b.v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm256_add_ps(v, b.v));
|
||||
|
||||
constexpr VectorF32<Len, Packing> operator+(VectorF32<Len, Packing> b) const {
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
return VectorF32<Len, Packing>(_mm_add_ph(v, b.v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
return VectorF32<Len, Packing>(_mm256_add_ph(v, b.v));
|
||||
} else {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm512_add_ps(v, b.v));
|
||||
return VectorF32<Len, Packing>(_mm512_add_ph(v, b.v));
|
||||
}
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> operator-(VectorF32<Len, Packing, Repeats> b) const {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm_sub_ps(v, b.v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm256_sub_ps(v, b.v));
|
||||
constexpr VectorF32<Len, Packing> operator-(VectorF32<Len, Packing> b) const {
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
return VectorF32<Len, Packing>(_mm_sub_ph(v, b.v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
return VectorF32<Len, Packing>(_mm256_sub_ph(v, b.v));
|
||||
} else {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm512_sub_ps(v, b.v));
|
||||
return VectorF32<Len, Packing>(_mm512_sub_ph(v, b.v));
|
||||
}
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> operator*(VectorF32<Len, Packing, Repeats> b) const {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm_mul_ps(v, b.v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm256_mul_ps(v, b.v));
|
||||
constexpr VectorF32<Len, Packing> operator*(VectorF32<Len, Packing> b) const {
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
return VectorF32<Len, Packing>(_mm_mul_ph(v, b.v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
return VectorF32<Len, Packing>(_mm256_mul_ph(v, b.v));
|
||||
} else {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm512_mul_ps(v, b.v));
|
||||
return VectorF32<Len, Packing>(_mm512_mul_ph(v, b.v));
|
||||
}
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> operator/(VectorF32<Len, Packing, Repeats> b) const {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm_div_ps(v, b.v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm256_div_ps(v, b.v));
|
||||
constexpr VectorF32<Len, Packing> operator/(VectorF32<Len, Packing> b) const {
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
return VectorF32<Len, Packing>(_mm_div_ph(v, b.v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
return VectorF32<Len, Packing>(_mm256_div_ph(v, b.v));
|
||||
} else {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm512_div_ps(v, b.v));
|
||||
return VectorF32<Len, Packing>(_mm512_div_ph(v, b.v));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
constexpr void operator+=(VectorF32<Len, Packing, Repeats> b) const {
|
||||
constexpr void operator+=(VectorF32<Len, Packing> b) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
v = _mm_add_ps(v, b.v);
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
|
|
@ -201,7 +200,7 @@ namespace Crafter {
|
|||
}
|
||||
}
|
||||
|
||||
constexpr void operator-=(VectorF32<Len, Packing, Repeats> b) const {
|
||||
constexpr void operator-=(VectorF32<Len, Packing> b) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
v = _mm_sub_ps(v, b.v);
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
|
|
@ -211,7 +210,7 @@ namespace Crafter {
|
|||
}
|
||||
}
|
||||
|
||||
constexpr void operator*=(VectorF32<Len, Packing, Repeats> b) const {
|
||||
constexpr void operator*=(VectorF32<Len, Packing> b) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
v = _mm_mul_ps(v, b.v);
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
|
|
@ -221,7 +220,7 @@ namespace Crafter {
|
|||
}
|
||||
}
|
||||
|
||||
constexpr void operator/=(VectorF32<Len, Packing, Repeats> b) const {
|
||||
constexpr void operator/=(VectorF32<Len, Packing> b) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
v = _mm_div_ps(v, b.v);
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
|
|
@ -231,60 +230,48 @@ namespace Crafter {
|
|||
}
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> operator+(float b) const {
|
||||
VectorF32<Len, Packing, Repeats> vB(b);
|
||||
return this + vB;
|
||||
constexpr VectorF32<Len, Packing> operator+(float b) {
|
||||
VectorF32<Len, Packing> vB(b);
|
||||
return *this + vB;
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> operator-(float b) const {
|
||||
VectorF32<Len, Packing, Repeats> vB(b);
|
||||
return this - vB;
|
||||
constexpr VectorF32<Len, Packing> operator-(float b) {
|
||||
VectorF32<Len, Packing> vB(b);
|
||||
return *this - vB;
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> operator*(float b) const {
|
||||
VectorF32<Len, Packing, Repeats> vB(b);
|
||||
return this * vB;
|
||||
constexpr VectorF32<Len, Packing> operator*(float b) {
|
||||
VectorF32<Len, Packing> vB(b);
|
||||
return *this * vB;
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> operator/(float b) const {
|
||||
VectorF32<Len, Packing, Repeats> vB(b);
|
||||
return this / vB;
|
||||
constexpr VectorF32<Len, Packing> operator/(float b) {
|
||||
VectorF32<Len, Packing> vB(b);
|
||||
return *this / vB;
|
||||
}
|
||||
|
||||
constexpr void operator+=(float b) const {
|
||||
VectorF32<Len, Packing, Repeats> vB(b);
|
||||
this += vB;
|
||||
constexpr void operator+=(float b) {
|
||||
VectorF32<Len, Packing> vB(b);
|
||||
*this += vB;
|
||||
}
|
||||
|
||||
constexpr void operator-=(float b) const {
|
||||
VectorF32<Len, Packing, Repeats> vB(b);
|
||||
this -= vB;
|
||||
constexpr void operator-=(float b) {
|
||||
VectorF32<Len, Packing> vB(b);
|
||||
*this -= vB;
|
||||
}
|
||||
|
||||
constexpr void operator*=(float b) const {
|
||||
VectorF32<Len, Packing, Repeats> vB(b);
|
||||
this *= vB;
|
||||
constexpr void operator*=(float b) {
|
||||
VectorF32<Len, Packing> vB(b);
|
||||
*this *= vB;
|
||||
}
|
||||
|
||||
constexpr void operator/=(float b) const {
|
||||
VectorF32<Len, Packing, Repeats> vB(b);
|
||||
this /= vB;
|
||||
constexpr void operator/=(float b) {
|
||||
VectorF32<Len, Packing> vB(b);
|
||||
*this /= vB;
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> operator-(){
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
constexpr std::uint64_t mask[] {0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000};
|
||||
__m128i sign_mask = _mm_loadu_si128(reinterpret_cast<const __m128i*>(mask));
|
||||
return VectorF32<Len, Packing, Repeats>(_mm_castsi128_ps(_mm_xor_si128(sign_mask, _mm_castps_si128(v))));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
constexpr std::uint64_t mask[] {0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000};
|
||||
__m256i sign_mask = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(mask));
|
||||
return VectorF32<Len, Packing, Repeats>(_mm256_castsi256_ps(_mm256_xor_si256(sign_mask, _mm256_castps_si256(v))));
|
||||
} else {
|
||||
constexpr std::uint64_t mask[] {0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000};
|
||||
__m512i sign_mask = _mm512_loadu_si512(reinterpret_cast<const __m256i*>(mask));
|
||||
return VectorF32<Len, Packing, Repeats>(_mm512_castsi512_ps(_mm512_xor_si512(sign_mask, _mm512_castps_si512(v))));
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing> operator-(){
|
||||
return Negate<GetAllTrue<Len>()>();
|
||||
}
|
||||
|
||||
constexpr bool operator==(VectorF32<Len, Packing, Repeats> b) const {
|
||||
|
|
@ -335,71 +322,47 @@ namespace Crafter {
|
|||
return Dot(*this, *this);
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> Cos() requires(Len == 3) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm_cos_ps(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm256_cos_ps(v));
|
||||
template <const std::array<std::uint8_t, Len> ShuffleValues>
|
||||
constexpr VectorF32<Len, Packing> Shuffle() {
|
||||
if constexpr(CheckEpi32Shuffle<ShuffleValues>()) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(v), GetShuffleMaskEpi32<ShuffleValues>())));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(v), GetShuffleMaskEpi32<ShuffleValues>())));
|
||||
} else {
|
||||
return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(v), GetShuffleMaskEpi32<ShuffleValues>())));
|
||||
}
|
||||
} else if constexpr(CheckEpi8Shuffle<ShuffleValues>()){
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
constexpr std::array<std::uint8_t, VectorF32<Len, Packing>::Alignment*2> shuffleMask = GetShuffleMaskEpi8<ShuffleValues>();
|
||||
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
|
||||
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(v), shuffleVec)));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
constexpr std::array<std::uint8_t, VectorF32<Len, Packing>::Alignment*2> shuffleMask = GetShuffleMaskEpi8<ShuffleValues>();
|
||||
__m256i shuffleVec = _mm256_loadu_epi8(shuffleMask.data());
|
||||
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(v)), _mm512_castsi256_si512(shuffleVec)))));
|
||||
} else {
|
||||
constexpr std::array<std::uint8_t, VectorF32<Len, Packing>::Alignment*2> shuffleMask = GetShuffleMaskEpi8<ShuffleValues>();
|
||||
__m512i shuffleVec = _mm512_loadu_epi8(shuffleMask.data());
|
||||
return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(v), shuffleVec)));
|
||||
}
|
||||
} else {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm512_cos_ps(v));
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
constexpr std::array<std::uint8_t, VectorF32<Len, Packing>::Alignment*2> shuffleMask = GetShuffleMaskEpi8<ShuffleValues>();
|
||||
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
|
||||
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(v), shuffleVec)));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
constexpr std::array<std::uint16_t, VectorF32<Len, Packing>::Alignment> permMask = GetPermuteMaskEpi32<ShuffleValues>();
|
||||
__m256i permIdx = _mm256_loadu_epi16(permMask.data());
|
||||
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_permutexvar_epi16(permIdx, _mm256_castps_si256(v))));
|
||||
} else {
|
||||
constexpr std::array<std::uint16_t, VectorF32<Len, Packing>::Alignment> permMask = GetPermuteMaskEpi32<ShuffleValues>();
|
||||
__m512i permIdx = _mm512_loadu_epi16(permMask.data());
|
||||
return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_permutexvar_epi16(permIdx, _mm512_castps_si512(v))));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> Sin() requires(Len == 3) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm_sin_ps(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm256_sin_ps(v));
|
||||
} else {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm512_sin_ps(v));
|
||||
}
|
||||
}
|
||||
|
||||
template <std::uint8_t A, std::uint8_t B, std::uint8_t C, std::uint8_t D>
|
||||
constexpr VectorF32<Len, Packing, Repeats> Shuffle() {
|
||||
constexpr std::uint32_t val =
|
||||
(A & 0x3) |
|
||||
((B & 0x3) << 2) |
|
||||
((C & 0x3) << 4) |
|
||||
((D & 0x3) << 6);
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(v), val)));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(v), val)));
|
||||
} else {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm512_castsi512_ps(_mm512_shuffle_epi32(_mm_512castps_si512(v), val)));
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
std::uint8_t A0, std::uint8_t B0, std::uint8_t C0, std::uint8_t D0,
|
||||
std::uint8_t A1, std::uint8_t B1, std::uint8_t C1, std::uint8_t D1
|
||||
>
|
||||
constexpr VectorF32<Len, Packing, Repeats> Shuffle() requires(Repeats == 2) {
|
||||
constexpr std::uint8_t shuffleMask[] {
|
||||
A0,A0,A0,A0,B0,B0,B0,B0,C0,C0,C0,C0,D0,D0,D0,D0,
|
||||
A1,A1,A1,A1,B1,B1,B1,B1,C1,C1,C1,C1,D1,D1,D1,D1,
|
||||
};
|
||||
__m256 shuffleVec = _mm256_loadu_epi8(shuffleMask);
|
||||
return VectorF32<Len, Packing, Repeats>(_mm256_castsi256_ps(_mm256_shuffle_epi8(_mm256_castps_si256(v), shuffleVec)));
|
||||
}
|
||||
|
||||
template <
|
||||
std::uint8_t A0, std::uint8_t B0, std::uint8_t C0, std::uint8_t D0, std::uint8_t E0, std::uint8_t F0, std::uint8_t G0, std::uint8_t H0,
|
||||
std::uint8_t A1, std::uint8_t B1, std::uint8_t C1, std::uint8_t D1, std::uint8_t E1, std::uint8_t F1, std::uint8_t G1, std::uint8_t H1,
|
||||
std::uint8_t A2, std::uint8_t B2, std::uint8_t C2, std::uint8_t D2, std::uint8_t E2, std::uint8_t F2, std::uint8_t G2, std::uint8_t H2,
|
||||
std::uint8_t A3, std::uint8_t B3, std::uint8_t C3, std::uint8_t D3, std::uint8_t E3, std::uint8_t F3, std::uint8_t G3, std::uint8_t H3
|
||||
>
|
||||
constexpr VectorF32<Len, Packing, Repeats> Shuffle() requires(Repeats == 4) {
|
||||
constexpr std::uint8_t shuffleMask[] {
|
||||
A0,A0,A0,A0,B0,B0,B0,B0,C0,C0,C0,C0,D0,D0,D0,D0,
|
||||
A1,A1,A1,A1,B1,B1,B1,B1,C1,C1,C1,C1,D1,D1,D1,D1,
|
||||
A2,A2,A2,A2,B2,B2,B2,B2,C2,C2,C2,C2,D2,D2,D2,D2,
|
||||
A3,A3,A3,A3,B3,B3,B3,B3,C3,C3,C3,C3,D3,D3,D3,D3,
|
||||
};
|
||||
__m512 shuffleVec = _mm512_loadu_epi8(shuffleMask);
|
||||
return VectorF32<Len, Packing, Repeats>(_mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(v), shuffleVec)));
|
||||
}
|
||||
|
||||
static constexpr VectorF32<Len, Packing, Repeats> MulitplyAdd(VectorF32<Len, Packing, Repeats> a, VectorF32<Len, Packing, Repeats> b, VectorF32<Len, Packing, Repeats> add) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue