1331 lines
No EOL
75 KiB
C++
Executable file
1331 lines
No EOL
75 KiB
C++
Executable file
/*
|
||
Crafter®.Math
|
||
Copyright (C) 2026 Catcrafts®
|
||
catcrafts.net
|
||
|
||
This library is free software; you can redistribute it and/or
|
||
modify it under the terms of the GNU Lesser General Public
|
||
License version 3.0 as published by the Free Software Foundation;
|
||
|
||
This library is distributed in the hope that it will be useful,
|
||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
||
Lesser General Public License for more details.
|
||
|
||
You should have received a copy of the GNU Lesser General Public
|
||
License along with this library; if not, write to the Free Software
|
||
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||
*/
|
||
module;
|
||
#ifdef __x86_64
|
||
#include <immintrin.h>
|
||
#endif
|
||
export module Crafter.Math:VectorF32;
|
||
import std;
|
||
import :Common;
|
||
|
||
#ifdef __AVX512FP16__
|
||
namespace Crafter {
|
||
export template <std::uint8_t Len, std::uint8_t Packing>
|
||
struct VectorF32 : public VectorBase<Len, Packing, float> {
|
||
template <std::uint8_t Len2, std::uint8_t Packing2>
|
||
friend struct VectorF32;
|
||
|
||
constexpr VectorF32() = default;
|
||
constexpr VectorF32(VectorBase<Len, Packing, float>::VectorType v) {
|
||
this->v = v;
|
||
}
|
||
constexpr VectorF32(const float* vB) {
|
||
Load(vB);
|
||
};
|
||
constexpr VectorF32(float val) {
|
||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||
this->v = _mm_set1_ps(val);
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||
this->v = _mm256_set1_ps(val);
|
||
} else {
|
||
this->v = _mm512_set1_ps(val);
|
||
}
|
||
};
|
||
constexpr void Load(const float* vB) {
|
||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||
this->v = _mm_loadu_ps(vB);
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||
this->v = _mm256_loadu_ps(vB);
|
||
} else {
|
||
this->v = _mm512_loadu_ps(vB);
|
||
}
|
||
}
|
||
constexpr void Store(float* vB) const {
|
||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||
_mm_storeu_ps(vB, this->v);
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||
_mm256_storeu_ps(vB, this->v);
|
||
} else {
|
||
_mm512_storeu_ps(vB, this->v);
|
||
}
|
||
}
|
||
|
||
constexpr std::array<float, VectorBase<Len, Packing, float>::AlignmentElement> Store() const {
|
||
std::array<float, VectorBase<Len, Packing, float>::AlignmentElement> returnArray;
|
||
Store(returnArray.data());
|
||
return returnArray;
|
||
}
|
||
|
||
template <std::uint8_t BLen, std::uint8_t BPacking>
|
||
constexpr operator VectorF32<BLen, BPacking>() const {
|
||
if constexpr (Len == BLen) {
|
||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256> && std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||
return VectorF32<BLen, BPacking>(_mm256_castps256_ps128(this->v));
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m512> && std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||
return VectorF32<BLen, BPacking>(_mm512_castps512_ps128(this->v));
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m512> && std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||
return VectorF32<BLen, BPacking>(_mm512_castps512_ps256(this->v));
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128> && std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||
return VectorF32<BLen, BPacking>(_mm256_castps128_ps256(this->v));
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128> && std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m512>) {
|
||
return VectorF32<BLen, BPacking>(_mm512_castps128_ps512(this->v));
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256> && std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m512>) {
|
||
return VectorF32<BLen, BPacking>(_mm512_castps256_ps512(this->v));
|
||
} else {
|
||
return VectorF32<BLen, BPacking>(this->v);
|
||
}
|
||
} else if constexpr (BLen <= Len) {
|
||
return this->template ExtractLo<BLen>();
|
||
} else {
|
||
if constexpr(std::is_same_v<typename VectorBase<BLen, BPacking, float>::VectorType, __m128>) {
|
||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask = VectorBase<Len, Packing, float>::template GetExtractLoMaskEpi8<BLen>();
|
||
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
|
||
return VectorF32<BLen, BPacking>(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(this->v), shuffleVec)));
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||
constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask =VectorBase<Len, Packing, float>::template GetExtractLoMaskepi32<BLen>();
|
||
__m256i permIdx = _mm256_loadu_epi32(permMask.data());
|
||
__m256i result = _mm256_permutexvar_epi32(permIdx, _mm_castps_si256(this->v));
|
||
return VectorF32<BLen, BPacking>(_mm_castsi128_ps(_mm256_castsi256_si128(result)));
|
||
} else {
|
||
constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask = VectorBase<Len, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
|
||
__m512i permIdx = _mm512_loadu_epi32(permMask.data());
|
||
__m512i result = _mm512_permutexvar_epi32(permIdx, _mm512_castps_si512(this->v));
|
||
return VectorF32<BLen, BPacking>(_mm_castsi128_ps(_mm512_castsi512_si128(result)));
|
||
}
|
||
} else if constexpr(std::is_same_v<typename VectorBase<BLen, BPacking, float>::VectorType, __m256>) {
|
||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||
constexpr std::array<std::uint32_t, VectorBase<BLen, Packing, float>::AlignmentElement> permMask = VectorBase<BLen, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
|
||
__m256i permIdx = _mm256_loadu_epi32(permMask.data());
|
||
__m256i result = _mm256_permutexvar_epi32(permIdx, _mm256_castsi128_si256(_mm_castps_si128(this->v)));
|
||
return VectorF32<BLen, BPacking>(_mm256_castsi256_ps(result));
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||
constexpr std::array<std::uint32_t, VectorBase<BLen, Packing, float>::AlignmentElement> permMask = VectorBase<BLen, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
|
||
__m256i permIdx = _mm256_loadu_epi32(permMask.data());
|
||
__m256i result = _mm256_permutexvar_epi32(permIdx, _mm256_castps_si256(this->v));
|
||
return VectorF32<BLen, BPacking>(_mm256_castsi256_ps(result));
|
||
} else {
|
||
constexpr std::array<std::uint32_t, VectorBase<BLen, Packing, float>::AlignmentElement> permMask = VectorBase<BLen, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
|
||
__m256i permIdx = _mm512_loadu_epi32(permMask.data());
|
||
__m256i result = _mm512_permutexvar_epi32(permIdx, _mm512_castsi512_si256(_mm512_castps_si512(this->v)));
|
||
return VectorF32<BLen, BPacking>(_mm256_castsi256_ps(result));
|
||
}
|
||
} else {
|
||
if constexpr(std::is_same_v<typename VectorBase<BLen, BPacking, float>::VectorType, __m128>) {
|
||
constexpr std::array<std::uint32_t, VectorBase<BLen, Packing, float>::AlignmentElement> permMask = VectorBase<BLen, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
|
||
__m512i permIdx = _mm512_loadu_epi32(permMask.data());
|
||
__m512i result = _mm512_permutexvar_epi32(permIdx, _mm512_castsi128_si512(_mm_castps_si128(this->v)));
|
||
return VectorF32<BLen, BPacking>(_mm512_castsi512_ps(result));
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||
constexpr std::array<std::uint32_t, VectorBase<BLen, Packing, float>::AlignmentElement> permMask = VectorBase<BLen, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
|
||
__m512i permIdx = _mm512_loadu_epi32(permMask.data());
|
||
__m512i result = _mm512_permutexvar_epi32(permIdx, _mm512_castsi256_si512(_mm256_castps_si256(this->v)));
|
||
return VectorF32<BLen, BPacking>(_mm512_castsi512_ps(result));
|
||
} else {
|
||
constexpr std::array<std::uint32_t, VectorBase<BLen, Packing, float>::AlignmentElement> permMask = VectorBase<BLen, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
|
||
__m512i permIdx = _mm512_loadu_epi32(permMask.data());
|
||
__m512i result = _mm512_permutexvar_epi32(permIdx, _mm512_castps_si512(this->v));
|
||
return VectorF32<BLen, BPacking>(_mm512_castsi512_ps(result));
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// Lane-wise addition over the whole backing register.
constexpr VectorF32<Len, Packing> operator+(VectorF32<Len, Packing> rhs) const {
    using Reg = typename VectorBase<Len, Packing, float>::VectorType;
    if constexpr (std::is_same_v<Reg, __m128>) {
        return VectorF32<Len, Packing>(_mm_add_ps(this->v, rhs.v));
    } else if constexpr (std::is_same_v<Reg, __m256>) {
        return VectorF32<Len, Packing>(_mm256_add_ps(this->v, rhs.v));
    } else {
        return VectorF32<Len, Packing>(_mm512_add_ps(this->v, rhs.v));
    }
}

// Lane-wise subtraction (this - rhs) over the whole backing register.
constexpr VectorF32<Len, Packing> operator-(VectorF32<Len, Packing> rhs) const {
    using Reg = typename VectorBase<Len, Packing, float>::VectorType;
    if constexpr (std::is_same_v<Reg, __m128>) {
        return VectorF32<Len, Packing>(_mm_sub_ps(this->v, rhs.v));
    } else if constexpr (std::is_same_v<Reg, __m256>) {
        return VectorF32<Len, Packing>(_mm256_sub_ps(this->v, rhs.v));
    } else {
        return VectorF32<Len, Packing>(_mm512_sub_ps(this->v, rhs.v));
    }
}

// Lane-wise multiplication over the whole backing register.
constexpr VectorF32<Len, Packing> operator*(VectorF32<Len, Packing> rhs) const {
    using Reg = typename VectorBase<Len, Packing, float>::VectorType;
    if constexpr (std::is_same_v<Reg, __m128>) {
        return VectorF32<Len, Packing>(_mm_mul_ps(this->v, rhs.v));
    } else if constexpr (std::is_same_v<Reg, __m256>) {
        return VectorF32<Len, Packing>(_mm256_mul_ps(this->v, rhs.v));
    } else {
        return VectorF32<Len, Packing>(_mm512_mul_ps(this->v, rhs.v));
    }
}

// Lane-wise division (this / rhs) over the whole backing register.
constexpr VectorF32<Len, Packing> operator/(VectorF32<Len, Packing> rhs) const {
    using Reg = typename VectorBase<Len, Packing, float>::VectorType;
    if constexpr (std::is_same_v<Reg, __m128>) {
        return VectorF32<Len, Packing>(_mm_div_ps(this->v, rhs.v));
    } else if constexpr (std::is_same_v<Reg, __m256>) {
        return VectorF32<Len, Packing>(_mm256_div_ps(this->v, rhs.v));
    } else {
        return VectorF32<Len, Packing>(_mm512_div_ps(this->v, rhs.v));
    }
}
|
||
|
||
|
||
constexpr void operator+=(VectorF32<Len, Packing> b) {
|
||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||
this->v = _mm_add_ps(this->v, b.v);
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||
this->v = _mm256_add_ps(this->v, b.v);
|
||
} else {
|
||
this->v = _mm512_add_ps(this->v, b.v);
|
||
}
|
||
}
|
||
|
||
constexpr void operator-=(VectorF32<Len, Packing> b) {
|
||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||
this->v = _mm_sub_ps(this->v, b.v);
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||
this->v = _mm256_sub_ps(this->v, b.v);
|
||
} else {
|
||
this->v = _mm512_sub_ps(this->v, b.v);
|
||
}
|
||
}
|
||
|
||
constexpr void operator*=(VectorF32<Len, Packing> b) {
|
||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||
this->v = _mm_mul_ps(this->v, b.v);
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||
this->v = _mm256_mul_ps(this->v, b.v);
|
||
} else {
|
||
this->v = _mm512_mul_ps(this->v, b.v);
|
||
}
|
||
}
|
||
|
||
constexpr void operator/=(VectorF32<Len, Packing> b) {
|
||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||
this->v = _mm_div_ps(this->v, b.v);
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||
this->v = _mm256_div_ps(this->v, b.v);
|
||
} else {
|
||
this->v = _mm512_div_ps(this->v, b.v);
|
||
}
|
||
}
|
||
|
||
// Scalar overloads: `b` is broadcast to every lane (via the float
// constructor) and combined lane-wise with this vector.
// These only read *this, so they are const-qualified. Previously they were
// not, which made `constVec * 2.0f` ill-formed.
constexpr VectorF32<Len, Packing> operator+(float b) const {
    return *this + VectorF32<Len, Packing>(b);
}

