/* Crafter®.Math Copyright (C) 2026 Catcrafts® catcrafts.net This library is free software; you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License version 3.0 as published by the Free Software Foundation; This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. You should have received a copy of the GNU Lesser General Public License along with this library; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ module; #ifdef __x86_64 #include #endif export module Crafter.Math:VectorF16; import std; import :Vector; #ifdef __AVX512FP16__ namespace Crafter { export template struct VectorF16 { private: static consteval std::uint8_t GetAlingment() { if(Len * Packing <= 8) { return 8; } else if(Len * Packing <= 16) { return 16; } else if(Len * Packing <= 32) { return 32; } } using VectorType = std::conditional_t< (Len * Packing > 16), __m512h, std::conditional_t<(Len * Packing > 8), __m256h, __m128h> >; VectorType v; public: template friend class VectorF16; static constexpr std::uint32_t MaxSize = 32; static constexpr std::uint8_t Alignment = GetAlingment(); static_assert(Len * Packing <= MaxSize, "Len * Packing exceeds MaxSize"); constexpr VectorF16() = default; constexpr VectorF16(VectorType v) : v(v) {} constexpr VectorF16(const _Float16* vB) { Load(vB); }; constexpr VectorF16(_Float16 val) { if constexpr(std::is_same_v) { v = _mm_set1_ph(val); } else if constexpr(std::is_same_v) { v = _mm256_set1_ph(val); } else { v = _mm512_set1_ph(val); } }; constexpr void Load(const _Float16* vB) { if constexpr(std::is_same_v) { v = _mm_loadu_ph(vB); } else if constexpr(std::is_same_v) { v = _mm256_loadu_ph(vB); } else { v = _mm512_loadu_ph(vB); } } constexpr void Store(_Float16* vB) const { if 
constexpr(std::is_same_v) { _mm_storeu_ph(vB, v); } else if constexpr(std::is_same_v) { _mm256_storeu_ph(vB, v); } else { _mm512_storeu_ph(vB, v); } } constexpr Vector<_Float16, Len*Packing, Alignment> Store() const { Vector<_Float16, Len*Packing, Alignment> returnVec; Store(returnVec.v); return returnVec; } template constexpr operator VectorF16() const { if constexpr(std::is_same_v && std::is_same_v::VectorType, __m128h>) { return VectorF16(_mm256_castph256_ph128(v)); } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m128h>) { return VectorF16(_mm512_castph512_ph128(v)); } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m256h>) { return VectorF16(_mm512_castph512_ph256(v)); } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m256h>) { return VectorF16(_mm256_castph128_ph256(v)); } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m512h>) { return VectorF16(_mm512_castph128_ph512(v)); } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m512h>) { return VectorF16(_mm512_castph256_ph512(v)); } else { return VectorF16(v); } } constexpr VectorF16 operator+(VectorF16 b) const { if constexpr(std::is_same_v) { return VectorF16(_mm_add_ph(v, b.v)); } else if constexpr(std::is_same_v) { return VectorF16(_mm256_add_ph(v, b.v)); } else { return VectorF16(_mm512_add_ph(v, b.v)); } } constexpr VectorF16 operator-(VectorF16 b) const { if constexpr(std::is_same_v) { return VectorF16(_mm_sub_ph(v, b.v)); } else if constexpr(std::is_same_v) { return VectorF16(_mm256_sub_ph(v, b.v)); } else { return VectorF16(_mm512_sub_ph(v, b.v)); } } constexpr VectorF16 operator*(VectorF16 b) const { if constexpr(std::is_same_v) { return VectorF16(_mm_mul_ph(v, b.v)); } else if constexpr(std::is_same_v) { return VectorF16(_mm256_mul_ph(v, b.v)); } else { return VectorF16(_mm512_mul_ph(v, b.v)); } } constexpr VectorF16 operator/(VectorF16 b) const { if constexpr(std::is_same_v) { return VectorF16(_mm_div_ph(v, 
b.v)); } else if constexpr(std::is_same_v) { return VectorF16(_mm256_div_ph(v, b.v)); } else { return VectorF16(_mm512_div_ph(v, b.v)); } } constexpr void operator+=(VectorF16 b) { if constexpr(std::is_same_v) { v = _mm_add_ph(v, b.v); } else if constexpr(std::is_same_v) { v = _mm256_add_ph(v, b.v); } else { v = _mm512_add_ph(v, b.v); } } constexpr void operator-=(VectorF16 b) { if constexpr(std::is_same_v) { v = _mm_sub_ph(v, b.v); } else if constexpr(std::is_same_v) { v = _mm256_sub_ph(v, b.v); } else { v = _mm512_sub_ph(v, b.v); } } constexpr void operator*=(VectorF16 b) { if constexpr(std::is_same_v) { v = _mm_mul_ph(v, b.v); } else if constexpr(std::is_same_v) { v = _mm256_mul_ph(v, b.v); } else { v = _mm512_mul_ph(v, b.v); } } constexpr void operator/=(VectorF16 b) { if constexpr(std::is_same_v) { v = _mm_div_ph(v, b.v); } else if constexpr(std::is_same_v) { v = _mm256_div_ph(v, b.v); } else { v = _mm512_div_ph(v, b.v); } } constexpr VectorF16 operator+(_Float16 b) { VectorF16 vB(b); return *this + vB; } constexpr VectorF16 operator-(_Float16 b) { VectorF16 vB(b); return *this - vB; } constexpr VectorF16 operator*(_Float16 b) { VectorF16 vB(b); return *this * vB; } constexpr VectorF16 operator/(_Float16 b) { VectorF16 vB(b); return *this / vB; } constexpr void operator+=(_Float16 b) { VectorF16 vB(b); *this += vB; } constexpr void operator-=(_Float16 b) { VectorF16 vB(b); *this -= vB; } constexpr void operator*=(_Float16 b) { VectorF16 vB(b); *this *= vB; } constexpr void operator/=(_Float16 b) { VectorF16 vB(b); *this /= vB; } constexpr VectorF16 operator-(){ return Negate(); } constexpr bool operator==(VectorF16 b) const { if constexpr(std::is_same_v) { return _mm_cmp_ph_mask(v, b.v, _CMP_EQ_OQ) == 255; } else if constexpr(std::is_same_v) { return _mm256_cmp_ph_mask(v, b.v, _CMP_EQ_OQ) == 65535; } else { return _mm512_cmp_ph_mask(v, b.v, _CMP_EQ_OQ) == 4294967295; } } constexpr bool operator!=(VectorF16 b) const { if constexpr(std::is_same_v) { return 
_mm_cmp_ph_mask(v, b.v, _CMP_EQ_OQ) != 255; } else if constexpr(std::is_same_v) { return _mm256_cmp_ph_mask(v, b.v, _CMP_EQ_OQ) != 65535; } else { return _mm512_cmp_ph_mask(v, b.v, _CMP_EQ_OQ) != 4294967295; } } constexpr void Normalize() { if constexpr(std::is_same_v) { _Float16 dot = LengthSq(); __m128h vec = _mm_set1_ph(dot); __m128h sqrt = _mm_sqrt_ph(vec); v = _mm_div_ph(v, sqrt); } else if constexpr(std::is_same_v) { _Float16 dot = LengthSq(); __m256h vec = _mm256_set1_ph(dot); __m256h sqrt = _mm256_sqrt_ph(vec); v = _mm256_div_ph(v, sqrt); } else { _Float16 dot = LengthSq(); __m512h vec = _mm512_set1_ph(dot); __m512h sqrt = _mm512_sqrt_ph(vec); v = _mm512_div_ph(v, sqrt); } } constexpr _Float16 Length() const { _Float16 Result = LengthSq(); return std::sqrtf(Result); } constexpr _Float16 LengthSq() const { return Dot(*this, *this); } constexpr VectorF16 Cos() { if constexpr (std::is_same_v) { __m256 wide = _mm256_cvtph_ps(_mm_castph_si128(v)); wide = cos_f32x8(wide); return VectorF16( _mm_castsi128_ph(_mm256_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT))); } else if constexpr (std::is_same_v) { __m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(v)); wide = cos_f32x16(wide); return VectorF16( _mm256_castsi256_ph(_mm512_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT))); } else { __m256i lo = _mm512_castsi512_si256(_mm512_castph_si512(v)); __m256i hi = _mm512_extracti64x4_epi64(_mm512_castph_si512(v), 1); __m256i lo_ph = _mm512_cvtps_ph(cos_f32x16(_mm512_cvtph_ps(lo)), _MM_FROUND_TO_NEAREST_INT); __m256i hi_ph = _mm512_cvtps_ph(cos_f32x16(_mm512_cvtph_ps(hi)), _MM_FROUND_TO_NEAREST_INT); return VectorF16( _mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(lo_ph), hi_ph, 1))); } } constexpr VectorF16 Sin() { if constexpr (std::is_same_v) { __m256 wide = _mm256_cvtph_ps(_mm_castph_si128(v)); wide = sin_f32x8(wide); return VectorF16(_mm_castsi128_ph(_mm256_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT))); } else if constexpr (std::is_same_v) { __m512 wide = 
_mm512_cvtph_ps(_mm256_castph_si256(v)); wide = sin_f32x16(wide); return VectorF16(_mm256_castsi256_ph(_mm512_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT))); } else { __m256i lo = _mm512_castsi512_si256(_mm512_castph_si512(v)); __m256i hi = _mm512_extracti64x4_epi64(_mm512_castph_si512(v), 1); __m256i lo_ph = _mm512_cvtps_ph(sin_f32x16(_mm512_cvtph_ps(lo)), _MM_FROUND_TO_NEAREST_INT); __m256i hi_ph = _mm512_cvtps_ph(sin_f32x16(_mm512_cvtph_ps(hi)), _MM_FROUND_TO_NEAREST_INT); return VectorF16(_mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(lo_ph), hi_ph, 1))); } } std::tuple, VectorF16> SinCos() { if constexpr (std::is_same_v) { __m256 wide = _mm256_cvtph_ps(_mm_castph_si128(v)); __m256 s, c; sincos_f32x8(wide, s, c); return { VectorF16(_mm_castsi128_ph(_mm256_cvtps_ph(s, _MM_FROUND_TO_NEAREST_INT))), VectorF16(_mm_castsi128_ph(_mm256_cvtps_ph(c, _MM_FROUND_TO_NEAREST_INT))) }; } else if constexpr (std::is_same_v) { __m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(v)); __m512 s, c; sincos_f32x16(wide, s, c); return { VectorF16(_mm256_castsi256_ph(_mm512_cvtps_ph(s, _MM_FROUND_TO_NEAREST_INT))), VectorF16(_mm256_castsi256_ph(_mm512_cvtps_ph(c, _MM_FROUND_TO_NEAREST_INT))) }; } else { __m256i lo = _mm512_castsi512_si256(_mm512_castph_si512(v)); __m256i hi = _mm512_extracti64x4_epi64(_mm512_castph_si512(v), 1); __m512 s_lo, c_lo, s_hi, c_hi; sincos_f32x16(_mm512_cvtph_ps(lo), s_lo, c_lo); sincos_f32x16(_mm512_cvtph_ps(hi), s_hi, c_hi); auto pack = [](__m256i lo_ph, __m256i hi_ph) { return _mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(lo_ph), hi_ph, 1)); }; return { VectorF16(pack(_mm512_cvtps_ph(s_lo, _MM_FROUND_TO_NEAREST_INT), _mm512_cvtps_ph(s_hi, _MM_FROUND_TO_NEAREST_INT))), VectorF16(pack( _mm512_cvtps_ph(c_lo, _MM_FROUND_TO_NEAREST_INT), _mm512_cvtps_ph(c_hi, _MM_FROUND_TO_NEAREST_INT))) }; } } template values> constexpr VectorF16 Negate() { std::array mask = GetNegateMask(); if constexpr(std::is_same_v) { return 
VectorF16(_mm_castsi128_ph(_mm_xor_si128(_mm_castph_si128(v), _mm_loadu_epi16(mask.data())))); } else if constexpr(std::is_same_v) { return VectorF16(_mm256_castsi2568_ph(_mm256_xor_si256(_mm256_castph_si256(v), _mm_loadu_epi16(mask.data())))); } else { return VectorF16(_mm512_castsi512_ph(_mm512_xor_si256(_mm512_castph_si512(v), _mm_loadu_epi16(mask.data())))); } } template ShuffleValues> constexpr VectorF16 Shuffle() { if constexpr(VectorF16::CheckEpi32Shuffle()) { if constexpr(std::is_same_v) { return VectorF16(_mm_castsi128_ph(_mm_shuffle_epi32(_mm_castph_si128(v), GetShuffleMaskEpi32()))); } else if constexpr(std::is_same_v) { return VectorF16(_mm256_castsi256_ph(_mm256_shuffle_epi32(_mm256_castph_si256(v), GetShuffleMaskEpi32()))); } else { return VectorF16(_mm512_castsi512_ph(_mm512_shuffle_epi32(_mm512_castph_si512(v), GetShuffleMaskEpi32()))); } } else { if constexpr(std::is_same_v) { constexpr std::array shuffleMask = GetShuffleMaskEpi8(); __m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data()); return VectorF16(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(v), shuffleVec))); } else if constexpr(std::is_same_v) { constexpr std::array shuffleMask = GetShuffleMaskEpi8(); __m256i shuffleVec = _mm256_loadu_epi8(shuffleMask.data()); return VectorF16(_mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castph_si256(v)), _mm512_castsi256_si512(shuffleVec))))); } else { constexpr std::array shuffleMask = GetShuffleMaskEpi8(); __m512i shuffleVec = _mm512_loadu_epi8(shuffleMask.data()); return VectorF16(_mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(v), shuffleVec))); } } } static constexpr VectorF16 MulitplyAdd(VectorF16 a, VectorF16 b, VectorF16 add) { if constexpr(std::is_same_v) { return VectorF16(_mm_fmadd_ph(a.v, b.v, add.v)); } else if constexpr(std::is_same_v) { return VectorF16(_mm256_fmadd_ph(a.v, b.v, add.v)); } else { return VectorF16(_mm512_fmadd_ph(a.v, b.v, add.v)); } } static constexpr 
VectorF16 MulitplySub(VectorF16 a, VectorF16 b, VectorF16 sub) { if constexpr(std::is_same_v) { return VectorF16(_mm_fmsub_ph(a.v, b.v, sub.v)); } else if constexpr(std::is_same_v) { return VectorF16(_mm256_fmsub_ph(a.v, b.v, sub.v)); } else { return VectorF16(_mm512_fmsub_ph(a.v, b.v, sub.v)); } } constexpr static VectorF16 Cross(VectorF16 a, VectorF16 b) requires(Len == 3) { if constexpr(std::is_same_v) { constexpr std::array shuffleMask1 = GetShuffleMaskEpi8<{{1,2,0}}>(); __m128i shuffleVec1 = _mm_loadu_epi8(shuffleMask1.data()); __m128h row1 = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(a.v), shuffleVec1)); __m128h row4 = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(b.v), shuffleVec1)); constexpr std::array shuffleMask3 = GetShuffleMaskEpi8<{{2,0,1}}>(); __m128i shuffleVec3 = _mm_loadu_epi8(shuffleMask3.data()); __m128h row3 = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(a.v), shuffleVec3)); __m128h row2 = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(b.v), shuffleVec3)); __m128h result = _mm_mul_ph(row3, row4); return _mm_fmsub_ph(row1,row2,result); } else if constexpr (std::is_same_v) { constexpr std::array shuffleMask1 = GetShuffleMaskEpi8<{{1,2,0}}>(); __m512i shuffleVec1 = _mm512_castsi256_si512(_mm256_loadu_epi8(shuffleMask1.data())); __m256h row1 = _mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castph_si256(a.v)), shuffleVec1))); __m256h row4 = _mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castph_si256(b.v)), shuffleVec1))); constexpr std::array shuffleMask3 = GetShuffleMaskEpi8<{{2,0,1}}>(); __m512i shuffleVec3 = _mm512_castsi256_si512(_mm256_loadu_epi8(shuffleMask3.data())); __m256h row3 = _mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castph_si256(a.v)), shuffleVec3))); __m256h row2 = 
_mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castph_si256(b.v)), shuffleVec3))); __m256h result = _mm256_mul_ph(row3, row4); return _mm256_fmsub_ph(row1,row2,result); } else { constexpr std::array shuffleMask1 = GetShuffleMaskEpi8<{{1,2,0}}>(); __m512i shuffleVec1 = _mm512_loadu_epi8(shuffleMask1.data()); __m512h row1 = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(a.v), shuffleVec1)); __m512h row4 = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(b.v), shuffleVec1)); constexpr std::array shuffleMask3 = GetShuffleMaskEpi8<{{2,0,1}}>(); __m512i shuffleVec3 = _mm512_loadu_epi8(shuffleMask3.data()); __m512h row3 = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(a.v), shuffleVec3)); __m512h row2 = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(b.v), shuffleVec3)); __m512h result = _mm512_mul_ph(row3, row4); return _mm512_fmsub_ph(row1,row2,result); } } constexpr static _Float16 Dot(VectorF16 a, VectorF16 b) { if constexpr(std::is_same_v) { __m128h mul = _mm_mul_ph(a.v, b.v); return _mm_reduce_add_ph(mul); } else if constexpr(std::is_same_v) { __m256h mul = _mm256_mul_ph(a.v, b.v); return _mm256_reduce_add_ph(mul); } else { __m512h mul = _mm512_mul_ph(a.v, b.v); return _mm512_reduce_add_ph(mul); } } constexpr static std::tuple, VectorF16, VectorF16, VectorF16, VectorF16, VectorF16, VectorF16, VectorF16> Normalize( VectorF16 A, VectorF16 B, VectorF16 C, VectorF16 D, VectorF16 E, VectorF16 F, VectorF16 G, VectorF16 H ) requires(Len == 8) { constexpr std::array shuffleMaskA = GetShuffleMaskEpi8<{{0,0,0,0,0,0,0,0}}>(); constexpr std::array shuffleMaskB = GetShuffleMaskEpi8<{{1,1,1,1,1,1,1,1}}>(); constexpr std::array shuffleMaskC = GetShuffleMaskEpi8<{{2,2,2,2,2,2,2,2}}>(); constexpr std::array shuffleMaskD = GetShuffleMaskEpi8<{{3,3,3,3,3,3,3,3}}>(); constexpr std::array shuffleMaskE = GetShuffleMaskEpi8<{{4,4,4,4,4,4,4,4}}>(); constexpr std::array shuffleMaskF = 
GetShuffleMaskEpi8<{{5,5,5,5,5,5,5,5}}>(); constexpr std::array shuffleMaskG = GetShuffleMaskEpi8<{{6,6,6,6,6,6,6,6}}>(); constexpr std::array shuffleMaskH = GetShuffleMaskEpi8<{{7,7,7,7,7,7,7,7}}>(); if constexpr(std::is_same_v) { VectorF16 lenght = Length(A, B, C, D, E, F, G, H); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1}; __m128h one = _mm_loadu_ph(oneArr); __m128h fLenght = _mm_div_ph(one, lenght.v); __m128i shuffleVecA = _mm_loadu_epi8(shuffleMaskA.data()); __m128h fLenghtA = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecA)); __m128i shuffleVecB = _mm_loadu_epi8(shuffleMaskB.data()); __m128h fLenghtB = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecB)); __m128i shuffleVecC = _mm_loadu_epi8(shuffleMaskC.data()); __m128h fLenghtC = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecC)); __m128i shuffleVecD = _mm_loadu_epi8(shuffleMaskD.data()); __m128h fLenghtD = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecD)); __m128i shuffleVecE = _mm_loadu_epi8(shuffleMaskE.data()); __m128h fLenghtE = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecE)); __m128i shuffleVecF = _mm_loadu_epi8(shuffleMaskF.data()); __m128h fLenghtF = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecF)); __m128i shuffleVecG = _mm_loadu_epi8(shuffleMaskG.data()); __m128h fLenghtG = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecG)); __m128i shuffleVecH = _mm_loadu_epi8(shuffleMaskH.data()); __m128h fLenghtH = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecH)); return { _mm_mul_ph(A.v, fLenghtA), _mm_mul_ph(B.v, fLenghtB), _mm_mul_ph(C.v, fLenghtC), _mm_mul_ph(D.v, fLenghtD), _mm_mul_ph(E.v, fLenghtE), _mm_mul_ph(F.v, fLenghtF), _mm_mul_ph(G.v, fLenghtG), _mm_mul_ph(H.v, fLenghtH) }; } else if constexpr(std::is_same_v) { VectorF16 lenght = Length(A, B, C, D, E, F, G, H); constexpr _Float16 oneArr[] {1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; __m256h one = _mm256_loadu_ph(oneArr); __m256h fLenght = _mm256_div_ph(one, lenght.v); __m256i shuffleVecA = _mm256_loadu_epi8(shuffleMaskA.data()); __m256h fLenghtA = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecA)); __m256i shuffleVecB = _mm256_loadu_epi8(shuffleMaskB.data()); __m256h fLenghtB = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecB)); __m256i shuffleVecC = _mm256_loadu_epi8(shuffleMaskC.data()); __m256h fLenghtC = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecC)); __m256i shuffleVecD = _mm256_loadu_epi8(shuffleMaskD.data()); __m256h fLenghtD = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecD)); __m256i shuffleVecE = _mm256_loadu_epi8(shuffleMaskE.data()); __m256h fLenghtE = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecE)); __m256i shuffleVecF = _mm256_loadu_epi8(shuffleMaskF.data()); __m256h fLenghtF = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecF)); __m256i shuffleVecG = _mm256_loadu_epi8(shuffleMaskG.data()); __m256h fLenghtG = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecG)); __m256i shuffleVecH = _mm256_loadu_epi8(shuffleMaskH.data()); __m256h fLenghtH = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecH)); return { _mm256_mul_ph(A.v, fLenghtA), _mm256_mul_ph(B.v, fLenghtB), _mm256_mul_ph(C.v, fLenghtC), _mm256_mul_ph(D.v, fLenghtD), _mm256_mul_ph(E.v, fLenghtE), _mm256_mul_ph(F.v, fLenghtF), _mm256_mul_ph(G.v, fLenghtG), _mm256_mul_ph(H.v, fLenghtH) }; } else { VectorF16 lenght = Length(A, B, C, D, E, F, G, H); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; __m512h one = _mm512_loadu_ph(oneArr); __m512h fLenght = _mm512_div_ph(one, lenght.v); __m512i 
shuffleVecA = _mm512_loadu_epi8(shuffleMaskA.data()); __m512h fLenghtA = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecA)); __m512i shuffleVecB = _mm512_loadu_epi8(shuffleMaskB.data()); __m512h fLenghtB = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecB)); __m512i shuffleVecC = _mm512_loadu_epi8(shuffleMaskC.data()); __m512h fLenghtC = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecC)); __m512i shuffleVecD = _mm512_loadu_epi8(shuffleMaskD.data()); __m512h fLenghtD = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecD)); __m512i shuffleVecE = _mm512_loadu_epi8(shuffleMaskE.data()); __m512h fLenghtE = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecE)); __m512i shuffleVecF = _mm512_loadu_epi8(shuffleMaskF.data()); __m512h fLenghtF = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecF)); __m512i shuffleVecG = _mm512_loadu_epi8(shuffleMaskG.data()); __m512h fLenghtG = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecG)); __m512i shuffleVecH = _mm512_loadu_epi8(shuffleMaskH.data()); __m512h fLenghtH = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecH)); return { _mm512_mul_ph(A.v, fLenghtA), _mm512_mul_ph(B.v, fLenghtB), _mm512_mul_ph(C.v, fLenghtC), _mm512_mul_ph(D.v, fLenghtD), _mm512_mul_ph(E.v, fLenghtE), _mm512_mul_ph(F.v, fLenghtF), _mm512_mul_ph(G.v, fLenghtG), _mm512_mul_ph(H.v, fLenghtH) }; } } constexpr static std::tuple, VectorF16, VectorF16, VectorF16> Normalize( VectorF16 A, VectorF16 C, VectorF16 E, VectorF16 G ) requires(Len == 4) { constexpr std::array shuffleMaskA = GetShuffleMaskEpi8<{{0,0,0,0}}>(); constexpr std::array shuffleMaskC = GetShuffleMaskEpi8<{{1,1,1,1}}>(); constexpr std::array shuffleMaskE = GetShuffleMaskEpi8<{{2,2,2,2}}>(); constexpr std::array shuffleMaskG = 
GetShuffleMaskEpi8<{{3,3,3,3}}>(); if constexpr(std::is_same_v) { VectorF16 lenght = Length(A, C, E, G); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1}; __m128h one = _mm_loadu_ph(oneArr); __m128h fLenght = _mm_div_ph(one, lenght.v); __m128i shuffleVecA = _mm_loadu_epi8(shuffleMaskA.data()); __m128h fLenghtA = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecA)); __m128i shuffleVecC = _mm_loadu_epi8(shuffleMaskC.data()); __m128h fLenghtC = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecC)); __m128i shuffleVecE = _mm_loadu_epi8(shuffleMaskE.data()); __m128h fLenghtE = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecE)); __m128i shuffleVecG = _mm_loadu_epi8(shuffleMaskG.data()); __m128h fLenghtG = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecG)); return { _mm_mul_ph(A.v, fLenghtA), _mm_mul_ph(C.v, fLenghtC), _mm_mul_ph(E.v, fLenghtE), _mm_mul_ph(G.v, fLenghtG), }; } else if constexpr(std::is_same_v) { VectorF16 lenght = Length(A, C, E, G); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; __m256h one = _mm256_loadu_ph(oneArr); __m256h fLenght = _mm256_div_ph(one, lenght.v); __m256i shuffleVecA = _mm256_loadu_epi8(shuffleMaskA.data()); __m256h fLenghtA = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecA)); __m256i shuffleVecC = _mm256_loadu_epi8(shuffleMaskC.data()); __m256h fLenghtC = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecC)); __m256i shuffleVecE = _mm256_loadu_epi8(shuffleMaskE.data()); __m256h fLenghtE = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecE)); __m256i shuffleVecG = _mm256_loadu_epi8(shuffleMaskG.data()); __m256h fLenghtG = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecG)); return { _mm256_mul_ph(A.v, fLenghtA), _mm256_mul_ph(C.v, fLenghtC), _mm256_mul_ph(E.v, fLenghtE), _mm256_mul_ph(G.v, 
fLenghtG), }; } else { VectorF16 lenght = Length(A, C, E, G); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; __m512h one = _mm512_loadu_ph(oneArr); __m512h fLenght = _mm512_div_ph(one, lenght.v); __m512i shuffleVecA = _mm512_loadu_epi8(shuffleMaskA.data()); __m512h fLenghtA = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecA)); __m512i shuffleVecC = _mm512_loadu_epi8(shuffleMaskC.data()); __m512h fLenghtC = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecC)); __m512i shuffleVecE = _mm512_loadu_epi8(shuffleMaskE.data()); __m512h fLenghtE = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecE)); __m512i shuffleVecG = _mm512_loadu_epi8(shuffleMaskG.data()); __m512h fLenghtG = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecG)); return { VectorF16(_mm512_mul_ph(A.v, fLenghtA)), VectorF16(_mm512_mul_ph(C.v, fLenghtC)), VectorF16(_mm512_mul_ph(E.v, fLenghtE)), VectorF16(_mm512_mul_ph(G.v, fLenghtG)), }; } } constexpr static std::tuple, VectorF16> Normalize( VectorF16 A, VectorF16 E ) requires(Len == 2) { constexpr std::array shuffleMaskA = GetShuffleMaskEpi8<{{0,0}}>(); constexpr std::array shuffleMaskE = GetShuffleMaskEpi8<{{1,1}}>(); if constexpr(std::is_same_v) { VectorF16 lenght = Length(A, E); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1}; __m128h one = _mm_loadu_ph(oneArr); __m128h fLenght = _mm_div_ph(one, lenght.v); __m128i shuffleVecA = _mm_loadu_epi8(shuffleMaskA.data()); __m128h fLenghtA = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecA)); __m128i shuffleVecE = _mm_loadu_epi8(shuffleMaskE.data()); __m128h fLenghtE = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecE)); return { _mm_mul_ph(A.v, fLenghtA), _mm_mul_ph(E.v, fLenghtE), }; } else if constexpr(std::is_same_v) { VectorF16 lenght = Length(A, E); 
constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; __m256h one = _mm256_loadu_ph(oneArr); __m256h fLenght = _mm256_div_ph(one, lenght.v); __m256i shuffleVecA = _mm256_loadu_epi8(shuffleMaskA.data()); __m256h fLenghtA = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecA)); __m256i shuffleVecE = _mm256_loadu_epi8(shuffleMaskE.data()); __m256h fLenghtE = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecE)); return { _mm256_mul_ph(A.v, fLenghtA), _mm256_mul_ph(E.v, fLenghtE), }; } else { VectorF16 lenght = Length(A, E); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; __m512h one = _mm512_loadu_ph(oneArr); __m512h fLenght = _mm512_div_ph(one, lenght.v); __m512i shuffleVecA = _mm512_loadu_epi8(shuffleMaskA.data()); __m512h fLenghtA = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecA)); __m512i shuffleVecE = _mm512_loadu_epi8(shuffleMaskE.data()); __m512h fLenghtE = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecE)); return { _mm512_mul_ph(A.v, fLenghtA), _mm512_mul_ph(E.v, fLenghtE), }; } } constexpr static VectorF16 Length( VectorF16 A, VectorF16 B, VectorF16 C, VectorF16 D, VectorF16 E, VectorF16 F, VectorF16 G, VectorF16 H ) requires(Len == 8) { VectorF16 lenghtSq = LengthSq(A, B, C, D, E, F, G, H); if constexpr(std::is_same_v) { return VectorF16(_mm_sqrt_ph(lenghtSq.v)); } else if constexpr(std::is_same_v) { return VectorF16(_mm256_sqrt_ph(lenghtSq.v)); } else { return VectorF16(_mm512_sqrt_ph(lenghtSq.v)); } } constexpr static VectorF16 Length( VectorF16 A, VectorF16 C, VectorF16 E, VectorF16 G ) requires(Len == 4) { VectorF16 lenghtSq = LengthSq(A, C, E, G); if constexpr(std::is_same_v) { return VectorF16(_mm_sqrt_ph(lenghtSq.v)); } else if constexpr(std::is_same_v) { return VectorF16(_mm256_sqrt_ph(lenghtSq.v)); } else { return 
VectorF16(_mm512_sqrt_ph(lenghtSq.v)); } } constexpr static VectorF16 Length( VectorF16 A, VectorF16 E ) requires(Len == 2) { VectorF16 lenghtSq = LengthSq(A, E); if constexpr(std::is_same_v) { return VectorF16(_mm_sqrt_ph(lenghtSq.v)); } else if constexpr(std::is_same_v) { return VectorF16(_mm256_sqrt_ph(lenghtSq.v)); } else { return VectorF16(_mm512_sqrt_ph(lenghtSq.v)); } } constexpr static VectorF16 LengthSq( VectorF16 A, VectorF16 B, VectorF16 C, VectorF16 D, VectorF16 E, VectorF16 F, VectorF16 G, VectorF16 H ) requires(Len == 8) { return Dot(A, A, B, B, C, C, D, D, E, E, F, F, G, G, H, H); } constexpr static VectorF16 LengthSq( VectorF16 A, VectorF16 C, VectorF16 E, VectorF16 G ) requires(Len == 4) { return Dot(A, A, C, C, E, E, G, G); } constexpr static VectorF16 LengthSq( VectorF16 A, VectorF16 E ) requires(Len == 2) { return Dot(A, A, E, E); } constexpr static VectorF16 Dot( VectorF16 A0, VectorF16 A1, VectorF16 B0, VectorF16 B1, VectorF16 C0, VectorF16 C1, VectorF16 D0, VectorF16 D1, VectorF16 E0, VectorF16 E1, VectorF16 F0, VectorF16 F1, VectorF16 G0, VectorF16 G1, VectorF16 H0, VectorF16 H1 ) requires(Len == 8) { if constexpr(std::is_same_v) { __m128h mulA = _mm_mul_ph(A0.v, A1.v); __m128h mulB = _mm_mul_ph(B0.v, B1.v); __m128i row12Temp1 = _mm_unpacklo_epi16(_mm_castph_si128(mulA), _mm_castph_si128(mulB)); // A1 B1 A2 B2 A3 B3 A4 B4 __m128i row56Temp1 = _mm_unpackhi_epi16(_mm_castph_si128(mulA), _mm_castph_si128(mulB)); // A5 B5 A6 B6 A7 B7 A8 B8 __m128i row1TempTemp1 = row12Temp1; __m128i row5TempTemp1 = row56Temp1; __m128h mulC = _mm_mul_ph(C0.v, C1.v); __m128h mulD = _mm_mul_ph(D0.v, D1.v); __m128i row34Temp1 = _mm_unpacklo_epi16(_mm_castph_si128(mulC), _mm_castph_si128(mulD)); // C1 D1 C2 D2 C3 D3 C4 D4 __m128i row78Temp1 = _mm_unpackhi_epi16(_mm_castph_si128(mulA), _mm_castph_si128(mulB)); // C5 D5 C6 D6 C7 D7 C8 D8 row12Temp1 = _mm_unpacklo_epi16(row12Temp1, row34Temp1); // A1 C1 B1 D1 A2 C2 B2 D2 row34Temp1 = _mm_unpackhi_epi16(row1TempTemp1, 
row34Temp1); // A3 C3 B3 D3 A4 C4 B4 D4
            row56Temp1 = _mm_unpacklo_epi16(row56Temp1, row78Temp1); // A5 C5 B5 D5 A6 C6 B6 D6
            row78Temp1 = _mm_unpackhi_epi16(row5TempTemp1, row78Temp1); // A7 C7 B7 D7 A8 C8 B8 D8
            __m128h mulE = _mm_mul_ph(E0.v, E1.v);
            __m128h mulF = _mm_mul_ph(F0.v, F1.v);
            __m128i row12Temp2 = _mm_unpacklo_epi16(_mm_castph_si128(mulE), _mm_castph_si128(mulF)); // E1 F1 E2 F2 E3 F3 E4 F4
            __m128i row56Temp2 = _mm_unpackhi_epi16(_mm_castph_si128(mulE), _mm_castph_si128(mulF)); // E5 F5 E6 F6 E7 F7 E8 F8
            __m128i row1TempTemp2 = row12Temp2;
            __m128i row5TempTemp2 = row56Temp2;
            __m128h mulG = _mm_mul_ph(G0.v, G1.v);
            __m128h mulH = _mm_mul_ph(H0.v, H1.v);
            __m128i row34Temp2 = _mm_unpacklo_epi16(_mm_castph_si128(mulG), _mm_castph_si128(mulH)); // G1 H1 G2 H2 G3 H3 G4 H4
            // BUG FIX: was unpackhi(mulE, mulF), which duplicated the E/F products instead of
            // producing the high G/H elements the comment (and the sums below) require.
            __m128i row78Temp2 = _mm_unpackhi_epi16(_mm_castph_si128(mulG), _mm_castph_si128(mulH)); // G5 H5 G6 H6 G7 H7 G8 H8
            row12Temp2 = _mm_unpacklo_epi16(row12Temp2, row34Temp2); // E1 G1 F1 H1 E2 G2 F2 H2
            row34Temp2 = _mm_unpackhi_epi16(row1TempTemp2, row34Temp2); // E3 G3 F3 H3 E4 G4 F4 H4
            row56Temp2 = _mm_unpacklo_epi16(row56Temp2, row78Temp2); // E5 G5 F5 H5 E6 G6 F6 H6
            row78Temp2 = _mm_unpackhi_epi16(row5TempTemp2, row78Temp2); // E7 G7 F7 H7 E8 G8 F8 H8
            // Every element of vector k ends up in lane k across row1..row8, so the order of
            // rows does not matter for the final sum.
            __m128h row1 = _mm_castsi128_ph(_mm_unpackhi_epi16(row12Temp1, row12Temp2)); // A1 E1 C1 G1 B1 F1 D1 H1
            __m128h row2 = _mm_castsi128_ph(_mm_unpacklo_epi16(row12Temp1, row12Temp2)); // A2 E2 C2 G2 B2 F2 D2 H2
            __m128h row3 = _mm_castsi128_ph(_mm_unpackhi_epi16(row34Temp1, row34Temp2)); // A3 E3 C3 G3 B3 F3 D3 H3
            __m128h row4 = _mm_castsi128_ph(_mm_unpacklo_epi16(row34Temp1, row34Temp2)); // A4 E4 C4 G4 B4 F4 D4 H4
            __m128h row5 = _mm_castsi128_ph(_mm_unpackhi_epi16(row56Temp1, row56Temp2)); // A5 E5 C5 G5 B5 F5 D5 H5
            __m128h row6 = _mm_castsi128_ph(_mm_unpacklo_epi16(row56Temp1, row56Temp2)); // A6 E6 C6 G6 B6 F6 D6 H6
            __m128h row7 = _mm_castsi128_ph(_mm_unpackhi_epi16(row78Temp1, row78Temp2)); // A7 E7 C7 G7 B7 F7 D7 H7
            __m128h row8 = _mm_castsi128_ph(_mm_unpacklo_epi16(row78Temp1, row78Temp2)); // A8 E8 C8 G8 B8 F8 D8 H8
            row1 = _mm_add_ph(row1, row2);
            row1 = _mm_add_ph(row1, row3);
            row1 = _mm_add_ph(row1, row4);
            row1 = _mm_add_ph(row1, row5);
            row1 = _mm_add_ph(row1, row6);
            row1 = _mm_add_ph(row1, row7);
            row1 = _mm_add_ph(row1, row8);
            return row1;
        } else if constexpr(std::is_same_v<VectorType, __m256h>) {
            __m256h mulA = _mm256_mul_ph(A0.v, A1.v);
            __m256h mulB = _mm256_mul_ph(B0.v, B1.v);
            __m256i row12Temp1 = _mm256_unpacklo_epi16(_mm256_castph_si256(mulA), _mm256_castph_si256(mulB)); // A1 B1 A2 B2 A3 B3 A4 B4
            __m256i row56Temp1 = _mm256_unpackhi_epi16(_mm256_castph_si256(mulA), _mm256_castph_si256(mulB)); // A5 B5 A6 B6 A7 B7 A8 B8
            __m256i row1TempTemp1 = row12Temp1;
            __m256i row5TempTemp1 = row56Temp1;
            __m256h mulC = _mm256_mul_ph(C0.v, C1.v);
            __m256h mulD = _mm256_mul_ph(D0.v, D1.v);
            __m256i row34Temp1 = _mm256_unpacklo_epi16(_mm256_castph_si256(mulC), _mm256_castph_si256(mulD)); // C1 D1 C2 D2 C3 D3 C4 D4
            // BUG FIX: was unpackhi(mulA, mulB) — must take the high C/D products.
            __m256i row78Temp1 = _mm256_unpackhi_epi16(_mm256_castph_si256(mulC), _mm256_castph_si256(mulD)); // C5 D5 C6 D6 C7 D7 C8 D8
            row12Temp1 = _mm256_unpacklo_epi16(row12Temp1, row34Temp1); // A1 C1 B1 D1 A2 C2 B2 D2
            row34Temp1 = _mm256_unpackhi_epi16(row1TempTemp1, row34Temp1); // A3 C3 B3 D3 A4 C4 B4 D4
            row56Temp1 = _mm256_unpacklo_epi16(row56Temp1, row78Temp1); // A5 C5 B5 D5 A6 C6 B6 D6
            row78Temp1 = _mm256_unpackhi_epi16(row5TempTemp1, row78Temp1); // A7 C7 B7 D7 A8 C8 B8 D8
            __m256h mulE = _mm256_mul_ph(E0.v, E1.v);
            __m256h mulF = _mm256_mul_ph(F0.v, F1.v);
            __m256i row12Temp2 = _mm256_unpacklo_epi16(_mm256_castph_si256(mulE), _mm256_castph_si256(mulF)); // E1 F1 E2 F2 E3 F3 E4 F4
            __m256i row56Temp2 = _mm256_unpackhi_epi16(_mm256_castph_si256(mulE), _mm256_castph_si256(mulF)); // E5 F5 E6 F6 E7 F7 E8 F8
            __m256i row1TempTemp2 = row12Temp2;
            __m256i row5TempTemp2 = row56Temp2;
            __m256h mulG = _mm256_mul_ph(G0.v, G1.v);
            __m256h mulH = _mm256_mul_ph(H0.v, H1.v);
            __m256i row34Temp2 = _mm256_unpacklo_epi16(_mm256_castph_si256(mulG), _mm256_castph_si256(mulH)); // G1 H1 G2 H2 G3 H3 G4 H4
            // BUG FIX: was unpackhi(mulE, mulF) — must take the high G/H products.
            __m256i row78Temp2 = _mm256_unpackhi_epi16(_mm256_castph_si256(mulG), _mm256_castph_si256(mulH)); // G5 H5 G6 H6 G7 H7 G8 H8
            row12Temp2 = _mm256_unpacklo_epi16(row12Temp2, row34Temp2); // E1 G1 F1 H1 E2 G2 F2 H2
            row34Temp2 = _mm256_unpackhi_epi16(row1TempTemp2, row34Temp2); // E3 G3 F3 H3 E4 G4 F4 H4
            row56Temp2 = _mm256_unpacklo_epi16(row56Temp2, row78Temp2); // E5 G5 F5 H5 E6 G6 F6 H6
            row78Temp2 = _mm256_unpackhi_epi16(row5TempTemp2, row78Temp2); // E7 G7 F7 H7 E8 G8 F8 H8
            __m256h row1 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(row12Temp1, row12Temp2)); // A1 E1 C1 G1 B1 F1 D1 H1
            __m256h row2 = _mm256_castsi256_ph(_mm256_unpacklo_epi16(row12Temp1, row12Temp2)); // A2 E2 C2 G2 B2 F2 D2 H2
            __m256h row3 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(row34Temp1, row34Temp2)); // A3 E3 C3 G3 B3 F3 D3 H3
            __m256h row4 = _mm256_castsi256_ph(_mm256_unpacklo_epi16(row34Temp1, row34Temp2)); // A4 E4 C4 G4 B4 F4 D4 H4
            __m256h row5 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(row56Temp1, row56Temp2)); // A5 E5 C5 G5 B5 F5 D5 H5
            __m256h row6 = _mm256_castsi256_ph(_mm256_unpacklo_epi16(row56Temp1, row56Temp2)); // A6 E6 C6 G6 B6 F6 D6 H6
            __m256h row7 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(row78Temp1, row78Temp2)); // A7 E7 C7 G7 B7 F7 D7 H7
            __m256h row8 = _mm256_castsi256_ph(_mm256_unpacklo_epi16(row78Temp1, row78Temp2)); // A8 E8 C8 G8 B8 F8 D8 H8
            row1 = _mm256_add_ph(row1, row2);
            row1 = _mm256_add_ph(row1, row3);
            row1 = _mm256_add_ph(row1, row4);
            row1 = _mm256_add_ph(row1, row5);
            row1 = _mm256_add_ph(row1, row6);
            row1 = _mm256_add_ph(row1, row7);
            row1 = _mm256_add_ph(row1, row8);
            return row1;
        } else {
            __m512h mulA = _mm512_mul_ph(A0.v, A1.v);
            __m512h mulB = _mm512_mul_ph(B0.v, B1.v);
            __m512i row12Temp1 = _mm512_unpacklo_epi16(_mm512_castph_si512(mulA), _mm512_castph_si512(mulB)); // A1 B1 A2 B2 A3 B3 A4 B4
            __m512i row56Temp1 = _mm512_unpackhi_epi16(_mm512_castph_si512(mulA), _mm512_castph_si512(mulB)); // A5 B5 A6 B6 A7 B7 A8 B8
            __m512i row1TempTemp1 = row12Temp1;
            __m512i row5TempTemp1 = row56Temp1;
            __m512h mulC = _mm512_mul_ph(C0.v, C1.v);
            __m512h mulD = _mm512_mul_ph(D0.v, D1.v);
            __m512i row34Temp1 = _mm512_unpacklo_epi16(_mm512_castph_si512(mulC), _mm512_castph_si512(mulD)); // C1 D1 C2 D2 C3 D3 C4 D4
            // BUG FIX: was unpackhi(mulA, mulB) — must take the high C/D products.
            __m512i row78Temp1 = _mm512_unpackhi_epi16(_mm512_castph_si512(mulC), _mm512_castph_si512(mulD)); // C5 D5 C6 D6 C7 D7 C8 D8
            row12Temp1 = _mm512_unpacklo_epi16(row12Temp1, row34Temp1); // A1 C1 B1 D1 A2 C2 B2 D2
            row34Temp1 = _mm512_unpackhi_epi16(row1TempTemp1, row34Temp1); // A3 C3 B3 D3 A4 C4 B4 D4
            row56Temp1 = _mm512_unpacklo_epi16(row56Temp1, row78Temp1); // A5 C5 B5 D5 A6 C6 B6 D6
            row78Temp1 = _mm512_unpackhi_epi16(row5TempTemp1, row78Temp1); // A7 C7 B7 D7 A8 C8 B8 D8
            __m512h mulE = _mm512_mul_ph(E0.v, E1.v);
            __m512h mulF = _mm512_mul_ph(F0.v, F1.v);
            __m512i row12Temp2 = _mm512_unpacklo_epi16(_mm512_castph_si512(mulE), _mm512_castph_si512(mulF)); // E1 F1 E2 F2 E3 F3 E4 F4
            __m512i row56Temp2 = _mm512_unpackhi_epi16(_mm512_castph_si512(mulE), _mm512_castph_si512(mulF)); // E5 F5 E6 F6 E7 F7 E8 F8
            __m512i row1TempTemp2 = row12Temp2;
            __m512i row5TempTemp2 = row56Temp2;
            __m512h mulG = _mm512_mul_ph(G0.v, G1.v);
            __m512h mulH = _mm512_mul_ph(H0.v, H1.v);
            __m512i row34Temp2 = _mm512_unpacklo_epi16(_mm512_castph_si512(mulG), _mm512_castph_si512(mulH)); // G1 H1 G2 H2 G3 H3 G4 H4
            // BUG FIX: was unpackhi(mulE, mulF) — must take the high G/H products.
            __m512i row78Temp2 = _mm512_unpackhi_epi16(_mm512_castph_si512(mulG), _mm512_castph_si512(mulH)); // G5 H5 G6 H6 G7 H7 G8 H8
            row12Temp2 = _mm512_unpacklo_epi16(row12Temp2, row34Temp2); // E1 G1 F1 H1 E2 G2 F2 H2
            row34Temp2 = _mm512_unpackhi_epi16(row1TempTemp2, row34Temp2); // E3 G3 F3 H3 E4 G4 F4 H4
            row56Temp2 = _mm512_unpacklo_epi16(row56Temp2, row78Temp2); // E5 G5 F5 H5 E6 G6 F6 H6
            row78Temp2 = _mm512_unpackhi_epi16(row5TempTemp2, row78Temp2); // E7 G7 F7 H7 E8 G8 F8 H8
            __m512h row1 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(row12Temp1, row12Temp2)); // A1 E1 C1 G1 B1 F1 D1 H1
            __m512h row2 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(row12Temp1, row12Temp2)); // A2 E2 C2 G2 B2 F2 D2 H2
            __m512h row3 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(row34Temp1, row34Temp2)); // A3 E3 C3 G3 B3 F3 D3 H3
            __m512h row4 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(row34Temp1, row34Temp2)); // A4 E4 C4 G4 B4 F4 D4 H4
            __m512h row5 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(row56Temp1, row56Temp2)); // A5 E5 C5 G5 B5 F5 D5 H5
            __m512h row6 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(row56Temp1, row56Temp2)); // A6 E6 C6 G6 B6 F6 D6 H6
            __m512h row7 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(row78Temp1, row78Temp2)); // A7 E7 C7 G7 B7 F7 D7 H7
            __m512h row8 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(row78Temp1, row78Temp2)); // A8 E8 C8 G8 B8 F8 D8 H8
            row1 = _mm512_add_ph(row1, row2);
            row1 = _mm512_add_ph(row1, row3);
            row1 = _mm512_add_ph(row1, row4);
            row1 = _mm512_add_ph(row1, row5);
            row1 = _mm512_add_ph(row1, row6);
            row1 = _mm512_add_ph(row1, row7);
            row1 = _mm512_add_ph(row1, row8);
            return row1;
        }
    }

    /// Dot products for packed 4-element vectors: multiplies each input pair element-wise,
    /// then transposes the products in-register so each output lane receives the sum of one
    /// vector's products. Lane naming in the comments (A..H, indices 1..4) follows the
    /// transpose diagrams used throughout this file.
    constexpr static VectorF16 Dot(
        VectorF16 A0, VectorF16 A1,
        VectorF16 C0, VectorF16 C1,
        VectorF16 E0, VectorF16 E1,
        VectorF16 G0, VectorF16 G1
    ) requires(Len == 4) {
        if constexpr(std::is_same_v<VectorType, __m128h>) {
            __m128h mulA = _mm_mul_ph(A0.v, A1.v);
            __m128h mulC = _mm_mul_ph(C0.v, C1.v);
            __m128i row12Temp1 = _mm_unpacklo_epi16(_mm_castph_si128(mulA), _mm_castph_si128(mulC)); // A1 C1 A2 C2 A3 C3 A4 C4
            __m128i row34Temp1 = _mm_unpackhi_epi16(_mm_castph_si128(mulA), _mm_castph_si128(mulC)); // B1 D1 B2 D2 B3 D3 B4 D4
            __m128i row1TempTemp1 = row12Temp1;
            __m128i row5TempTemp1 = row34Temp1;
            __m128h mulE = _mm_mul_ph(E0.v, E1.v);
            __m128h mulG = _mm_mul_ph(G0.v, G1.v);
            __m128i row12Temp2 = _mm_unpacklo_epi16(_mm_castph_si128(mulE), _mm_castph_si128(mulG)); // E1 G1 E2 G2 E3 G3 E4 G4
            __m128i row34Temp2 = _mm_unpackhi_epi16(_mm_castph_si128(mulE), _mm_castph_si128(mulG)); // F1 H1 F2 H2 F3 H3 F4 H4
            row12Temp1 = _mm_unpacklo_epi16(row12Temp1, row12Temp2); // A1 E1 C1 G1 A2 E2 C2 G2
            // BUG FIX: interleave with the OLD row12Temp2 (E/G products) before it is
            // overwritten; the original passed row34Temp1 here, yielding A3 B3 C3 D3 instead
            // of the A3 E3 C3 G3 layout the final rows below depend on.
            row34Temp1 = _mm_unpackhi_epi16(row1TempTemp1, row12Temp2); // A3 E3 C3 G3 A4 E4 C4 G4
            row12Temp2 = _mm_unpacklo_epi16(row5TempTemp1, row34Temp2); // B1 F1 D1 H1 B2 F2 D2 H2
            row34Temp2 = _mm_unpackhi_epi16(row5TempTemp1, row34Temp2); // B3 F3 D3 H3 B4 F4 D4 H4
            __m128h row1 = _mm_castsi128_ph(_mm_unpacklo_epi16(row12Temp1, row12Temp2)); // A1 B1 E1 F1 C1 D1 G1 H1
            __m128h row2 = _mm_castsi128_ph(_mm_unpackhi_epi16(row12Temp1, row12Temp2)); // A2 B2 E2 F2 C2 D2 G2 H2
            __m128h row3 = _mm_castsi128_ph(_mm_unpacklo_epi16(row34Temp1, row34Temp2)); // A3 B3 E3 F3 C3 D3 G3 H3
            __m128h row4 = _mm_castsi128_ph(_mm_unpackhi_epi16(row34Temp1, row34Temp2)); // A4 B4 E4 F4 C4 D4 G4 H4
            row1 = _mm_add_ph(row1, row2);
            row1 = _mm_add_ph(row1, row3);
            row1 = _mm_add_ph(row1, row4);
            return row1;
        } else if constexpr(std::is_same_v<VectorType, __m256h>) {
            __m256h mulA = _mm256_mul_ph(A0.v, A1.v);
            __m256h mulC = _mm256_mul_ph(C0.v, C1.v);
            __m256i row12Temp1 = _mm256_unpacklo_epi16(_mm256_castph_si256(mulA), _mm256_castph_si256(mulC)); // A1 C1 A2 C2 A3 C3 A4 C4
            __m256i row34Temp1 = _mm256_unpackhi_epi16(_mm256_castph_si256(mulA), _mm256_castph_si256(mulC)); // B1 D1 B2 D2 B3 D3 B4 D4
            __m256i row1TempTemp1 = row12Temp1;
            __m256i row5TempTemp1 = row34Temp1;
            __m256h mulE = _mm256_mul_ph(E0.v, E1.v);
            __m256h mulG = _mm256_mul_ph(G0.v, G1.v);
            __m256i row12Temp2 = _mm256_unpacklo_epi16(_mm256_castph_si256(mulE), _mm256_castph_si256(mulG)); // E1 G1 E2 G2 E3 G3 E4 G4
            __m256i row34Temp2 = _mm256_unpackhi_epi16(_mm256_castph_si256(mulE), _mm256_castph_si256(mulG)); // F1 H1 F2 H2 F3 H3 F4 H4
            row12Temp1 = _mm256_unpacklo_epi16(row12Temp1, row12Temp2); // A1 E1 C1 G1 A2 E2 C2 G2
            // BUG FIX: see 128-bit branch — must use the old row12Temp2, computed before it
            // is overwritten.
            row34Temp1 = _mm256_unpackhi_epi16(row1TempTemp1, row12Temp2); // A3 E3 C3 G3 A4 E4 C4 G4
            row12Temp2 = _mm256_unpacklo_epi16(row5TempTemp1, row34Temp2); // B1 F1 D1 H1 B2 F2 D2 H2
            row34Temp2 = _mm256_unpackhi_epi16(row5TempTemp1, row34Temp2); // B3 F3 D3 H3 B4 F4 D4 H4
            __m256h row1 = _mm256_castsi256_ph(_mm256_unpacklo_epi16(row12Temp1, row12Temp2)); // A1 B1 E1 F1 C1 D1 G1 H1
            __m256h row2 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(row12Temp1, row12Temp2)); // A2 B2 E2 F2 C2 D2 G2 H2
            __m256h row3 = _mm256_castsi256_ph(_mm256_unpacklo_epi16(row34Temp1, row34Temp2)); // A3 B3 E3 F3 C3 D3 G3 H3
            __m256h row4 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(row34Temp1, row34Temp2)); // A4 B4 E4 F4 C4 D4 G4 H4
            row1 = _mm256_add_ph(row1, row2);
            row1 = _mm256_add_ph(row1, row3);
            row1 = _mm256_add_ph(row1, row4);
            return row1;
        } else {
            __m512h mulA = _mm512_mul_ph(A0.v, A1.v);
            __m512h mulC = _mm512_mul_ph(C0.v, C1.v);
            __m512i row12Temp1 = _mm512_unpacklo_epi16(_mm512_castph_si512(mulA), _mm512_castph_si512(mulC)); // A1 C1 A2 C2 A3 C3 A4 C4
            __m512i row34Temp1 = _mm512_unpackhi_epi16(_mm512_castph_si512(mulA), _mm512_castph_si512(mulC)); // B1 D1 B2 D2 B3 D3 B4 D4
            __m512i row1TempTemp1 = row12Temp1;
            __m512i row5TempTemp1 = row34Temp1;
            __m512h mulE = _mm512_mul_ph(E0.v, E1.v);
            __m512h mulG = _mm512_mul_ph(G0.v, G1.v);
            __m512i row12Temp2 = _mm512_unpacklo_epi16(_mm512_castph_si512(mulE), _mm512_castph_si512(mulG)); // E1 G1 E2 G2 E3 G3 E4 G4
            __m512i row34Temp2 = _mm512_unpackhi_epi16(_mm512_castph_si512(mulE), _mm512_castph_si512(mulG)); // F1 H1 F2 H2 F3 H3 F4 H4
            row12Temp1 = _mm512_unpacklo_epi16(row12Temp1, row12Temp2); // A1 E1 C1 G1 A2 E2 C2 G2
            // BUG FIX: see 128-bit branch — must use the old row12Temp2, computed before it
            // is overwritten.
            row34Temp1 = _mm512_unpackhi_epi16(row1TempTemp1, row12Temp2); // A3 E3 C3 G3 A4 E4 C4 G4
            row12Temp2 = _mm512_unpacklo_epi16(row5TempTemp1, row34Temp2); // B1 F1 D1 H1 B2 F2 D2 H2
            row34Temp2 = _mm512_unpackhi_epi16(row5TempTemp1, row34Temp2); // B3 F3 D3 H3 B4 F4 D4 H4
            __m512h row1 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(row12Temp1, row12Temp2)); // A1 B1 E1 F1 C1 D1 G1 H1
            __m512h row2 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(row12Temp1, row12Temp2)); // A2 B2 E2 F2 C2 D2 G2 H2
            __m512h row3 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(row34Temp1, row34Temp2)); // A3 B3 E3 F3 C3 D3 G3 H3
            __m512h row4 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(row34Temp1, row34Temp2)); // A4 B4 E4 F4 C4 D4 G4 H4
            row1 = _mm512_add_ph(row1, row2);
            row1 = _mm512_add_ph(row1, row3);
            row1 = _mm512_add_ph(row1, row4);
            return row1;
        }
    }

    /// Dot products for packed 2-element vectors (see the Len == 4 overload for the scheme).
    constexpr static VectorF16 Dot(
        VectorF16 A0, VectorF16 A1,
        VectorF16 E0, VectorF16 E1
    ) requires(Len == 2) {
        if constexpr(std::is_same_v<VectorType, __m128h>) {
            __m128h mulA = _mm_mul_ph(A0.v, A1.v);
            __m128h mulE = _mm_mul_ph(E0.v, E1.v);
            __m128i row12Temp1 = _mm_unpacklo_epi16(_mm_castph_si128(mulA), _mm_castph_si128(mulE)); // A1 E1 A2 E2 B1 F1 B2 F2
            __m128i row12Temp2 = _mm_unpackhi_epi16(_mm_castph_si128(mulA), _mm_castph_si128(mulE)); // C1 G1 C2 G2 D1 H1 D2 H2
            __m128i row12Temp1Temp = row12Temp1;
            row12Temp1 = _mm_unpacklo_epi16(row12Temp1, row12Temp2); // A1 C1 E1 G1 A2 C2 E2 G2
            row12Temp2 = _mm_unpackhi_epi16(row12Temp1Temp, row12Temp2); // B1 D1 F1 H1 B2 D2 F2 H2
            __m128h row1 = _mm_castsi128_ph(_mm_unpacklo_epi16(row12Temp1, row12Temp2)); // A1 B1 C1 D1 E1 F1 G1 H1
            __m128h row2 = _mm_castsi128_ph(_mm_unpackhi_epi16(row12Temp1, row12Temp2)); // A2 B2 C2 D2 E2 F2 G2 H2
            return _mm_add_ph(row1, row2);
        } else if constexpr(std::is_same_v<VectorType, __m256h>) {
            __m256h mulA = _mm256_mul_ph(A0.v, A1.v);
            __m256h mulE = _mm256_mul_ph(E0.v, E1.v);
            __m256i row12Temp1 = _mm256_unpacklo_epi16(_mm256_castph_si256(mulA), _mm256_castph_si256(mulE)); // A1 E1 A2 E2 B1 F1 B2 F2
            __m256i row12Temp2 = _mm256_unpackhi_epi16(_mm256_castph_si256(mulA), _mm256_castph_si256(mulE)); // C1 G1 C2 G2 D1 H1 D2 H2
            __m256i row12Temp1Temp = row12Temp1;
            row12Temp1 = _mm256_unpacklo_epi16(row12Temp1, row12Temp2); // A1 C1 E1 G1 A2 C2 E2 G2
            row12Temp2 = _mm256_unpackhi_epi16(row12Temp1Temp, row12Temp2); // B1 D1 F1 H1 B2 D2 F2 H2
            __m256h row1 = _mm256_castsi256_ph(_mm256_unpacklo_epi16(row12Temp1, row12Temp2)); // A1 B1 C1 D1 E1 F1 G1 H1
            __m256h row2 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(row12Temp1, row12Temp2)); // A2 B2 C2 D2 E2 F2 G2 H2
            return _mm256_add_ph(row1, row2);
        } else {
            __m512h mulA = _mm512_mul_ph(A0.v, A1.v);
            __m512h mulE = _mm512_mul_ph(E0.v, E1.v);
            __m512i row12Temp1 = _mm512_unpacklo_epi16(_mm512_castph_si512(mulA), _mm512_castph_si512(mulE)); // A1 E1 A2 E2 B1 F1 B2 F2
            __m512i row12Temp2 = _mm512_unpackhi_epi16(_mm512_castph_si512(mulA), _mm512_castph_si512(mulE)); // C1 G1 C2 G2 D1 H1 D2 H2
            __m512i row12Temp1Temp = row12Temp1;
            row12Temp1 = _mm512_unpacklo_epi16(row12Temp1, row12Temp2); // A1 C1 E1 G1 A2 C2 E2 G2
            row12Temp2 = _mm512_unpackhi_epi16(row12Temp1Temp, row12Temp2); // B1 D1 F1 H1 B2 D2 F2 H2
            __m512h row1 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(row12Temp1, row12Temp2)); // A1 B1 C1 D1 E1 F1 G1 H1
            __m512h row2 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(row12Temp1, row12Temp2)); // A2 B2 C2 D2 E2 F2 G2 H2
            return _mm512_add_ph(row1, row2);
        }
    }

    /// Per-lane blend: lanes where ShuffleValues[i] is true are taken from `b`, the rest
    /// from `a` (pattern repeated across packs).
    template<std::array<bool, Len> ShuffleValues>
    constexpr static VectorF16 Blend(VectorF16 a, VectorF16 b) {
        if constexpr(std::is_same_v<VectorType, __m128h>) {
            return _mm_castsi128_ph(_mm_blend_epi16(_mm_castph_si128(a.v), _mm_castph_si128(b.v), GetBlendMaskEpi16<ShuffleValues>()));
        } else if constexpr(std::is_same_v<VectorType, __m256h>) {
// BUG FIX: the original nested #ifndef only fired when BOTH extensions were missing;
// _mm256_mask_blend_epi16 needs AVX512BW *and* AVX512VL, so either one missing is an error.
#if !defined(__AVX512BW__) || !defined(__AVX512VL__)
            static_assert(false, "No __AVX512BW__ and __AVX512VL__ support");
#endif
            // BUG FIX: the mask is the FIRST parameter of _mm256_mask_blend_epi16.
            return _mm256_castsi256_ph(_mm256_mask_blend_epi16(GetBlendMaskEpi16<ShuffleValues>(), _mm256_castph_si256(a.v), _mm256_castph_si256(b.v)));
        } else {
            // BUG FIX: _mm512_blend_epi16 does not exist; the masked form is the AVX-512 blend.
            return _mm512_castsi512_ph(_mm512_mask_blend_epi16(GetBlendMaskEpi16<ShuffleValues>(), _mm512_castph_si512(a.v), _mm512_castph_si512(b.v)));
        }
    }

    /// Rotates vector `v` by quaternion `q` (x,y,z,w lane order; w broadcast via Shuffle<3>).
    constexpr static VectorF16 Rotate(VectorF16<3, Packing> v, VectorF16<4, Packing> q) requires(Len == 3) {
        VectorF16<3, Packing> qv(q.v);
        VectorF16 t = Cross(qv, v) * _Float16(2);
        return v + t * q.template Shuffle<{{3,3,3,3}}>() + Cross(qv, t);
    }

    /// Rotates `v` by quaternion `q` around `pivot` (translate, rotate, translate back).
    /// BUG FIX: the original declared the return type as the unrelated VectorF16<4, 2>;
    /// the computed value is a 3-vector of this Packing, matching Rotate.
    constexpr static VectorF16 RotatePivot(VectorF16<3, Packing> v, VectorF16<4, Packing> q, VectorF16<3, Packing> pivot) requires(Len == 3) {
        VectorF16 translated = v - pivot;
        VectorF16<3, Packing> qv(q.v);
        VectorF16 t = Cross(qv, translated) * _Float16(2);
        VectorF16 rotated = translated + t * q.template Shuffle<{{3,3,3,3}}>() + Cross(qv, t);
        return rotated + pivot;
    }

    /// Builds a quaternion from half Euler angles (input is angle/2 per axis).
    /// NOTE(review): "Quanternion" and "MulitplyAdd" are misspelled but are part of the
    /// public interface / an out-of-view helper, so the names are kept as-is.
    constexpr static VectorF16<4, Packing> QuanternionFromEuler(VectorF16<3, Packing> EulerHalf) requires(Len == 4) {
        std::tuple<VectorF16<3, Packing>, VectorF16<3, Packing>> sinCos = EulerHalf.SinCos();
        VectorF16<4, Packing> sin = std::get<0>(sinCos);
        VectorF16<4, Packing> cos = std::get<1>(sinCos);
        VectorF16<4, Packing> row1 = cos.template Shuffle<{{0,0,0,0}}>();
        row1 = Blend<{{0,1,1,1}}>(sin, row1);
        VectorF16<4, Packing> row2 = cos.template Shuffle<{{1,1,1,1}}>();
        row2 = Blend<{{1,0,1,1}}>(sin, row2);
        row1 *= row2;
        VectorF16<4, Packing> row3 = cos.template Shuffle<{{2,2,2,2}}>();
        row3 = Blend<{{1,1,0,1}}>(sin, row3);
        row1 *= row3;
        VectorF16<4, Packing> row4 = sin.template Shuffle<{{0,0,0,0}}>();
        row4 = Blend<{{0,1,1,1}}>(cos, row4);
        VectorF16<4, Packing> row5 = sin.template Shuffle<{{1,1,1,1}}>();
        row5 = Blend<{{1,0,1,1}}>(cos, row5);
        row4 *= row5;
        VectorF16<4, Packing> row6 = sin.template Shuffle<{{2,2,2,2}}>();
        row6 = Blend<{{1,1,0,1}}>(cos, row6);
        row6 = row6.template Negate<{{true,false,true,false}}>();
        row1 = MulitplyAdd(row4, row6, row1);
        return row1;
    }
private:
    /// Per-lane 16-bit sign-bit mask: 0x8000 where `values[i]` is true, 0 elsewhere.
    template<std::array<bool, Len> values>
    static consteval std::array<std::uint16_t, Len> GetNegateMask() {
        std::array<std::uint16_t, Len> mask;
        for(std::uint8_t i = 0; i < Len; i++) {
            if(values[i]) {
                mask[i] = 0b1000000000000000;
            } else {
                mask[i] = 0;
            }
        }
        return mask;
    }

    /// Sign-bit mask for every lane.
    static consteval std::array<std::uint16_t, Len> GetNegateMaskAll() {
        std::array<std::uint16_t, Len> mask;
        for(std::uint8_t i = 0; i < Len; i++) {
            mask[i] = 0b1000000000000000;
        }
        return mask;
    }

    /// Packs pairwise epi16 shuffle indices into an epi32 shuffle immediate.
    /// BUG FIX: the original declared the return type as bool, truncating the mask to 0/1.
    /// NOTE(review): `ShuffleValues[i] & 0b11` keeps the low two bits of an epi16 lane index;
    /// for paired indices >= 4 this differs from the epi32 lane index (value >> 1) — confirm
    /// against the (out-of-view) Shuffle implementation.
    template<std::array<std::uint8_t, Len> ShuffleValues>
    static consteval std::uint8_t GetShuffleMaskEpi32() {
        std::uint8_t mask = 0;
        for(std::uint8_t i = 0; i < std::min(Len, std::uint32_t(8)); i+=2) {
            mask = mask | (ShuffleValues[i] & 0b11) << i;
        }
        return mask;
    }

    /// Expands per-lane epi16 shuffle indices into byte-level pshufb control, repeated per pack.
    template<std::array<std::uint8_t, Len> ShuffleValues>
    static consteval std::array<std::uint8_t, 16> GetShuffleMaskEpi8() requires (std::is_same_v<VectorType, __m128h>) {
        std::array<std::uint8_t, 16> shuffleMask {{0}};
        for(std::uint8_t i2 = 0; i2 < Packing; i2++) {
            for(std::uint8_t i = 0; i < Len; i++) {
                shuffleMask[(i2*Len*2)+(i*2)] = ShuffleValues[i]*2+(i2*Len*2);
                shuffleMask[(i2*Len*2)+(i*2+1)] = ShuffleValues[i]*2+1+(i2*Len*2);
            }
        }
        return shuffleMask;
    }

    template<std::array<std::uint8_t, Len> ShuffleValues>
    static consteval std::array<std::uint8_t, 32> GetShuffleMaskEpi8() requires (std::is_same_v<VectorType, __m256h>) {
        std::array<std::uint8_t, 32> shuffleMask {{0}};
        for(std::uint8_t i2 = 0; i2 < Packing; i2++) {
            for(std::uint8_t i = 0; i < Len; i++) {
                shuffleMask[(i2*Len*2)+(i*2)] = ShuffleValues[i]*2+(i2*Len*2);
                shuffleMask[(i2*Len*2)+(i*2+1)] = ShuffleValues[i]*2+1+(i2*Len*2);
            }
        }
        return shuffleMask;
    }

    template<std::array<std::uint8_t, Len> ShuffleValues>
    static consteval std::array<std::uint8_t, 64> GetShuffleMaskEpi8() requires (std::is_same_v<VectorType, __m512h>) {
        std::array<std::uint8_t, 64> shuffleMask {{0}};
        for(std::uint8_t i2 = 0; i2 < Packing; i2++) {
            for(std::uint8_t i = 0; i < Len; i++) {
                shuffleMask[(i2*Len*2)+(i*2)] = ShuffleValues[i]*2+(i2*Len*2);
                shuffleMask[(i2*Len*2)+(i*2+1)] = ShuffleValues[i]*2+1+(i2*Len*2);
            }
        }
        return shuffleMask;
    }

    /// All-true selection mask (used as a default blend/negate pattern).
    static consteval std::array<bool, Len> GetAllTrue() {
        std::array<bool, Len> arr{};
        arr.fill(true);
        return arr;
    }

    /// True when the epi16 shuffle can be expressed as an epi32 shuffle (indices come in
    /// consecutive even/odd pairs).
    /// NOTE(review): the second loop compares every ShuffleValues[i] against ShuffleValues[0],
    /// [8], ... which only accepts uniform patterns; the intent was probably to require the
    /// pattern to repeat per 128-bit lane — confirm against the Shuffle implementation.
    template<std::array<std::uint8_t, Len> ShuffleValues>
    static consteval bool CheckEpi32Shuffle() {
        for(std::uint8_t i = 1; i < Len; i+=2) {
            if(ShuffleValues[i-1] != ShuffleValues[i] - 1) { return false; }
        }
        for(std::uint8_t i = 0; i < Len; i++) {
            for(std::uint8_t i2 = 0; i2 < Len; i2 += 8) {
                if(ShuffleValues[i] != ShuffleValues[i2]) { return false; }
            }
        }
        return true;
    }

    /// Blend bit-mask for epi16 blends: bit (pack*Len + lane) set where the lane selects `b`.
    template<std::array<bool, Len> ShuffleValues>
    static consteval std::uint8_t GetBlendMaskEpi16() requires (std::is_same_v<VectorType, __m128h>) {
        std::uint8_t mask = 0;
        for (std::uint8_t i2 = 0; i2 < Packing; i2++) {
            for (std::uint8_t i = 0; i < Len; i++) {
                if (ShuffleValues[i]) { mask |= (1u << (i2 * Len + i)); }
            }
        }
        return mask;
    }

    template<std::array<bool, Len> ShuffleValues>
    static consteval std::uint16_t GetBlendMaskEpi16() requires (std::is_same_v<VectorType, __m256h>) {
        std::uint16_t mask = 0;
        for (std::uint8_t i2 = 0; i2 < Packing; i2++) {
            for (std::uint8_t i = 0; i < Len; i++) {
                if (ShuffleValues[i]) { mask |= (1u << (i2 * Len + i)); }
            }
        }
        return mask;
    }

    template<std::array<bool, Len> ShuffleValues>
    static consteval std::uint32_t GetBlendMaskEpi16() requires (std::is_same_v<VectorType, __m512h>) {
        std::uint32_t mask = 0;
        for (std::uint8_t i2 = 0; i2 < Packing; i2++) {
            for (std::uint8_t i = 0; i < Len; i++) {
                if (ShuffleValues[i]) { mask |= (1u << (i2 * Len + i)); }
            }
        }
        return mask;
    }

    // Constants for single-precision sin/cos range reduction and polynomial evaluation.
    static constexpr float two_over_pi = 0.6366197723675814f;
    static constexpr float pi_over_2_hi = 1.5707963267341256f;
    static constexpr float pi_over_2_lo = 6.077100506506192e-11f;
    // Cos polynomial on [-pi/4, pi/4]: c0 + c2*r^2 + c4*r^4 + ...
    static constexpr float c0 = 1.0f;
    static constexpr float c2 = -0.4999999642372f;
    static constexpr float c4 = 0.0416666418707f;
    static constexpr float c6 = -0.0013888397720f;
    static constexpr float c8 = 0.0000248015873f;
    static constexpr float c10 = -0.0000002752258f;
    // Sin polynomial on [-pi/4, pi/4]: r * (1 + s1*r^2 + s3*r^4 + ...)
    static constexpr float s1 = -0.1666666641831f;
    static constexpr float s3 = 0.0083333293858f;
    static constexpr float s5 = -0.0001984090955f;
    static constexpr float s7 = 0.0000027526372f;
    static constexpr float s9 = -0.0000000239013f;

    // Reduce |x| into [-pi/4, pi/4]; outputs reduced value r, r^2 and the quadrant q.
    constexpr void range_reduce_f32x8(__m256 ax, __m256& r, __m256& r2, __m256i& q) {
        __m256 fq = _mm256_round_ps(_mm256_mul_ps(ax, _mm256_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
        q = _mm256_cvtps_epi32(fq);
        // Two-step (hi/lo) subtraction keeps precision for large inputs.
        r = _mm256_sub_ps(ax, _mm256_mul_ps(fq, _mm256_set1_ps(pi_over_2_hi)));
        r = _mm256_sub_ps(r, _mm256_mul_ps(fq, _mm256_set1_ps(pi_over_2_lo)));
        r2 = _mm256_mul_ps(r, r);
    }

    // cos(x): use cos_poly when q even, sin_poly when q odd; negate if (q+1)&2.
    constexpr __m256 cos_f32x8(__m256 x) {
        const __m256 sign_mask = _mm256_set1_ps(-0.0f);
        __m256 ax = _mm256_andnot_ps(sign_mask, x);
        __m256 r, r2;
        __m256i q;
        range_reduce_f32x8(ax, r, r2, q);
        __m256 cos_r, sin_r;
        sincos_poly_f32x8(r, r2, cos_r, sin_r);
        __m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1));
        __m256 use_sin = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1)));
        __m256 result = _mm256_blendv_ps(cos_r, sin_r, use_sin);
        __m256i need_neg = _mm256_and_si256(_mm256_add_epi32(q, _mm256_set1_epi32(1)), _mm256_set1_epi32(2));
        __m256 neg_mask = _mm256_castsi256_ps(_mm256_slli_epi32(need_neg, 30));
        return _mm256_xor_ps(result, neg_mask);
    }

    // Evaluates both minimax polynomials for the reduced argument (Horner form via FMA).
    constexpr void sincos_poly_f32x8(__m256 r, __m256 r2, __m256& cos_r, __m256& sin_r) {
        cos_r = _mm256_fmadd_ps(_mm256_set1_ps(c10), r2, _mm256_set1_ps(c8));
        cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c6));
        cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c4));
        cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c2));
        cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c0));
        sin_r = _mm256_fmadd_ps(_mm256_set1_ps(s9), r2, _mm256_set1_ps(s7));
        sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s5));
        sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s3));
        sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s1));
        sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(1.0f));
        sin_r = _mm256_mul_ps(sin_r, r);
    }

    // sin(x): use sin_poly when q even, cos_poly when q odd; negate if q&2; respect input sign.
    constexpr __m256 sin_f32x8(__m256 x) {
        const __m256 sign_mask = _mm256_set1_ps(-0.0f);
        __m256 x_sign = _mm256_and_ps(x, sign_mask);
        __m256 ax = _mm256_andnot_ps(sign_mask, x);
        __m256 r, r2;
        __m256i q;
        range_reduce_f32x8(ax, r, r2, q);
        __m256 cos_r, sin_r;
        sincos_poly_f32x8(r, r2, cos_r, sin_r);
        __m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1));
        __m256 use_cos = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1)));
        __m256 result = _mm256_blendv_ps(sin_r, cos_r, use_cos);
        __m256i need_neg = _mm256_and_si256(q, _mm256_set1_epi32(2));
        __m256 neg_mask = _mm256_castsi256_ps(_mm256_slli_epi32(need_neg, 30));
        result = _mm256_xor_ps(result, neg_mask);
        // Apply original sign of x
        return _mm256_xor_ps(result, x_sign);
    }

    // --- 512-bit helpers ---
    constexpr void range_reduce_f32x16(__m512 ax, __m512& r, __m512& r2, __m512i& q) {
        __m512 fq = _mm512_roundscale_ps(_mm512_mul_ps(ax, _mm512_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
        q = _mm512_cvtps_epi32(fq);
        r = _mm512_sub_ps(ax, _mm512_mul_ps(fq, _mm512_set1_ps(pi_over_2_hi)));
        r = _mm512_sub_ps(r, _mm512_mul_ps(fq, _mm512_set1_ps(pi_over_2_lo)));
        r2 = _mm512_mul_ps(r, r);
    }

    constexpr void sincos_poly_f32x16(__m512 r, __m512 r2, __m512& cos_r, __m512& sin_r) {
        cos_r = _mm512_fmadd_ps(_mm512_set1_ps(c10), r2, _mm512_set1_ps(c8));
        cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c6));
        cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c4));
        cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c2));
        cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c0));
        sin_r = _mm512_fmadd_ps(_mm512_set1_ps(s9), r2, _mm512_set1_ps(s7));
        sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s5));
        sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s3));
        sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s1));
        sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(1.0f));
        sin_r = _mm512_mul_ps(sin_r, r);
    }

    constexpr __m512 cos_f32x16(__m512 x) {
        __m512 ax = _mm512_abs_ps(x);
        __m512 r, r2;
        __m512i q;
        range_reduce_f32x16(ax, r, r2, q);
        __m512 cos_r, sin_r;
        sincos_poly_f32x16(r, r2, cos_r, sin_r);
        __mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1));
        __m512 result = _mm512_mask_blend_ps(odd, cos_r, sin_r);
        __m512i need_neg = _mm512_and_si512(_mm512_add_epi32(q, _mm512_set1_epi32(1)), _mm512_set1_epi32(2));
        __m512 neg_mask = _mm512_castsi512_ps(_mm512_slli_epi32(need_neg, 30));
        return _mm512_xor_ps(result, neg_mask);
    }

    constexpr __m512 sin_f32x16(__m512 x) {
        __m512 x_sign = _mm512_and_ps(x, _mm512_set1_ps(-0.0f));
        __m512 ax = _mm512_abs_ps(x);
        __m512 r, r2;
        __m512i q;
        range_reduce_f32x16(ax, r, r2, q);
        __m512 cos_r, sin_r;
        sincos_poly_f32x16(r, r2, cos_r, sin_r);
        __mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1));
        __m512 result = _mm512_mask_blend_ps(odd, sin_r, cos_r);
        __m512i need_neg = _mm512_and_si512(q, _mm512_set1_epi32(2));
        __m512 neg_mask = _mm512_castsi512_ps(_mm512_slli_epi32(need_neg, 30));
        result = _mm512_xor_ps(result, neg_mask);
        return _mm512_xor_ps(result, x_sign);
    }

    // --- 256-bit sincos ---
    constexpr void sincos_f32x8(__m256 x, __m256& out_sin, __m256& out_cos) {
        const __m256 sign_mask = _mm256_set1_ps(-0.0f);
        __m256 x_sign = _mm256_and_ps(x, sign_mask);
        __m256 ax = _mm256_andnot_ps(sign_mask, x);
        __m256 r, r2;
        __m256i q;
        range_reduce_f32x8(ax, r, r2, q);
        __m256 cos_r, sin_r;
        sincos_poly_f32x8(r, r2, cos_r, sin_r);
        __m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1));
        __m256 is_odd = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1)));
        // cos: swap on odd, negate if (q+1)&2
        out_cos = _mm256_blendv_ps(cos_r, sin_r, is_odd);
        __m256i cos_neg = _mm256_and_si256(_mm256_add_epi32(q, _mm256_set1_epi32(1)), _mm256_set1_epi32(2));
        out_cos = _mm256_xor_ps(out_cos, _mm256_castsi256_ps(_mm256_slli_epi32(cos_neg, 30)));
        // sin: swap on odd, negate if q&2, apply input sign
        out_sin = _mm256_blendv_ps(sin_r, cos_r, is_odd);
        __m256i sin_neg = _mm256_and_si256(q, _mm256_set1_epi32(2));
        out_sin = _mm256_xor_ps(out_sin, _mm256_castsi256_ps(_mm256_slli_epi32(sin_neg, 30)));
        out_sin = _mm256_xor_ps(out_sin, x_sign);
    }

    // --- 512-bit sincos ---
    constexpr void sincos_f32x16(__m512 x, __m512& out_sin, __m512& out_cos) {
        __m512 x_sign = _mm512_and_ps(x, _mm512_set1_ps(-0.0f));
        __m512 ax = _mm512_abs_ps(x);
        __m512 r, r2;
        __m512i q;
        range_reduce_f32x16(ax, r, r2, q);
        __m512 cos_r, sin_r;
        sincos_poly_f32x16(r, r2, cos_r, sin_r);
        __mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1));
        // cos
        out_cos = _mm512_mask_blend_ps(odd, cos_r, sin_r);
        __m512i cos_neg = _mm512_and_si512(_mm512_add_epi32(q, _mm512_set1_epi32(1)), _mm512_set1_epi32(2));
        out_cos = _mm512_xor_ps(out_cos, _mm512_castsi512_ps(_mm512_slli_epi32(cos_neg, 30)));
        // sin
        out_sin = _mm512_mask_blend_ps(odd, sin_r, cos_r);
        __m512i sin_neg = _mm512_and_si512(q, _mm512_set1_epi32(2));
        out_sin = _mm512_xor_ps(out_sin, _mm512_castsi512_ps(_mm512_slli_epi32(sin_neg, 30)));
        out_sin = _mm512_xor_ps(out_sin, x_sign);
    }
};
}

/// std::format support: prints "{{a,b,...}{c,d,...}}" — one inner brace group per pack.
/// NOTE(review): the original stores into Crafter::Vector<..., 0> although Store() returns
/// a Vector with this type's Alignment; kept as-is pending confirmation of Vector's
/// alignment-conversion semantics.
export template<std::uint32_t Len, std::uint32_t Packing>
struct std::formatter<Crafter::VectorF16<Len, Packing>> : std::formatter<std::string> {
    constexpr auto format(const Crafter::VectorF16<Len, Packing>& obj, format_context& ctx) const {
        Crafter::Vector<_Float16, Len * Packing, 0> vec = obj.Store();
        std::string out = "{";
        for(std::uint32_t i = 0; i < Packing; i++) {
            out += "{";
            for(std::uint32_t i2 = 0; i2 < Len; i2++) {
                out += std::format("{}", static_cast<float>(vec.v[i * Len + i2]));
                if (i2 + 1 < Len) out += ",";
            }
            out += "}";
        }
        out += "}";
        return std::formatter<std::string>::format(out, ctx);
    }
};
#endif