/* Crafter®.Math Copyright (C) 2026 Catcrafts® catcrafts.net This library is free software; you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License version 3.0 as published by the Free Software Foundation; This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. You should have received a copy of the GNU Lesser General Public License along with this library; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ module; #ifdef __x86_64 #include #endif export module Crafter.Math:VectorF16; import std; import :Vector; #ifdef __AVX512FP16__ namespace Crafter { export template struct VectorF16 { private: static consteval std::uint8_t GetAlingment() { if(Len * Packing <= 8) { return 8; } else if(Len * Packing <= 16) { return 16; } else if(Len * Packing <= 32) { return 32; } } using VectorType = std::conditional_t< (Len * Packing > 16), __m512h, std::conditional_t<(Len * Packing > 8), __m256h, __m128h> >; VectorType v; public: static constexpr std::uint32_t MaxSize = 32; static constexpr std::uint8_t Alignment = GetAlingment(); static_assert(Len * Packing <= MaxSize, "Len * Packing exceeds MaxSize"); private: template values> static consteval std::array GetNegateMask() { std::array mask{0}; for(std::uint8_t i2 = 0; i2 < Packing; i2++) { for(std::uint8_t i = 0; i < Len; i++) { if(values[i]) { mask[i2*Len+i] = 0b1000000000000000; } else { mask[i2*Len+i] = 0; } } } return mask; } static consteval std::array GetNegateMaskAll() { std::array mask{0}; for(std::uint8_t i = 0; i < Packing*Len; i++) { mask[i] = 0b1000000000000000; } return mask; } template ShuffleValues> static consteval bool GetShuffleMaskEpi32() { std::uint8_t mask = 0; for(std::uint8_t i = 0; i < std::min(Len, std::uint32_t(8)); 
i+=2) { mask = mask | (ShuffleValues[i] & 0b11) << i; } return mask; } template ShuffleValues> static consteval std::array::Alignment> GetPermuteMaskEpi16() { std::array::Alignment> shuffleMask {{0}}; for(std::uint8_t i2 = 0; i2 < Packing; i2++) { for(std::uint8_t i = 0; i < Len; i++) { shuffleMask[i2*Len+i] = ShuffleValues[i]+i2*Len; } } return shuffleMask; } static consteval std::array GetAllTrue() { std::array arr{}; arr.fill(true); return arr; } template ShuffleValues> static consteval bool CheckEpi32Shuffle() { for(std::uint8_t i = 1; i < Len; i+=2) { if(ShuffleValues[i-1] != ShuffleValues[i] - 1) { return false; } } for(std::uint8_t i = 0; i < Len; i++) { for(std::uint8_t i2 = 0; i2 < Len; i2 += 8) { if(ShuffleValues[i] != ShuffleValues[i2]) { return false; } } } return true; } template ShuffleValues> static consteval bool CheckEpi8Shuffle() { for(std::uint8_t i = 0; i < Len; i++) { std::uint8_t lane = i / 8; if(ShuffleValues[i] < lane * 8 || ShuffleValues[i] > lane * 8 + 7) { return false; } } return true; } template ShuffleValues> static consteval std::uint8_t GetBlendMaskEpi16() requires (std::is_same_v){ std::uint8_t mask = 0; for (std::uint8_t i2 = 0; i2 < Packing; i2++) { for (std::uint8_t i = 0; i < Len; i++) { if (ShuffleValues[i]) { mask |= (1u << (i2 * Len + i)); } } } return mask; } template ShuffleValues> static consteval std::uint16_t GetBlendMaskEpi16() requires (std::is_same_v){ std::uint16_t mask = 0; for (std::uint8_t i2 = 0; i2 < Packing; i2++) { for (std::uint8_t i = 0; i < Len; i++) { if (ShuffleValues[i]) { mask |= (1u << (i2 * Len + i)); } } } return mask; } template ShuffleValues> static consteval std::uint32_t GetBlendMaskEpi16() requires (std::is_same_v){ std::uint32_t mask = 0; for (std::uint8_t i2 = 0; i2 < Packing; i2++) { for (std::uint8_t i = 0; i < Len; i++) { if (ShuffleValues[i]) { mask |= (1u << (i2 * Len + i)); } } } return mask; } template ShuffleValues> static consteval std::array::Alignment*2> GetShuffleMaskEpi8() { 
std::array::Alignment*2> shuffleMask {{0}}; for(std::uint8_t i2 = 0; i2 < Packing; i2++) { for(std::uint8_t i = 0; i < Len; i++) { shuffleMask[(i2*Len*2)+(i*2)] = ShuffleValues[i]*2+(i2*Len*2); shuffleMask[(i2*Len*2)+(i*2+1)] = ShuffleValues[i]*2+1+(i2*Len*2); } } return shuffleMask; } public: template friend class VectorF16; constexpr VectorF16() = default; constexpr VectorF16(VectorType v) : v(v) {} constexpr VectorF16(const _Float16* vB) { Load(vB); }; constexpr VectorF16(_Float16 val) { if constexpr(std::is_same_v) { v = _mm_set1_ph(val); } else if constexpr(std::is_same_v) { v = _mm256_set1_ph(val); } else { v = _mm512_set1_ph(val); } }; constexpr void Load(const _Float16* vB) { if constexpr(std::is_same_v) { v = _mm_loadu_ph(vB); } else if constexpr(std::is_same_v) { v = _mm256_loadu_ph(vB); } else { v = _mm512_loadu_ph(vB); } } constexpr void Store(_Float16* vB) const { if constexpr(std::is_same_v) { _mm_storeu_ph(vB, v); } else if constexpr(std::is_same_v) { _mm256_storeu_ph(vB, v); } else { _mm512_storeu_ph(vB, v); } } constexpr Vector<_Float16, Len*Packing, Alignment> Store() const { Vector<_Float16, Len*Packing, Alignment> returnVec; Store(returnVec.v); return returnVec; } template constexpr operator VectorF16() const { if constexpr (Len == BLen) { if constexpr(std::is_same_v && std::is_same_v::VectorType, __m128h>) { return VectorF16(_mm256_castph256_ph128(v)); } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m128h>) { return VectorF16(_mm512_castph512_ph128(v)); } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m256h>) { return VectorF16(_mm512_castph512_ph256(v)); } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m256h>) { return VectorF16(_mm256_castph128_ph256(v)); } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m512h>) { return VectorF16(_mm512_castph128_ph512(v)); } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m512h>) { return 
VectorF16(_mm512_castph256_ph512(v)); } else { return VectorF16(v); } } else if constexpr (BLen <= Len) { return this->template ExtractLo(); } else { if constexpr(std::is_same_v::VectorType, __m128h>) { if constexpr(std::is_same_v) { constexpr std::array::Alignment*2> shuffleMask = GetExtractLoMaskEpi8(); __m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data()); return VectorF16(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(v), shuffleVec))); } else if constexpr(std::is_same_v) { constexpr std::array::Alignment> permMask = GetExtractLoMaskEpi16(); __m256i permIdx = _mm256_loadu_epi16(permMask.data()); __m256i result = _mm256_permutexvar_epi16(permIdx, _mm_castph_si256(v)); return VectorF16(_mm_castsi128_ph(_mm256_castsi256_si128(result))); } else { constexpr std::array::Alignment> permMask = GetExtractLoMaskEpi16(); __m512i permIdx = _mm512_loadu_epi16(permMask.data()); __m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castph_si512(v)); return VectorF16(_mm_castsi128_ph(_mm512_castsi512_si128(result))); } } else if constexpr(std::is_same_v::VectorType, __m256h>) { if constexpr(std::is_same_v) { constexpr std::array::Alignment> permMask = GetExtractLoMaskEpi16(); __m256i permIdx = _mm256_loadu_epi16(permMask.data()); __m256i result = _mm256_permutexvar_epi16(permIdx, _mm256_castsi128_si256(_mm_castph_si128(v))); return VectorF16(_mm256_castsi256_ph(result)); } else if constexpr(std::is_same_v) { constexpr std::array::Alignment> permMask = GetExtractLoMaskEpi16(); __m256i permIdx = _mm256_loadu_epi16(permMask.data()); __m256i result = _mm256_permutexvar_epi16(permIdx, _mm256_castph_si256(v)); return VectorF16(_mm256_castsi256_ph(result)); } else { constexpr std::array::Alignment> permMask = GetExtractLoMaskEpi16(); __m256i permIdx = _mm512_loadu_epi16(permMask.data()); __m256i result = _mm512_permutexvar_epi16(permIdx, _mm512_castsi512_si256(_mm512_castph_si512(v))); return VectorF16(_mm256_castsi256_ph(result)); } } else { if constexpr(std::is_same_v) { 
constexpr std::array::Alignment> permMask = GetExtractLoMaskEpi16(); __m512i permIdx = _mm512_loadu_epi16(permMask.data()); __m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castsi128_si512(_mm_castph_si128(v))); return VectorF16(_mm512_castsi512_ph(result)); } else if constexpr(std::is_same_v) { constexpr std::array::Alignment> permMask = GetExtractLoMaskEpi16(); __m512i permIdx = _mm512_loadu_epi16(permMask.data()); __m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castsi256_si512(_mm256_castph_si256(v))); return VectorF16(_mm512_castsi512_ph(result)); } else { constexpr std::array::Alignment> permMask = GetExtractLoMaskEpi16(); __m512i permIdx = _mm512_loadu_epi16(permMask.data()); __m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castph_si512(v)); return VectorF16(_mm512_castsi512_ph(result)); } } } } constexpr VectorF16 operator+(VectorF16 b) const { if constexpr(std::is_same_v) { return VectorF16(_mm_add_ph(v, b.v)); } else if constexpr(std::is_same_v) { return VectorF16(_mm256_add_ph(v, b.v)); } else { return VectorF16(_mm512_add_ph(v, b.v)); } } constexpr VectorF16 operator-(VectorF16 b) const { if constexpr(std::is_same_v) { return VectorF16(_mm_sub_ph(v, b.v)); } else if constexpr(std::is_same_v) { return VectorF16(_mm256_sub_ph(v, b.v)); } else { return VectorF16(_mm512_sub_ph(v, b.v)); } } constexpr VectorF16 operator*(VectorF16 b) const { if constexpr(std::is_same_v) { return VectorF16(_mm_mul_ph(v, b.v)); } else if constexpr(std::is_same_v) { return VectorF16(_mm256_mul_ph(v, b.v)); } else { return VectorF16(_mm512_mul_ph(v, b.v)); } } constexpr VectorF16 operator/(VectorF16 b) const { if constexpr(std::is_same_v) { return VectorF16(_mm_div_ph(v, b.v)); } else if constexpr(std::is_same_v) { return VectorF16(_mm256_div_ph(v, b.v)); } else { return VectorF16(_mm512_div_ph(v, b.v)); } } constexpr void operator+=(VectorF16 b) { if constexpr(std::is_same_v) { v = _mm_add_ph(v, b.v); } else if constexpr(std::is_same_v) { v = 
_mm256_add_ph(v, b.v); } else { v = _mm512_add_ph(v, b.v); } } constexpr void operator-=(VectorF16 b) { if constexpr(std::is_same_v) { v = _mm_sub_ph(v, b.v); } else if constexpr(std::is_same_v) { v = _mm256_sub_ph(v, b.v); } else { v = _mm512_sub_ph(v, b.v); } } constexpr void operator*=(VectorF16 b) { if constexpr(std::is_same_v) { v = _mm_mul_ph(v, b.v); } else if constexpr(std::is_same_v) { v = _mm256_mul_ph(v, b.v); } else { v = _mm512_mul_ph(v, b.v); } } constexpr void operator/=(VectorF16 b) { if constexpr(std::is_same_v) { v = _mm_div_ph(v, b.v); } else if constexpr(std::is_same_v) { v = _mm256_div_ph(v, b.v); } else { v = _mm512_div_ph(v, b.v); } } constexpr VectorF16 operator+(_Float16 b) { VectorF16 vB(b); return *this + vB; } constexpr VectorF16 operator-(_Float16 b) { VectorF16 vB(b); return *this - vB; } constexpr VectorF16 operator*(_Float16 b) { VectorF16 vB(b); return *this * vB; } constexpr VectorF16 operator/(_Float16 b) { VectorF16 vB(b); return *this / vB; } constexpr void operator+=(_Float16 b) { VectorF16 vB(b); *this += vB; } constexpr void operator-=(_Float16 b) { VectorF16 vB(b); *this -= vB; } constexpr void operator*=(_Float16 b) { VectorF16 vB(b); *this *= vB; } constexpr void operator/=(_Float16 b) { VectorF16 vB(b); *this /= vB; } constexpr VectorF16 operator-(){ return Negate(); } constexpr bool operator==(VectorF16 b) const { if constexpr(std::is_same_v) { return _mm_cmp_ph_mask(v, b.v, _CMP_EQ_OQ) == 255; } else if constexpr(std::is_same_v) { return _mm256_cmp_ph_mask(v, b.v, _CMP_EQ_OQ) == 65535; } else { return _mm512_cmp_ph_mask(v, b.v, _CMP_EQ_OQ) == 4294967295; } } constexpr bool operator!=(VectorF16 b) const { if constexpr(std::is_same_v) { return _mm_cmp_ph_mask(v, b.v, _CMP_EQ_OQ) != 255; } else if constexpr(std::is_same_v) { return _mm256_cmp_ph_mask(v, b.v, _CMP_EQ_OQ) != 65535; } else { return _mm512_cmp_ph_mask(v, b.v, _CMP_EQ_OQ) != 4294967295; } } template static consteval std::array::Alignment*2> 
GetExtractLoMaskEpi8() { std::array::Alignment*2> mask {{0}}; for(std::uint8_t i2 = 0; i2 < Packing; i2++) { for(std::uint8_t i = 0; i < ExtractLen; i++) { mask[(i2*ExtractLen*2)+(i*2)] = i*2+(i2*Len*2); mask[(i2*ExtractLen*2)+(i*2+1)] = i*2+1+(i2*Len*2); } } return mask; } template static consteval std::array::Alignment> GetExtractLoMaskEpi16() { std::array::Alignment> mask{}; for (std::uint16_t i2 = 0; i2 < Packing; i2++) { for (std::uint16_t i = 0; i < ExtractLen; i++) { mask[i2 * ExtractLen + i] = i + (i2 * Len); } } return mask; } template constexpr VectorF16 ExtractLo() const { if constexpr(Packing > 1) { if constexpr(std::is_same_v) { constexpr std::array::Alignment*2> shuffleMask = GetExtractLoMaskEpi8(); __m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data()); return VectorF16(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(v), shuffleVec))); } else if constexpr(std::is_same_v) { constexpr std::array::Alignment> permMask = GetExtractLoMaskEpi16(); __m256i permIdx = _mm256_loadu_epi16(permMask.data()); __m256i result = _mm256_permutexvar_epi16(permIdx, _mm256_castph_si256(v)); if constexpr(std::is_same_v::VectorType, __m128h>) { return VectorF16(_mm256_castph256_ph128(_mm256_castsi256_ph(result))); } else { return VectorF16(_mm256_castsi256_ph(result)); } } else { constexpr std::array::Alignment> permMask = GetExtractLoMaskEpi16(); __m512i permIdx = _mm512_loadu_epi16(permMask.data()); __m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castph_si512(v)); if constexpr(std::is_same_v::VectorType, __m128h>) { return VectorF16(_mm512_castph512_ph128(_mm512_castsi512_ph(result))); } else if constexpr(std::is_same_v::VectorType, __m256h>) { return VectorF16(_mm512_castph512_ph256(_mm512_castsi512_ph(result))); } else { return VectorF16(_mm512_castsi512_ph(result)); } } } else { if constexpr(std::is_same_v && std::is_same_v::VectorType, __m128h>) { return VectorF16(_mm256_castph256_ph128(v)); } else if constexpr(std::is_same_v && 
std::is_same_v::VectorType, __m128h>) { return VectorF16(_mm512_castph512_ph128(v)); } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m256h>) { return VectorF16(_mm512_castph512_ph256(v)); } else { return VectorF16(v); } } } constexpr void Normalize() requires(Packing == 1) { if constexpr(std::is_same_v) { _Float16 dot = LengthSq(); __m128h vec = _mm_set1_ph(dot); __m128h sqrt = _mm_sqrt_ph(vec); v = _mm_div_ph(v, sqrt); } else if constexpr(std::is_same_v) { _Float16 dot = LengthSq(); __m256h vec = _mm256_set1_ph(dot); __m256h sqrt = _mm256_sqrt_ph(vec); v = _mm256_div_ph(v, sqrt); } else { _Float16 dot = LengthSq(); __m512h vec = _mm512_set1_ph(dot); __m512h sqrt = _mm512_sqrt_ph(vec); v = _mm512_div_ph(v, sqrt); } } constexpr _Float16 Length() const requires(Packing == 1) { _Float16 Result = LengthSq(); return std::sqrtf(Result); } constexpr _Float16 LengthSq() const requires(Packing == 1) { return Dot(*this, *this); } constexpr VectorF16 Cos() { if constexpr (std::is_same_v) { __m256 wide = _mm256_cvtph_ps(_mm_castph_si128(v)); wide = cos_f32x8(wide); return VectorF16( _mm_castsi128_ph(_mm256_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT))); } else if constexpr (std::is_same_v) { __m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(v)); wide = cos_f32x16(wide); return VectorF16( _mm256_castsi256_ph(_mm512_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT))); } else { __m256i lo = _mm512_castsi512_si256(_mm512_castph_si512(v)); __m256i hi = _mm512_extracti64x4_epi64(_mm512_castph_si512(v), 1); __m256i lo_ph = _mm512_cvtps_ph(cos_f32x16(_mm512_cvtph_ps(lo)), _MM_FROUND_TO_NEAREST_INT); __m256i hi_ph = _mm512_cvtps_ph(cos_f32x16(_mm512_cvtph_ps(hi)), _MM_FROUND_TO_NEAREST_INT); return VectorF16( _mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(lo_ph), hi_ph, 1))); } } constexpr VectorF16 Sin() { if constexpr (std::is_same_v) { __m256 wide = _mm256_cvtph_ps(_mm_castph_si128(v)); wide = sin_f32x8(wide); return 
VectorF16(_mm_castsi128_ph(_mm256_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT))); } else if constexpr (std::is_same_v) { __m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(v)); wide = sin_f32x16(wide); return VectorF16(_mm256_castsi256_ph(_mm512_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT))); } else { __m256i lo = _mm512_castsi512_si256(_mm512_castph_si512(v)); __m256i hi = _mm512_extracti64x4_epi64(_mm512_castph_si512(v), 1); __m256i lo_ph = _mm512_cvtps_ph(sin_f32x16(_mm512_cvtph_ps(lo)), _MM_FROUND_TO_NEAREST_INT); __m256i hi_ph = _mm512_cvtps_ph(sin_f32x16(_mm512_cvtph_ps(hi)), _MM_FROUND_TO_NEAREST_INT); return VectorF16(_mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(lo_ph), hi_ph, 1))); } } std::tuple, VectorF16> SinCos() { if constexpr (std::is_same_v) { __m256 wide = _mm256_cvtph_ps(_mm_castph_si128(v)); __m256 s, c; sincos_f32x8(wide, s, c); return { VectorF16(_mm_castsi128_ph(_mm256_cvtps_ph(s, _MM_FROUND_TO_NEAREST_INT))), VectorF16(_mm_castsi128_ph(_mm256_cvtps_ph(c, _MM_FROUND_TO_NEAREST_INT))) }; } else if constexpr (std::is_same_v) { __m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(v)); __m512 s, c; sincos_f32x16(wide, s, c); return { VectorF16(_mm256_castsi256_ph(_mm512_cvtps_ph(s, _MM_FROUND_TO_NEAREST_INT))), VectorF16(_mm256_castsi256_ph(_mm512_cvtps_ph(c, _MM_FROUND_TO_NEAREST_INT))) }; } else { __m256i lo = _mm512_castsi512_si256(_mm512_castph_si512(v)); __m256i hi = _mm512_extracti64x4_epi64(_mm512_castph_si512(v), 1); __m512 s_lo, c_lo, s_hi, c_hi; sincos_f32x16(_mm512_cvtph_ps(lo), s_lo, c_lo); sincos_f32x16(_mm512_cvtph_ps(hi), s_hi, c_hi); auto pack = [](__m256i lo_ph, __m256i hi_ph) { return _mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(lo_ph), hi_ph, 1)); }; return { VectorF16(pack(_mm512_cvtps_ph(s_lo, _MM_FROUND_TO_NEAREST_INT), _mm512_cvtps_ph(s_hi, _MM_FROUND_TO_NEAREST_INT))), VectorF16(pack( _mm512_cvtps_ph(c_lo, _MM_FROUND_TO_NEAREST_INT), _mm512_cvtps_ph(c_hi, _MM_FROUND_TO_NEAREST_INT))) }; } } template 
values> constexpr VectorF16 Negate() { std::array mask = GetNegateMask(); if constexpr(std::is_same_v) { return VectorF16(_mm_castsi128_ph(_mm_xor_si128(_mm_castph_si128(v), _mm_loadu_epi16(mask.data())))); } else if constexpr(std::is_same_v) { return VectorF16(_mm256_castsi256_ph(_mm256_xor_si256(_mm256_castph_si256(v), _mm256_loadu_epi16(mask.data())))); } else { return VectorF16(_mm512_castsi512_ph(_mm512_xor_si512(_mm512_castph_si512(v), _mm512_loadu_epi16(mask.data())))); } } static constexpr VectorF16 MulitplyAdd(VectorF16 a, VectorF16 b, VectorF16 add) { if constexpr(std::is_same_v) { return VectorF16(_mm_fmadd_ph(a.v, b.v, add.v)); } else if constexpr(std::is_same_v) { return VectorF16(_mm256_fmadd_ph(a.v, b.v, add.v)); } else { return VectorF16(_mm512_fmadd_ph(a.v, b.v, add.v)); } } static constexpr VectorF16 MulitplySub(VectorF16 a, VectorF16 b, VectorF16 sub) { if constexpr(std::is_same_v) { return VectorF16(_mm_fmsub_ph(a.v, b.v, sub.v)); } else if constexpr(std::is_same_v) { return VectorF16(_mm256_fmsub_ph(a.v, b.v, sub.v)); } else { return VectorF16(_mm512_fmsub_ph(a.v, b.v, sub.v)); } } constexpr static VectorF16 Cross(VectorF16 a, VectorF16 b) requires(Len == 3) { if constexpr(std::is_same_v) { constexpr std::array::Alignment*2> shuffleMask1 = GetShuffleMaskEpi8<{{1,2,0}}>(); __m128i shuffleVec1 = _mm_loadu_epi8(shuffleMask1.data()); __m128h row1 = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(a.v), shuffleVec1)); __m128h row4 = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(b.v), shuffleVec1)); constexpr std::array::Alignment*2> shuffleMask3 = GetShuffleMaskEpi8<{{2,0,1}}>(); __m128i shuffleVec3 = _mm_loadu_epi8(shuffleMask3.data()); __m128h row3 = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(a.v), shuffleVec3)); __m128h row2 = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(b.v), shuffleVec3)); __m128h result = _mm_mul_ph(row3, row4); return _mm_fmsub_ph(row1,row2,result); } else if constexpr (std::is_same_v) { constexpr 
std::array::Alignment*2> shuffleMask1 = GetShuffleMaskEpi8<{{1,2,0}}>(); __m512i shuffleVec1 = _mm512_castsi256_si512(_mm256_loadu_epi8(shuffleMask1.data())); __m256h row1 = _mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castph_si256(a.v)), shuffleVec1))); __m256h row4 = _mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castph_si256(b.v)), shuffleVec1))); constexpr std::array::Alignment*2> shuffleMask3 = GetShuffleMaskEpi8<{{2,0,1}}>(); __m512i shuffleVec3 = _mm512_castsi256_si512(_mm256_loadu_epi8(shuffleMask3.data())); __m256h row3 = _mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castph_si256(a.v)), shuffleVec3))); __m256h row2 = _mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castph_si256(b.v)), shuffleVec3))); __m256h result = _mm256_mul_ph(row3, row4); return _mm256_fmsub_ph(row1,row2,result); } else { constexpr std::array::Alignment*2> shuffleMask1 = GetShuffleMaskEpi8<{{1,2,0}}>(); __m512i shuffleVec1 = _mm512_loadu_epi8(shuffleMask1.data()); __m512h row1 = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(a.v), shuffleVec1)); __m512h row4 = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(b.v), shuffleVec1)); constexpr std::array::Alignment*2> shuffleMask3 = GetShuffleMaskEpi8<{{2,0,1}}>(); __m512i shuffleVec3 = _mm512_loadu_epi8(shuffleMask3.data()); __m512h row3 = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(a.v), shuffleVec3)); __m512h row2 = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(b.v), shuffleVec3)); __m512h result = _mm512_mul_ph(row3, row4); return _mm512_fmsub_ph(row1,row2,result); } } constexpr static _Float16 Dot(VectorF16 a, VectorF16 b) requires(Packing == 1) { if constexpr(std::is_same_v) { __m128h mul = _mm_mul_ph(a.v, b.v); return _mm_reduce_add_ph(mul); } else if constexpr(std::is_same_v) { 
static_assert(std::is_same_v, "a.v is NOT VectorType"); __m256h mul = _mm256_mul_ph(a.v, b.v); return _mm256_reduce_add_ph(mul); } else { __m512h mul = _mm512_mul_ph(a.v, b.v); return _mm512_reduce_add_ph(mul); } } template ShuffleValues> constexpr VectorF16 Shuffle() { if constexpr(CheckEpi32Shuffle()) { if constexpr(std::is_same_v) { return VectorF16(_mm_castsi128_ph(_mm_shuffle_epi32(_mm_castph_si128(v), GetShuffleMaskEpi32()))); } else if constexpr(std::is_same_v) { return VectorF16(_mm256_castsi256_ph(_mm256_shuffle_epi32(_mm256_castph_si256(v), GetShuffleMaskEpi32()))); } else { return VectorF16(_mm512_castsi512_ph(_mm512_shuffle_epi32(_mm512_castph_si512(v), GetShuffleMaskEpi32()))); } } else if constexpr(CheckEpi8Shuffle()){ if constexpr(std::is_same_v) { constexpr std::array::Alignment*2> shuffleMask = GetShuffleMaskEpi8(); __m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data()); return VectorF16(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(v), shuffleVec))); } else if constexpr(std::is_same_v) { constexpr std::array::Alignment*2> shuffleMask = GetShuffleMaskEpi8(); __m256i shuffleVec = _mm256_loadu_epi8(shuffleMask.data()); return VectorF16(_mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castph_si256(v)), _mm512_castsi256_si512(shuffleVec))))); } else { constexpr std::array::Alignment*2> shuffleMask = GetShuffleMaskEpi8(); __m512i shuffleVec = _mm512_loadu_epi8(shuffleMask.data()); return VectorF16(_mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(v), shuffleVec))); } } else { if constexpr(std::is_same_v) { constexpr std::array::Alignment*2> shuffleMask = GetShuffleMaskEpi8(); __m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data()); return VectorF16(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(v), shuffleVec))); } else if constexpr(std::is_same_v) { constexpr std::array::Alignment> permMask = GetPermuteMaskEpi16(); __m256i permIdx = _mm256_loadu_epi16(permMask.data()); return 
VectorF16(_mm256_castsi256_ph(_mm256_permutexvar_epi16(permIdx, _mm256_castph_si256(v)))); } else { constexpr std::array::Alignment> permMask = GetPermuteMaskEpi16(); __m512i permIdx = _mm512_loadu_epi16(permMask.data()); return VectorF16(_mm512_castsi512_ph(_mm512_permutexvar_epi16(permIdx, _mm512_castph_si512(v)))); } } } constexpr static std::tuple, VectorF16, VectorF16, VectorF16> Normalize( VectorF16 A, VectorF16 C, VectorF16 E, VectorF16 G ) requires(Len == 4 && Packing*Len == Alignment) { if constexpr(std::is_same_v) { VectorF16<1, 8> lenght = LengthNoShuffle(A, E, C, G); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1}; __m128h one = _mm_loadu_ph(oneArr); VectorF16<8, 1> fLenght(_mm_div_ph(one, lenght.v)); VectorF16<8, 1> fLenghtA = fLenght.template Shuffle<{{0,0,0,0,1,1,1,1}}>(); VectorF16<8, 1> fLenghtC = fLenght.template Shuffle<{{2,2,2,2,3,3,3,3}}>(); VectorF16<8, 1> fLenghtE = fLenght.template Shuffle<{{4,4,4,4,5,5,5,5}}>(); VectorF16<8, 1> fLenghtG = fLenght.template Shuffle<{{6,6,6,6,7,7,7,7}}>(); return { _mm_mul_ph(A.v, fLenghtA.v), _mm_mul_ph(C.v, fLenghtC.v), _mm_mul_ph(E.v, fLenghtE.v), _mm_mul_ph(G.v, fLenghtG.v) }; } else if constexpr(std::is_same_v) { VectorF16<1, 16> lenght = LengthNoShuffle(A, E, C, G); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; __m256h one = _mm256_loadu_ph(oneArr); VectorF16<16, 1> fLenght(_mm256_div_ph(one, lenght.v)); VectorF16<16, 1> fLenghtA = fLenght.template Shuffle<{{0,0,0,0,1,1,1,1,8,8,8,8,9,9,9,9}}>(); VectorF16<16, 1> fLenghtC = fLenght.template Shuffle<{{2,2,2,2,3,3,3,3,10,10,10,10,11,11,11,11}}>(); VectorF16<16, 1> fLenghtE = fLenght.template Shuffle<{{4,4,4,4,5,5,5,5,12,12,12,12,13,13,13,13}}>(); VectorF16<16, 1> fLenghtG = fLenght.template Shuffle<{{6,6,6,6,7,7,7,7,14,14,14,14,15,15,15,15}}>(); return { _mm256_mul_ph(A.v, fLenghtA.v), _mm256_mul_ph(C.v, fLenghtC.v), _mm256_mul_ph(E.v, fLenghtE.v), _mm256_mul_ph(G.v, fLenghtG.v) }; } else { VectorF16<1, 32> lenght = 
LengthNoShuffle(A, E, C, G); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; __m512h one = _mm512_loadu_ph(oneArr); VectorF16<32, 1> fLenght(_mm512_div_ph(one, lenght.v)); VectorF16<32, 1> fLenght2(lenght.v); VectorF16<32, 1> fLenghtA = fLenght.template Shuffle<{{0,0,0,0,1,1,1,1,8,8,8,8,9,9,9,9,16,16,16,16,17,17,17,17,24,24,24,24,25,25,25,25}}>(); VectorF16<32, 1> fLenghtC = fLenght.template Shuffle<{{2,2,2,2,3,3,3,3,10,10,10,10,11,11,11,11,18,18,18,18,19,19,19,19,26,26,26,26,27,27,27,27}}>(); VectorF16<32, 1> fLenghtE = fLenght.template Shuffle<{{4,4,4,4,5,5,5,5,12,12,12,12,13,13,13,13,20,20,20,20,21,21,21,21,28,28,28,28,29,29,29,29}}>(); VectorF16<32, 1> fLenghtG = fLenght.template Shuffle<{{6,6,6,6,7,7,7,7,14,14,14,14,15,15,15,15,22,22,22,22,23,23,23,23,30,30,30,30,31,31,31,31}}>(); return { VectorF16(_mm512_mul_ph(A.v, fLenghtA.v)), VectorF16(_mm512_mul_ph(C.v, fLenghtC.v)), VectorF16(_mm512_mul_ph(E.v, fLenghtE.v)), VectorF16(_mm512_mul_ph(G.v, fLenghtG.v)), }; } } constexpr static std::tuple, VectorF16> Normalize( VectorF16 A, VectorF16 E ) requires(Len == 2 && Packing*Len == Alignment) { if constexpr(std::is_same_v) { VectorF16<1, 8> lenght = LengthNoShuffle(A, E); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1}; __m128h one = _mm_loadu_ph(oneArr); VectorF16<8, 1> fLenght(_mm_div_ph(one, lenght.v)); VectorF16<8, 1> fLenghtA = fLenght.template Shuffle<{{0,0,1,1,2,2,3,3}}>(); VectorF16<8, 1> fLenghtE = fLenght.template Shuffle<{{4,4,5,5,6,6,7,7}}>(); return { _mm_mul_ph(A.v, fLenghtA.v), _mm_mul_ph(E.v, fLenghtE.v), }; } else if constexpr(std::is_same_v) { VectorF16<1, 16> lenght = LengthNoShuffle(A, E); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; __m256h one = _mm256_loadu_ph(oneArr); VectorF16<16, 1> fLenght(_mm256_div_ph(one, lenght.v)); VectorF16<16, 1> fLenghtA = fLenght.template Shuffle<{{0,0,1,1,2,2,3,3,8,8,9,9,10,10,11,11}}>(); 
VectorF16<16, 1> fLenghtE = fLenght.template Shuffle<{{4,4,5,5,6,6,7,7,12,12,13,13,14,14,15,15}}>(); return { _mm256_mul_ph(A.v, fLenghtA.v), _mm256_mul_ph(E.v, fLenghtE.v), }; } else { VectorF16<1, 32> lenght = LengthNoShuffle(A, E); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; __m512h one = _mm512_loadu_ph(oneArr); VectorF16<32, 1> fLenght(_mm512_div_ph(one, lenght.v)); VectorF16<32, 1> fLenghtA = fLenght.template Shuffle<{{0,0,1,1,2,2,3,3,8,8,9,9,10,10,11,11,16,16,17,17,18,18,19,19,24,24,25,25,26,26,27,27}}>(); VectorF16<32, 1> fLenghtE = fLenght.template Shuffle<{{4,4,5,5,6,6,7,7,12,12,13,13,14,14,15,15,20,20,21,21,22,22,23,23,28,28,29,29,30,30,31,31}}>(); return { _mm512_mul_ph(A.v, fLenghtA.v), _mm512_mul_ph(E.v, fLenghtE.v), }; } } constexpr static VectorF16<1, Packing*4> Length( VectorF16 A, VectorF16 C, VectorF16 E, VectorF16 G ) requires(Len == 4 && Packing*Len == Alignment) { VectorF16<1, Packing*4> lenghtSq = LengthSq(A, C, E, G); if constexpr(std::is_same_v) { return VectorF16<1, Packing*4>(_mm_sqrt_ph(lenghtSq.v)); } else if constexpr(std::is_same_v) { return VectorF16<1, Packing*4>(_mm256_sqrt_ph(lenghtSq.v)); } else { return VectorF16<1, Packing*4>(_mm512_sqrt_ph(lenghtSq.v)); } } constexpr static VectorF16<1, Packing*2> Length( VectorF16 A, VectorF16 E ) requires(Len == 2 && Packing*Len == Alignment) { VectorF16<1, Packing*2> lenghtSq = LengthSq(A, E); if constexpr(std::is_same_v) { return VectorF16<1, Packing*2>(_mm_sqrt_ph(lenghtSq.v)); } else if constexpr(std::is_same_v) { return VectorF16<1, Packing*2>(_mm256_sqrt_ph(lenghtSq.v)); } else { return VectorF16<1, Packing*2>(_mm512_sqrt_ph(lenghtSq.v)); } } constexpr static VectorF16<1, Packing*4> LengthSq( VectorF16 A, VectorF16 C, VectorF16 E, VectorF16 G ) requires(Len == 4 && Packing*Len == Alignment) { return Dot(A, A, C, C, E, E, G, G); } constexpr static VectorF16<1, Packing*2> LengthSq( VectorF16 A, VectorF16 E 
) requires(Len == 2 && Packing*Len == Alignment) { return Dot(A, A, E, E); } constexpr static VectorF16<1, Packing*4> Dot( VectorF16 A0, VectorF16 A1, VectorF16 C0, VectorF16 C1, VectorF16 E0, VectorF16 E1, VectorF16 G0, VectorF16 G1 ) requires(Len == 4 && Packing*Len == Alignment) { if constexpr(std::is_same_v) { return DotNoShuffle(A0, A1, E0, E1, C0, C1, G0, G1); } else if constexpr(std::is_same_v) { VectorF16<16, 1> vec(DotNoShuffle(A0, A1, C0, C1, E0, E1, G0, G1).v); vec = vec.template Shuffle<{{ 0,1,8,9, 4,5,12,13, 2,3,10,11, 6,7,14,15 }}>(); return vec.v; } else { VectorF16<32, 1> vec(DotNoShuffle(A0, A1, C0, C1, E0, E1, G0, G1).v); vec = vec.template Shuffle<{{ 0,1,8,9, 16,17,24,25, 4,5,12,13, 20,21,28,29, 2,3,10,11, 18,19,24,25, 6,7,14,15, 22,23,30,31 }}>(); return vec.v; } } constexpr static VectorF16<1, Packing*2> Dot( VectorF16 A0, VectorF16 A1, VectorF16 E0, VectorF16 E1 ) requires(Len == 2 && Packing*Len == Alignment) { if constexpr(std::is_same_v) { return DotNoShuffle(A0, A1, E0, E1); } else if constexpr(std::is_same_v) { VectorF16<16, 1> vec(DotNoShuffle(A0, A1, E0, E1).v); vec = vec.template Shuffle<{{0,1,2,3,8,9,10,11,4,5,6,7,12,13,14,15}}>(); return vec.v; } else { VectorF16<32, 1> vec(DotNoShuffle(A0, A1, E0, E1).v); vec = vec.template Shuffle<{{0,1,2,3,8,9,10,11,16,17,18,19,24,25,26,27,4,5,6,7,12,13,14,15,20,21,22,23,28,29,30,31}}>(); return vec.v; } } private: constexpr static VectorF16<1, Packing*4> LengthNoShuffle( VectorF16 A, VectorF16 C, VectorF16 E, VectorF16 G ) requires(Len == 4 && Packing*Len == Alignment) { VectorF16<1, Packing*4> lenghtSq = LengthSqNoShuffle(A, C, E, G); if constexpr(std::is_same_v) { return VectorF16<1, Packing*4>(_mm_sqrt_ph(lenghtSq.v)); } else if constexpr(std::is_same_v) { return VectorF16<1, Packing*4>(_mm256_sqrt_ph(lenghtSq.v)); } else { return VectorF16<1, Packing*4>(_mm512_sqrt_ph(lenghtSq.v)); } } constexpr static VectorF16<1, Packing*2> LengthNoShuffle( VectorF16 A, VectorF16 E ) requires(Len == 2 && 
Packing*Len == Alignment) {
            VectorF16<1, Packing*2> lenghtSq = LengthSqNoShuffle(A, E);
            if constexpr(std::is_same_v<VectorType, __m128h>) {
                return VectorF16<1, Packing*2>(_mm_sqrt_ph(lenghtSq.v));
            } else if constexpr(std::is_same_v<VectorType, __m256h>) {
                return VectorF16<1, Packing*2>(_mm256_sqrt_ph(lenghtSq.v));
            } else {
                return VectorF16<1, Packing*2>(_mm512_sqrt_ph(lenghtSq.v));
            }
        }

        /**
         * Squared lengths of the 4-component vectors packed in A, C, E and G.
         * Each argument holds Packing vectors; the result holds Packing*4 scalars in the
         * lane order produced by the 8-argument DotNoShuffle (not in argument order).
         */
        constexpr static VectorF16<1, Packing*4> LengthSqNoShuffle(
            VectorF16<Len, Packing> A,
            VectorF16<Len, Packing> C,
            VectorF16<Len, Packing> E,
            VectorF16<Len, Packing> G
        ) requires(Len == 4 && Packing*Len == Alignment) {
            return DotNoShuffle(A, A, C, C, E, E, G, G);
        }

        /**
         * Squared lengths of the 2-component vectors packed in A and E.
         * The result holds Packing*2 scalars in the lane order of the 4-argument DotNoShuffle.
         */
        constexpr static VectorF16<1, Packing*2> LengthSqNoShuffle(
            VectorF16<Len, Packing> A,
            VectorF16<Len, Packing> E
        ) requires(Len == 2 && Packing*Len == Alignment) {
            return DotNoShuffle(A, A, E, E);
        }

        /**
         * Pairwise dot products of 4-component vectors taken from the register pairs
         * (A0,A1), (C0,C1), (E0,E1), (G0,G1). Each result lane is one dot product.
         *
         * The products are transposed with 16-bit unpacks, which act independently on every
         * 128-bit lane. Within one 128-bit lane, name the two vectors from the A pair "A,B",
         * from C "C,D", from E "E,F" and from G "G,H" (digit = component). The result lane
         * order is then A B E F C D G H; no extra permute is spent restoring argument order.
         * Partial products are summed as ((x1 + x2) + x3) + x4.
         */
        constexpr static VectorF16<1, Packing*4> DotNoShuffle(
            VectorF16<Len, Packing> A0, VectorF16<Len, Packing> A1,
            VectorF16<Len, Packing> C0, VectorF16<Len, Packing> C1,
            VectorF16<Len, Packing> E0, VectorF16<Len, Packing> E1,
            VectorF16<Len, Packing> G0, VectorF16<Len, Packing> G1
        ) requires(Len == 4 && Packing*Len == Alignment) {
            const VectorType mulA = RawMul(A0.v, A1.v); // A1 A2 A3 A4 B1 B2 B3 B4
            const VectorType mulC = RawMul(C0.v, C1.v); // C1 C2 C3 C4 D1 D2 D3 D4
            const VectorType mulE = RawMul(E0.v, E1.v); // E1 E2 E3 E4 F1 F2 F3 F4
            const VectorType mulG = RawMul(G0.v, G1.v); // G1 G2 G3 G4 H1 H2 H3 H4

            const VectorType ac = RawUnpackLo16(mulA, mulC); // A1 C1 A2 C2 A3 C3 A4 C4
            const VectorType bd = RawUnpackHi16(mulA, mulC); // B1 D1 B2 D2 B3 D3 B4 D4
            const VectorType eg = RawUnpackLo16(mulE, mulG); // E1 G1 E2 G2 E3 G3 E4 G4
            const VectorType fh = RawUnpackHi16(mulE, mulG); // F1 H1 F2 H2 F3 H3 F4 H4

            const VectorType aecg12 = RawUnpackLo16(ac, eg); // A1 E1 C1 G1 A2 E2 C2 G2
            const VectorType aecg34 = RawUnpackHi16(ac, eg); // A3 E3 C3 G3 A4 E4 C4 G4
            const VectorType bfdh12 = RawUnpackLo16(bd, fh); // B1 F1 D1 H1 B2 F2 D2 H2
            const VectorType bfdh34 = RawUnpackHi16(bd, fh); // B3 F3 D3 H3 B4 F4 D4 H4

            VectorType sum = RawUnpackLo16(aecg12, bfdh12);         // A1 B1 E1 F1 C1 D1 G1 H1
            sum = RawAdd(sum, RawUnpackHi16(aecg12, bfdh12));        // + A2 B2 E2 F2 C2 D2 G2 H2
            sum = RawAdd(sum, RawUnpackLo16(aecg34, bfdh34));        // + A3 B3 E3 F3 C3 D3 G3 H3
            sum = RawAdd(sum, RawUnpackHi16(aecg34, bfdh34));        // + A4 B4 E4 F4 C4 D4 G4 H4
            return VectorF16<1, Packing*4>(sum);
        }

        /**
         * Pairwise dot products of 2-component vectors taken from (A0,A1) and (E0,E1).
         *
         * Within one 128-bit lane, the A pair holds vectors A,B,C,D and the E pair holds
         * E,F,G,H (digit = component). The result lane order is A B C D E F G H, computed
         * as x1 + x2.
         */
        constexpr static VectorF16<1, Packing*2> DotNoShuffle(
            VectorF16<Len, Packing> A0, VectorF16<Len, Packing> A1,
            VectorF16<Len, Packing> E0, VectorF16<Len, Packing> E1
        ) requires(Len == 2 && Packing*Len == Alignment) {
            const VectorType mulA = RawMul(A0.v, A1.v); // A1 A2 B1 B2 C1 C2 D1 D2
            const VectorType mulE = RawMul(E0.v, E1.v); // E1 E2 F1 F2 G1 G2 H1 H2

            const VectorType lo = RawUnpackLo16(mulA, mulE); // A1 E1 A2 E2 B1 F1 B2 F2
            const VectorType hi = RawUnpackHi16(mulA, mulE); // C1 G1 C2 G2 D1 H1 D2 H2

            const VectorType acegEven = RawUnpackLo16(lo, hi); // A1 C1 E1 G1 A2 C2 E2 G2
            const VectorType bdfhOdd = RawUnpackHi16(lo, hi);  // B1 D1 F1 H1 B2 D2 F2 H2

            const VectorType row1 = RawUnpackLo16(acegEven, bdfhOdd); // A1 B1 C1 D1 E1 F1 G1 H1
            const VectorType row2 = RawUnpackHi16(acegEven, bdfhOdd); // A2 B2 C2 D2 E2 F2 G2 H2
            return VectorF16<1, Packing*2>(RawAdd(row1, row2));
        }

    private:
        // Width-dispatching wrappers over the raw register type. They let the transposes above be
        // written once instead of once per 128/256/512-bit register.

        /// Lane-wise fp16 multiply.
        static constexpr VectorType RawMul(VectorType a, VectorType b) {
            if constexpr(std::is_same_v<VectorType, __m128h>) {
                return _mm_mul_ph(a, b);
            } else if constexpr(std::is_same_v<VectorType, __m256h>) {
                return _mm256_mul_ph(a, b);
            } else {
                return _mm512_mul_ph(a, b);
            }
        }

        /// Lane-wise fp16 add.
        static constexpr VectorType RawAdd(VectorType a, VectorType b) {
            if constexpr(std::is_same_v<VectorType, __m128h>) {
                return _mm_add_ph(a, b);
            } else if constexpr(std::is_same_v<VectorType, __m256h>) {
                return _mm256_add_ph(a, b);
            } else {
                return _mm512_add_ph(a, b);
            }
        }

        /// Interleave the low four 16-bit elements of each 128-bit lane of a and b.
        static constexpr VectorType RawUnpackLo16(VectorType a, VectorType b) {
            if constexpr(std::is_same_v<VectorType, __m128h>) {
                return _mm_castsi128_ph(_mm_unpacklo_epi16(_mm_castph_si128(a), _mm_castph_si128(b)));
            } else if constexpr(std::is_same_v<VectorType, __m256h>) {
                return _mm256_castsi256_ph(_mm256_unpacklo_epi16(_mm256_castph_si256(a), _mm256_castph_si256(b)));
            } else {
                return _mm512_castsi512_ph(_mm512_unpacklo_epi16(_mm512_castph_si512(a), _mm512_castph_si512(b)));
            }
        }

        /// Interleave the high four 16-bit elements of each 128-bit lane of a and b.
        static constexpr VectorType RawUnpackHi16(VectorType a, VectorType b) {
            if constexpr(std::is_same_v<VectorType, __m128h>) {
                return _mm_castsi128_ph(_mm_unpackhi_epi16(_mm_castph_si128(a), _mm_castph_si128(b)));
            } else if constexpr(std::is_same_v<VectorType, __m256h>) {
                return _mm256_castsi256_ph(_mm256_unpackhi_epi16(_mm256_castph_si256(a), _mm256_castph_si256(b)));
            } else {
                return _mm512_castsi512_ph(_mm512_unpackhi_epi16(_mm512_castph_si512(a), _mm512_castph_si512(b)));
            }
        }

    public:
        /**
         * Per-component select. Component i of every pack is taken from b when ShuffleValues[i]
         * is true, otherwise from a. The same pattern is repeated for all Packing packs
         * (the mask is built by GetBlendMaskEpi16).
         */
        template <std::array<bool, Len> ShuffleValues>
        constexpr static VectorF16<Len, Packing> Blend(VectorF16<Len, Packing> a, VectorF16<Len, Packing> b) {
            if constexpr(std::is_same_v<VectorType, __m128h>) {
                return _mm_castsi128_ph(_mm_blend_epi16(_mm_castph_si128(a.v), _mm_castph_si128(b.v), GetBlendMaskEpi16<ShuffleValues>()));
            } else if constexpr(std::is_same_v<VectorType, __m256h>) {
// _mm256_mask_blend_epi16 needs BOTH AVX512BW and AVX512VL; fail if either one is missing.
#if !defined(__AVX512BW__) || !defined(__AVX512VL__)
                static_assert(false, "No __AVX512BW__ and __AVX512VL__ support");
#endif
                return _mm256_castsi256_ph(_mm256_mask_blend_epi16(GetBlendMaskEpi16<ShuffleValues>(), _mm256_castph_si256(a.v), _mm256_castph_si256(b.v)));
            } else {
                return _mm512_castsi512_ph(_mm512_mask_blend_epi16(GetBlendMaskEpi16<ShuffleValues>(), _mm512_castph_si512(a.v), _mm512_castph_si512(b.v)));
            }
        }

        /**
         * Rotate v by quaternion q (xyz in components 0..2, w in component 3):
         *   t = 2 * cross(q.xyz, v);  v' = v + w * t + cross(q.xyz, t)
         * NOTE(review): presumably q is expected to be a unit quaternion.
         */
        constexpr static VectorF16<3, Packing> Rotate(VectorF16<3, Packing> v, VectorF16<4, Packing> q) requires(Len == 3) {
            VectorF16<3, Packing> qv(q);
            VectorF16<3, Packing> t = Cross(qv, v) * _Float16(2);
            return v + t * q.template Shuffle<{{3,3,3,3}}>() + Cross(qv, t);
        }

        /**
         * Rotate v by quaternion q around pivot: Rotate(v - pivot, q) + pivot.
         * Returns the same 3-component, Packing-wide shape as its inputs (and as Rotate).
         */
        constexpr static VectorF16<3, Packing> RotatePivot(VectorF16<3, Packing> v, VectorF16<4, Packing> q, VectorF16<3, Packing> pivot) requires(Len == 3) {
            return Rotate(v - pivot, q) + pivot;
        }

        /**
         * Quaternion from half Euler angles. With s_i = sin(EulerHalf[i]) and c_i = cos(EulerHalf[i]),
         * the result (x, y, z, w) is
         *   x = s0 c1 c2 - c0 s1 s2,   y = c0 s1 c2 + s0 c1 s2,
         *   z = c0 c1 s2 - s0 s1 c2,   w = c0 c1 c2 + s0 s1 s2
         * assuming MulitplyAdd(a, b, c) computes a * b + c.
         */
        constexpr static VectorF16<4, Packing> QuanternionFromEuler(VectorF16<3, Packing> EulerHalf) requires(Len == 4) {
            std::tuple<VectorF16<3, Packing>, VectorF16<3, Packing>> sinCos = EulerHalf.SinCos();
            VectorF16<4, Packing> sin = std::get<0>(sinCos);
            VectorF16<4, Packing> cos = std::get<1>(sinCos);

            // First product: (s0 c0 c0 c0) * (c1 s1 c1 c1) * (c2 c2 s2 c2)
            VectorF16<4, Packing> row1 = cos.template Shuffle<{{0,0,0,0}}>();
            row1 = Blend<{{0,1,1,1}}>(sin, row1);
            VectorF16<4, Packing> row2 = cos.template Shuffle<{{1,1,1,1}}>();
            row2 = Blend<{{1,0,1,1}}>(sin, row2);
            row1 *= row2;
            VectorF16<4, Packing> row3 = cos.template Shuffle<{{2,2,2,2}}>();
            row3 = Blend<{{1,1,0,1}}>(sin, row3);
            row1 *= row3;

            // Second product: (c0 s0 s0 s0) * (s1 c1 s1 s1) * (-s2 s2 -c2 s2)
            VectorF16<4, Packing> row4 = sin.template Shuffle<{{0,0,0,0}}>();
            row4 = Blend<{{0,1,1,1}}>(cos, row4);
            VectorF16<4, Packing> row5 = sin.template Shuffle<{{1,1,1,1}}>();
            row5 = Blend<{{1,0,1,1}}>(cos, row5);
            row4 *= row5;
            VectorF16<4, Packing> row6 = sin.template Shuffle<{{2,2,2,2}}>();
            row6 = Blend<{{1,1,0,1}}>(cos, row6);
            row6 = row6.template Negate<{{true,false,true,false}}>();

            row1 = MulitplyAdd(row4, row6, row1);
            return row1;
        }

        // Cody-Waite reduction by pi/2 in single precision, evaluated with FMA:
        //   r = (x - q*pi_over_2_hi) - q*pi_over_2_lo
        // pi_over_2_hi is float(pi/2) exactly (0x3FC90FDB) and pi_over_2_lo is the float nearest
        // to pi/2 - pi_over_2_hi. A double-precision hi/lo split does not work here, because its
        // hi part is not exactly representable as a float.
        static constexpr float two_over_pi = 0.6366197723675814f;
        static constexpr float pi_over_2_hi = 1.57079637050628662109375f;
        static constexpr float pi_over_2_lo = -4.37113900018624283e-8f;

        // Cos polynomial on [-pi/4, pi/4]: c0 + c2*r^2 + c4*r^4 + ...
        static constexpr float c0 = 1.0f;
        static constexpr float c2 = -0.4999999642372f;
        static constexpr float c4 = 0.0416666418707f;
        static constexpr float c6 = -0.0013888397720f;
        static constexpr float c8 = 0.0000248015873f;
        static constexpr float c10 = -0.0000002752258f;

        // Sin polynomial on [-pi/4, pi/4]: r * (1 + s1*r^2 + s3*r^4 + ...)
        static constexpr float s1 = -0.1666666641831f;
        static constexpr float s3 = 0.0083333293858f;
        static constexpr float s5 = -0.0001984090955f;
        static constexpr float s7 = 0.0000027526372f;
        static constexpr float s9 = -0.0000000239013f;

        // Reduce |x| (ax) into about [-pi/4, pi/4]. Outputs the reduced value r, r2 = r*r and the
        // quadrant q = round(ax * 2/pi).
        static constexpr void range_reduce_f32x8(__m256 ax, __m256& r, __m256& r2, __m256i& q) {
            __m256 fq = _mm256_round_ps(_mm256_mul_ps(ax, _mm256_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
            q = _mm256_cvtps_epi32(fq);
            // fnmadd computes ax - fq*hi without rounding the product first.
            r = _mm256_fnmadd_ps(fq, _mm256_set1_ps(pi_over_2_hi), ax);
            r = _mm256_fnmadd_ps(fq, _mm256_set1_ps(pi_over_2_lo), r);
            r2 = _mm256_mul_ps(r, r);
        }

        // cos(x): use cos_poly when q even, sin_poly when q odd; negate if (q+1)&2
        static constexpr __m256 cos_f32x8(__m256 x) {
            const __m256 sign_mask = _mm256_set1_ps(-0.0f);
            __m256 ax = _mm256_andnot_ps(sign_mask, x); // |x|; cos is even
            __m256 r, r2;
            __m256i q;
            range_reduce_f32x8(ax, r, r2, q);
            __m256 cos_r, sin_r;
            sincos_poly_f32x8(r, r2, cos_r, sin_r);
            __m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1));
            __m256 use_sin = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1)));
            __m256 result = _mm256_blendv_ps(cos_r, sin_r, use_sin);
            __m256i need_neg = _mm256_and_si256(_mm256_add_epi32(q, _mm256_set1_epi32(1)), _mm256_set1_epi32(2));
            // Bit 1 shifted left by 30 lands on the float sign bit (bit 31).
            __m256 neg_mask = _mm256_castsi256_ps(_mm256_slli_epi32(need_neg, 30));
            return _mm256_xor_ps(result, neg_mask);
        }

        // Evaluate both polynomials at r (r2 = r*r) in Horner form with FMA.
        static constexpr void sincos_poly_f32x8(__m256 r, __m256 r2, __m256& cos_r, __m256& sin_r) {
            cos_r = _mm256_fmadd_ps(_mm256_set1_ps(c10), r2, _mm256_set1_ps(c8));
            cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c6));
            cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c4));
            cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c2));
            cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c0));

            sin_r = _mm256_fmadd_ps(_mm256_set1_ps(s9), r2, _mm256_set1_ps(s7));
            sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s5));
            sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s3));
            sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s1));
            sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(1.0f));
            sin_r = _mm256_mul_ps(sin_r, r);
        }

        // sin(x): use sin_poly when q even, cos_poly when q odd; negate if q&2; respect input sign
        static constexpr __m256 sin_f32x8(__m256 x) {
            const __m256 sign_mask = _mm256_set1_ps(-0.0f);
            __m256 x_sign = _mm256_and_ps(x, sign_mask);
            __m256 ax = _mm256_andnot_ps(sign_mask, x);
            __m256 r, r2;
            __m256i q;
            range_reduce_f32x8(ax, r, r2, q);
            __m256 cos_r, sin_r;
            sincos_poly_f32x8(r, r2, cos_r, sin_r);
            __m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1));
            __m256 use_cos = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1)));
            __m256 result = _mm256_blendv_ps(sin_r, cos_r, use_cos);
            __m256i need_neg = _mm256_and_si256(q, _mm256_set1_epi32(2));
            __m256 neg_mask = _mm256_castsi256_ps(_mm256_slli_epi32(need_neg, 30));
            result = _mm256_xor_ps(result, neg_mask);
            // Apply original sign of x (sin is odd)
            return _mm256_xor_ps(result, x_sign);
        }

        // --- 512-bit helpers ---

        // 16-lane counterpart of range_reduce_f32x8.
        static constexpr void range_reduce_f32x16(__m512 ax, __m512& r, __m512& r2, __m512i& q) {
            __m512 fq = _mm512_roundscale_ps(_mm512_mul_ps(ax, _mm512_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
            q = _mm512_cvtps_epi32(fq);
            r = _mm512_fnmadd_ps(fq, _mm512_set1_ps(pi_over_2_hi), ax);
            r = _mm512_fnmadd_ps(fq, _mm512_set1_ps(pi_over_2_lo), r);
            r2 = _mm512_mul_ps(r, r);
        }

        // 16-lane counterpart of sincos_poly_f32x8.
        static constexpr void sincos_poly_f32x16(__m512 r, __m512 r2, __m512& cos_r, __m512& sin_r) {
            cos_r = _mm512_fmadd_ps(_mm512_set1_ps(c10), r2, _mm512_set1_ps(c8));
            cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c6));
            cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c4));
            cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c2));
            cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c0));

            sin_r = _mm512_fmadd_ps(_mm512_set1_ps(s9), r2, _mm512_set1_ps(s7));
            sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s5));
            sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s3));
            sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s1));
            sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(1.0f));
            sin_r = _mm512_mul_ps(sin_r, r);
        }

        // 16-lane cos(x); same quadrant logic as cos_f32x8, using mask registers.
        static constexpr __m512 cos_f32x16(__m512 x) {
            __m512 ax = _mm512_abs_ps(x);
            __m512 r, r2;
            __m512i q;
            range_reduce_f32x16(ax, r, r2, q);
            __m512 cos_r, sin_r;
            sincos_poly_f32x16(r, r2, cos_r, sin_r);
            __mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1));
            __m512 result = _mm512_mask_blend_ps(odd, cos_r, sin_r);
            __m512i need_neg = _mm512_and_si512(_mm512_add_epi32(q, _mm512_set1_epi32(1)), _mm512_set1_epi32(2));
            __m512 neg_mask = _mm512_castsi512_ps(_mm512_slli_epi32(need_neg, 30));
            return _mm512_xor_ps(result, neg_mask);
        }

        // 16-lane sin(x); same quadrant/sign logic as sin_f32x8, using mask registers.
        static constexpr __m512 sin_f32x16(__m512 x) {
            __m512 x_sign = _mm512_and_ps(x, _mm512_set1_ps(-0.0f));
            __m512 ax = _mm512_abs_ps(x);
            __m512 r, r2;
            __m512i q;
            range_reduce_f32x16(ax, r, r2, q);
            __m512 cos_r, sin_r;
            sincos_poly_f32x16(r, r2, cos_r, sin_r);
            __mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1));
            __m512 result = _mm512_mask_blend_ps(odd, sin_r, cos_r);
            __m512i need_neg = _mm512_and_si512(q, _mm512_set1_epi32(2));
            __m512 neg_mask = _mm512_castsi512_ps(_mm512_slli_epi32(need_neg, 30));
            result = _mm512_xor_ps(result, neg_mask);
            return _mm512_xor_ps(result, x_sign);
        }

        // --- 256-bit sincos ---

        // sin(x) and cos(x) sharing one range reduction and one polynomial evaluation.
        static constexpr void sincos_f32x8(__m256 x, __m256& out_sin, __m256& out_cos) {
            const __m256 sign_mask = _mm256_set1_ps(-0.0f);
            __m256 x_sign = _mm256_and_ps(x, sign_mask);
            __m256 ax = _mm256_andnot_ps(sign_mask, x);
            __m256 r, r2;
            __m256i q;
            range_reduce_f32x8(ax, r, r2, q);
            __m256 cos_r, sin_r;
            sincos_poly_f32x8(r, r2, cos_r, sin_r);
            __m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1));
            __m256 is_odd = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1)));

            // cos: swap on odd, negate if (q+1)&2
            out_cos = _mm256_blendv_ps(cos_r, sin_r, is_odd);
            __m256i cos_neg = _mm256_and_si256(_mm256_add_epi32(q, _mm256_set1_epi32(1)), _mm256_set1_epi32(2));
            out_cos = _mm256_xor_ps(out_cos, _mm256_castsi256_ps(_mm256_slli_epi32(cos_neg, 30)));

            // sin: swap on odd, negate if q&2, apply input sign
            out_sin = _mm256_blendv_ps(sin_r, cos_r, is_odd);
            __m256i sin_neg = _mm256_and_si256(q, _mm256_set1_epi32(2));
            out_sin = _mm256_xor_ps(out_sin, _mm256_castsi256_ps(_mm256_slli_epi32(sin_neg, 30)));
            out_sin = _mm256_xor_ps(out_sin, x_sign);
        }

        // --- 512-bit sincos ---

        // 16-lane counterpart of sincos_f32x8.
        static constexpr void sincos_f32x16(__m512 x, __m512& out_sin, __m512& out_cos) {
            __m512 x_sign = _mm512_and_ps(x, _mm512_set1_ps(-0.0f));
            __m512 ax = _mm512_abs_ps(x);
            __m512 r, r2;
            __m512i q;
            range_reduce_f32x16(ax, r, r2, q);
            __m512 cos_r, sin_r;
            sincos_poly_f32x16(r, r2, cos_r, sin_r);
            __mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1));

            // cos
            out_cos = _mm512_mask_blend_ps(odd, cos_r, sin_r);
            __m512i cos_neg = _mm512_and_si512(_mm512_add_epi32(q, _mm512_set1_epi32(1)), _mm512_set1_epi32(2));
            out_cos = _mm512_xor_ps(out_cos, _mm512_castsi512_ps(_mm512_slli_epi32(cos_neg, 30)));

            // sin
            out_sin = _mm512_mask_blend_ps(odd, sin_r, cos_r);
            __m512i sin_neg = _mm512_and_si512(q, _mm512_set1_epi32(2));
            out_sin = _mm512_xor_ps(out_sin, _mm512_castsi512_ps(_mm512_slli_epi32(sin_neg, 30)));
            out_sin = _mm512_xor_ps(out_sin, x_sign);
        }
    };
}

// Formats a VectorF16 as "{" + one "{c0,c1,...}" group per pack + "}" (no separator between
// groups); each _Float16 component is printed via float.
export template <std::uint32_t Len, std::uint32_t Packing>
struct std::formatter<Crafter::VectorF16<Len, Packing>> : std::formatter<std::string> {
    constexpr auto format(const Crafter::VectorF16<Len, Packing>& obj, format_context& ctx) const {
        Crafter::Vector<_Float16, Len * Packing, 0> vec = obj.Store();
        std::string out = "{";
        for(std::uint32_t i = 0; i < Packing; i++) {
            out += "{";
            for(std::uint32_t i2 = 0; i2 < Len; i2++) {
                out += std::format("{}", static_cast<float>(vec.v[i * Len + i2]));
                if (i2 + 1 < Len) out += ",";
            }
            out += "}";
        }
        out += "}";
        return std::formatter<std::string>::format(out, ctx);
    }
};
#endif