constexpr VectorF32<Len, Packing> operator-(float b) const {
    return *this - VectorF32<Len, Packing>(b);
}

constexpr VectorF32<Len, Packing> operator*(float b) const {
    return *this * VectorF32<Len, Packing>(b);
}

constexpr VectorF32<Len, Packing> operator/(float b) const {
    return *this / VectorF32<Len, Packing>(b);
}
|
||
|
||
constexpr void operator+=(float b) {
|
||
VectorF32<Len, Packing> vB(b);
|
||
*this += vB;
|
||
}
|
||
|
||
constexpr void operator-=(float b) {
|
||
VectorF32<Len, Packing> vB(b);
|
||
*this -= vB;
|
||
}
|
||
|
||
constexpr void operator*=(float b) {
|
||
VectorF32<Len, Packing> vB(b);
|
||
*this *= vB;
|
||
}
|
||
|
||
constexpr void operator/=(float b) {
|
||
VectorF32<Len, Packing> vB(b);
|
||
*this /= vB;
|
||
}
|
||
|
||
// Unary minus: applies Negate with the lane selection returned by
// VectorBase::GetAllTrue() (presumably all Len lanes selected).
// Now const-qualified so `-constVec` compiles. Negation runs on a local copy,
// so this does not depend on Negate() itself being const.
constexpr VectorF32<Len, Packing> operator-() const {
    VectorF32<Len, Packing> copy = *this;
    return copy.template Negate<VectorBase<Len, Packing, float>::GetAllTrue()>();
}
|
||
|
||
constexpr bool operator==(VectorF32<Len, Packing> b) const {
|
||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||
return _mm_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) == 15;
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||
return _mm256_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) == 255;
|
||
} else {
|
||
return _mm512_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) == 65535;
|
||
}
|
||
}
|
||
|
||
constexpr bool operator!=(VectorF32<Len, Packing> b) const {
|
||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||
return _mm_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) != 15;
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||
return _mm256_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) != 255;
|
||
} else {
|
||
return _mm512_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) != 65535;
|
||
}
|
||
}
|
||
|
||
template<std::uint32_t ExtractLen>
|
||
constexpr VectorF32<ExtractLen, Packing> ExtractLo() const {
|
||
if constexpr(Packing > 1) {
|
||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||
constexpr std::array<std::uint8_t,VectorBase<Len, Packing, float>::Alignment> shuffleMask = VectorBase<Len, Packing, float>::template GetExtractLoMaskEpi8<ExtractLen>();
|
||
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
|
||
return VectorF32<ExtractLen, Packing>(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(this->v), shuffleVec)));
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||
constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask = VectorBase<Len, Packing, float>::template GetExtractLoMaskEpi32<ExtractLen>();
|
||
__m256i permIdx = _mm256_loadu_epi32(permMask.data());
|
||
__m256i result = _mm256_permutexvar_epi32(permIdx, _mm256_castps_si256(this->v));
|
||
if constexpr(std::is_same_v<typename VectorBase<ExtractLen, Packing, float>::VectorType, __m128>) {
|
||
return VectorF32<ExtractLen, Packing>(_mm256_castps256_ps128(_mm256_castsi256_ps(result)));
|
||
} else {
|
||
return VectorF32<ExtractLen, Packing>(_mm256_castsi256_ps(result));
|
||
}
|
||
} else {
|
||
constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask = VectorBase<Len, Packing, float>::template GetExtractLoMaskEpi32<ExtractLen>();
|
||
__m512i permIdx = _mm512_loadu_epi32(permMask.data());
|
||
__m512i result = _mm512_permutexvar_epi32(permIdx, _mm512_castps_si512(this->v));
|
||
if constexpr(std::is_same_v<typename VectorBase<ExtractLen, Packing, float>::VectorType, __m128>) {
|
||
return VectorF32<ExtractLen, Packing>(_mm512_castps512_ps128(_mm512_castsi512_ps(result)));
|
||
} else if constexpr(std::is_same_v<typename VectorBase<ExtractLen, Packing, float>::VectorType, __m256>) {
|
||
return VectorF32<ExtractLen, Packing>(_mm512_castps512_ps256(_mm512_castsi512_ps(result)));
|
||
} else {
|
||
return VectorF32<ExtractLen, Packing>(_mm512_castsi512_ps(result));
|
||
}
|
||
}
|
||
} else {
|
||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256> && std::is_same_v<typename VectorBase<ExtractLen, Packing, float>::VectorType, __m128>) {
|
||
return VectorF32<ExtractLen, Packing>(_mm256_castps256_ps128(this->v));
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m512> && std::is_same_v<typename VectorBase<ExtractLen, Packing, float>::VectorType, __m128>) {
|
||
return VectorF32<ExtractLen, Packing>(_mm512_castps512_ps128(this->v));
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m512> && std::is_same_v<typename VectorBase<ExtractLen, Packing, float>::VectorType, __m256>) {
|
||
return VectorF32<ExtractLen, Packing>(_mm512_castps512_ps256(this->v));
|
||
} else {
|
||
return VectorF32<ExtractLen, Packing>(this->v);
|
||
}
|
||
}
|
||
}
|
||
|
||
// Lane-wise cosine (presumably; the maths lives in VectorBase). Dispatches
// on the backing register width to VectorBase::cos_f32x4 / cos_f32x8 /
// cos_f32x16. Accuracy and range reduction are defined there, not here.
// Const-qualified because it only reads `v`; previously it could not be
// called on a const vector. This assumes the cos_f32xN helpers are static,
// as their call form (register passed explicitly) suggests.
constexpr VectorF32<Len, Packing> Cos() const {
    if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
        return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::cos_f32x4(this->v));
    } else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
        return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::cos_f32x8(this->v));
    } else {
        return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::cos_f32x16(this->v));
    }
}

