Crafter.Math/interfaces/Crafter.Math-VectorF32.cppm

837 lines
No EOL
43 KiB
C++
Executable file

/*
Crafter®.Math
Copyright (C) 2026 Catcrafts®
catcrafts.net
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License version 3.0 as published by the Free Software Foundation;
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
module;
#ifdef __x86_64
#include <immintrin.h>
#endif
export module Crafter.Math:VectorF32;
import std;
import :Vector;
import :Common;
namespace Crafter {
export template <std::uint32_t Len, std::uint32_t Packing>
struct VectorF32 {
#ifdef __AVX512F__
static constexpr std::uint32_t MaxSize = 16;
#else
static constexpr std::uint32_t MaxSize = 8;
#endif
static constexpr std::uint32_t MaxElement = 4;
static consteval std::uint32_t GetAlignment() {
#ifdef __AVX512F__
if constexpr (Len * Packing <= 4) {
return 4;
}
if constexpr (Len * Packing <= 8) {
return 8;
}
if constexpr (Len * Packing <= 16) {
return 16;
}
static_assert(Len * Packing <= 16, "Len * Packing is larger than supported max size of 16");
#else
if constexpr (Len * Packing <= 4) {
return 4;
}
if constexpr (Len * Packing <= 8) {
return 8;
}
static_assert(Len * Packing <= 8, "Len * Packing is larger than supported max size of 8");
#endif
}
using VectorType = std::conditional_t<
(Len * Packing > 8), __m512h,
std::conditional_t<(Len * Packing > 4), __m256h, __m128>
>;
VectorType v;
constexpr VectorF32() = default;
constexpr VectorF32(VectorType v) : v(v) {}
constexpr VectorF32(const float* vB) {
Load(vB);
};
constexpr VectorF32(const _Float16* vB) {
Load(vB);
};
constexpr VectorF32(float val) {
if constexpr(std::is_same_v<VectorType, __m128>) {
v = _mm_set1_ps(val);
} else if constexpr(std::is_same_v<VectorType, __m256>) {
v = _mm256_set1_ps(val);
} else {
v = _mm512_set1_ps(val);
}
};
constexpr void Load(const float* vB) {
if constexpr(std::is_same_v<VectorType, __m128>) {
v = _mm_loadu_ps(vB);
} else if constexpr(std::is_same_v<VectorType, __m256>) {
v = _mm256_loadu_ps(vB);
} else {
v = _mm512_loadu_ps(vB);
}
}
constexpr void Store(float* vB) const {
if constexpr(std::is_same_v<VectorType, __m128>) {
_mm_storeu_ps(vB, v);
} else if constexpr(std::is_same_v<VectorType, __m256>) {
_mm256_storeu_ps(vB, v);
} else {
_mm512_storeu_ps(vB, v);
}
}
constexpr void Load(const _Float16* vB) {
if constexpr(std::is_same_v<VectorType, __m128>) {
v = _mm_cvtps_ps(_mm_loadu_si128(reinterpret_cast<__m128i const*>(vB)));
} else if constexpr(std::is_same_v<VectorType, __m256>) {
v = _mm256_cvtps_ps(_mm_loadu_si128(reinterpret_cast<__m128i const*>(vB)));
} else {
v = _mm512_cvtps_ps(_mm256_loadu_si256(reinterpret_cast<__m256i const*>(vB)));
}
}
constexpr void Store(_Float16* vB) const {
if constexpr(std::is_same_v<VectorType, __m128>) {
_mm_storeu_si128(_mm_cvtps_ps(v, _MM_FROUND_TO_NEAREST_INT), v);
} else if constexpr(std::is_same_v<VectorType, __m256>) {
_mm_storeu_si128(_mm256_cvtps_ps(v, _MM_FROUND_TO_NEAREST_INT), v);
} else {
_mm256_storeu_si256(_mm512_cvtps_ps(v, _MM_FROUND_TO_NEAREST_INT), v);
}
}
constexpr std::array<float, Alignment> Store() const {
std::array<float, Alignment> returnArray;
Store(returnArray.data());
return returnArray;
}
template <std::uint32_t BLen, std::uint32_t BPacking>
constexpr operator VectorF32<BLen, BPacking>() const {
if constexpr (Len == BLen) {
if constexpr(std::is_same_v<VectorType, __m256> && std::is_same_v<typename VectorF32<BLen, BPacking>::VectorType, __m128>) {
return VectorF32<BLen, BPacking>(_mm256_castps256_ps128(v));
} else if constexpr(std::is_same_v<VectorType, __m512> && std::is_same_v<typename VectorF32<BLen, BPacking>::VectorType, __m128>) {
return VectorF32<BLen, BPacking>(_mm512_castps512_ps128(v));
} else if constexpr(std::is_same_v<VectorType, __m512> && std::is_same_v<typename VectorF32<BLen, BPacking>::VectorType, __m256>) {
return VectorF32<BLen, BPacking>(_mm512_castps512_ps256(v));
} else if constexpr(std::is_same_v<VectorType, __m128> && std::is_same_v<typename VectorF32<BLen, BPacking>::VectorType, __m256>) {
return VectorF32<BLen, BPacking>(_mm256_castps128_ps256(v));
} else if constexpr(std::is_same_v<VectorType, __m128> && std::is_same_v<typename VectorF32<BLen, BPacking>::VectorType, __m512>) {
return VectorF32<BLen, BPacking>(_mm512_castps128_ps512(v));
} else if constexpr(std::is_same_v<VectorType, __m256> && std::is_same_v<typename VectorF32<BLen, BPacking>::VectorType, __m512>) {
return VectorF32<BLen, BPacking>(_mm512_castps256_ps512(v));
} else {
return VectorF32<BLen, BPacking>(v);
}
} else if constexpr (BLen <= Len) {
return this->template ExtractLo<BLen>();
} else {
return VectorF32<BLen, BPacking>(v);
}
}
constexpr VectorF32<Len, Packing> operator+(VectorF32<Len, Packing> b) const {
if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF32<Len, Packing>(_mm_add_ph(v, b.v));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
return VectorF32<Len, Packing>(_mm256_add_ph(v, b.v));
} else {
return VectorF32<Len, Packing>(_mm512_add_ph(v, b.v));
}
}
constexpr VectorF32<Len, Packing> operator-(VectorF32<Len, Packing> b) const {
if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF32<Len, Packing>(_mm_sub_ph(v, b.v));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
return VectorF32<Len, Packing>(_mm256_sub_ph(v, b.v));
} else {
return VectorF32<Len, Packing>(_mm512_sub_ph(v, b.v));
}
}
constexpr VectorF32<Len, Packing> operator*(VectorF32<Len, Packing> b) const {
if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF32<Len, Packing>(_mm_mul_ph(v, b.v));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
return VectorF32<Len, Packing>(_mm256_mul_ph(v, b.v));
} else {
return VectorF32<Len, Packing>(_mm512_mul_ph(v, b.v));
}
}
constexpr VectorF32<Len, Packing> operator/(VectorF32<Len, Packing> b) const {
if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF32<Len, Packing>(_mm_div_ph(v, b.v));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
return VectorF32<Len, Packing>(_mm256_div_ph(v, b.v));
} else {
return VectorF32<Len, Packing>(_mm512_div_ph(v, b.v));
}
}
constexpr void operator+=(VectorF32<Len, Packing> b) {
if constexpr(std::is_same_v<VectorType, __m128>) {
v = _mm_add_ps(v, b.v);
} else if constexpr(std::is_same_v<VectorType, __m256>) {
v = _mm256_add_ps(v, b.v);
} else {
v = _mm512_add_ps(v, b.v);
}
}
constexpr void operator-=(VectorF32<Len, Packing> b) {
if constexpr(std::is_same_v<VectorType, __m128>) {
v = _mm_sub_ps(v, b.v);
} else if constexpr(std::is_same_v<VectorType, __m256>) {
v = _mm256_sub_ps(v, b.v);
} else {
v = _mm512_sub_ps(v, b.v);
}
}
constexpr void operator*=(VectorF32<Len, Packing> b) {
if constexpr(std::is_same_v<VectorType, __m128>) {
v = _mm_mul_ps(v, b.v);
} else if constexpr(std::is_same_v<VectorType, __m256>) {
v = _mm256_mul_ps(v, b.v);
} else {
v = _mm512_mul_ps(v, b.v);
}
}
constexpr void operator/=(VectorF32<Len, Packing> b) {
if constexpr(std::is_same_v<VectorType, __m128>) {
v = _mm_div_ps(v, b.v);
} else if constexpr(std::is_same_v<VectorType, __m256>) {
v = _mm256_div_ps(v, b.v);
} else {
v = _mm512_div_ps(v, b.v);
}
}
constexpr VectorF32<Len, Packing> operator+(float b) {
VectorF32<Len, Packing> vB(b);
return *this + vB;
}
constexpr VectorF32<Len, Packing> operator-(float b) {
VectorF32<Len, Packing> vB(b);
return *this - vB;
}
constexpr VectorF32<Len, Packing> operator*(float b) {
VectorF32<Len, Packing> vB(b);
return *this * vB;
}
constexpr VectorF32<Len, Packing> operator/(float b) {
VectorF32<Len, Packing> vB(b);
return *this / vB;
}
constexpr void operator+=(float b) {
VectorF32<Len, Packing> vB(b);
*this += vB;
}
constexpr void operator-=(float b) {
VectorF32<Len, Packing> vB(b);
*this -= vB;
}
constexpr void operator*=(float b) {
VectorF32<Len, Packing> vB(b);
*this *= vB;
}
constexpr void operator/=(float b) {
VectorF32<Len, Packing> vB(b);
*this /= vB;
}
constexpr VectorF32<Len, Packing> operator-(){
return Negate<GetAllTrue<Len>()>();
}
constexpr bool operator==(VectorF32<Len, Packing, Repeats> b) const {
if constexpr(std::is_same_v<VectorType, __m128>) {
return _mm_cmp_ps_mask(v, b.v, _CMP_EQ_OQ) == 255;
} else if constexpr(std::is_same_v<VectorType, __m256>) {
return _mm256_cmp_ps_mask(v, b.v, _CMP_EQ_OQ) == 65535;
} else {
return _mm512_cmp_ps_mask(v, b.v, _CMP_EQ_OQ) == 4294967295;
}
}
constexpr bool operator!=(VectorF32<Len, Packing, Repeats> b) const {
if constexpr(std::is_same_v<VectorType, __m128>) {
return _mm_cmp_ps_mask(v, b.v, _CMP_EQ_OQ) != 255;
} else if constexpr(std::is_same_v<VectorType, __m256>) {
return _mm256_cmp_ps_mask(v, b.v, _CMP_EQ_OQ) != 65535;
} else {
return _mm512_cmp_ps_mask(v, b.v, _CMP_EQ_OQ) != 4294967295;
}
}
constexpr void Normalize() {
if constexpr(std::is_same_v<VectorType, __m128>) {
float dot = LengthSq();
__m128 vec = _mm_set1_ps(dot);
__m128 sqrt = _mm_rsqrt_ps(vec);
v = _mm_div_ps(v, sqrt);
} else if constexpr(std::is_same_v<VectorType, __m256>) {
float dot = LengthSq();
__m256 vec = _mm256_set1_ps(dot);
__m256 sqrt = _mm256_rsqrt_ps(vec);
v = _mm256_div_ps(v, sqrt);
} else {
float dot = LengthSq();
__m512 vec = _mm512_set1_ps(dot);
__m512 sqrt = _mm512_rsqrt14_ps(vec);
v = _mm512_div_ps(v, sqrt);
}
}
constexpr float Length() const {
float Result = LengthSq();
return std::sqrtf(Result);
}
constexpr float LengthSq() const {
return Dot(*this, *this);
}
template <const std::array<std::uint8_t, Len> ShuffleValues>
constexpr VectorF32<Len, Packing> Shuffle() {
if constexpr(CheckEpi32Shuffle<ShuffleValues>()) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(v), GetShuffleMaskEpi32<ShuffleValues>())));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(v), GetShuffleMaskEpi32<ShuffleValues>())));
} else {
return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(v), GetShuffleMaskEpi32<ShuffleValues>())));
}
} else if constexpr(CheckEpi8Shuffle<ShuffleValues>()){
if constexpr(std::is_same_v<VectorType, __m128h>) {
constexpr std::array<std::uint8_t, VectorF32<Len, Packing>::Alignment*2> shuffleMask = GetShuffleMaskEpi8<ShuffleValues>();
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(v), shuffleVec)));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
constexpr std::array<std::uint8_t, VectorF32<Len, Packing>::Alignment*2> shuffleMask = GetShuffleMaskEpi8<ShuffleValues>();
__m256i shuffleVec = _mm256_loadu_epi8(shuffleMask.data());
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(v)), _mm512_castsi256_si512(shuffleVec)))));
} else {
constexpr std::array<std::uint8_t, VectorF32<Len, Packing>::Alignment*2> shuffleMask = GetShuffleMaskEpi8<ShuffleValues>();
__m512i shuffleVec = _mm512_loadu_epi8(shuffleMask.data());
return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(v), shuffleVec)));
}
} else {
if constexpr(std::is_same_v<VectorType, __m128h>) {
constexpr std::array<std::uint8_t, VectorF32<Len, Packing>::Alignment*2> shuffleMask = GetShuffleMaskEpi8<ShuffleValues>();
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(v), shuffleVec)));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
constexpr std::array<std::uint16_t, VectorF32<Len, Packing>::Alignment> permMask = GetPermuteMaskEpi32<ShuffleValues>();
__m256i permIdx = _mm256_loadu_epi16(permMask.data());
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_permutexvar_epi16(permIdx, _mm256_castps_si256(v))));
} else {
constexpr std::array<std::uint16_t, VectorF32<Len, Packing>::Alignment> permMask = GetPermuteMaskEpi32<ShuffleValues>();
__m512i permIdx = _mm512_loadu_epi16(permMask.data());
return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_permutexvar_epi16(permIdx, _mm512_castps_si512(v))));
}
}
}
static constexpr VectorF32<Len, Packing, Repeats> MulitplyAdd(VectorF32<Len, Packing, Repeats> a, VectorF32<Len, Packing, Repeats> b, VectorF32<Len, Packing, Repeats> add) {
if constexpr(std::is_same_v<VectorType, __m128>) {
return VectorF32<Len, Packing, Repeats>(_mm_fmadd_ps(a, b, add));
} else if constexpr(std::is_same_v<VectorType, __m256>) {
return VectorF32<Len, Packing, Repeats>(_mm256_fmadd_ps(a, b, add));
} else {
return VectorF32<Len, Packing, Repeats>(_mm512_fmadd_ps(a, b, add));
}
}
static constexpr VectorF32<Len, Packing, Repeats> MulitplySub(VectorF32<Len, Packing, Repeats> a, VectorF32<Len, Packing, Repeats> b, VectorF32<Len, Packing, Repeats> sub) {
if constexpr(std::is_same_v<VectorType, __m128>) {
return VectorF32<Len, Packing, Repeats>(_mm_fmsub_ps(a, b, sub));
} else if constexpr(std::is_same_v<VectorType, __m256>) {
return VectorF32<Len, Packing, Repeats>(_mm256_fmsub_ps(a, b, sub));
} else {
return VectorF32<Len, Packing, Repeats>(_mm512_fmsub_ps(a, b, sub));
}
}
constexpr static VectorF32<Len, Packing, Repeats> Cross(VectorF32<Len, Packing, Repeats> a, VectorF32<Len, Packing, Repeats> b) requires(Len == 3) {
if constexpr(Len == 3) {
if constexpr(Repeats == 1) {
__m128 row4 = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(b.v), 0b01'10'00'11));
__m128 row3 = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(a.v), 0b01'10'00'11));
__m128 result = _mm_mul_ps(row3, row4);
__m128 row1 = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(a.v), 0b10'00'01'11));
__m128 row2 = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(b.v), 0b10'00'01'11));
return _mm_fmsub_ps(row1,row2,result);
}
if constexpr(Repeats == 2) {
__m256 row4 = _mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(b.v), 0b01'10'00'11));
__m256 row3 = _mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(a.v), 0b01'10'00'11));
__m256 result = _mm256_mul_ps(row3, row4);
__m256 row1 = _mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(a.v), 0b10'00'01'11));
__m256 row2 = _mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(b.v), 0b10'00'01'11));
return _mm256_fmsub_ps(row1,row2,result);
}
if constexpr(Repeats == 4) {
__m512 row4 = _mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(b.v), 0b01'10'00'11));
__m512 row3 = _mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(a.v), 0b01'10'00'11));
__m512 result = _mm512_mul_ps(row3, row4);
__m512 row1 = _mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(a.v), 0b10'00'01'11));
__m512 row2 = _mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(b.v), 0b10'00'01'11));
return _mm512_fmsub_ps(row1,row2,result);
}
}
}
constexpr static float Dot(VectorF32<Len, 1, 1> a, VectorF32<Len, 1, 1> b) {
if constexpr(std::is_same_v<VectorType, __m128>) {
union UN {
float f;
int i;
};
UN val;
val.i = _mm_extract_ps(_mm_dp_ps(a.v, b.v, 0b01110111), 0);
return val.f;
} else if constexpr(std::is_same_v<VectorType, __m256>) {
union UN {
float f;
int i;
};
UN val;
val.i = _mm_extract_epi32(_mm256_castsi256_si128(_mm256_castps_si256(_mm256_dp_ps(a.v, b.v, 0b01110111))), 0);
return val.f;
} else {
__m512 mul = _mm512_mul_ps(a.v, b.v);
return _mm512_reduce_add_ps(mul);
}
}
constexpr static std::tuple<VectorF32<Len, Packing, Repeats>, VectorF32<Len, Packing, Repeats>, VectorF32<Len, Packing, Repeats>, VectorF32<Len, Packing, Repeats>, VectorF32<Len, Packing, Repeats>, VectorF32<Len, Packing, Repeats>, VectorF32<Len, Packing, Repeats>, VectorF32<Len, Packing, Repeats>> Normalize(
VectorF32<Len, Packing, Repeats> A,
VectorF32<Len, Packing, Repeats> B,
VectorF32<Len, Packing, Repeats> C,
VectorF32<Len, Packing, Repeats> D
) requires(Packing == 1) {
if constexpr(std::is_same_v<VectorType, __m128>) {
VectorF32<Len, Packing, Repeats> lenght = Length(A, B, C, D);
constexpr float oneArr[] {1, 1, 1, 1};
__m128 one = _mm_loadu_ps(oneArr);
__m128 fLenght = _mm_div_ps(one, lenght.v);
__m128 fLenghtA = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(fLenght), 0b00'00'00'00));
__m128 fLenghtB = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(fLenght), 0b01'01'01'01));
__m128 fLenghtC = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(fLenght), 0b10'10'10'10));
__m128 fLenghtD = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(fLenght), 0b11'11'11'11));
return {
_mm_mul_ps(A.v, fLenghtA),
_mm_mul_ps(B.v, fLenghtB),
_mm_mul_ps(C.v, fLenghtC),
_mm_mul_ps(D.v, fLenghtD),
};
} else if constexpr(std::is_same_v<VectorType, __m256>) {
VectorF32<Len, Packing, Repeats> lenght = Length(A, B, C, D);
constexpr float oneArr[] {1, 1, 1, 1, 1, 1, 1, 1};
__m256 one = _mm256_loadu_ps(oneArr);
__m256 fLenght = _mm256_div_ps(one, lenght.v);
__m256 fLenghtA = _mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(fLenght), 0b00'00'00'00));
__m256 fLenghtB = _mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(fLenght), 0b01'01'01'01));
__m256 fLenghtC = _mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(fLenght), 0b10'10'10'10));
__m256 fLenghtD = _mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(fLenght), 0b11'11'11'11));
return {
_mm256_mul_ps(A.v, fLenghtA),
_mm256_mul_ps(B.v, fLenghtB),
_mm256_mul_ps(C.v, fLenghtC),
_mm256_mul_ps(D.v, fLenghtD),
};
} else {
VectorF32<Len, Packing, Repeats> lenght = Length(A, B, C, D);
constexpr float oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
__m512 one = _mm512_loadu_ps(oneArr);
__m512 fLenght = _mm512_div_ps(one, lenght.v);
__m512 fLenghtA = _mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(fLenght), 0b00'00'00'00));
__m512 fLenghtB = _mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(fLenght), 0b01'01'01'01));
__m512 fLenghtC = _mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(fLenght), 0b10'10'10'10));
__m512 fLenghtD = _mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(fLenght), 0b11'11'11'11));
return {
_mm512_mul_ps(A.v, fLenghtA),
_mm512_mul_ps(B.v, fLenghtB),
_mm512_mul_ps(C.v, fLenghtC),
_mm512_mul_ps(D.v, fLenghtD),
};
}
}
constexpr static std::tuple<VectorF32<Len, Packing, Repeats>, VectorF32<Len, Packing, Repeats>, VectorF32<Len, Packing, Repeats>, VectorF32<Len, Packing, Repeats>> Normalize(
VectorF32<Len, Packing, Repeats> A,
VectorF32<Len, Packing, Repeats> C
) requires(Packing == 2) {
if constexpr(std::is_same_v<VectorType, __m128>) {
VectorF32<Len, Packing, Repeats> lenght = Length(A, C);
constexpr float oneArr[] {1, 1, 1, 1};
__m128 one = _mm_loadu_ps(oneArr);
__m128 fLenght = _mm_div_ps(one, lenght.v);
__m128 fLenghtA = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(fLenght), 0b00'00'01'01));
__m128 fLenghtC = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(fLenght), 0b10'10'11'11));
return {
_mm_mul_ps(A.v, fLenghtA),
_mm_mul_ps(C.v, fLenghtC),
};
} else if constexpr(std::is_same_v<VectorType, __m256>) {
VectorF32<Len, Packing, Repeats> lenght = Length(A, C);
constexpr float oneArr[] {1, 1, 1, 1, 1, 1, 1, 1};
__m256 one = _mm256_loadu_ps(oneArr);
__m256 fLenght = _mm256_div_ps(one, lenght.v);
__m256 fLenghtA = _mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(fLenght), 0b00'00'01'01));
__m256 fLenghtC = _mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(fLenght), 0b10'10'11'11));
return {
_mm256_mul_ps(A.v, fLenghtA),
_mm256_mul_ps(C.v, fLenghtC),
};
} else {
VectorF32<Len, Packing, Repeats> lenght = Length(A, C);
constexpr float oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
__m512 one = _mm512_loadu_ps(oneArr);
__m512 fLenght = _mm512_div_ps(one, lenght.v);
__m512 fLenghtA = _mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(fLenght), 0b00'00'01'01));
__m512 fLenghtC = _mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(fLenght), 0b10'10'11'11));
return {
_mm512_mul_ps(A.v, fLenghtA),
_mm512_mul_ps(C.v, fLenghtC),
};
}
}
constexpr static VectorF32<Len, Packing, Repeats> Length(
VectorF32<Len, Packing, Repeats> A,
VectorF32<Len, Packing, Repeats> B,
VectorF32<Len, Packing, Repeats> C,
VectorF32<Len, Packing, Repeats> D
) requires(Packing == 1) {
VectorF32<Len, Packing, Repeats> lenghtSq = LengthSq(A, B, C, D);
if constexpr(std::is_same_v<VectorType, __m128>) {
return VectorF32<Len, Packing, Repeats>(_mm_sqrt_ps(lenghtSq.v));
} else if constexpr(std::is_same_v<VectorType, __m256>) {
return VectorF32<Len, Packing, Repeats>(_mm256_sqrt_ps(lenghtSq.v));
} else {
return VectorF32<Len, Packing, Repeats>(_mm512_sqrt_ps(lenghtSq.v));
}
}
constexpr static VectorF32<Len, Packing, Repeats> Length(
VectorF32<Len, Packing, Repeats> A,
VectorF32<Len, Packing, Repeats> C
) requires(Packing == 2) {
VectorF32<Len, Packing, Repeats> lenghtSq = LengthSq(A, C);
if constexpr(std::is_same_v<VectorType, __m128>) {
return VectorF32<Len, Packing, Repeats>(_mm_sqrt_ps(lenghtSq.v));
} else if constexpr(std::is_same_v<VectorType, __m256>) {
return VectorF32<Len, Packing, Repeats>(_mm256_sqrt_ps(lenghtSq.v));
} else {
return VectorF32<Len, Packing, Repeats>(_mm512_sqrt_ps(lenghtSq.v));
}
}
constexpr static VectorF32<Len, Packing, Repeats> LengthSq(
VectorF32<Len, Packing, Repeats> A,
VectorF32<Len, Packing, Repeats> B,
VectorF32<Len, Packing, Repeats> C,
VectorF32<Len, Packing, Repeats> D
) requires(Packing == 1) {
return Dot(A, A, B, B, C, C, D, D);
}
constexpr static VectorF32<Len, Packing, Repeats> LengthSq(
VectorF32<Len, Packing, Repeats> A,
VectorF32<Len, Packing, Repeats> C
) requires(Packing == 2) {
return Dot(A, A, C, C);
}
constexpr static VectorF32<Len, Packing, Repeats> Dot(
VectorF32<Len, Packing, Repeats> A0, VectorF32<Len, Packing, Repeats> A1,
VectorF32<Len, Packing, Repeats> B0, VectorF32<Len, Packing, Repeats> B1,
VectorF32<Len, Packing, Repeats> C0, VectorF32<Len, Packing, Repeats> C1,
VectorF32<Len, Packing, Repeats> D0, VectorF32<Len, Packing, Repeats> D1
) requires(Packing == 1) {
if constexpr(std::is_same_v<VectorType, __m128>) {
__m128 mulA = _mm_mul_ps(A0.v, A1.v);
__m128 mulB = _mm_mul_ps(B0.v, B1.v);
__m128i row12Temp1 = _mm_unpacklo_epi32(_mm_castps_si128(mulA), _mm_castps_si128(mulB)); // A1 B1 A2 B2
__m128i row56Temp1 = _mm_unpackhi_epi32(_mm_castps_si128(mulA), _mm_castps_si128(mulB)); // A3 B3 A4 B4
__m128i row1TempTemp1 = row12Temp1;
__m128i row5TempTemp1 = row56Temp1;
__m128 mulC = _mm_mul_ps(C0.v, C1.v);
__m128 mulD = _mm_mul_ps(D0.v, D1.v);
__m128i row34Temp1 = _mm_unpacklo_epi32(_mm_castps_si128(mulC), _mm_castps_si128(mulD)); // C1 D1 C2 D2
__m128i row78Temp1 = _mm_unpackhi_epi32(_mm_castps_si128(mulA), _mm_castps_si128(mulB)); // C3 D3 C4 D4
row12Temp1 = _mm_unpacklo_epi32(row12Temp1, row34Temp1); // A1 C1 B1 D1
row34Temp1 = _mm_unpackhi_epi32(row1TempTemp1, row34Temp1); // A2 C2 B2 D2
row56Temp1 = _mm_unpacklo_epi32(row56Temp1, row78Temp1); // A3 C3 B3 D3
row78Temp1 = _mm_unpackhi_epi32(row5TempTemp1, row78Temp1); // A4 C4 B4 D4
__m128 row1 = _mm_add_ps(row12Temp1, row34Temp1);
row1 = _mm_add_ps(row1, row56Temp1);
row1 = _mm_add_ps(row1, row78Temp1);
return row1;
} else if constexpr(std::is_same_v<VectorType, __m256>) {
__m256 mulA = _mm256_mul_ps(A0.v, A1.v);
__m256 mulB = _mm256_mul_ps(B0.v, B1.v);
__m256i row12Temp1 = _mm256_unpacklo_epi32(_mm256_castps_si256(mulA), _mm256_castps_si256(mulB)); // A1 B1 A2 B2
__m256i row56Temp1 = _mm256_unpackhi_epi32(_mm256_castps_si256(mulA), _mm256_castps_si256(mulB)); // A3 B3 A4 B4
__m256i row1TempTemp1 = row12Temp1;
__m256i row5TempTemp1 = row56Temp1;
__m256 mulC = _mm256_mul_ps(C0.v, C1.v);
__m256 mulD = _mm256_mul_ps(D0.v, D1.v);
__m256i row34Temp1 = _mm256_unpacklo_epi32(_mm256_castps_si256(mulC), _mm256_castps_si256(mulD)); // C1 D1 C2 D2
__m256i row78Temp1 = _mm256_unpackhi_epi32(_mm256_castps_si256(mulA), _mm256_castps_si256(mulB)); // C3 D3 C4 D4
row12Temp1 = _mm256_unpacklo_epi32(row12Temp1, row34Temp1); // A1 C1 B1 D1
row34Temp1 = _mm256_unpackhi_epi32(row1TempTemp1, row34Temp1); // A2 C2 B2 D2
row56Temp1 = _mm256_unpacklo_epi32(row56Temp1, row78Temp1); // A3 C3 B3 D3
row78Temp1 = _mm256_unpackhi_epi32(row5TempTemp1, row78Temp1); // A4 C4 B4 D4
__m256 row1 = _mm256_add_ps(row12Temp1, row34Temp1);
row1 = _mm256_add_ps(row1, row56Temp1);
row1 = _mm256_add_ps(row1, row78Temp1);
return row1;
} else {
__m512 mulA = _mm512_mul_ps(A0.v, A1.v);
__m512 mulB = _mm512_mul_ps(B0.v, B1.v);
__m512i row12Temp1 = _mm512_unpacklo_epi32(_mm512_castps_si512(mulA), _mm512_castps_si512(mulB)); // A1 B1 A2 B2
__m512i row56Temp1 = _mm512_unpackhi_epi32(_mm512_castps_si512(mulA), _mm512_castps_si512(mulB)); // A3 B3 A4 B4
__m512i row1TempTemp1 = row12Temp1;
__m512i row5TempTemp1 = row56Temp1;
__m512 mulC = _mm512_mul_ps(C0.v, C1.v);
__m512 mulD = _mm512_mul_ps(D0.v, D1.v);
__m512i row34Temp1 = _mm512_unpacklo_epi32(_mm512_castps_si512(mulC), _mm512_castps_si512(mulD)); // C1 D1 C2 D2
__m512i row78Temp1 = _mm512_unpackhi_epi32(_mm512_castps_si512(mulA), _mm512_castps_si512(mulB)); // C3 D3 C4 D4
row12Temp1 = _mm512_unpacklo_epi32(row12Temp1, row34Temp1); // A1 C1 B1 D1
row34Temp1 = _mm512_unpackhi_epi32(row1TempTemp1, row34Temp1); // A2 C2 B2 D2
row56Temp1 = _mm512_unpacklo_epi32(row56Temp1, row78Temp1); // A3 C3 B3 D3
row78Temp1 = _mm512_unpackhi_epi32(row5TempTemp1, row78Temp1); // A4 C4 B4 D4
__m512 row1 = _mm512_add_ps(row12Temp1, row34Temp1);
row1 = _mm512_add_ps(row1, row56Temp1);
row1 = _mm512_add_ps(row1, row78Temp1);
return row1;
}
}
constexpr static VectorF32<Len, Packing, Repeats> Dot(
VectorF32<Len, Packing, Repeats> A0, VectorF32<Len, Packing, Repeats> A1,
VectorF32<Len, Packing, Repeats> C0, VectorF32<Len, Packing, Repeats> C1
) requires(Packing == 2) {
if constexpr(std::is_same_v<VectorType, __m128>) {
__m128 mulA = _mm_mul_ps(A0.v, A1.v);
__m128 mulB = _mm_mul_ps(C0.v, C1.v);
__m128i row12Temp1 = _mm_unpacklo_epi32(_mm_castps_si128(mulA), _mm_castps_si128(mulB)); // A1 C1 A2 C2
__m128i row56Temp1 = _mm_unpackhi_epi32(_mm_castps_si128(mulA), _mm_castps_si128(mulB)); // B1 D1 B2 D2
__m128i row1TempTemp1 = row12Temp1;
__m128i row5TempTemp1 = row56Temp1;
row12Temp1 = _mm_unpacklo_epi32(row12Temp1, row56Temp1); // A1 B1 C1 D1
row56Temp1 = _mm_unpackhi_epi32(row1TempTemp1, row56Temp1); // A2 B2 C2 D2
return _mm_add_ps(row12Temp1, row56Temp1);
} else if constexpr(std::is_same_v<VectorType, __m256>) {
__m256 mulA = _mm256_mul_ps(A0.v, A1.v);
__m256 mulB = _mm256_mul_ps(C0.v, C1.v);
__m256i row12Temp1 = _mm256_unpacklo_epi32(_mm256_castps_si256(mulA), _mm256_castps_si256(mulB)); // A1 C1 A2 C2
__m256i row56Temp1 = _mm256_unpackhi_epi32(_mm256_castps_si256(mulA), _mm256_castps_si256(mulB)); // B1 D1 B2 D2
__m256i row1TempTemp1 = row12Temp1;
__m256i row5TempTemp1 = row56Temp1;
row12Temp1 = _mm256_unpacklo_epi32(row12Temp1, row56Temp1); // A1 B1 C1 D1
row56Temp1 = _mm256_unpackhi_epi32(row1TempTemp1, row56Temp1); // A2 B2 C2 D2
return _mm256_add_ps(row12Temp1, row56Temp1);
} else {
__m512 mulA = _mm512_mul_ps(A0.v, A1.v);
__m512 mulB = _mm512_mul_ps(C0.v, C1.v);
__m512i row12Temp1 = _mm512_unpacklo_epi32(_mm512_castps_si512(mulA), _mm512_castps_si512(mulB)); // A1 C1 A2 C2
__m512i row56Temp1 = _mm512_unpackhi_epi32(_mm512_castps_si512(mulA), _mm512_castps_si512(mulB)); // B1 D1 B2 D2
__m512i row1TempTemp1 = row12Temp1;
__m512i row5TempTemp1 = row56Temp1;
row12Temp1 = _mm512_unpacklo_epi32(row12Temp1, row56Temp1); // A1 B1 C1 D1
row56Temp1 = _mm512_unpackhi_epi32(row1TempTemp1, row56Temp1); // A2 B2 C2 D2
return _mm512_add_ps(row12Temp1, row56Temp1);
}
}
template <std::uint8_t A, std::uint8_t B, std::uint8_t C, std::uint8_t D>
constexpr static VectorF32<Len, Packing, Repeats> Blend(VectorF32<Len, Packing, Repeats> a, VectorF32<Len, Packing, Repeats> b) {
if constexpr(std::is_same_v<VectorType, __m128>) {
constexpr std::uint8_t val =
(A & 1) |
((B & 1) << 1) |
((C & 1) << 2) |
((D & 1) << 3);
return _mm_castsi128_ps(_mm_blend_epi32(_mm_castps_si128(a.v), _mm_castps_si128(b), val));
} else if constexpr(std::is_same_v<VectorType, __m256>) {
constexpr std::uint8_t val =
(A & 1) |
((B & 1) << 1) |
((C & 1) << 2) |
((D & 1) << 3);
return _mm256_castsi256_ps(_mm256_blend_epi32(_mm256_castps_si256(a.v), _mm256_castps_si256(b), val));
} else {
constexpr std::uint16_t val =
(A & 1) |
((B & 1) << 1) |
((C & 1) << 2) |
((D & 1) << 3) |
((A & 1) << 4) |
((B & 1) << 5) |
((C & 1) << 6) |
((D & 1) << 7) |
((A & 1) << 8) |
((B & 1) << 9) |
((C & 1) << 10) |
((D & 1) << 11) |
((A & 1) << 12) |
((B & 1) << 13) |
((C & 1) << 14) |
((D & 1) << 15);
return _mm512_castsi512_ps(_mm512_mask_blend_epi32(val, _mm512_castps_si512(a.v), _mm512_castps_si512(b)));
}
}
constexpr static VectorF32<Len, Packing, Repeats> Rotate(VectorF32<3, 2, Repeats> v, VectorF32<4, 2, Repeats> q) requires(Len == 3 && Packing == 1) {
VectorF32<3, 2, Repeats> qv(q.v);
VectorF32<Len, Packing, Repeats> t = Cross(qv, v) * float(2);
return v + t * q.template Shuffle<3,3,3,3>(); + Cross(qv, t);
}
constexpr static VectorF32<4, 2, Repeats> RotatePivot(VectorF32<3, 2, Repeats> v, VectorF32<4, 2, Repeats> q, VectorF32<3, 2, Repeats> pivot) requires(Len == 3 && Packing == 1) {
VectorF32<Len, Packing, Repeats> translated = v - pivot;
VectorF32<3, 2, Repeats> qv(q.v);
VectorF32<Len, Packing, Repeats> t = Cross(qv, translated) * float(2);
VectorF32<Len, Packing, Repeats> rotated = translated + t * q.template Shuffle<3,3,3,3>() + Cross(qv, t);
return rotated + pivot;
}
constexpr static VectorF32<4, 2, Repeats> QuanternionFromEuler(VectorF32<3, 2, Repeats> EulerHalf) requires(Len == 3 && Packing == 1) {
VectorF32<3, 2, Repeats> sin = EulerHalf.Sin();
VectorF32<3, 2, Repeats> cos = EulerHalf.Cos();
VectorF32<3, 2, Repeats> row1 = cos.template Shuffle<0,0,0,0>();
row1 = VectorF32<3, 2, Repeats>::Blend<0,1,1,1>(sin, row1);
VectorF32<3, 2, Repeats> row2 = cos.template Shuffle<1,1,1,1>();
row2 = VectorF32<3, 2, Repeats>::Blend<1,0,1,1>(sin, row2);
row1 = row2;
VectorF32<3, 2, Repeats> row3 = cos.template Shuffle<2,2,2,2>();
row3 = VectorF32<3, 2, Repeats>::Blend<1,1,0,1>(sin, row3);
VectorF32<3, 2, Repeats> row4 = sin.template Shuffle<0,0,0,0>();
row4 = VectorF32<3, 2, Repeats>::Blend<1,0,0,0>(sin, row4);
if constexpr(std::is_same_v<VectorType, __m128>) {
constexpr std::uint64_t mask[] {0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000};
__m128i sign_mask = _mm_load_si128(reinterpret_cast<const __m128i*>(mask));
row4.v = (_mm_castsi128_ps(_mm_xor_si128(sign_mask, _mm_castps_si128(row4.v))));
} else if constexpr(std::is_same_v<VectorType, __m256>) {
constexpr std::uint64_t mask[] {0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000};
__m256i sign_mask = _mm256_load_si256(reinterpret_cast<const __m256i*>(mask));
row4.v = (_mm256_castsi256_ps(_mm256_xor_si256(sign_mask, _mm256_castps_si256(row4.v))));
} else {
constexpr std::uint64_t mask[] {0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000};
__m512i sign_mask = _mm512_load_si512(reinterpret_cast<const __m256i*>(mask));
row4.v = (_mm512_castsi512_ps(_mm512_xor_si512(sign_mask, _mm512_castps_si512(row4.v))));
}
row1 = MulitplyAdd(row1, row3, row4);
VectorF32<3, 2, Repeats> row5 = sin.template Shuffle<1,1,1,1>();
row5 = VectorF32<3, 2, Repeats>::Blend<0,1,0,0>(sin, row5);
row1 *= row5;
VectorF32<3, 2, Repeats> row6 = sin.template Shuffle<2,2,2,2>();
row6 = VectorF32<3, 2, Repeats>::Blend<0,0,1,0>(sin, row6);
return row1 * row6;
}
};
}
export template <std::uint32_t Len, std::uint32_t Packing, std::uint32_t Repeats>
struct std::formatter<Crafter::VectorF32<Len, Packing, Repeats>> : std::formatter<std::string> {
    /// Renders the vector as nested braces: one "{...}" per repeat, each holding one
    /// "{c0,c1,...}" per packed vector. Component `component` of packed vector `pack` in
    /// repeat `repeat` is read from register lane repeat * Packing * Len + pack * Len + component.
    auto format(const Crafter::VectorF32<Len, Packing, Repeats>& obj, format_context& ctx) const {
        // Store() copies every register lane (Alignment >= Len * Packing * Repeats), so all
        // indices below are in bounds.
        const auto lanes = obj.Store();
        std::string out;
        for (std::uint32_t repeat = 0; repeat < Repeats; ++repeat) {
            out += "{";
            for (std::uint32_t pack = 0; pack < Packing; ++pack) {
                out += "{";
                for (std::uint32_t component = 0; component < Len; ++component) {
                    std::format_to(std::back_inserter(out), "{}", lanes[repeat * Packing * Len + pack * Len + component]);
                    if (component + 1 < Len) out += ",";
                }
                out += "}";
            }
            out += "}";
        }
        return std::formatter<std::string>::format(out, ctx);
    }
};