// Lane-wise sine (presumably; the maths lives in VectorBase). Dispatches on
// the backing register width to VectorBase::sin_f32x4 / sin_f32x8 /
// sin_f32x16. Const-qualified for the same reason as Cos().
constexpr VectorF32<Len, Packing> Sin() const {
    if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
        return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::sin_f32x4(this->v));
    } else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
        return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::sin_f32x8(this->v));
    } else {
        return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::sin_f32x16(this->v));
    }
}
|
||
|
||
// Computes both trigonometric results in one call. Returns {s, c} as filled
// by VectorBase::sincos_f32x4 / sincos_f32x8 / sincos_f32x16 (presumably sine
// first, cosine second), chosen by the backing register width.
// Const-qualified because it only reads `v`; previously it could not be
// called on a const vector.
std::tuple<VectorF32<Len, Packing>, VectorF32<Len, Packing>> SinCos() const {
    if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
        __m128 s, c;
        VectorBase<Len, Packing, float>::sincos_f32x4(this->v, s, c);
        return {
            VectorF32<Len, Packing>(s),
            VectorF32<Len, Packing>(c)
        };
    } else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
        __m256 s, c;
        VectorBase<Len, Packing, float>::sincos_f32x8(this->v, s, c);
        return {
            VectorF32<Len, Packing>(s),
            VectorF32<Len, Packing>(c)
        };
    } else {
        __m512 s, c;
        VectorBase<Len, Packing, float>::sincos_f32x16(this->v, s, c);
        return {
            VectorF32<Len, Packing>(s),
            VectorF32<Len, Packing>(c)
        };
    }
}
|
||
|
||
template <std::array<bool, Len> values>
|
||
constexpr VectorF32<Len, Packing> Negate() {
|
||
std::array<float, VectorBase<Len, Packing, float>::AlignmentElement> mask = VectorBase<Len, Packing, float>::template GetNegateMask<values>();
|
||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_xor_si128(_mm_castps_si128(this->v), _mm_loadu_epi32(mask.data()))));
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(this->v), _mm256_loadu_epi32(mask.data()))));
|
||
} else {
|
||
return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(this->v), _mm512_loadu_epi32(mask.data()))));
|
||
}
|
||
}
|
||
|
||
// Fused multiply-add per lane: a * b + add, rounded once (vfmadd).
// The public name keeps its original spelling ("Mulitply") for compatibility.
static constexpr VectorF32<Len, Packing> MulitplyAdd(VectorF32<Len, Packing> a, VectorF32<Len, Packing> b, VectorF32<Len, Packing> add) {
    using Reg = typename VectorBase<Len, Packing, float>::VectorType;
    if constexpr (std::is_same_v<Reg, __m128>) {
        const __m128 fused = _mm_fmadd_ps(a.v, b.v, add.v);
        return VectorF32<Len, Packing>(fused);
    } else if constexpr (std::is_same_v<Reg, __m256>) {
        const __m256 fused = _mm256_fmadd_ps(a.v, b.v, add.v);
        return VectorF32<Len, Packing>(fused);
    } else {
        const __m512 fused = _mm512_fmadd_ps(a.v, b.v, add.v);
        return VectorF32<Len, Packing>(fused);
    }
}

// Fused multiply-subtract per lane: a * b - sub, rounded once (vfmsub).
// The public name keeps its original spelling ("Mulitply") for compatibility.
static constexpr VectorF32<Len, Packing> MulitplySub(VectorF32<Len, Packing> a, VectorF32<Len, Packing> b, VectorF32<Len, Packing> sub) {
    using Reg = typename VectorBase<Len, Packing, float>::VectorType;
    if constexpr (std::is_same_v<Reg, __m128>) {
        const __m128 fused = _mm_fmsub_ps(a.v, b.v, sub.v);
        return VectorF32<Len, Packing>(fused);
    } else if constexpr (std::is_same_v<Reg, __m256>) {
        const __m256 fused = _mm256_fmsub_ps(a.v, b.v, sub.v);
        return VectorF32<Len, Packing>(fused);
    } else {
        const __m512 fused = _mm512_fmsub_ps(a.v, b.v, sub.v);
        return VectorF32<Len, Packing>(fused);
    }
}
|
||
|
||
// Cross product of two 3-component vectors (only enabled when Len == 3).
// Writing xP for "x byte-shuffled by GetShuffleMaskEpi8<P>", it returns
//     a{1,2,0} * b{2,0,1} - a{2,0,1} * b{1,2,0}
// using one multiply plus one fused multiply-subtract. If the mask places
// source lane P[i] into lane i, this is a.yzx * b.zxy - a.zxy * b.yzx.
// The byte shuffle (vpshufb) only moves bytes within each 128-bit lane.
// NOTE(review): for wider packed layouts, confirm the masks never need to
// move a float across a 128-bit boundary.
constexpr static VectorF32<Len, Packing> Cross(VectorF32<Len, Packing> a, VectorF32<Len, Packing> b) requires(Len == 3) {
    using Base = VectorBase<Len, Packing, float>;
    using Reg = typename Base::VectorType;

    constexpr std::array<std::uint8_t, Base::Alignment> mask120 = Base::template GetShuffleMaskEpi8<{{1,2,0}}>();
    constexpr std::array<std::uint8_t, Base::Alignment> mask201 = Base::template GetShuffleMaskEpi8<{{2,0,1}}>();

    if constexpr (std::is_same_v<Reg, __m128>) {
        const __m128i idx120 = _mm_loadu_epi8(mask120.data());
        const __m128i idx201 = _mm_loadu_epi8(mask201.data());
        const __m128i aBits = _mm_castps_si128(a.v);
        const __m128i bBits = _mm_castps_si128(b.v);
        const __m128 a120 = _mm_castsi128_ps(_mm_shuffle_epi8(aBits, idx120));
        const __m128 b120 = _mm_castsi128_ps(_mm_shuffle_epi8(bBits, idx120));
        const __m128 a201 = _mm_castsi128_ps(_mm_shuffle_epi8(aBits, idx201));
        const __m128 b201 = _mm_castsi128_ps(_mm_shuffle_epi8(bBits, idx201));
        return VectorF32<Len, Packing>(_mm_fmsub_ps(a120, b201, _mm_mul_ps(a201, b120)));
    } else if constexpr (std::is_same_v<Reg, __m256>) {
        // _mm256_shuffle_epi8 performs the same in-lane byte shuffle the
        // original obtained by widening to zmm and using _mm512_shuffle_epi8.
        const __m256i idx120 = _mm256_loadu_epi8(mask120.data());
        const __m256i idx201 = _mm256_loadu_epi8(mask201.data());
        const __m256i aBits = _mm256_castps_si256(a.v);
        const __m256i bBits = _mm256_castps_si256(b.v);
        const __m256 a120 = _mm256_castsi256_ps(_mm256_shuffle_epi8(aBits, idx120));
        const __m256 b120 = _mm256_castsi256_ps(_mm256_shuffle_epi8(bBits, idx120));
        const __m256 a201 = _mm256_castsi256_ps(_mm256_shuffle_epi8(aBits, idx201));
        const __m256 b201 = _mm256_castsi256_ps(_mm256_shuffle_epi8(bBits, idx201));
        return VectorF32<Len, Packing>(_mm256_fmsub_ps(a120, b201, _mm256_mul_ps(a201, b120)));
    } else {
        const __m512i idx120 = _mm512_loadu_epi8(mask120.data());
        const __m512i idx201 = _mm512_loadu_epi8(mask201.data());
        const __m512i aBits = _mm512_castps_si512(a.v);
        const __m512i bBits = _mm512_castps_si512(b.v);
        const __m512 a120 = _mm512_castsi512_ps(_mm512_shuffle_epi8(aBits, idx120));
        const __m512 b120 = _mm512_castsi512_ps(_mm512_shuffle_epi8(bBits, idx120));
        const __m512 a201 = _mm512_castsi512_ps(_mm512_shuffle_epi8(aBits, idx201));
        const __m512 b201 = _mm512_castsi512_ps(_mm512_shuffle_epi8(bBits, idx201));
        return VectorF32<Len, Packing>(_mm512_fmsub_ps(a120, b201, _mm512_mul_ps(a201, b120)));
    }
}
|
||
|
||
template <const std::array<std::uint8_t, Len> ShuffleValues>
|
||
constexpr VectorF32<Len, Packing> Shuffle() {
|
||
if constexpr(VectorBase<Len, Packing, float>::template CheckEpi32Shuffle<ShuffleValues>()) {
|
||
constexpr std::uint8_t imm = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi32<ShuffleValues>();
|
||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(this->v), imm)));
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(this->v), imm)));
|
||
} else {
|
||
return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(this->v), imm)));
|
||
}
|
||
} else if constexpr(VectorBase<Len, Packing, float>::template CheckEpi8Shuffle<ShuffleValues>()){
|
||
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<ShuffleValues>();
|
||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
|
||
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(this->v), shuffleVec)));
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||
__m256i shuffleVec = _mm256_loadu_epi8(shuffleMask.data());
|
||
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(this->v)), _mm512_castsi256_si512(shuffleVec)))));
|
||
} else {
|
||
__m512i shuffleVec = _mm512_loadu_epi8(shuffleMask.data());
|
||
return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(this->v), shuffleVec)));
|
||
}
|
||
} else {
|
||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<ShuffleValues>();
|
||
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
|
||
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(this->v), shuffleVec)));
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||
constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask = VectorBase<Len, Packing, float>::template GetPermuteMaskEpi32<ShuffleValues>();
|
||
__m256i permIdx = _mm256_loadu_epi32(permMask.data());
|
||
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_permutexvar_epi32(permIdx, _mm256_castps_si256(this->v))));
|
||
} else {
|
||
constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask = VectorBase<Len, Packing, float>::template GetPermuteMaskEpi32<ShuffleValues>();
|
||
__m512i permIdx = _mm512_loadu_epi32(permMask.data());
|
||
return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_permutexvar_epi32(permIdx, _mm512_castps_si512(this->v))));
|
||
}
|
||
}
|
||
}
|
||
|
||
// Normalizes four batches of packed vec4s in one pass. Each argument register
// holds Packing vec4s back to back; every vec4 is scaled by 1/|v|.
//
// LengthNoShuffle is called with the arguments ordered (A, C, B, D): its
// transpose interleaves the 2nd and 3rd pair, so this order leaves the
// lengths as [|A| |B| |C| |D|] inside every 128-bit lane. The per-lane
// broadcasts below rely on that layout.
//
// A zero-length vector is not special-cased: its reciprocal is +inf and its
// components become NaN (0 * inf).
//
// Returns {A/|A|, B/|B|, C/|C|, D/|D|}, applied per packed vec4.
constexpr static std::tuple<VectorF32<Len, Packing>, VectorF32<Len, Packing>, VectorF32<Len, Packing>, VectorF32<Len, Packing>> Normalize(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> B,
    VectorF32<Len, Packing> C,
    VectorF32<Len, Packing> D
) requires(Len == 4 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
        VectorF32<1, 4> length = LengthNoShuffle(A, C, B, D);
        VectorF32<4, 1> invLength(_mm_div_ps(_mm_set1_ps(1.0f), length.v));

        VectorF32<4, 1> invA = invLength.template Shuffle<{{0,0,0,0}}>();
        VectorF32<4, 1> invB = invLength.template Shuffle<{{1,1,1,1}}>();
        VectorF32<4, 1> invC = invLength.template Shuffle<{{2,2,2,2}}>();
        VectorF32<4, 1> invD = invLength.template Shuffle<{{3,3,3,3}}>();

        return {
            VectorF32<Len, Packing>(_mm_mul_ps(A.v, invA.v)),
            VectorF32<Len, Packing>(_mm_mul_ps(B.v, invB.v)),
            VectorF32<Len, Packing>(_mm_mul_ps(C.v, invC.v)),
            VectorF32<Len, Packing>(_mm_mul_ps(D.v, invD.v))
        };
    } else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
        VectorF32<1, 8> length = LengthNoShuffle(A, C, B, D);
        VectorF32<8, 1> invLength(_mm256_div_ps(_mm256_set1_ps(1.0f), length.v));

        // Same broadcast repeated in each 128-bit lane (one vec4 per lane).
        VectorF32<8, 1> invA = invLength.template Shuffle<{{0,0,0,0,4,4,4,4}}>();
        VectorF32<8, 1> invB = invLength.template Shuffle<{{1,1,1,1,5,5,5,5}}>();
        VectorF32<8, 1> invC = invLength.template Shuffle<{{2,2,2,2,6,6,6,6}}>();
        VectorF32<8, 1> invD = invLength.template Shuffle<{{3,3,3,3,7,7,7,7}}>();

        return {
            VectorF32<Len, Packing>(_mm256_mul_ps(A.v, invA.v)),
            VectorF32<Len, Packing>(_mm256_mul_ps(B.v, invB.v)),
            VectorF32<Len, Packing>(_mm256_mul_ps(C.v, invC.v)),
            VectorF32<Len, Packing>(_mm256_mul_ps(D.v, invD.v))
        };
    } else {
        VectorF32<1, 16> length = LengthNoShuffle(A, C, B, D);
        VectorF32<16, 1> invLength(_mm512_div_ps(_mm512_set1_ps(1.0f), length.v));

        VectorF32<16, 1> invA = invLength.template Shuffle<{{0,0,0,0,4,4,4,4,8,8,8,8,12,12,12,12}}>();
        VectorF32<16, 1> invB = invLength.template Shuffle<{{1,1,1,1,5,5,5,5,9,9,9,9,13,13,13,13}}>();
        VectorF32<16, 1> invC = invLength.template Shuffle<{{2,2,2,2,6,6,6,6,10,10,10,10,14,14,14,14}}>();
        VectorF32<16, 1> invD = invLength.template Shuffle<{{3,3,3,3,7,7,7,7,11,11,11,11,15,15,15,15}}>();

        return {
            VectorF32<Len, Packing>(_mm512_mul_ps(A.v, invA.v)),
            VectorF32<Len, Packing>(_mm512_mul_ps(B.v, invB.v)),
            VectorF32<Len, Packing>(_mm512_mul_ps(C.v, invC.v)),
            VectorF32<Len, Packing>(_mm512_mul_ps(D.v, invD.v))
        };
    }
}
|
||
|
||
// Normalizes four vec3s, one per 128-bit register (lane 3 is padding).
// Length() returns [|A| |B| |C| |D|]; each reciprocal is broadcast across its
// register and multiplied in. Zero-length inputs are not special-cased.
constexpr static std::tuple<VectorF32<Len, Packing>, VectorF32<Len, Packing>, VectorF32<Len, Packing>, VectorF32<Len, Packing>> Normalize(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> B,
    VectorF32<Len, Packing> C,
    VectorF32<Len, Packing> D
) requires(Len == 3 && Packing == 1) {
    VectorF32<4, 1> reciprocal(_mm_div_ps(_mm_set1_ps(1.0f), Length(A, B, C, D).v));

    const __m128 scaleA = reciprocal.template Shuffle<{{0,0,0,0}}>().v;
    const __m128 scaleB = reciprocal.template Shuffle<{{1,1,1,1}}>().v;
    const __m128 scaleC = reciprocal.template Shuffle<{{2,2,2,2}}>().v;
    const __m128 scaleD = reciprocal.template Shuffle<{{3,3,3,3}}>().v;

    return {
        _mm_mul_ps(A.v, scaleA),
        _mm_mul_ps(B.v, scaleB),
        _mm_mul_ps(C.v, scaleC),
        _mm_mul_ps(D.v, scaleD)
    };
}
|
||
|
||
// Normalizes eight vec3s stored two per 256-bit register
// ([x y z][x y z] + 2 padding lanes). Length() yields
// [A.first A.second B.first B.second C.first C.second D.first D.second].
// Each reciprocal is spread over the three components of its vec3. Padding
// lanes 6-7 take index 0, the same as the zero-filled tail of a shorter
// initializer would. Zero-length inputs are not special-cased.
constexpr static std::tuple<VectorF32<Len, Packing>, VectorF32<Len, Packing>, VectorF32<Len, Packing>, VectorF32<Len, Packing>> Normalize(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> B,
    VectorF32<Len, Packing> C,
    VectorF32<Len, Packing> D
) requires(Len == 3 && Packing == 2) {
    VectorF32<8, 1> reciprocal(_mm256_div_ps(_mm256_set1_ps(1.0f), Length(A, B, C, D).v));

    const __m256 scaleA = reciprocal.template Shuffle<{{0,0,0, 1,1,1, 0,0}}>().v;
    const __m256 scaleB = reciprocal.template Shuffle<{{2,2,2, 3,3,3, 0,0}}>().v;
    const __m256 scaleC = reciprocal.template Shuffle<{{4,4,4, 5,5,5, 0,0}}>().v;
    const __m256 scaleD = reciprocal.template Shuffle<{{6,6,6, 7,7,7, 0,0}}>().v;

    return {
        _mm256_mul_ps(A.v, scaleA),
        _mm256_mul_ps(B.v, scaleB),
        _mm256_mul_ps(C.v, scaleC),
        _mm256_mul_ps(D.v, scaleD)
    };
}
|
||
|
||
// Normalizes fifteen vec3s stored five per 512-bit register (lane 15 is
// padding). Length() yields five results per argument (A: 0-4, B: 5-9,
// C: 10-14); each reciprocal is spread over the three components of its vec3.
// Zero-length inputs are not special-cased.
constexpr static std::tuple<VectorF32<Len, Packing>, VectorF32<Len, Packing>, VectorF32<Len, Packing>> Normalize(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> B,
    VectorF32<Len, Packing> C
) requires(Len == 3 && Packing == 5) {
    VectorF32<15, 1> reciprocal(_mm512_div_ps(_mm512_set1_ps(1.0f), Length(A, B, C).v));

    const __m512 scaleA = reciprocal.template Shuffle<{{0,0,0, 1,1,1, 2,2,2, 3,3,3, 4,4,4}}>().v;
    const __m512 scaleB = reciprocal.template Shuffle<{{5,5,5, 6,6,6, 7,7,7, 8,8,8, 9,9,9}}>().v;
    const __m512 scaleC = reciprocal.template Shuffle<{{10,10,10, 11,11,11, 12,12,12, 13,13,13, 14,14,14}}>().v;

    return {
        _mm512_mul_ps(A.v, scaleA),
        _mm512_mul_ps(B.v, scaleB),
        _mm512_mul_ps(C.v, scaleC)
    };
}
|
||
|
||
// Normalizes two batches of packed vec2s. LengthNoShuffle leaves the lengths
// in per-128-bit-lane order [A A B B]. Each reciprocal is duplicated onto
// both components of its vec2 before scaling. Zero-length inputs are not
// special-cased.
constexpr static std::tuple<VectorF32<Len, Packing>, VectorF32<Len, Packing>> Normalize(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> B
) requires(Len == 2 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    using Register = typename VectorBase<Len, Packing, float>::VectorType;

    if constexpr(std::is_same_v<Register, __m128>) {
        VectorF32<4, 1> reciprocal(_mm_div_ps(_mm_set1_ps(1.0f), LengthNoShuffle(A, B).v));
        const __m128 scaleA = reciprocal.template Shuffle<{{0,0,1,1}}>().v;
        const __m128 scaleB = reciprocal.template Shuffle<{{2,2,3,3}}>().v;
        return {_mm_mul_ps(A.v, scaleA), _mm_mul_ps(B.v, scaleB)};
    } else if constexpr(std::is_same_v<Register, __m256>) {
        VectorF32<8, 1> reciprocal(_mm256_div_ps(_mm256_set1_ps(1.0f), LengthNoShuffle(A, B).v));
        const __m256 scaleA = reciprocal.template Shuffle<{{0,0,1,1,4,4,5,5}}>().v;
        const __m256 scaleB = reciprocal.template Shuffle<{{2,2,3,3,6,6,7,7}}>().v;
        return {_mm256_mul_ps(A.v, scaleA), _mm256_mul_ps(B.v, scaleB)};
    } else {
        VectorF32<16, 1> reciprocal(_mm512_div_ps(_mm512_set1_ps(1.0f), LengthNoShuffle(A, B).v));
        const __m512 scaleA = reciprocal.template Shuffle<{{0,0,1,1,4,4,5,5,8,8,9,9,12,12,13,13}}>().v;
        const __m512 scaleB = reciprocal.template Shuffle<{{2,2,3,3,6,6,7,7,10,10,11,11,14,14,15,15}}>().v;
        return {_mm512_mul_ps(A.v, scaleA), _mm512_mul_ps(B.v, scaleB)};
    }
}
|
||
|
||
// Euclidean lengths of the packed vec4s in A, B, C and D: the element-wise
// square root of LengthSq, in the same order as Dot.
constexpr static VectorF32<1, Packing*4> Length(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> B,
    VectorF32<Len, Packing> C,
    VectorF32<Len, Packing> D
) requires(Len == 4 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    using Register = typename VectorBase<Len, Packing, float>::VectorType;
    using Result = VectorF32<1, Packing*4>;

    const Result squared = LengthSq(A, B, C, D);
    if constexpr(std::is_same_v<Register, __m128>) {
        return Result(_mm_sqrt_ps(squared.v));
    } else if constexpr(std::is_same_v<Register, __m256>) {
        return Result(_mm256_sqrt_ps(squared.v));
    } else {
        return Result(_mm512_sqrt_ps(squared.v));
    }
}
|
||
|
||
// Euclidean lengths of four vec3s (one per 128-bit register):
// [|A| |B| |C| |D|].
constexpr static VectorF32<1, 4> Length(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> B,
    VectorF32<Len, Packing> C,
    VectorF32<Len, Packing> D
) requires(Len == 3 && Packing == 1) {
    const VectorF32<1, 4> squared = LengthSq(A, B, C, D);
    const __m128 root = _mm_sqrt_ps(squared.v);
    return VectorF32<1, 4>(root);
}
|
||
|
||
// Euclidean lengths of eight vec3s stored two per 256-bit register, in the
// order [A.first A.second B.first B.second ... D.second].
constexpr static VectorF32<1, 8> Length(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> B,
    VectorF32<Len, Packing> C,
    VectorF32<Len, Packing> D
) requires(Len == 3 && Packing == 2) {
    const VectorF32<1, 8> squared = LengthSq(A, B, C, D);
    const __m256 root = _mm256_sqrt_ps(squared.v);
    return VectorF32<1, 8>(root);
}
|
||
|
||
// Euclidean lengths of fifteen vec3s stored five per 512-bit register
// (A: results 0-4, B: 5-9, C: 10-14).
constexpr static VectorF32<1, 15> Length(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> B,
    VectorF32<Len, Packing> C
) requires(Len == 3 && Packing == 5) {
    const VectorF32<1, 15> squared = LengthSq(A, B, C);
    const __m512 root = _mm512_sqrt_ps(squared.v);
    return VectorF32<1, 15>(root);
}
|
||
|
||
// Euclidean lengths of the packed vec2s in A and C: the element-wise square
// root of LengthSq, in the same order as Dot.
constexpr static VectorF32<1, Packing*2> Length(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> C
) requires(Len == 2 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    using Register = typename VectorBase<Len, Packing, float>::VectorType;
    using Result = VectorF32<1, Packing*2>;

    const Result squared = LengthSq(A, C);
    if constexpr(std::is_same_v<Register, __m128>) {
        return Result(_mm_sqrt_ps(squared.v));
    } else if constexpr(std::is_same_v<Register, __m256>) {
        return Result(_mm256_sqrt_ps(squared.v));
    } else {
        return Result(_mm512_sqrt_ps(squared.v));
    }
}
|
||
|
||
// Squared lengths of the packed vec4s in A, B, C and D: each argument dotted
// with itself, returned in Dot's order.
constexpr static VectorF32<1, Packing*4> LengthSq(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> B,
    VectorF32<Len, Packing> C,
    VectorF32<Len, Packing> D
) requires(Len == 4 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    VectorF32<1, Packing*4> squared = Dot(A, A, B, B, C, C, D, D);
    return squared;
}
|
||
|
||
// Squared lengths of four vec3s (one per 128-bit register): each argument
// dotted with itself.
constexpr static VectorF32<1, 4> LengthSq(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> B,
    VectorF32<Len, Packing> C,
    VectorF32<Len, Packing> D
) requires(Len == 3 && Packing == 1) {
    VectorF32<1, 4> squared = Dot(A, A, B, B, C, C, D, D);
    return squared;
}
|
||
|
||
// Squared lengths of eight vec3s stored two per 256-bit register: each
// argument dotted with itself.
constexpr static VectorF32<1, 8> LengthSq(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> B,
    VectorF32<Len, Packing> C,
    VectorF32<Len, Packing> D
) requires(Len == 3 && Packing == 2) {
    VectorF32<1, 8> squared = Dot(A, A, B, B, C, C, D, D);
    return squared;
}
|
||
|
||
// Squared lengths of fifteen vec3s stored five per 512-bit register: each
// argument dotted with itself.
constexpr static VectorF32<1, 15> LengthSq(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> B,
    VectorF32<Len, Packing> C
) requires(Len == 3 && Packing == 5) {
    VectorF32<1, 15> squared = Dot(A, A, B, B, C, C);
    return squared;
}
|
||
|
||
// Squared lengths of the packed vec2s in A and C: each argument dotted with
// itself, returned in Dot's order.
constexpr static VectorF32<1, Packing*2> LengthSq(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> C
) requires(Len == 2 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    VectorF32<1, Packing*2> squared = Dot(A, A, C, C);
    return squared;
}
|
||
|
||
// Dot products of packed vec4s: X0[i]·X1[i] for every vec4 i held in the
// X0/X1 registers (X = A, B, C, D). The result is ordered by pair: all A
// results, then B, then C, then D.
//
// DotNoShuffle produces [A C B D] inside each 128-bit lane. The 128-bit case
// swaps the B and C pairs on input; the wider cases regroup the lanes after.
constexpr static VectorF32<1, Packing*4> Dot(
    VectorF32<Len, Packing> A0, VectorF32<Len, Packing> A1,
    VectorF32<Len, Packing> B0, VectorF32<Len, Packing> B1,
    VectorF32<Len, Packing> C0, VectorF32<Len, Packing> C1,
    VectorF32<Len, Packing> D0, VectorF32<Len, Packing> D1
) requires(Len == 4 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    using Register = typename VectorBase<Len, Packing, float>::VectorType;

    if constexpr(std::is_same_v<Register, __m128>) {
        return DotNoShuffle(A0, A1, C0, C1, B0, B1, D0, D1);
    } else if constexpr(std::is_same_v<Register, __m256>) {
        // [Ax Cx Bx Dx | Ay Cy By Dy] -> [Ax Ay Bx By Cx Cy Dx Dy]
        VectorF32<8, 1> perLane(DotNoShuffle(A0, A1, B0, B1, C0, C1, D0, D1).v);
        return perLane.template Shuffle<{{0,4, 2,6, 1,5, 3,7}}>().v;
    } else {
        // Four lanes of [A C B D] -> A0..A3, B0..B3, C0..C3, D0..D3
        VectorF32<16, 1> perLane(DotNoShuffle(A0, A1, B0, B1, C0, C1, D0, D1).v);
        return perLane.template Shuffle<{{0,4,8,12, 2,6,10,14, 1,5,9,13, 3,7,11,15}}>().v;
    }
}
|
||
|
||
// Four vec3 dot products, one vec3 per 128-bit register (lane 3 ignored):
// returns [A0·A1, B0·B1, C0·C1, D0·D1].
//
// The element-wise products are transposed so that one register holds every
// x term, one every y term and one every z term; summing the three registers
// gives all four dot products at once.
constexpr static VectorF32<1, 4> Dot(
    VectorF32<Len, Packing> A0, VectorF32<Len, Packing> A1,
    VectorF32<Len, Packing> B0, VectorF32<Len, Packing> B1,
    VectorF32<Len, Packing> C0, VectorF32<Len, Packing> C1,
    VectorF32<Len, Packing> D0, VectorF32<Len, Packing> D1
) requires(Len == 3 && Packing == 1) {
    const __m128 pa = _mm_mul_ps(A0.v, A1.v); // ax ay az _
    const __m128 pb = _mm_mul_ps(B0.v, B1.v); // bx by bz _
    const __m128 pc = _mm_mul_ps(C0.v, C1.v); // cx cy cz _
    const __m128 pd = _mm_mul_ps(D0.v, D1.v); // dx dy dz _

    const __m128 abLow  = _mm_unpacklo_ps(pa, pb); // ax bx ay by
    const __m128 cdLow  = _mm_unpacklo_ps(pc, pd); // cx dx cy dy
    const __m128 abHigh = _mm_unpackhi_ps(pa, pb); // az bz _  _
    const __m128 cdHigh = _mm_unpackhi_ps(pc, pd); // cz dz _  _

    const __m128 xs = _mm_movelh_ps(abLow, cdLow);   // ax bx cx dx
    const __m128 ys = _mm_movehl_ps(cdLow, abLow);   // ay by cy dy
    const __m128 zs = _mm_movelh_ps(abHigh, cdHigh); // az bz cz dz

    return _mm_add_ps(_mm_add_ps(xs, ys), zs);
}
|
||
|
||
// Eight vec3 dot products. Each 256-bit register holds two vec3s followed by
// two padding lanes: [x y z | x y z | _ _] (the bar marks vec3 boundaries, not
// 128-bit lanes). Returns
//   [A.first A.second B.first B.second C.first C.second D.first D.second]
// where each entry is the dot of the matching vec3s in X0 and X1.
//
// Steps:
//   1. Multiply, then permute each product to [p.x q.x p.y q.y p.z q.z _ _],
//      where p/q are its two vec3s.
//   2. 64-bit unpack inside 128-bit lanes:
//        lo(a, b) = [x of a,b | z of a,b]
//        hi(a, b) = [y of a,b | padding].
//   3. Join 128-bit halves of the AB and CD results into x, y and z rows,
//      then sum the rows.
constexpr static VectorF32<1, 8> Dot(
    VectorF32<Len, Packing> A0, VectorF32<Len, Packing> A1,
    VectorF32<Len, Packing> B0, VectorF32<Len, Packing> B1,
    VectorF32<Len, Packing> C0, VectorF32<Len, Packing> C1,
    VectorF32<Len, Packing> D0, VectorF32<Len, Packing> D1
) requires(Len == 3 && Packing == 2) {
    const __m256i byComponent = _mm256_setr_epi32(0, 3, 1, 4, 2, 5, 6, 7);

    const __m256i a = _mm256_castps_si256(_mm256_permutevar8x32_ps(_mm256_mul_ps(A0.v, A1.v), byComponent));
    const __m256i b = _mm256_castps_si256(_mm256_permutevar8x32_ps(_mm256_mul_ps(B0.v, B1.v), byComponent));
    const __m256i c = _mm256_castps_si256(_mm256_permutevar8x32_ps(_mm256_mul_ps(C0.v, C1.v), byComponent));
    const __m256i d = _mm256_castps_si256(_mm256_permutevar8x32_ps(_mm256_mul_ps(D0.v, D1.v), byComponent));

    const __m256i abXZ = _mm256_unpacklo_epi64(a, b); // x of A,B | z of A,B
    const __m256i abY  = _mm256_unpackhi_epi64(a, b); // y of A,B | padding
    const __m256i cdXZ = _mm256_unpacklo_epi64(c, d); // x of C,D | z of C,D
    const __m256i cdY  = _mm256_unpackhi_epi64(c, d); // y of C,D | padding

    const __m256 xs = _mm256_castsi256_ps(_mm256_permute2x128_si256(abXZ, cdXZ, 0x20)); // low, low
    const __m256 ys = _mm256_castsi256_ps(_mm256_permute2x128_si256(abY,  cdY,  0x20)); // low, low
    const __m256 zs = _mm256_castsi256_ps(_mm256_permute2x128_si256(abXZ, cdXZ, 0x31)); // high, high

    return _mm256_add_ps(_mm256_add_ps(xs, ys), zs);
}
|
||
|
||
// Fifteen vec3 dot products. Each 512-bit register holds five vec3s back to
// back (vec3 i in lanes 3i..3i+2) followed by one padding lane. Result:
//   [0..4]   dots of the five vec3s in A0/A1
//   [5..9]   dots of the five vec3s in B0/B1
//   [10..14] dots of the five vec3s in C0/C1
//   [15]     padding (lane 0 of the A product counted three times)
//
// For each component k (x, y, z), one register is built holding component k
// of all fifteen products:
//   - lanes 0..9 come from a single two-source permute over mul0|mul1;
//   - lanes 10..14 come from a masked single-source permute of mul2.
// This takes 6 permutes in total. Gathering each register separately and then
// shifting and blending needs 15 permutes plus 6 masked moves. Every lane
// (including padding) and the x + y + z summation order are unchanged, so
// the result is bit-identical.
constexpr static VectorF32<1, 15> Dot(
    VectorF32<Len, Packing> A0, VectorF32<Len, Packing> A1,
    VectorF32<Len, Packing> B0, VectorF32<Len, Packing> B1,
    VectorF32<Len, Packing> C0, VectorF32<Len, Packing> C1
) requires(Len == 3 && Packing == 5) {
    __m512 mul0 = _mm512_mul_ps(A0.v, A1.v); // vec3s a b c d e
    __m512 mul1 = _mm512_mul_ps(B0.v, B1.v); // vec3s f g h i j
    __m512 mul2 = _mm512_mul_ps(C0.v, C1.v); // vec3s k l m n o

    // Two-source permute indices: 0..15 select from mul0, 16..31 from mul1.
    // Lanes 10..15 read mul0[0]; 10..14 are overwritten below, 15 is padding.
    const __m512i loX = _mm512_setr_epi32(0, 3, 6,  9, 12, 16, 19, 22, 25, 28, 0, 0, 0, 0, 0, 0);
    const __m512i loY = _mm512_setr_epi32(1, 4, 7, 10, 13, 17, 20, 23, 26, 29, 0, 0, 0, 0, 0, 0);
    const __m512i loZ = _mm512_setr_epi32(2, 5, 8, 11, 14, 18, 21, 24, 27, 30, 0, 0, 0, 0, 0, 0);

    // Indices into mul2, used only in lanes 10..14 (selected by the mask).
    const __m512i hiX = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 6,  9, 12, 0);
    const __m512i hiY = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 4, 7, 10, 13, 0);
    const __m512i hiZ = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 5, 8, 11, 14, 0);
    constexpr __mmask16 lanes10to14 = 0x7C00u;

    __m512 xs = _mm512_mask_permutexvar_ps(_mm512_permutex2var_ps(mul0, loX, mul1), lanes10to14, hiX, mul2);
    __m512 ys = _mm512_mask_permutexvar_ps(_mm512_permutex2var_ps(mul0, loY, mul1), lanes10to14, hiY, mul2);
    __m512 zs = _mm512_mask_permutexvar_ps(_mm512_permutex2var_ps(mul0, loZ, mul1), lanes10to14, hiZ, mul2);

    return _mm512_add_ps(_mm512_add_ps(xs, ys), zs);
}
|
||
|
||
// Dot products of packed vec2s: X0[i]·X1[i] for every vec2 i held in the
// X0/X1 registers (X = A, C). All A results come first, then all C results.
//
// DotNoShuffle groups its output per 128-bit lane as [A A C C]. With a
// single lane that is already the final order; wider registers pull the A
// pairs to the front and the C pairs behind them.
constexpr static VectorF32<1, Packing*2> Dot(
    VectorF32<Len, Packing> A0, VectorF32<Len, Packing> A1,
    VectorF32<Len, Packing> C0, VectorF32<Len, Packing> C1
) requires(Len == 2 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    using Register = typename VectorBase<Len, Packing, float>::VectorType;

    if constexpr(std::is_same_v<Register, __m128>) {
        return DotNoShuffle(A0, A1, C0, C1);
    } else if constexpr(std::is_same_v<Register, __m256>) {
        VectorF32<8, 1> perLane(DotNoShuffle(A0, A1, C0, C1).v);
        return perLane.template Shuffle<{{0,1, 4,5, 2,3, 6,7}}>().v;
    } else {
        VectorF32<16, 1> perLane(DotNoShuffle(A0, A1, C0, C1).v);
        return perLane.template Shuffle<{{0,1, 4,5, 8,9, 12,13, 2,3, 6,7, 10,11, 14,15}}>().v;
    }
}
|
||
|
||
|
||
private:
|
||
// Lengths of packed vec4s: the element-wise square root of LengthSqNoShuffle.
// The results stay in DotNoShuffle's per-128-bit-lane order rather than the
// order the public Dot returns.
constexpr static VectorF32<1, Packing*4> LengthNoShuffle(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> B,
    VectorF32<Len, Packing> C,
    VectorF32<Len, Packing> D
) requires(Len == 4 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    using Register = typename VectorBase<Len, Packing, float>::VectorType;
    using Result = VectorF32<1, Packing*4>;

    const Result squared = LengthSqNoShuffle(A, B, C, D);
    if constexpr(std::is_same_v<Register, __m128>) {
        return Result(_mm_sqrt_ps(squared.v));
    } else if constexpr(std::is_same_v<Register, __m256>) {
        return Result(_mm256_sqrt_ps(squared.v));
    } else {
        return Result(_mm512_sqrt_ps(squared.v));
    }
}
|
||
|
||
// Lengths of packed vec2s: the element-wise square root of LengthSqNoShuffle,
// left in DotNoShuffle's per-128-bit-lane order.
constexpr static VectorF32<1, Packing*2> LengthNoShuffle(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> C
) requires(Len == 2 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    using Register = typename VectorBase<Len, Packing, float>::VectorType;
    using Result = VectorF32<1, Packing*2>;

    const Result squared = LengthSqNoShuffle(A, C);
    if constexpr(std::is_same_v<Register, __m128>) {
        return Result(_mm_sqrt_ps(squared.v));
    } else if constexpr(std::is_same_v<Register, __m256>) {
        return Result(_mm256_sqrt_ps(squared.v));
    } else {
        return Result(_mm512_sqrt_ps(squared.v));
    }
}
|
||
|
||
// Squared lengths of packed vec4s (each argument dotted with itself), left
// in DotNoShuffle's per-128-bit-lane order.
constexpr static VectorF32<1, Packing*4> LengthSqNoShuffle(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> B,
    VectorF32<Len, Packing> C,
    VectorF32<Len, Packing> D
) requires(Len == 4 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    VectorF32<1, Packing*4> squared = DotNoShuffle(A, A, B, B, C, C, D, D);
    return squared;
}
|
||
|
||
// Squared lengths of packed vec2s (each argument dotted with itself), left
// in DotNoShuffle's per-128-bit-lane order.
constexpr static VectorF32<1, Packing*2> LengthSqNoShuffle(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> C
) requires(Len == 2 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    VectorF32<1, Packing*2> squared = DotNoShuffle(A, A, C, C);
    return squared;
}
|
||
|
||
|
||
/// Computes four dot products of 4-component vectors at once, without reordering
/// the results afterwards.
///
/// Per 128-bit lane, each pair (X0, X1) contributes one 4-component vector per
/// operand. The element-wise products X0*X1 are transposed with two rounds of
/// 32-bit interleaves. The four horizontal sums then reduce to three vertical adds.
///
/// Float unpacks (`_mm*_unpack{lo,hi}_ps`) are used throughout. They interleave
/// exactly the same 32-bit elements as the epi32 variants but keep the data in
/// float registers, so no implicit integer<->float vector conversions
/// (-flax-vector-conversions) are needed.
///
/// @return Per 128-bit lane: { dot(A0,A1), dot(C0,C1), dot(B0,B1), dot(D0,D1) }.
///         Note the A, C, B, D order. It is what a double 32-bit interleave
///         produces, and it is kept as-is for existing callers.
///         NOTE(review): confirm callers rely on this order rather than A, B, C, D.
constexpr static VectorF32<1, Packing*4> DotNoShuffle(
    VectorF32<Len, Packing> A0, VectorF32<Len, Packing> A1,
    VectorF32<Len, Packing> B0, VectorF32<Len, Packing> B1,
    VectorF32<Len, Packing> C0, VectorF32<Len, Packing> C1,
    VectorF32<Len, Packing> D0, VectorF32<Len, Packing> D1
) requires(Len == 4 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
        __m128 mulA = _mm_mul_ps(A0.v, A1.v);
        __m128 mulB = _mm_mul_ps(B0.v, B1.v);
        __m128 mulC = _mm_mul_ps(C0.v, C1.v);
        __m128 mulD = _mm_mul_ps(D0.v, D1.v);

        __m128 ab01 = _mm_unpacklo_ps(mulA, mulB); // a0 b0 a1 b1
        __m128 ab23 = _mm_unpackhi_ps(mulA, mulB); // a2 b2 a3 b3
        __m128 cd01 = _mm_unpacklo_ps(mulC, mulD); // c0 d0 c1 d1
        __m128 cd23 = _mm_unpackhi_ps(mulC, mulD); // c2 d2 c3 d3

        __m128 col0 = _mm_unpacklo_ps(ab01, cd01); // a0 c0 b0 d0
        __m128 col1 = _mm_unpackhi_ps(ab01, cd01); // a1 c1 b1 d1
        __m128 col2 = _mm_unpacklo_ps(ab23, cd23); // a2 c2 b2 d2
        __m128 col3 = _mm_unpackhi_ps(ab23, cd23); // a3 c3 b3 d3

        // Summation order ((col0 + col1) + col2) + col3.
        __m128 sum = _mm_add_ps(col0, col1);
        sum = _mm_add_ps(sum, col2);
        sum = _mm_add_ps(sum, col3);
        return sum;
    } else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
        // The 256-bit unpacks operate independently within each 128-bit lane.
        __m256 mulA = _mm256_mul_ps(A0.v, A1.v);
        __m256 mulB = _mm256_mul_ps(B0.v, B1.v);
        __m256 mulC = _mm256_mul_ps(C0.v, C1.v);
        __m256 mulD = _mm256_mul_ps(D0.v, D1.v);

        __m256 ab01 = _mm256_unpacklo_ps(mulA, mulB); // a0 b0 a1 b1 (per lane)
        __m256 ab23 = _mm256_unpackhi_ps(mulA, mulB); // a2 b2 a3 b3
        __m256 cd01 = _mm256_unpacklo_ps(mulC, mulD); // c0 d0 c1 d1
        __m256 cd23 = _mm256_unpackhi_ps(mulC, mulD); // c2 d2 c3 d3

        __m256 col0 = _mm256_unpacklo_ps(ab01, cd01); // a0 c0 b0 d0
        __m256 col1 = _mm256_unpackhi_ps(ab01, cd01); // a1 c1 b1 d1
        __m256 col2 = _mm256_unpacklo_ps(ab23, cd23); // a2 c2 b2 d2
        __m256 col3 = _mm256_unpackhi_ps(ab23, cd23); // a3 c3 b3 d3

        __m256 sum = _mm256_add_ps(col0, col1);
        sum = _mm256_add_ps(sum, col2);
        sum = _mm256_add_ps(sum, col3);
        return sum;
    } else {
        // The 512-bit unpacks operate independently within each 128-bit lane.
        __m512 mulA = _mm512_mul_ps(A0.v, A1.v);
        __m512 mulB = _mm512_mul_ps(B0.v, B1.v);
        __m512 mulC = _mm512_mul_ps(C0.v, C1.v);
        __m512 mulD = _mm512_mul_ps(D0.v, D1.v);

        __m512 ab01 = _mm512_unpacklo_ps(mulA, mulB); // a0 b0 a1 b1 (per lane)
        __m512 ab23 = _mm512_unpackhi_ps(mulA, mulB); // a2 b2 a3 b3
        __m512 cd01 = _mm512_unpacklo_ps(mulC, mulD); // c0 d0 c1 d1
        __m512 cd23 = _mm512_unpackhi_ps(mulC, mulD); // c2 d2 c3 d3

        __m512 col0 = _mm512_unpacklo_ps(ab01, cd01); // a0 c0 b0 d0
        __m512 col1 = _mm512_unpackhi_ps(ab01, cd01); // a1 c1 b1 d1
        __m512 col2 = _mm512_unpacklo_ps(ab23, cd23); // a2 c2 b2 d2
        __m512 col3 = _mm512_unpackhi_ps(ab23, cd23); // a3 c3 b3 d3

        __m512 sum = _mm512_add_ps(col0, col1);
        sum = _mm512_add_ps(sum, col2);
        sum = _mm512_add_ps(sum, col3);
        return sum;
    }
}
|
||
|
||
/// Computes dot products of 2-component vectors from two packed operand pairs,
/// without reordering the results afterwards.
///
/// Per 128-bit lane, A0/A1 hold two 2-component vectors (call them A and B) and
/// C0/C1 hold two more (C and D). The element-wise products are transposed with
/// two rounds of 32-bit float interleaves, so the horizontal sums become one
/// vertical add.
///
/// Float unpacks are used instead of epi32 unpacks on reinterpreted registers.
/// This gives identical bits without relying on implicit integer<->float vector
/// conversions.
///
/// @return Per 128-bit lane: { dot(A), dot(B), dot(C), dot(D) }.
constexpr static VectorF32<1, Packing*2> DotNoShuffle(
    VectorF32<Len, Packing> A0, VectorF32<Len, Packing> A1,
    VectorF32<Len, Packing> C0, VectorF32<Len, Packing> C1
) requires(Len == 2 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
        __m128 mulA = _mm_mul_ps(A0.v, A1.v);          // A1 A2 B1 B2
        __m128 mulC = _mm_mul_ps(C0.v, C1.v);          // C1 C2 D1 D2
        __m128 lo = _mm_unpacklo_ps(mulA, mulC);       // A1 C1 A2 C2
        __m128 hi = _mm_unpackhi_ps(mulA, mulC);       // B1 D1 B2 D2
        __m128 firsts = _mm_unpacklo_ps(lo, hi);       // A1 B1 C1 D1
        __m128 seconds = _mm_unpackhi_ps(lo, hi);      // A2 B2 C2 D2
        return _mm_add_ps(firsts, seconds);
    } else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
        // The 256-bit unpacks operate independently within each 128-bit lane.
        __m256 mulA = _mm256_mul_ps(A0.v, A1.v);
        __m256 mulC = _mm256_mul_ps(C0.v, C1.v);
        __m256 lo = _mm256_unpacklo_ps(mulA, mulC);    // A1 C1 A2 C2 (per lane)
        __m256 hi = _mm256_unpackhi_ps(mulA, mulC);    // B1 D1 B2 D2
        __m256 firsts = _mm256_unpacklo_ps(lo, hi);    // A1 B1 C1 D1
        __m256 seconds = _mm256_unpackhi_ps(lo, hi);   // A2 B2 C2 D2
        return _mm256_add_ps(firsts, seconds);
    } else {
        // The 512-bit unpacks operate independently within each 128-bit lane.
        __m512 mulA = _mm512_mul_ps(A0.v, A1.v);
        __m512 mulC = _mm512_mul_ps(C0.v, C1.v);
        __m512 lo = _mm512_unpacklo_ps(mulA, mulC);    // A1 C1 A2 C2 (per lane)
        __m512 hi = _mm512_unpackhi_ps(mulA, mulC);    // B1 D1 B2 D2
        __m512 firsts = _mm512_unpacklo_ps(lo, hi);    // A1 B1 C1 D1
        __m512 seconds = _mm512_unpackhi_ps(lo, hi);   // A2 B2 C2 D2
        return _mm512_add_ps(firsts, seconds);
    }
}
|
||
public:
|
||
|
||
template <std::array<bool, Len> ShuffleValues>
|
||
constexpr static VectorF32<Len, Packing> Blend(VectorF32<Len, Packing> a, VectorF32<Len, Packing> b) {
|
||
constexpr auto mask = VectorBase<Len, Packing, float>::template GetBlendMaskEpi32<ShuffleValues>();
|
||
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
|
||
return _mm_castsi128_ps(_mm_blend_epi32(_mm_castps_si128(a.v), _mm_castps_si128(b.v), mask));
|
||
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
|
||
#ifndef __AVX512BW__
|
||
#ifndef __AVX512VL__
|
||
static_assert(false, "No __AVX512BW__ and __AVX512VL__ support");
|
||
#endif
|
||
#endif
|
||
return _mm256_castsi256_ps(_mm256_mask_blend_epi32(mask, _mm256_castps_si256(a.v), _mm256_castps_si256(b.v)));
|
||
} else {
|
||
return _mm512_castsi512_ps(_mm512_mask_blend_epi32(mask, _mm512_castps_si512(a.v), _mm512_castps_si512(b.v)));
|
||
}
|
||
}
|
||
|
||
/// Rotates the 3-component vector `v` by the quaternion `q`.
///
/// `q` stores its vector part in lanes 0..2 and its scalar part in lane 3 (lane 3
/// is broadcast below). Uses the two-cross-product form
///     v' = v + w * t + q.xyz x t,   with  t = 2 * (q.xyz x v),
/// which assumes `q` is a unit quaternion (not checked here).
constexpr static VectorF32<Len, Packing> Rotate(VectorF32<3, Packing> v, VectorF32<4, Packing> q) requires(Len == 3) {
    // Vector (imaginary) part of the quaternion.
    VectorF32<3, Packing> axis(q);
    // t = 2 * (q.xyz x v)
    VectorF32<Len, Packing> twiceCross = Cross(axis, v) * 2.0f;
    // Scale t by the broadcast scalar part, then add the second cross product.
    return v + twiceCross * q.template Shuffle<{{3,3,3,3}}>() + Cross(axis, twiceCross);
}
|
||
|
||
/// Rotates the point `v` by the quaternion `q` around `pivot` instead of the origin.
///
/// The point is moved into pivot-relative space, rotated with the same
/// two-cross-product formula as Rotate, and moved back.
///
/// The return type is `VectorF32<Len, Packing>`, matching Rotate. The previous
/// hard-coded `VectorF32<4, 2>` did not match the 3-component result computed
/// here, and could not hold it for other packings.
///
/// @param v     Point to rotate.
/// @param q     Rotation quaternion; vector part in lanes 0..2, scalar part in
///              lane 3. Assumed normalized.
/// @param pivot Centre of rotation.
constexpr static VectorF32<Len, Packing> RotatePivot(VectorF32<3, Packing> v, VectorF32<4, Packing> q, VectorF32<3, Packing> pivot) requires(Len == 3) {
    VectorF32<Len, Packing> translated = v - pivot;
    VectorF32<3, Packing> qv(q.v);
    // t = 2 * (q.xyz x p)
    VectorF32<Len, Packing> t = Cross(qv, translated) * float(2);
    // p' = p + w * t + q.xyz x t
    VectorF32<Len, Packing> rotated = translated + t * q.template Shuffle<{{3,3,3,3}}>() + Cross(qv, t);
    return rotated + pivot;
}
|
||
|
||
/// Builds a quaternion (x, y, z, w in lanes 0..3) from Euler angles.
///
/// `EulerHalf` presumably already holds the *half* angles (x, y, z in lanes
/// 0..2), as its name suggests; no halving is done here. Write sN/cN for the
/// sine/cosine of lane N.
///
/// The result is one product term plus one multiply-add term:
///     primary   = [sx cy cz, cx sy cz, cx cy sz, cx cy cz]
///     secondary = [cx sy, sx cy, sx sy, sx sy] * [-sz, sz, -cz, sz]
///
/// Assumptions:
/// - Blend takes `false` lanes from its first argument.
/// - Negate flips the flagged lanes.
/// - MulitplyAdd(a, b, c) = a*b + c.
///
/// Under these assumptions this yields
///     x = sx cy cz - cx sy sz,   y = cx sy cz + sx cy sz,
///     z = cx cy sz - sx sy cz,   w = cx cy cz + sx sy sz,
/// which is the standard Z-Y-X (yaw-pitch-roll) Euler-to-quaternion formula.
constexpr static VectorF32<4, Packing> QuanternionFromEuler(VectorF32<3, Packing> EulerHalf) requires(Len == 4) {
    std::tuple<VectorF32<3, Packing>, VectorF32<3, Packing>> halfSinCos = EulerHalf.SinCos();
    VectorF32<4, Packing> s = std::get<0>(halfSinCos);
    VectorF32<4, Packing> c = std::get<1>(halfSinCos);

    // Primary product, one factor per axis: [s c c c] * [c s c c] * [c c s c].
    VectorF32<4, Packing> primary = Blend<{{0,1,1,1}}>(s, c.template Shuffle<{{0,0,0,0}}>());
    VectorF32<4, Packing> primaryY = Blend<{{1,0,1,1}}>(s, c.template Shuffle<{{1,1,1,1}}>());
    VectorF32<4, Packing> primaryZ = Blend<{{1,1,0,1}}>(s, c.template Shuffle<{{2,2,2,2}}>());
    primary *= primaryY;
    primary *= primaryZ;

    // Secondary product: sine/cosine roles swapped, z factor sign-adjusted per lane.
    VectorF32<4, Packing> secondaryXY = Blend<{{0,1,1,1}}>(c, s.template Shuffle<{{0,0,0,0}}>());
    VectorF32<4, Packing> secondaryY = Blend<{{1,0,1,1}}>(c, s.template Shuffle<{{1,1,1,1}}>());
    secondaryXY *= secondaryY;

    VectorF32<4, Packing> secondaryZ = Blend<{{1,1,0,1}}>(c, s.template Shuffle<{{2,2,2,2}}>());
    secondaryZ = secondaryZ.template Negate<{{true,false,true,false}}>();

    return MulitplyAdd(secondaryXY, secondaryZ, primary);
}
|
||
};
|
||
}
|
||
|
||
|
||
/// std::format support for VectorF32, e.g. `{{1,2},{3,4}}`: one inner brace
/// group per packed vector, with components separated by commas.
///
/// The template parameters must be `std::uint8_t` to match VectorF32's own
/// parameters. With different types (e.g. `std::uint32_t`), deduction of this
/// partial specialization fails ([temp.deduct.type]) and it never matches.
///
/// NOTE(review): element i2 of packed vector i is read from index i * Len + i2,
/// i.e. packed vectors are assumed to be stored back-to-back with stride Len.
/// Confirm this for layouts where Len * Packing != AlignmentElement.
export template <std::uint8_t Len, std::uint8_t Packing>
struct std::formatter<Crafter::VectorF32<Len, Packing>> : std::formatter<std::string> {
    constexpr auto format(const Crafter::VectorF32<Len, Packing>& obj, format_context& ctx) const {
        std::array<float, Crafter::VectorF32<Len, Packing>::AlignmentElement> vec = obj.Store();
        std::string out = "{";
        for(std::uint32_t i = 0; i < Packing; i++) {
            out += "{";
            for(std::uint32_t i2 = 0; i2 < Len; i2++) {
                // Append directly instead of building a temporary string per element.
                std::format_to(std::back_inserter(out), "{}", vec[i * Len + i2]);
                if (i2 + 1 < Len) out += ",";
            }
            out += "}";
        }
        out += "}";
        return std::formatter<std::string>::format(out, ctx);
    }
};
|
||
#endif |