2026-03-19 02:19:01 +01:00
|
|
|
/*
|
|
|
|
|
Crafter®.Math
|
|
|
|
|
Copyright (C) 2026 Catcrafts®
|
|
|
|
|
catcrafts.net
|
|
|
|
|
|
|
|
|
|
This library is free software; you can redistribute it and/or
|
|
|
|
|
modify it under the terms of the GNU Lesser General Public
|
|
|
|
|
License version 3.0 as published by the Free Software Foundation;
|
|
|
|
|
|
|
|
|
|
This library is distributed in the hope that it will be useful,
|
|
|
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
|
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
|
|
|
Lesser General Public License for more details.
|
|
|
|
|
|
|
|
|
|
You should have received a copy of the GNU Lesser General Public
|
|
|
|
|
License along with this library; if not, write to the Free Software
|
|
|
|
|
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
|
|
|
|
*/
|
|
|
|
|
module;
|
|
|
|
|
#ifdef __x86_64
|
|
|
|
|
#include <immintrin.h>
|
|
|
|
|
#endif
|
|
|
|
|
export module Crafter.Math:VectorF16;
|
|
|
|
|
import std;
|
|
|
|
|
import :Vector;
|
|
|
|
|
|
|
|
|
|
#ifdef __AVX512FP16__
|
|
|
|
|
namespace Crafter {
|
|
|
|
|
export template <std::uint32_t Len, std::uint32_t Packing, std::uint32_t Repeats>
|
|
|
|
|
struct VectorF16 {
|
|
|
|
|
// Largest lane count accepted by the static_asserts in GetAlignment().
// NOTE(review): presumably the 32 _Float16 lanes of one 512-bit register — confirm.
static constexpr std::uint32_t MaxSize = 32;
|
|
|
|
|
// Matches the "packed size of 8" limit asserted in GetAlignment().
// NOTE(review): presumably the 8 _Float16 lanes of one 128-bit block — confirm.
static constexpr std::uint32_t MaxElement = 8;
|
|
|
|
|
// Rounds Len * Packing up to the lane count of the smallest register block
// (8, 16 or 32 _Float16 lanes) that can hold one repeat.
//
// NOTE(review): the "<= 8" assertion below rejects every instantiation that
// would reach the 16 and 32 results, so only 8 is reachable today. Confirm
// whether that assertion is a temporary restriction.
static consteval std::uint32_t GetAlignment() {
    static_assert(Len * Packing <= 32, "Len * Packing is larger than supported max size of 32");
    static_assert(Len * Packing <= 8, "Len * Packing is larger than supported packed size of 8");
    static_assert(Len * Packing * Repeats <= 32, "Len * Packing * Repeats is larger than supported max of 32");

    constexpr std::uint32_t lanesPerRepeat = Len * Packing;
    if constexpr (lanesPerRepeat <= 8) {
        return 8;
    } else if constexpr (lanesPerRepeat <= 16) {
        return 16;
    } else {
        return 32;
    }
}
|
|
|
|
|
// Total number of _Float16 lanes occupied: the per-repeat lane count from
// GetAlignment() multiplied by Repeats.
static consteval std::uint32_t GetTotalSize() {
    constexpr std::uint32_t lanesPerRepeat = GetAlignment();
    return lanesPerRepeat * Repeats;
}
|
|
|
|
|
|
|
|
|
|
using VectorType = std::conditional_t<
|
|
|
|
|
(GetTotalSize() == 32), __m512h,
|
|
|
|
|
std::conditional_t<(GetTotalSize() == 16), __m256h, __m128h>
|
|
|
|
|
>;
|
|
|
|
|
|
|
|
|
|
// The raw SIMD register holding all lanes of this vector.
VectorType v;
|
2026-03-18 03:16:29 +01:00
|
|
|
|
2026-03-19 02:19:01 +01:00
|
|
|
// Default construction leaves the register `v` uninitialized.
constexpr VectorF16() = default;
|
|
|
|
|
// Wraps an existing register value. Intentionally implicit: several members
// return raw intrinsic results that are converted through this constructor.
constexpr VectorF16(VectorType v) : v(v) {}
|
2026-03-22 03:51:09 +01:00
|
|
|
constexpr VectorF16(const _Float16* vB) {
|
|
|
|
|
Load(vB);
|
|
|
|
|
};
|
|
|
|
|
// Broadcast constructor: every lane of the register is set to `val`.
constexpr VectorF16(_Float16 val) {
    constexpr bool is128 = std::is_same_v<VectorType, __m128h>;
    constexpr bool is256 = std::is_same_v<VectorType, __m256h>;
    if constexpr (is128) {
        v = _mm_set1_ph(val);
    } else if constexpr (is256) {
        v = _mm256_set1_ph(val);
    } else {
        v = _mm512_set1_ph(val);
    }
}
|
2026-03-22 03:51:09 +01:00
|
|
|
constexpr void Load(const _Float16* vB) {
|
2026-03-19 02:19:01 +01:00
|
|
|
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
2026-03-22 03:51:09 +01:00
|
|
|
v = _mm_loadu_ph(vB);
|
2026-03-19 02:19:01 +01:00
|
|
|
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
2026-03-22 03:51:09 +01:00
|
|
|
v = _mm256_loadu_ph(vB);
|
2026-03-19 02:19:01 +01:00
|
|
|
} else {
|
2026-03-22 03:51:09 +01:00
|
|
|
v = _mm512_loadu_ph(vB);
|
2026-03-19 02:19:01 +01:00
|
|
|
}
|
|
|
|
|
}
|
2026-03-22 20:53:17 +01:00
|
|
|
// Writes the whole register to `vB` using an unaligned store (8, 16 or 32
// halves for __m128h, __m256h and __m512h respectively), so `vB` must have
// room for that many elements.
constexpr void Store(_Float16* vB) const {
    constexpr bool is128 = std::is_same_v<VectorType, __m128h>;
    constexpr bool is256 = std::is_same_v<VectorType, __m256h>;
    if constexpr (is128) {
        _mm_storeu_ph(vB, v);
    } else if constexpr (is256) {
        _mm256_storeu_ph(vB, v);
    } else {
        _mm512_storeu_ph(vB, v);
    }
}
|
|
|
|
|
|
|
|
|
|
template <std::uint32_t VLen, std::uint32_t VAlign>
|
|
|
|
|
constexpr Vector<_Float16, VLen, VAlign> Store() const {
|
|
|
|
|
Vector<_Float16, VLen, VAlign> returnVec;
|
2026-03-22 20:53:17 +01:00
|
|
|
Store(returnVec.v);
|
2026-03-19 02:19:01 +01:00
|
|
|
return returnVec;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <std::uint32_t BLen, std::uint32_t BPacking, std::uint32_t BRepeats>
|
|
|
|
|
constexpr operator VectorF16<BLen, BPacking, BRepeats>() const {
|
|
|
|
|
if constexpr(std::is_same_v<VectorType, __m256h> && std::is_same_v<typename VectorF16<BLen, BPacking, BRepeats>::VectorType, __m128h>) {
|
|
|
|
|
return VectorF16<BLen, BPacking, BRepeats>(_mm256_castph256_ph128(v));
|
|
|
|
|
} else if constexpr(std::is_same_v<VectorType, __m512h> && std::is_same_v<typename VectorF16<BLen, BPacking, BRepeats>::VectorType, __m128h>) {
|
|
|
|
|
return VectorF16<BLen, BPacking, BRepeats>(_mm512_castph512_ph128(v));
|
|
|
|
|
} else if constexpr(std::is_same_v<VectorType, __m512h> && std::is_same_v<typename VectorF16<BLen, BPacking, BRepeats>::VectorType, __m256h>) {
|
|
|
|
|
return VectorF16<BLen, BPacking, BRepeats>(_mm512_castph512_ph256(v));
|
|
|
|
|
} else if constexpr(std::is_same_v<VectorType, __m128h> && std::is_same_v<typename VectorF16<BLen, BPacking, BRepeats>::VectorType, __m256h>) {
|
|
|
|
|
return VectorF16<BLen, BPacking, BRepeats>(_mm256_castph128_ph256(v));
|
|
|
|
|
} else if constexpr(std::is_same_v<VectorType, __m128h> && std::is_same_v<typename VectorF16<BLen, BPacking, BRepeats>::VectorType, __m512h>) {
|
|
|
|
|
return VectorF16<BLen, BPacking, BRepeats>(_mm512_castph128_ph512(v));
|
|
|
|
|
} else if constexpr(std::is_same_v<VectorType, __m256h> && std::is_same_v<typename VectorF16<BLen, BPacking, BRepeats>::VectorType, __m512h>) {
|
|
|
|
|
return VectorF16<BLen, BPacking, BRepeats>(_mm512_castph256_ph512(v));
|
|
|
|
|
} else {
|
|
|
|
|
return VectorF16<BLen, BPacking, BRepeats>(v);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Lane-wise addition of two vectors of the same shape.
constexpr VectorF16<Len, Packing, Repeats> operator+(VectorF16<Len, Packing, Repeats> b) const {
    using Self = VectorF16<Len, Packing, Repeats>;
    constexpr bool is128 = std::is_same_v<VectorType, __m128h>;
    constexpr bool is256 = std::is_same_v<VectorType, __m256h>;
    if constexpr (is128) {
        return Self(_mm_add_ph(v, b.v));
    } else if constexpr (is256) {
        return Self(_mm256_add_ph(v, b.v));
    } else {
        return Self(_mm512_add_ph(v, b.v));
    }
}
|
|
|
|
|
|
|
|
|
|
// Lane-wise subtraction: each lane of `b` is subtracted from this vector.
constexpr VectorF16<Len, Packing, Repeats> operator-(VectorF16<Len, Packing, Repeats> b) const {
    using Self = VectorF16<Len, Packing, Repeats>;
    constexpr bool is128 = std::is_same_v<VectorType, __m128h>;
    constexpr bool is256 = std::is_same_v<VectorType, __m256h>;
    if constexpr (is128) {
        return Self(_mm_sub_ph(v, b.v));
    } else if constexpr (is256) {
        return Self(_mm256_sub_ph(v, b.v));
    } else {
        return Self(_mm512_sub_ph(v, b.v));
    }
}
|
|
|
|
|
|
|
|
|
|
// Lane-wise multiplication of two vectors of the same shape.
constexpr VectorF16<Len, Packing, Repeats> operator*(VectorF16<Len, Packing, Repeats> b) const {
    using Self = VectorF16<Len, Packing, Repeats>;
    constexpr bool is128 = std::is_same_v<VectorType, __m128h>;
    constexpr bool is256 = std::is_same_v<VectorType, __m256h>;
    if constexpr (is128) {
        return Self(_mm_mul_ph(v, b.v));
    } else if constexpr (is256) {
        return Self(_mm256_mul_ph(v, b.v));
    } else {
        return Self(_mm512_mul_ph(v, b.v));
    }
}
|
|
|
|
|
|
|
|
|
|
// Lane-wise division: each lane of this vector is divided by the matching
// lane of `b`.
constexpr VectorF16<Len, Packing, Repeats> operator/(VectorF16<Len, Packing, Repeats> b) const {
    using Self = VectorF16<Len, Packing, Repeats>;
    constexpr bool is128 = std::is_same_v<VectorType, __m128h>;
    constexpr bool is256 = std::is_same_v<VectorType, __m256h>;
    if constexpr (is128) {
        return Self(_mm_div_ph(v, b.v));
    } else if constexpr (is256) {
        return Self(_mm256_div_ph(v, b.v));
    } else {
        return Self(_mm512_div_ph(v, b.v));
    }
}
|
2026-03-18 03:16:29 +01:00
|
|
|
|
|
|
|
|
|
2026-03-22 03:51:09 +01:00
|
|
|
// Adds `b` to this vector lane by lane, in place.
// The operator is non-const because it assigns to `v`; the previous `const`
// qualifier made every instantiation ill-formed.
constexpr void operator+=(VectorF16<Len, Packing, Repeats> b) {
    if constexpr (std::is_same_v<VectorType, __m128h>) {
        v = _mm_add_ph(v, b.v);
    } else if constexpr (std::is_same_v<VectorType, __m256h>) {
        v = _mm256_add_ph(v, b.v);
    } else {
        v = _mm512_add_ph(v, b.v);
    }
}
|
|
|
|
|
|
2026-03-22 03:51:09 +01:00
|
|
|
// Subtracts `b` from this vector lane by lane, in place.
// The operator is non-const because it assigns to `v`; the previous `const`
// qualifier made every instantiation ill-formed.
constexpr void operator-=(VectorF16<Len, Packing, Repeats> b) {
    if constexpr (std::is_same_v<VectorType, __m128h>) {
        v = _mm_sub_ph(v, b.v);
    } else if constexpr (std::is_same_v<VectorType, __m256h>) {
        v = _mm256_sub_ph(v, b.v);
    } else {
        v = _mm512_sub_ph(v, b.v);
    }
}
|
|
|
|
|
|
2026-03-22 03:51:09 +01:00
|
|
|
// Multiplies this vector by `b` lane by lane, in place.
// The operator is non-const because it assigns to `v`; the previous `const`
// qualifier made every instantiation ill-formed.
constexpr void operator*=(VectorF16<Len, Packing, Repeats> b) {
    if constexpr (std::is_same_v<VectorType, __m128h>) {
        v = _mm_mul_ph(v, b.v);
    } else if constexpr (std::is_same_v<VectorType, __m256h>) {
        v = _mm256_mul_ph(v, b.v);
    } else {
        v = _mm512_mul_ph(v, b.v);
    }
}
|
|
|
|
|
|
2026-03-22 03:51:09 +01:00
|
|
|
// Divides this vector by `b` lane by lane, in place.
// The operator is non-const because it assigns to `v`; the previous `const`
// qualifier made every instantiation ill-formed.
constexpr void operator/=(VectorF16<Len, Packing, Repeats> b) {
    if constexpr (std::is_same_v<VectorType, __m128h>) {
        v = _mm_div_ph(v, b.v);
    } else if constexpr (std::is_same_v<VectorType, __m256h>) {
        v = _mm256_div_ph(v, b.v);
    } else {
        v = _mm512_div_ph(v, b.v);
    }
}
|
|
|
|
|
|
2026-03-22 03:51:09 +01:00
|
|
|
// Adds the scalar `b` to every lane by broadcasting it and reusing the
// vector overload. `*this` is required here: `this + vB` would be pointer
// arithmetic on the object pointer and does not compile.
constexpr VectorF16<Len, Packing, Repeats> operator+(_Float16 b) const {
    const VectorF16<Len, Packing, Repeats> vB(b);
    return *this + vB;
}
|
|
|
|
|
|
|
|
|
|
// Subtracts the scalar `b` from every lane by broadcasting it and reusing
// the vector overload. `*this` is required here: `this - vB` would be
// pointer arithmetic on the object pointer and does not compile.
constexpr VectorF16<Len, Packing, Repeats> operator-(_Float16 b) const {
    const VectorF16<Len, Packing, Repeats> vB(b);
    return *this - vB;
}
|
|
|
|
|
|
|
|
|
|
// Multiplies every lane by the scalar `b` by broadcasting it and reusing the
// vector overload. `*this` is required here: `this * vB` applies the
// operator to the object pointer and does not compile.
constexpr VectorF16<Len, Packing, Repeats> operator*(_Float16 b) const {
    const VectorF16<Len, Packing, Repeats> vB(b);
    return *this * vB;
}
|
|
|
|
|
|
|
|
|
|
// Divides every lane by the scalar `b` by broadcasting it and reusing the
// vector overload. `*this` is required here: `this / vB` applies the
// operator to the object pointer and does not compile.
constexpr VectorF16<Len, Packing, Repeats> operator/(_Float16 b) const {
    const VectorF16<Len, Packing, Repeats> vB(b);
    return *this / vB;
}
|
|
|
|
|
|
|
|
|
|
// Adds the scalar `b` to every lane in place.
// Non-const because it modifies the vector, and applied to `*this` because
// `this += vB` would try to modify the object pointer itself.
constexpr void operator+=(_Float16 b) {
    const VectorF16<Len, Packing, Repeats> vB(b);
    *this += vB;
}
|
|
|
|
|
|
|
|
|
|
// Subtracts the scalar `b` from every lane in place.
// Non-const because it modifies the vector, and applied to `*this` because
// `this -= vB` would try to modify the object pointer itself.
constexpr void operator-=(_Float16 b) {
    const VectorF16<Len, Packing, Repeats> vB(b);
    *this -= vB;
}
|
|
|
|
|
|
|
|
|
|
// Multiplies every lane by the scalar `b` in place.
// Non-const because it modifies the vector, and applied to `*this` because
// `this *= vB` would try to modify the object pointer itself.
constexpr void operator*=(_Float16 b) {
    const VectorF16<Len, Packing, Repeats> vB(b);
    *this *= vB;
}
|
|
|
|
|
|
|
|
|
|
// Divides every lane by the scalar `b` in place.
// Non-const because it modifies the vector, and applied to `*this` because
// `this /= vB` would try to modify the object pointer itself.
constexpr void operator/=(_Float16 b) {
    const VectorF16<Len, Packing, Repeats> vB(b);
    *this /= vB;
}
|
|
|
|
|
|
2026-03-19 02:19:01 +01:00
|
|
|
// Unary negation: flips the sign bit (bit 15) of every half-precision lane
// with an integer XOR, so zeros, infinities and NaNs only change sign.
//
// The sign mask is broadcast with set1 instead of being loaded from a local
// array: the previous arrays were only 16-byte aligned yet were read with
// the 32- and 64-byte aligned load intrinsics. The operator is now const so
// it can be applied to const vectors and temporaries.
constexpr VectorF16<Len, Packing, Repeats> operator-() const {
    // 0x8000 is the sign bit of an IEEE binary16 value.
    constexpr short signBit = static_cast<short>(0x8000);
    if constexpr (std::is_same_v<VectorType, __m128h>) {
        const __m128i signMask = _mm_set1_epi16(signBit);
        return VectorF16<Len, Packing, Repeats>(_mm_castsi128_ph(_mm_xor_si128(signMask, _mm_castph_si128(v))));
    } else if constexpr (std::is_same_v<VectorType, __m256h>) {
        const __m256i signMask = _mm256_set1_epi16(signBit);
        return VectorF16<Len, Packing, Repeats>(_mm256_castsi256_ph(_mm256_xor_si256(signMask, _mm256_castph_si256(v))));
    } else {
        const __m512i signMask = _mm512_set1_epi16(signBit);
        return VectorF16<Len, Packing, Repeats>(_mm512_castsi512_ph(_mm512_xor_si512(signMask, _mm512_castph_si512(v))));
    }
}
|
|
|
|
|
|
|
|
|
|
// True only when every lane of the register compares equal to the matching
// lane of `b`. The comparison is ordered (_CMP_EQ_OQ), so a NaN in any lane
// makes the result false. All lanes of the register take part, not just the
// first Len.
constexpr bool operator==(VectorF16<Len, Packing, Repeats> b) const {
    if constexpr (std::is_same_v<VectorType, __m128h>) {
        const auto equalLanes = _mm_cmp_ph_mask(v, b.v, _CMP_EQ_OQ);
        return equalLanes == 0xFF;          // 8 lanes
    } else if constexpr (std::is_same_v<VectorType, __m256h>) {
        const auto equalLanes = _mm256_cmp_ph_mask(v, b.v, _CMP_EQ_OQ);
        return equalLanes == 0xFFFF;        // 16 lanes
    } else {
        const auto equalLanes = _mm512_cmp_ph_mask(v, b.v, _CMP_EQ_OQ);
        return equalLanes == 0xFFFFFFFFu;   // 32 lanes
    }
}
|
|
|
|
|
|
2026-03-22 03:51:09 +01:00
|
|
|
// True when at least one lane of the register fails the ordered equality
// test (_CMP_EQ_OQ) against `b`; a NaN in any lane therefore yields true.
// All lanes of the register take part, not just the first Len.
constexpr bool operator!=(VectorF16<Len, Packing, Repeats> b) const {
    if constexpr (std::is_same_v<VectorType, __m128h>) {
        const auto equalLanes = _mm_cmp_ph_mask(v, b.v, _CMP_EQ_OQ);
        return equalLanes != 0xFF;          // 8 lanes
    } else if constexpr (std::is_same_v<VectorType, __m256h>) {
        const auto equalLanes = _mm256_cmp_ph_mask(v, b.v, _CMP_EQ_OQ);
        return equalLanes != 0xFFFF;        // 16 lanes
    } else {
        const auto equalLanes = _mm512_cmp_ph_mask(v, b.v, _CMP_EQ_OQ);
        return equalLanes != 0xFFFFFFFFu;   // 32 lanes
    }
}
|
|
|
|
|
|
|
|
|
|
constexpr void Normalize() {
|
|
|
|
|
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
|
|
|
|
_Float16 dot = LengthSq();
|
|
|
|
|
__m128h vec = _mm_set1_ph(dot);
|
|
|
|
|
__m128h sqrt = _mm_rsqrt_ph(vec);
|
|
|
|
|
v = _mm_div_ps(v, sqrt);
|
|
|
|
|
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
|
|
|
|
_Float16 dot = LengthSq();
|
|
|
|
|
__m256h vec = _mm256_set1_ph(dot);
|
|
|
|
|
__m256h sqrt = _mm256_rsqrt_ph(vec);
|
|
|
|
|
v = _mm256_div_ps(v, sqrt);
|
|
|
|
|
} else {
|
|
|
|
|
_Float16 dot = LengthSq();
|
|
|
|
|
__m512h vec = _mm512_set1_ph(dot);
|
|
|
|
|
__m512h sqrt = _mm512_rsqrt_ph(vec);
|
|
|
|
|
v = _mm512_div_ps(v, sqrt);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Euclidean length: the square root of LengthSq().
// The root is taken in single precision and converted back explicitly;
// float -> _Float16 is a narrowing conversion that is not implicit for
// extended floating-point types in C++23. std::sqrt(float) is used instead
// of std::sqrtf, which some standard libraries do not declare in namespace std.
constexpr _Float16 Length() const {
    const float lengthSq = static_cast<float>(LengthSq());
    return static_cast<_Float16>(std::sqrt(lengthSq));
}
|
|
|
|
|
|
|
|
|
|
// Squared length: the dot product of this vector with itself.
// Dot() takes VectorF16<Len, 1, 1>, so `*this` goes through the converting
// operator when this instantiation has a different shape.
constexpr _Float16 LengthSq() const {
    const VectorF16<Len, Packing, Repeats>& self = *this;
    return Dot(self, self);
}
|
|
|
|
|
|
2026-03-22 03:51:09 +01:00
|
|
|
// Lane-wise cosine of the vector.
// The middle branch now tests for __m256h; it previously re-tested __m128h,
// so 256-bit vectors fell through to the 512-bit intrinsic. Marked const
// because the vector is not modified.
// NOTE(review): _mm*_cos_ph are SVML intrinsics and may not be provided by
// every compiler — confirm the supported toolchains.
constexpr VectorF16<Len, Packing, Repeats> Cos() const requires(Len == 3) {
    if constexpr (std::is_same_v<VectorType, __m128h>) {
        return VectorF16<Len, Packing, Repeats>(_mm_cos_ph(v));
    } else if constexpr (std::is_same_v<VectorType, __m256h>) {
        return VectorF16<Len, Packing, Repeats>(_mm256_cos_ph(v));
    } else {
        return VectorF16<Len, Packing, Repeats>(_mm512_cos_ph(v));
    }
}
|
|
|
|
|
|
|
|
|
|
// Lane-wise sine of the vector.
// The middle branch now tests for __m256h; it previously re-tested __m128h,
// so 256-bit vectors fell through to the 512-bit intrinsic. Marked const
// because the vector is not modified.
// NOTE(review): _mm*_sin_ph are SVML intrinsics and may not be provided by
// every compiler — confirm the supported toolchains.
constexpr VectorF16<Len, Packing, Repeats> Sin() const requires(Len == 3) {
    if constexpr (std::is_same_v<VectorType, __m128h>) {
        return VectorF16<Len, Packing, Repeats>(_mm_sin_ph(v));
    } else if constexpr (std::is_same_v<VectorType, __m256h>) {
        return VectorF16<Len, Packing, Repeats>(_mm256_sin_ph(v));
    } else {
        return VectorF16<Len, Packing, Repeats>(_mm512_sin_ph(v));
    }
}
|
|
|
|
|
|
|
|
|
|
template <std::uint8_t A, std::uint8_t B, std::uint8_t C, std::uint8_t D, std::uint8_t E, std::uint8_t F, std::uint8_t G, std::uint8_t H>
|
|
|
|
|
constexpr VectorF16<Len, Packing, Repeats> Shuffle() {
|
|
|
|
|
if constexpr(A == B-1 && C == D-1 && E == F-1 && G == H-1) {
|
|
|
|
|
constexpr std::uint32_t val =
|
|
|
|
|
(A & 0x3) |
|
|
|
|
|
((B & 0x3) << 2) |
|
|
|
|
|
((C & 0x3) << 4) |
|
|
|
|
|
((D & 0x3) << 6) |
|
|
|
|
|
((E & 0x3) << 8) |
|
|
|
|
|
((F & 0x3) << 10) |
|
|
|
|
|
((G & 0x3) << 12) |
|
|
|
|
|
((H & 0x3) << 14);
|
|
|
|
|
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
|
|
|
|
return VectorF16<Len, Packing, Repeats>(_mm_castsi128_ph(_mm_shuffle_epi32(_mm_castph_si128(v), val)));
|
|
|
|
|
} else if constexpr(std::is_same_v<VectorType, __m128h>) {
|
|
|
|
|
return VectorF16<Len, Packing, Repeats>(_mm256_castsi256_ph(_mm256_shuffle_epi32(_mm256_castph_si256(v), val)));
|
|
|
|
|
} else {
|
|
|
|
|
return VectorF16<Len, Packing, Repeats>(_mm512_castsi512_ph(_mm512_shuffle_epi32(_mm_512castph_si512(v), val)));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
|
|
|
|
constexpr std::uint8_t shuffleMask[] {
|
|
|
|
|
A,A,B,B,C,C,D,D,E,E,F,F,G,G,H,H
|
|
|
|
|
};
|
|
|
|
|
__m128h shuffleVec = _mm_loadu_epi8(shuffleMask);
|
|
|
|
|
return VectorF16<Len, Packing, Repeats>(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(v), shuffleVec)));
|
|
|
|
|
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
|
|
|
|
constexpr std::uint8_t shuffleMask[] {
|
|
|
|
|
A,A,B,B,C,C,D,D,E,E,F,F,G,G,H,H,
|
|
|
|
|
A+16,A+16,B+16,B+16,C+16,C+16,D+16,D+16,E+16,E+16,F+16,F+16,G+16,G+16,H+16,H+16,
|
|
|
|
|
};
|
|
|
|
|
__m256h shuffleVec = _mm256_loadu_epi8(shuffleMask);
|
|
|
|
|
return VectorF16<Len, Packing, Repeats>(_mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(v), shuffleVec)));
|
|
|
|
|
} else {
|
|
|
|
|
constexpr std::uint8_t shuffleMask[] {
|
|
|
|
|
A,A,B,B,C,C,D,D,E,E,F,F,G,G,H,H,
|
|
|
|
|
A+16,A+16,B+16,B+16,C+16,C+16,D+16,D+16,E+16,E+16,F+16,F+16,G+16,G+16,H+16,H+16,
|
|
|
|
|
A+32,A+32,B+32,B+32,C+32,C+32,D+32,D+32,E+32,E+32,F+32,F+32,G+32,G+32,H+32,H+32,
|
|
|
|
|
A+48,A+48,B+48,B+48,C+48,C+48,D+48,D+48,E+48,E+48,F+48,F+48,G+48,G+48,H+48,H+48,
|
|
|
|
|
};
|
|
|
|
|
__m512h shuffleVec = _mm512_loadu_epi8(shuffleMask);
|
|
|
|
|
return VectorF16<Len, Packing, Repeats>(_mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(v), shuffleVec)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <
|
|
|
|
|
std::uint8_t A0, std::uint8_t B0, std::uint8_t C0, std::uint8_t D0, std::uint8_t E0, std::uint8_t F0, std::uint8_t G0, std::uint8_t H0,
|
|
|
|
|
std::uint8_t A1, std::uint8_t B1, std::uint8_t C1, std::uint8_t D1, std::uint8_t E1, std::uint8_t F1, std::uint8_t G1, std::uint8_t H1
|
|
|
|
|
>
|
|
|
|
|
constexpr VectorF16<Len, Packing, Repeats> Shuffle() requires(Repeats == 2) {
|
|
|
|
|
constexpr std::uint8_t shuffleMask[] {
|
|
|
|
|
A0,A0,B0,B0,C0,C0,D0,D0,E0,E0,F0,F0,G0,G0,H0,H0,
|
|
|
|
|
A1,A1,B1,B1,C1,C1,D1,D1,E1,E1,F1,F1,G1,G1,H1,H1,
|
|
|
|
|
};
|
|
|
|
|
__m256h shuffleVec = _mm256_loadu_epi8(shuffleMask);
|
|
|
|
|
return VectorF16<Len, Packing, Repeats>(_mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(v), shuffleVec)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <
|
|
|
|
|
std::uint8_t A0, std::uint8_t B0, std::uint8_t C0, std::uint8_t D0, std::uint8_t E0, std::uint8_t F0, std::uint8_t G0, std::uint8_t H0,
|
|
|
|
|
std::uint8_t A1, std::uint8_t B1, std::uint8_t C1, std::uint8_t D1, std::uint8_t E1, std::uint8_t F1, std::uint8_t G1, std::uint8_t H1,
|
|
|
|
|
std::uint8_t A2, std::uint8_t B2, std::uint8_t C2, std::uint8_t D2, std::uint8_t E2, std::uint8_t F2, std::uint8_t G2, std::uint8_t H2,
|
|
|
|
|
std::uint8_t A3, std::uint8_t B3, std::uint8_t C3, std::uint8_t D3, std::uint8_t E3, std::uint8_t F3, std::uint8_t G3, std::uint8_t H3
|
|
|
|
|
>
|
|
|
|
|
constexpr VectorF16<Len, Packing, Repeats> Shuffle() requires(Repeats == 4) {
|
|
|
|
|
constexpr std::uint8_t shuffleMask[] {
|
|
|
|
|
A0,A0,B0,B0,C0,C0,D0,D0,E0,E0,F0,F0,G0,G0,H0,H0,
|
|
|
|
|
A1,A1,B1,B1,C1,C1,D1,D1,E1,E1,F1,F1,G1,G1,H1,H1,
|
|
|
|
|
A2,A2,B2,B2,C2,C2,D2,D2,E2,E2,F2,F2,G2,G2,H2,H2,
|
|
|
|
|
A3,A3,B3,B3,C3,C3,D3,D3,E3,E3,F3,F3,G3,G3,H3,H3,
|
|
|
|
|
};
|
|
|
|
|
__m512h shuffleVec = _mm512_loadu_epi8(shuffleMask);
|
|
|
|
|
return VectorF16<Len, Packing, Repeats>(_mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(v), shuffleVec)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Fused multiply-add: returns a * b + add, lane by lane, with a single
// rounding per lane.
// The intrinsics take the raw registers, so the `.v` members are passed;
// the previous version passed the wrapper objects, which have no conversion
// to the register types.
static constexpr VectorF16<Len, Packing, Repeats> MulitplyAdd(VectorF16<Len, Packing, Repeats> a, VectorF16<Len, Packing, Repeats> b, VectorF16<Len, Packing, Repeats> add) {
    if constexpr (std::is_same_v<VectorType, __m128h>) {
        return VectorF16<Len, Packing, Repeats>(_mm_fmadd_ph(a.v, b.v, add.v));
    } else if constexpr (std::is_same_v<VectorType, __m256h>) {
        return VectorF16<Len, Packing, Repeats>(_mm256_fmadd_ph(a.v, b.v, add.v));
    } else {
        return VectorF16<Len, Packing, Repeats>(_mm512_fmadd_ph(a.v, b.v, add.v));
    }
}
|
|
|
|
|
|
|
|
|
|
// Fused multiply-subtract: returns a * b - sub, lane by lane, with a single
// rounding per lane.
// The intrinsics take the raw registers, so the `.v` members are passed;
// the previous version passed the wrapper objects, which have no conversion
// to the register types.
static constexpr VectorF16<Len, Packing, Repeats> MulitplySub(VectorF16<Len, Packing, Repeats> a, VectorF16<Len, Packing, Repeats> b, VectorF16<Len, Packing, Repeats> sub) {
    if constexpr (std::is_same_v<VectorType, __m128h>) {
        return VectorF16<Len, Packing, Repeats>(_mm_fmsub_ph(a.v, b.v, sub.v));
    } else if constexpr (std::is_same_v<VectorType, __m256h>) {
        return VectorF16<Len, Packing, Repeats>(_mm256_fmsub_ph(a.v, b.v, sub.v));
    } else {
        return VectorF16<Len, Packing, Repeats>(_mm512_fmsub_ph(a.v, b.v, sub.v));
    }
}
|
|
|
|
|
|
2026-03-19 03:22:22 +01:00
|
|
|
// Cross product of packed 3-component vectors.
//
// Layout used by the shuffle masks: each 128-bit block holds two vectors, in
// lanes 0-2 and lanes 4-6; lanes 3 and 7 stay where they are. For every
// vector the result is
//     a.yzx * b.zxy - a.zxy * b.yzx
// which is (a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x).
// Lanes 3 and 7 of the result are a*b - a*b of the matching input lanes.
//
// Fixes over the previous version: the byte indices for the second vector in
// each block were shifted by one (11..16 instead of 10..15), which split
// lanes across byte pairs and referenced byte 16 of a 16-byte block. Repeats
// values other than 1, 2 and 4 are now rejected at compile time instead of
// falling off the end of the function.
constexpr static VectorF16<Len, Packing, Repeats> Cross(VectorF16<Len, Packing, Repeats> a, VectorF16<Len, Packing, Repeats> b) requires(Len == 3 && Packing == 2) {
    using Self = VectorF16<Len, Packing, Repeats>;

    // Byte shuffle that rotates (x, y, z) to (y, z, x) in both halves of a
    // block; the 16-byte pattern is repeated so one table serves every width.
    constexpr auto yzxMask = [] {
        constexpr std::uint8_t block[16] {
            2, 3, 4, 5, 0, 1, 6, 7,
            10, 11, 12, 13, 8, 9, 14, 15
        };
        std::array<std::uint8_t, 64> bytes{};
        for (std::size_t i = 0; i < bytes.size(); ++i) {
            bytes[i] = block[i % 16];
        }
        return bytes;
    }();

    // Byte shuffle that rotates (x, y, z) to (z, x, y) in both halves of a block.
    constexpr auto zxyMask = [] {
        constexpr std::uint8_t block[16] {
            4, 5, 0, 1, 2, 3, 6, 7,
            12, 13, 8, 9, 10, 11, 14, 15
        };
        std::array<std::uint8_t, 64> bytes{};
        for (std::size_t i = 0; i < bytes.size(); ++i) {
            bytes[i] = block[i % 16];
        }
        return bytes;
    }();

    if constexpr (Repeats == 1) {
        const __m128i yzx = _mm_loadu_epi8(yzxMask.data());
        const __m128i zxy = _mm_loadu_epi8(zxyMask.data());
        const __m128h aYzx = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(a.v), yzx));
        const __m128h bYzx = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(b.v), yzx));
        const __m128h aZxy = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(a.v), zxy));
        const __m128h bZxy = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(b.v), zxy));
        return Self(_mm_fmsub_ph(aYzx, bZxy, _mm_mul_ph(aZxy, bYzx)));
    } else if constexpr (Repeats == 2) {
        const __m256i yzx = _mm256_loadu_epi8(yzxMask.data());
        const __m256i zxy = _mm256_loadu_epi8(zxyMask.data());
        const __m256h aYzx = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(a.v), yzx));
        const __m256h bYzx = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(b.v), yzx));
        const __m256h aZxy = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(a.v), zxy));
        const __m256h bZxy = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(b.v), zxy));
        return Self(_mm256_fmsub_ph(aYzx, bZxy, _mm256_mul_ph(aZxy, bYzx)));
    } else {
        static_assert(Repeats == 4, "Cross supports Repeats of 1, 2 or 4");
        const __m512i yzx = _mm512_loadu_epi8(yzxMask.data());
        const __m512i zxy = _mm512_loadu_epi8(zxyMask.data());
        const __m512h aYzx = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(a.v), yzx));
        const __m512h bYzx = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(b.v), yzx));
        const __m512h aZxy = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(a.v), zxy));
        const __m512h bZxy = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(b.v), zxy));
        return Self(_mm512_fmsub_ph(aYzx, bZxy, _mm512_mul_ph(aZxy, bYzx)));
    }
}
|
2026-03-19 02:19:01 +01:00
|
|
|
|
|
|
|
|
// Dot product of two unpacked vectors: multiplies lane by lane and sums
// every lane of the register.
//
// The register width is taken from the arguments rather than from the
// enclosing instantiation. The parameters are always VectorF16<Len, 1, 1>,
// so dispatching on this class's own VectorType selected the wrong intrinsic
// whenever the enclosing class used a wider register.
//
// NOTE(review): the reduction includes lanes beyond Len, so those lanes are
// presumably expected to hold zero — confirm with callers.
constexpr static _Float16 Dot(VectorF16<Len, 1, 1> a, VectorF16<Len, 1, 1> b) {
    using ArgRegister = decltype(a.v);
    if constexpr (std::is_same_v<ArgRegister, __m128h>) {
        const __m128h products = _mm_mul_ph(a.v, b.v);
        return _mm_reduce_add_ph(products);
    } else if constexpr (std::is_same_v<ArgRegister, __m256h>) {
        const __m256h products = _mm256_mul_ph(a.v, b.v);
        return _mm256_reduce_add_ph(products);
    } else {
        const __m512h products = _mm512_mul_ph(a.v, b.v);
        return _mm512_reduce_add_ph(products);
    }
}
|
2026-03-22 03:51:09 +01:00
|
|
|
|
2026-03-19 02:19:01 +01:00
|
|
|
|
|
|
|
|
constexpr static std::tuple<VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>> Normalize(
|
|
|
|
|
VectorF16<Len, Packing, Repeats> A,
|
|
|
|
|
VectorF16<Len, Packing, Repeats> B,
|
|
|
|
|
VectorF16<Len, Packing, Repeats> C,
|
|
|
|
|
VectorF16<Len, Packing, Repeats> D,
|
|
|
|
|
VectorF16<Len, Packing, Repeats> E,
|
|
|
|
|
VectorF16<Len, Packing, Repeats> F,
|
|
|
|
|
VectorF16<Len, Packing, Repeats> G,
|
|
|
|
|
VectorF16<Len, Packing, Repeats> H
|
|
|
|
|
) requires(Packing == 1) {
|
|
|
|
|
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
|
|
|
|
VectorF16<Len, Packing, Repeats> lenght = Length(A, B, C, D, E, F, G, H);
|
|
|
|
|
constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1};
|
|
|
|
|
__m128h one = _mm_loadu_ph(oneArr);
|
|
|
|
|
__m128h fLenght = _mm_div_ph(one, lenght.v);
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskA[] {
|
|
|
|
|
0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1
|
|
|
|
|
};
|
|
|
|
|
__m128i shuffleVecA = _mm_loadu_epi8(shuffleMaskA);
|
|
|
|
|
__m128h fLenghtA = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecA));
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskB[] {
|
|
|
|
|
2,3,2,3,2,3,2,3,2,3,2,3,2,3,2,3
|
|
|
|
|
};
|
|
|
|
|
__m128i shuffleVecB = _mm_loadu_epi8(shuffleMaskB);
|
|
|
|
|
__m128h fLenghtB = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecB));
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskC[] {
|
|
|
|
|
4,5,4,5,4,5,4,5,4,5,4,5,4,5,4,5
|
|
|
|
|
};
|
|
|
|
|
__m128i shuffleVecC = _mm_loadu_epi8(shuffleMaskC);
|
|
|
|
|
__m128h fLenghtC = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecC));
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskD[] {
|
|
|
|
|
6,7,6,7,6,7,6,7,6,7,6,7,6,7,6,7
|
|
|
|
|
};
|
|
|
|
|
__m128i shuffleVecD = _mm_loadu_epi8(shuffleMaskD);
|
|
|
|
|
__m128h fLenghtD = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecD));
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskE[] {
|
|
|
|
|
8,9,8,9,8,9,8,9,8,9,8,9,8,9,8,9
|
|
|
|
|
};
|
|
|
|
|
__m128i shuffleVecE = _mm_loadu_epi8(shuffleMaskE);
|
|
|
|
|
__m128h fLenghtE = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecE));
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskF[] {
|
|
|
|
|
10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
|
|
|
|
|
};
|
|
|
|
|
__m128i shuffleVecF = _mm_loadu_epi8(shuffleMaskF);
|
|
|
|
|
__m128h fLenghtF = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecF));
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskG[] {
|
|
|
|
|
12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
|
|
|
|
|
};
|
|
|
|
|
__m128i shuffleVecG = _mm_loadu_epi8(shuffleMaskG);
|
|
|
|
|
__m128h fLenghtG = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecG));
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskH[] {
|
|
|
|
|
14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
|
|
|
|
|
};
|
|
|
|
|
__m128i shuffleVecH = _mm_loadu_epi8(shuffleMaskH);
|
|
|
|
|
__m128h fLenghtH = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(fLenght), shuffleVecH));
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
_mm_mul_ph(A.v, fLenghtA),
|
|
|
|
|
_mm_mul_ph(B.v, fLenghtB),
|
|
|
|
|
_mm_mul_ph(C.v, fLenghtC),
|
|
|
|
|
_mm_mul_ph(D.v, fLenghtD),
|
|
|
|
|
_mm_mul_ph(E.v, fLenghtE),
|
|
|
|
|
_mm_mul_ph(F.v, fLenghtF),
|
|
|
|
|
_mm_mul_ph(G.v, fLenghtG),
|
|
|
|
|
_mm_mul_ph(H.v, fLenghtH)
|
|
|
|
|
};
|
|
|
|
|
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
|
|
|
|
VectorF16<Len, Packing, Repeats> lenght = Length(A, B, C, D, E, F, G, H);
|
|
|
|
|
constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
|
|
|
|
|
__m256h one = _mm256_loadu_ph(oneArr);
|
|
|
|
|
__m256h fLenght = _mm256_div_ph(one, lenght.v);
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskA[] {
|
|
|
|
|
0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,
|
|
|
|
|
0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1
|
|
|
|
|
};
|
|
|
|
|
__m256i shuffleVecA = _mm256_loadu_epi8(shuffleMaskA);
|
|
|
|
|
__m256h fLenghtA = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecA));
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskB[] {
|
|
|
|
|
2,3,2,3,2,3,2,3,2,3,2,3,2,3,2,3,
|
|
|
|
|
2,3,2,3,2,3,2,3,2,3,2,3,2,3,2,3
|
|
|
|
|
};
|
|
|
|
|
__m256i shuffleVecB = _mm256_loadu_epi8(shuffleMaskB);
|
|
|
|
|
__m256h fLenghtB = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecB));
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskC[] {
|
|
|
|
|
4,5,4,5,4,5,4,5,4,5,4,5,4,5,4,5,
|
|
|
|
|
4,5,4,5,4,5,4,5,4,5,4,5,4,5,4,5
|
|
|
|
|
};
|
|
|
|
|
__m256i shuffleVecC = _mm256_loadu_epi8(shuffleMaskC);
|
|
|
|
|
__m256h fLenghtC = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecC));
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskD[] {
|
|
|
|
|
6,7,6,7,6,7,6,7,6,7,6,7,6,7,6,7,
|
|
|
|
|
6,7,6,7,6,7,6,7,6,7,6,7,6,7,6,7
|
|
|
|
|
};
|
|
|
|
|
__m256i shuffleVecD = _mm256_loadu_epi8(shuffleMaskD);
|
|
|
|
|
__m256h fLenghtD = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecD));
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskE[] {
|
|
|
|
|
8,9,8,9,8,9,8,9,8,9,8,9,8,9,8,9,
|
|
|
|
|
8,9,8,9,8,9,8,9,8,9,8,9,8,9,8,9
|
|
|
|
|
};
|
|
|
|
|
__m256i shuffleVecE = _mm256_loadu_epi8(shuffleMaskE);
|
|
|
|
|
__m256h fLenghtE = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecE));
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskF[] {
|
|
|
|
|
10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
|
|
|
|
|
10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
|
|
|
|
|
};
|
|
|
|
|
__m256i shuffleVecF = _mm256_loadu_epi8(shuffleMaskF);
|
|
|
|
|
__m256h fLenghtF = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecF));
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskG[] {
|
|
|
|
|
12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
|
|
|
|
|
12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13
|
|
|
|
|
};
|
|
|
|
|
__m256i shuffleVecG = _mm256_loadu_epi8(shuffleMaskG);
|
|
|
|
|
__m256h fLenghtG = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecG));
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskH[] {
|
|
|
|
|
14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
|
|
|
|
|
14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
|
|
|
|
|
};
|
|
|
|
|
__m256i shuffleVecH = _mm256_loadu_epi8(shuffleMaskH);
|
|
|
|
|
__m256h fLenghtH = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(fLenght), shuffleVecH));
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
_mm256_mul_ph(A.v, fLenghtA),
|
|
|
|
|
_mm256_mul_ph(B.v, fLenghtB),
|
|
|
|
|
_mm256_mul_ph(C.v, fLenghtC),
|
|
|
|
|
_mm256_mul_ph(D.v, fLenghtD),
|
|
|
|
|
_mm256_mul_ph(E.v, fLenghtE),
|
|
|
|
|
_mm256_mul_ph(F.v, fLenghtF),
|
|
|
|
|
_mm256_mul_ph(G.v, fLenghtG),
|
|
|
|
|
_mm256_mul_ph(H.v, fLenghtH)
|
|
|
|
|
};
|
|
|
|
|
} else {
|
|
|
|
|
VectorF16<Len, Packing, Repeats> lenght = Length(A, B, C, D, E, F, G, H);
|
|
|
|
|
constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
|
|
|
|
|
__m512h one = _mm512_loadu_ph(oneArr);
|
|
|
|
|
__m512h fLenght = _mm512_div_ph(one, lenght.v);
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskA[] {
|
|
|
|
|
0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,
|
|
|
|
|
0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,
|
|
|
|
|
0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,
|
|
|
|
|
0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1
|
|
|
|
|
};
|
|
|
|
|
__m512i shuffleVecA = _mm512_loadu_epi8(shuffleMaskA);
|
|
|
|
|
__m512h fLenghtA = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecA));
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskB[] {
|
|
|
|
|
2,3,2,3,2,3,2,3,2,3,2,3,2,3,2,3,
|
|
|
|
|
2,3,2,3,2,3,2,3,2,3,2,3,2,3,2,3,
|
|
|
|
|
2,3,2,3,2,3,2,3,2,3,2,3,2,3,2,3,
|
|
|
|
|
2,3,2,3,2,3,2,3,2,3,2,3,2,3,2,3
|
|
|
|
|
};
|
|
|
|
|
__m512i shuffleVecB = _mm512_loadu_epi8(shuffleMaskB);
|
|
|
|
|
__m512h fLenghtB = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecB));
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskC[] {
|
|
|
|
|
4,5,4,5,4,5,4,5,4,5,4,5,4,5,4,5,
|
|
|
|
|
4,5,4,5,4,5,4,5,4,5,4,5,4,5,4,5,
|
|
|
|
|
4,5,4,5,4,5,4,5,4,5,4,5,4,5,4,5,
|
|
|
|
|
4,5,4,5,4,5,4,5,4,5,4,5,4,5,4,5
|
|
|
|
|
};
|
|
|
|
|
__m512i shuffleVecC = _mm512_loadu_epi8(shuffleMaskC);
|
|
|
|
|
__m512h fLenghtC = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecC));
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskD[] {
|
|
|
|
|
6,7,6,7,6,7,6,7,6,7,6,7,6,7,6,7,
|
|
|
|
|
6,7,6,7,6,7,6,7,6,7,6,7,6,7,6,7,
|
|
|
|
|
6,7,6,7,6,7,6,7,6,7,6,7,6,7,6,7,
|
|
|
|
|
6,7,6,7,6,7,6,7,6,7,6,7,6,7,6,7
|
|
|
|
|
};
|
|
|
|
|
__m512i shuffleVecD = _mm512_loadu_epi8(shuffleMaskD);
|
|
|
|
|
__m512h fLenghtD = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecD));
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskE[] {
|
|
|
|
|
8,9,8,9,8,9,8,9,8,9,8,9,8,9,8,9,
|
|
|
|
|
8,9,8,9,8,9,8,9,8,9,8,9,8,9,8,9,
|
|
|
|
|
8,9,8,9,8,9,8,9,8,9,8,9,8,9,8,9,
|
|
|
|
|
8,9,8,9,8,9,8,9,8,9,8,9,8,9,8,9
|
|
|
|
|
};
|
|
|
|
|
__m512i shuffleVecE = _mm512_loadu_epi8(shuffleMaskE);
|
|
|
|
|
__m512h fLenghtE = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecE));
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskF[] {
|
|
|
|
|
10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
|
|
|
|
|
10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
|
|
|
|
|
10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
|
|
|
|
|
10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
|
|
|
|
|
};
|
|
|
|
|
__m512i shuffleVecF = _mm512_loadu_epi8(shuffleMaskF);
|
|
|
|
|
__m512h fLenghtF = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecF));
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskG[] {
|
|
|
|
|
12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
|
|
|
|
|
12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
|
|
|
|
|
12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
|
|
|
|
|
12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13
|
|
|
|
|
};
|
|
|
|
|
__m512i shuffleVecG = _mm512_loadu_epi8(shuffleMaskG);
|
|
|
|
|
__m512h fLenghtG = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecG));
|
|
|
|
|
|
|
|
|
|
constexpr std::uint8_t shuffleMaskH[] {
|
|
|
|
|
14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
|
|
|
|
|
14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
|
|
|
|
|
14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
|
|
|
|
|
14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
|
|
|
|
|
};
|
|
|
|
|
__m512i shuffleVecH = _mm512_loadu_epi8(shuffleMaskH);
|
|
|
|
|
__m512h fLenghtH = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecH));
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
_mm512_mul_ph(A.v, fLenghtA),
|
|
|
|
|
_mm512_mul_ph(B.v, fLenghtB),
|
|
|
|
|
_mm512_mul_ph(C.v, fLenghtC),
|
|
|
|
|
_mm512_mul_ph(D.v, fLenghtD),
|
|
|
|
|
_mm512_mul_ph(E.v, fLenghtE),
|
|
|
|
|
_mm512_mul_ph(F.v, fLenghtF),
|
|
|
|
|
_mm512_mul_ph(G.v, fLenghtG),
|
|
|
|
|
_mm512_mul_ph(H.v, fLenghtH)
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2026-03-19 05:53:17 +01:00
|
|
|
// Normalizes the vectors stored in A, C, E and G, where every register carries
// two vectors per 128-bit lane (Packing == 2).
//
// Length(A, C, E, G) supplies the lengths. Within each 128-bit lane its 16-bit
// elements 0-1 are applied to A, 2-3 to C, 4-5 to E and 6-7 to G; each
// reciprocal length is spread over one half of the lane before the multiply.
//
// Returns the scaled A, C, E and G in that order. No zero-length check is made.
constexpr static std::tuple<VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>> Normalize(
    VectorF16<Len, Packing, Repeats> A,
    VectorF16<Len, Packing, Repeats> C,
    VectorF16<Len, Packing, Repeats> E,
    VectorF16<Len, Packing, Repeats> G
) requires(Packing == 2) {
    using Vec = VectorF16<Len, Packing, Repeats>;

    const Vec lengths = Length(A, C, E, G);
    const _Float16 one = static_cast<_Float16>(1);

    // Byte selectors for one 128-bit lane. Each must be exactly 16 bytes: the
    // first eight pick one reciprocal, the last eight pick the next one.
    const __m128i selectA = _mm_setr_epi8(0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3);
    const __m128i selectC = _mm_setr_epi8(4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7);
    const __m128i selectE = _mm_setr_epi8(8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11);
    const __m128i selectG = _mm_setr_epi8(12, 13, 12, 13, 12, 13, 12, 13, 14, 15, 14, 15, 14, 15, 14, 15);

    if constexpr(std::is_same_v<VectorType, __m128h>) {
        const __m128i inverse = _mm_castph_si128(_mm_div_ph(_mm_set1_ph(one), lengths.v));
        const auto scaled = [&](const Vec& value, __m128i select) {
            return Vec(_mm_mul_ph(value.v, _mm_castsi128_ph(_mm_shuffle_epi8(inverse, select))));
        };
        return { scaled(A, selectA), scaled(C, selectC), scaled(E, selectE), scaled(G, selectG) };
    } else if constexpr(std::is_same_v<VectorType, __m256h>) {
        const __m256i inverse = _mm256_castph_si256(_mm256_div_ph(_mm256_set1_ph(one), lengths.v));
        const auto scaled = [&](const Vec& value, __m128i select) {
            // The byte shuffle is lane-local, so the same selector is repeated in every lane.
            const __m256i selectAllLanes = _mm256_broadcastsi128_si256(select);
            return Vec(_mm256_mul_ph(value.v, _mm256_castsi256_ph(_mm256_shuffle_epi8(inverse, selectAllLanes))));
        };
        return { scaled(A, selectA), scaled(C, selectC), scaled(E, selectE), scaled(G, selectG) };
    } else {
        const __m512i inverse = _mm512_castph_si512(_mm512_div_ph(_mm512_set1_ph(one), lengths.v));
        const auto scaled = [&](const Vec& value, __m128i select) {
            // The byte shuffle is lane-local, so the same selector is repeated in every lane.
            const __m512i selectAllLanes = _mm512_broadcast_i32x4(select);
            return Vec(_mm512_mul_ph(value.v, _mm512_castsi512_ph(_mm512_shuffle_epi8(inverse, selectAllLanes))));
        };
        return { scaled(A, selectA), scaled(C, selectC), scaled(E, selectE), scaled(G, selectG) };
    }
}
|
|
|
|
|
|
2026-03-19 05:53:17 +01:00
|
|
|
// Normalizes the vectors stored in A and E, where every register carries four
// vectors per 128-bit lane (Packing == 4).
//
// Length(A, E) supplies the lengths. Within each 128-bit lane its 16-bit
// elements 0-3 are applied to A and 4-7 to E; each reciprocal length is
// duplicated onto two adjacent elements before the multiply.
//
// Returns the scaled A and E in that order. No zero-length check is made.
constexpr static std::tuple<VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>> Normalize(
    VectorF16<Len, Packing, Repeats> A,
    VectorF16<Len, Packing, Repeats> E
) requires(Packing == 4) {
    using Vec = VectorF16<Len, Packing, Repeats>;

    const Vec lengths = Length(A, E);
    const _Float16 one = static_cast<_Float16>(1);

    // Byte selectors for one 128-bit lane: every 16-bit reciprocal is taken twice.
    const __m128i selectA = _mm_setr_epi8(0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7);
    const __m128i selectE = _mm_setr_epi8(8, 9, 8, 9, 10, 11, 10, 11, 12, 13, 12, 13, 14, 15, 14, 15);

    // Results are wrapped in VectorF16 explicitly in every branch, matching the
    // 512-bit branch of the Packing == 2 overload.
    if constexpr(std::is_same_v<VectorType, __m128h>) {
        const __m128i inverse = _mm_castph_si128(_mm_div_ph(_mm_set1_ph(one), lengths.v));
        const auto scaled = [&](const Vec& value, __m128i select) {
            return Vec(_mm_mul_ph(value.v, _mm_castsi128_ph(_mm_shuffle_epi8(inverse, select))));
        };
        return { scaled(A, selectA), scaled(E, selectE) };
    } else if constexpr(std::is_same_v<VectorType, __m256h>) {
        const __m256i inverse = _mm256_castph_si256(_mm256_div_ph(_mm256_set1_ph(one), lengths.v));
        const auto scaled = [&](const Vec& value, __m128i select) {
            // The byte shuffle is lane-local, so the same selector is repeated in every lane.
            const __m256i selectAllLanes = _mm256_broadcastsi128_si256(select);
            return Vec(_mm256_mul_ph(value.v, _mm256_castsi256_ph(_mm256_shuffle_epi8(inverse, selectAllLanes))));
        };
        return { scaled(A, selectA), scaled(E, selectE) };
    } else {
        const __m512i inverse = _mm512_castph_si512(_mm512_div_ph(_mm512_set1_ph(one), lengths.v));
        const auto scaled = [&](const Vec& value, __m128i select) {
            // The byte shuffle is lane-local, so the same selector is repeated in every lane.
            const __m512i selectAllLanes = _mm512_broadcast_i32x4(select);
            return Vec(_mm512_mul_ph(value.v, _mm512_castsi512_ph(_mm512_shuffle_epi8(inverse, selectAllLanes))));
        };
        return { scaled(A, selectA), scaled(E, selectE) };
    }
}
|
|
|
|
|
|
|
|
|
|
// Normalizes eight 8-element vectors and returns them packed four to a 512-bit
// register: {A, B, C, D} in the first tuple element and {E, F, G, H} in the
// second, one input per 128-bit lane in that order.
//
// Length(A, ..., H) supplies the lengths; the selectors below take 16-bit
// element i of that result as the length of the i-th argument.
// No zero-length check is made.
//
// NOTE(review): four 8-element vectors fill a 512-bit register, yet the return
// type names Repeats == 8 - confirm that this is the intended instantiation.
constexpr static std::tuple<VectorF16<Len, Packing, 8>, VectorF16<Len, Packing, 8>> NormalizeRepeated(
    VectorF16<Len, Packing, Repeats> A,
    VectorF16<Len, Packing, Repeats> B,
    VectorF16<Len, Packing, Repeats> C,
    VectorF16<Len, Packing, Repeats> D,
    VectorF16<Len, Packing, Repeats> E,
    VectorF16<Len, Packing, Repeats> F,
    VectorF16<Len, Packing, Repeats> G,
    VectorF16<Len, Packing, Repeats> H
) requires(Len == 8 && Packing == 1 && Repeats == 1) {
    using Wide = VectorF16<Len, Packing, 8>;

    const VectorF16<Len, Packing, Repeats> lengths = Length(A, B, C, D, E, F, G, H);
    const __m128i inverse = _mm_castph_si128(_mm_div_ph(_mm_set1_ph(static_cast<_Float16>(1)), lengths.v));

    // _mm512_shuffle_epi8 only moves bytes inside a 128-bit lane, so every lane
    // needs its own copy of the eight reciprocals. A plain 128->512 cast would
    // leave the upper three lanes undefined.
    const __m512i inverseAllLanes = _mm512_broadcast_i32x4(inverse);

    // One row per 128-bit lane; each row splats a single reciprocal across the lane.
    constexpr std::uint8_t shuffleMaskABCD[] {
        0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,
        2,3,2,3,2,3,2,3,2,3,2,3,2,3,2,3,
        4,5,4,5,4,5,4,5,4,5,4,5,4,5,4,5,
        6,7,6,7,6,7,6,7,6,7,6,7,6,7,6,7
    };
    constexpr std::uint8_t shuffleMaskEFGH[] {
        8,9,8,9,8,9,8,9,8,9,8,9,8,9,8,9,
        10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
        14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
    };

    // Packs four inputs into the lanes of one register and scales each lane by
    // the reciprocal selected through `mask`. Sharing this between both halves
    // keeps the second half from accidentally reusing the first half's data.
    const auto normalizeFour = [&](__m128h v0, __m128h v1, __m128h v2, __m128h v3, const std::uint8_t* mask) {
        const __m512h scale = _mm512_castsi512_ph(_mm512_shuffle_epi8(inverseAllLanes, _mm512_loadu_epi8(mask)));
        __m512 packed = _mm512_castps128_ps512(_mm_castph_ps(v0)); // lanes 1-3 are filled by the inserts below
        packed = _mm512_insertf32x4(packed, _mm_castph_ps(v1), 1);
        packed = _mm512_insertf32x4(packed, _mm_castph_ps(v2), 2);
        packed = _mm512_insertf32x4(packed, _mm_castph_ps(v3), 3);
        return Wide(_mm512_mul_ph(_mm512_castps_ph(packed), scale));
    };

    return {
        normalizeFour(A.v, B.v, C.v, D.v, shuffleMaskABCD),
        normalizeFour(E.v, F.v, G.v, H.v, shuffleMaskEFGH)
    };
}
|
|
|
|
|
|
|
|
|
|
// Normalizes the vectors stored in eight registers that each carry two
// 4-element vectors (Len == 4, Packing == 2), returning them packed four
// registers to a 512-bit value: {A, B, C, D} first, {E, F, G, H} second, one
// input register per 128-bit lane in that order.
//
// NOTE(review): the eight inputs are treated as two independent groups of four,
// each normalized with its own four-argument Length call (the overload that is
// available for Packing == 2). Confirm that this is the intended contract.
// NOTE(review): four 128-bit inputs fill a 512-bit register, yet the return
// type names Repeats == 8 - confirm that this is the intended instantiation.
// No zero-length check is made.
constexpr static std::tuple<VectorF16<Len, Packing, 8>, VectorF16<Len, Packing, 8>> NormalizeRepeated(
    VectorF16<Len, Packing, Repeats> A,
    VectorF16<Len, Packing, Repeats> B,
    VectorF16<Len, Packing, Repeats> C,
    VectorF16<Len, Packing, Repeats> D,
    VectorF16<Len, Packing, Repeats> E,
    VectorF16<Len, Packing, Repeats> F,
    VectorF16<Len, Packing, Repeats> G,
    VectorF16<Len, Packing, Repeats> H
) requires(Len == 4 && Packing == 2 && Repeats == 1) {
    using Vec = VectorF16<Len, Packing, Repeats>;
    using Wide = VectorF16<Len, Packing, 8>;

    // One 16-byte row per 128-bit lane. Lane k takes reciprocals 2k and 2k+1 and
    // spreads each over four elements, i.e. over one packed 4-element vector.
    constexpr std::uint8_t shuffleMask[] {
        0,1,0,1,0,1,0,1,2,3,2,3,2,3,2,3,
        4,5,4,5,4,5,4,5,6,7,6,7,6,7,6,7,
        8,9,8,9,8,9,8,9,10,11,10,11,10,11,10,11,
        12,13,12,13,12,13,12,13,14,15,14,15,14,15,14,15
    };
    const __m512i select = _mm512_loadu_epi8(shuffleMask);

    const auto normalizeFour = [&](const Vec& v0, const Vec& v1, const Vec& v2, const Vec& v3) {
        const Vec lengths = Length(v0, v1, v2, v3);
        const __m128i inverse = _mm_castph_si128(_mm_div_ph(_mm_set1_ph(static_cast<_Float16>(1)), lengths.v));

        // _mm512_shuffle_epi8 is lane-local: broadcast the reciprocals so that
        // every lane can read all eight of them.
        const __m512i inverseAllLanes = _mm512_broadcast_i32x4(inverse);
        const __m512h scale = _mm512_castsi512_ph(_mm512_shuffle_epi8(inverseAllLanes, select));

        __m512 packed = _mm512_castps128_ps512(_mm_castph_ps(v0.v)); // lanes 1-3 are filled by the inserts below
        packed = _mm512_insertf32x4(packed, _mm_castph_ps(v1.v), 1);
        packed = _mm512_insertf32x4(packed, _mm_castph_ps(v2.v), 2);
        packed = _mm512_insertf32x4(packed, _mm_castph_ps(v3.v), 3);
        return Wide(_mm512_mul_ph(_mm512_castps_ph(packed), scale));
    };

    return {
        normalizeFour(A, B, C, D),
        normalizeFour(E, F, G, H)
    };
}
|
|
|
|
|
|
|
|
|
|
// Normalizes the vectors stored in eight registers that each carry four
// 2-element vectors (Len == 2, Packing == 4), returning them packed four
// registers to a 512-bit value: {A, B, C, D} first, {E, F, G, H} second, one
// input register per 128-bit lane in that order.
//
// NOTE(review): the eight inputs are treated as two independent groups of four;
// within a group the lengths are computed pairwise from the two-argument
// LengthSq overload (the one declared for Packing == 4). Confirm that this is
// the intended contract.
// NOTE(review): four 128-bit inputs fill a 512-bit register, yet the return
// type names Repeats == 8 - confirm that this is the intended instantiation.
// No zero-length check is made.
constexpr static std::tuple<VectorF16<Len, Packing, 8>, VectorF16<Len, Packing, 8>> NormalizeRepeated(
    VectorF16<Len, Packing, Repeats> A,
    VectorF16<Len, Packing, Repeats> B,
    VectorF16<Len, Packing, Repeats> C,
    VectorF16<Len, Packing, Repeats> D,
    VectorF16<Len, Packing, Repeats> E,
    VectorF16<Len, Packing, Repeats> F,
    VectorF16<Len, Packing, Repeats> G,
    VectorF16<Len, Packing, Repeats> H
) requires(Len == 2 && Packing == 4 && Repeats == 1) {
    using Vec = VectorF16<Len, Packing, Repeats>;
    using Wide = VectorF16<Len, Packing, 8>;

    // Reciprocal lengths of the eight vectors held in `first` and `second`:
    // elements 0-3 belong to `first`, elements 4-7 to `second`.
    const auto reciprocalLengths = [](const Vec& first, const Vec& second) {
        const __m128h lengths = _mm_sqrt_ph(LengthSq(first, second).v);
        return _mm_castph_si128(_mm_div_ph(_mm_set1_ph(static_cast<_Float16>(1)), lengths));
    };

    const auto normalizeFour = [&](const Vec& v0, const Vec& v1, const Vec& v2, const Vec& v3) {
        const __m128i inverse01 = reciprocalLengths(v0, v1);
        const __m128i inverse23 = reciprocalLengths(v2, v3);

        // Interleaving a register with itself doubles every 16-bit reciprocal,
        // giving one copy per component of each 2-element vector. The low half
        // serves the first register of a pair, the high half the second.
        __m512i scale = _mm512_castsi128_si512(_mm_unpacklo_epi16(inverse01, inverse01));
        scale = _mm512_inserti32x4(scale, _mm_unpackhi_epi16(inverse01, inverse01), 1);
        scale = _mm512_inserti32x4(scale, _mm_unpacklo_epi16(inverse23, inverse23), 2);
        scale = _mm512_inserti32x4(scale, _mm_unpackhi_epi16(inverse23, inverse23), 3);

        __m512 packed = _mm512_castps128_ps512(_mm_castph_ps(v0.v)); // lanes 1-3 are filled by the inserts below
        packed = _mm512_insertf32x4(packed, _mm_castph_ps(v1.v), 1);
        packed = _mm512_insertf32x4(packed, _mm_castph_ps(v2.v), 2);
        packed = _mm512_insertf32x4(packed, _mm_castph_ps(v3.v), 3);
        return Wide(_mm512_mul_ph(_mm512_castps_ph(packed), _mm512_castsi512_ph(scale)));
    };

    return {
        normalizeFour(A, B, C, D),
        normalizeFour(E, F, G, H)
    };
}
|
|
|
|
|
|
2026-03-19 02:19:01 +01:00
|
|
|
// Length for eight unpacked inputs (Packing == 1): the element-wise square root
// of LengthSq(A, ..., H), computed with the sqrt intrinsic that matches the
// register width of VectorType.
constexpr static VectorF16<Len, Packing, Repeats> Length(
    VectorF16<Len, Packing, Repeats> A,
    VectorF16<Len, Packing, Repeats> B,
    VectorF16<Len, Packing, Repeats> C,
    VectorF16<Len, Packing, Repeats> D,
    VectorF16<Len, Packing, Repeats> E,
    VectorF16<Len, Packing, Repeats> F,
    VectorF16<Len, Packing, Repeats> G,
    VectorF16<Len, Packing, Repeats> H
) requires(Packing == 1) {
    using Vec = VectorF16<Len, Packing, Repeats>;
    constexpr bool is128 = std::is_same_v<VectorType, __m128h>;
    constexpr bool is256 = std::is_same_v<VectorType, __m256h>;

    const Vec squared = LengthSq(A, B, C, D, E, F, G, H);

    if constexpr(is128) {
        return Vec(_mm_sqrt_ph(squared.v));
    } else if constexpr(is256) {
        return Vec(_mm256_sqrt_ph(squared.v));
    } else {
        // Any remaining VectorType is handled as a 512-bit register.
        return Vec(_mm512_sqrt_ph(squared.v));
    }
}
|
|
|
|
|
|
2026-03-19 05:53:17 +01:00
|
|
|
// Length for four inputs packed two vectors per register (Packing == 2): the
// element-wise square root of LengthSq(A, C, E, G), computed with the sqrt
// intrinsic that matches the register width of VectorType.
constexpr static VectorF16<Len, Packing, Repeats> Length(
    VectorF16<Len, Packing, Repeats> A,
    VectorF16<Len, Packing, Repeats> C,
    VectorF16<Len, Packing, Repeats> E,
    VectorF16<Len, Packing, Repeats> G
) requires(Packing == 2) {
    using Vec = VectorF16<Len, Packing, Repeats>;
    constexpr bool is128 = std::is_same_v<VectorType, __m128h>;
    constexpr bool is256 = std::is_same_v<VectorType, __m256h>;

    const Vec squared = LengthSq(A, C, E, G);

    if constexpr(is128) {
        return Vec(_mm_sqrt_ph(squared.v));
    } else if constexpr(is256) {
        return Vec(_mm256_sqrt_ph(squared.v));
    } else {
        // Any remaining VectorType is handled as a 512-bit register.
        return Vec(_mm512_sqrt_ph(squared.v));
    }
}
|
|
|
|
|
|
|
|
|
|
// Length for two inputs packed four vectors per register (Packing == 4): the
// element-wise square root of LengthSq(A, E), computed with the sqrt intrinsic
// that matches the register width of VectorType.
//
// The constraint is Packing == 4 because the two-argument LengthSq this
// forwards to is only declared for Packing == 4.
constexpr static VectorF16<Len, Packing, Repeats> Length(
    VectorF16<Len, Packing, Repeats> A,
    VectorF16<Len, Packing, Repeats> E
) requires(Packing == 4) {
    using Vec = VectorF16<Len, Packing, Repeats>;

    const Vec squared = LengthSq(A, E);

    if constexpr(std::is_same_v<VectorType, __m128h>) {
        return Vec(_mm_sqrt_ph(squared.v));
    } else if constexpr(std::is_same_v<VectorType, __m256h>) {
        return Vec(_mm256_sqrt_ph(squared.v));
    } else {
        // Any remaining VectorType is handled as a 512-bit register.
        return Vec(_mm512_sqrt_ph(squared.v));
    }
}
|
|
|
|
|
|
2026-03-19 02:19:01 +01:00
|
|
|
// Squared length for eight unpacked inputs (Packing == 1).
// Forwards every input to Dot paired with itself and returns Dot's result unchanged.
constexpr static VectorF16<Len, Packing, Repeats> LengthSq(
    VectorF16<Len, Packing, Repeats> A,
    VectorF16<Len, Packing, Repeats> B,
    VectorF16<Len, Packing, Repeats> C,
    VectorF16<Len, Packing, Repeats> D,
    VectorF16<Len, Packing, Repeats> E,
    VectorF16<Len, Packing, Repeats> F,
    VectorF16<Len, Packing, Repeats> G,
    VectorF16<Len, Packing, Repeats> H
) requires(Packing == 1) {
    return Dot(
        A, A,
        B, B,
        C, C,
        D, D,
        E, E,
        F, F,
        G, G,
        H, H
    );
}
|
|
|
|
|
|
2026-03-19 05:53:17 +01:00
|
|
|
// Squared length for four inputs packed two vectors per register (Packing == 2).
// Forwards every input to Dot paired with itself and returns Dot's result unchanged.
constexpr static VectorF16<Len, Packing, Repeats> LengthSq(
    VectorF16<Len, Packing, Repeats> A,
    VectorF16<Len, Packing, Repeats> C,
    VectorF16<Len, Packing, Repeats> E,
    VectorF16<Len, Packing, Repeats> G
) requires(Packing == 2) {
    return Dot(
        A, A,
        C, C,
        E, E,
        G, G
    );
}
|
|
|
|
|
|
|
|
|
|
// Squared length for two inputs packed four vectors per register (Packing == 4).
// Forwards every input to Dot paired with itself and returns Dot's result unchanged.
constexpr static VectorF16<Len, Packing, Repeats> LengthSq(
    VectorF16<Len, Packing, Repeats> A,
    VectorF16<Len, Packing, Repeats> E
) requires(Packing == 4) {
    return Dot(
        A, A,
        E, E
    );
}
|
|
|
|
|
|
2026-03-19 02:19:01 +01:00
|
|
|
constexpr static VectorF16<Len, Packing, Repeats> Dot(
|
|
|
|
|
VectorF16<Len, Packing, Repeats> A0, VectorF16<Len, Packing, Repeats> A1,
|
|
|
|
|
VectorF16<Len, Packing, Repeats> B0, VectorF16<Len, Packing, Repeats> B1,
|
|
|
|
|
VectorF16<Len, Packing, Repeats> C0, VectorF16<Len, Packing, Repeats> C1,
|
|
|
|
|
VectorF16<Len, Packing, Repeats> D0, VectorF16<Len, Packing, Repeats> D1,
|
|
|
|
|
VectorF16<Len, Packing, Repeats> E0, VectorF16<Len, Packing, Repeats> E1,
|
|
|
|
|
VectorF16<Len, Packing, Repeats> F0, VectorF16<Len, Packing, Repeats> F1,
|
|
|
|
|
VectorF16<Len, Packing, Repeats> G0, VectorF16<Len, Packing, Repeats> G1,
|
|
|
|
|
VectorF16<Len, Packing, Repeats> H0, VectorF16<Len, Packing, Repeats> H1
|
|
|
|
|
) requires(Packing == 1) {
|
|
|
|
|
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
|
|
|
|
__m128h mulA = _mm_mul_ph(A0.v, A1.v);
|
|
|
|
|
__m128h mulB = _mm_mul_ph(B0.v, B1.v);
|
|
|
|
|
__m128i row12Temp1 = _mm_unpacklo_epi16(_mm_castph_si128(mulA), _mm_castph_si128(mulB)); // A1 B1 A2 B2 A3 B3 A4 B4
|
|
|
|
|
__m128i row56Temp1 = _mm_unpackhi_epi16(_mm_castph_si128(mulA), _mm_castph_si128(mulB)); // A5 B5 A6 B6 A7 B7 A8 B8
|
|
|
|
|
__m128i row1TempTemp1 = row12Temp1;
|
|
|
|
|
__m128i row5TempTemp1 = row56Temp1;
|
|
|
|
|
|
|
|
|
|
__m128h mulC = _mm_mul_ph(C0.v, C1.v);
|
|
|
|
|
__m128h mulD = _mm_mul_ph(D0.v, D1.v);
|
|
|
|
|
__m128i row34Temp1 = _mm_unpacklo_epi16(_mm_castph_si128(mulC), _mm_castph_si128(mulD)); // C1 D1 C2 D2 C3 D3 C4 D4
|
|
|
|
|
__m128i row78Temp1 = _mm_unpackhi_epi16(_mm_castph_si128(mulA), _mm_castph_si128(mulB)); // C5 D5 C6 D6 C7 D7 C8 D8
|
|
|
|
|
|
|
|
|
|
row12Temp1 = _mm_unpacklo_epi16(row12Temp1, row34Temp1); // A1 C1 B1 D1 A2 C2 B2 D2
|
|
|
|
|
row34Temp1 = _mm_unpackhi_epi16(row1TempTemp1, row34Temp1); // A3 C3 B3 D3 A4 C4 B4 D4
|
|
|
|
|
row56Temp1 = _mm_unpacklo_epi16(row56Temp1, row78Temp1); // A5 C5 B5 D5 A6 C6 B6 D6
|
|
|
|
|
row78Temp1 = _mm_unpackhi_epi16(row5TempTemp1, row78Temp1); // A7 C7 B7 D7 A8 C8 B8 D8
|
|
|
|
|
|
|
|
|
|
__m128h mulE = _mm_mul_ph(E0.v, E1.v);
|
|
|
|
|
__m128h mulF = _mm_mul_ph(F0.v, F1.v);
|
|
|
|
|
__m128i row12Temp2 = _mm_unpacklo_epi16(_mm_castph_si128(mulE), _mm_castph_si128(mulF)); //E1 F1 E2 F2 E3 F3 E4 F4
|
|
|
|
|
__m128i row56Temp2 = _mm_unpackhi_epi16(_mm_castph_si128(mulE), _mm_castph_si128(mulF)); //E5 F5 E6 F6 E7 F7 E8 F8
|
|
|
|
|
__m128i row1TempTemp2 = row12Temp2;
|
|
|
|
|
__m128i row5TempTemp2 = row56Temp2;
|
|
|
|
|
|
|
|
|
|
__m128h mulG = _mm_mul_ph(G0.v, G1.v);
|
|
|
|
|
__m128h mulH = _mm_mul_ph(H0.v, H1.v);
|
|
|
|
|
__m128i row34Temp2 = _mm_unpacklo_epi16(_mm_castph_si128(mulG), _mm_castph_si128(mulH)); //G1 H1 G2 H2 G3 H3 G4 H4
|
|
|
|
|
__m128i row78Temp2 = _mm_unpackhi_epi16(_mm_castph_si128(mulE), _mm_castph_si128(mulF)); //G5 H5 G6 H6 G7 H7 G8 H8
|
|
|
|
|
|
|
|
|
|
row12Temp2 = _mm_unpacklo_epi16(row12Temp2, row34Temp2); // E1 G1 F1 H1 E2 G2 F2 H2
|
|
|
|
|
row34Temp2 = _mm_unpackhi_epi16(row1TempTemp2, row34Temp2); // E3 G3 F3 H3 E4 G4 F4 H4
|
|
|
|
|
row56Temp2 = _mm_unpacklo_epi16(row56Temp2, row78Temp2); // E5 G5 F5 H5 E6 G6 F6 H6
|
|
|
|
|
row78Temp2 = _mm_unpackhi_epi16(row5TempTemp2, row78Temp2); // E7 G7 F7 H7 E8 G8 F8 H8
|
|
|
|
|
|
|
|
|
|
__m128h row1 = _mm_castsi128_ph(_mm_unpackhi_epi16(row12Temp1, row12Temp2));// A1 E1 C1 G1 B1 F1 D1 H1
|
|
|
|
|
__m128h row2 = _mm_castsi128_ph(_mm_unpacklo_epi16(row12Temp1, row12Temp2));// A2 E2 C2 G2 B2 F2 D2 H2
|
|
|
|
|
__m128h row3 = _mm_castsi128_ph(_mm_unpackhi_epi16(row34Temp1, row34Temp2));// A3 E3 C3 G3 B3 F3 D3 H3
|
|
|
|
|
__m128h row4 = _mm_castsi128_ph(_mm_unpacklo_epi16(row34Temp1, row34Temp2));// A4 E4 C4 G4 B4 F4 D4 H4
|
|
|
|
|
__m128h row5 = _mm_castsi128_ph(_mm_unpackhi_epi16(row56Temp1, row56Temp2));// A5 E5 C5 G5 B5 F5 D5 H5
|
|
|
|
|
__m128h row6 = _mm_castsi128_ph(_mm_unpacklo_epi16(row56Temp1, row56Temp2));// A6 E6 C6 G6 B6 F6 D6 H6
|
|
|
|
|
__m128h row7 = _mm_castsi128_ph(_mm_unpackhi_epi16(row78Temp1, row78Temp2));// A7 E7 C7 G7 B7 F7 D7 H7
|
|
|
|
|
__m128h row8 = _mm_castsi128_ph(_mm_unpacklo_epi16(row78Temp1, row78Temp2));// A8 E8 C8 G8 B8 F8 D8 H8
|
|
|
|
|
|
|
|
|
|
row1 = _mm_add_ph(row1, row2);
|
|
|
|
|
row1 = _mm_add_ph(row1, row3);
|
|
|
|
|
row1 = _mm_add_ph(row1, row4);
|
|
|
|
|
row1 = _mm_add_ph(row1, row5);
|
|
|
|
|
row1 = _mm_add_ph(row1, row6);
|
|
|
|
|
row1 = _mm_add_ph(row1, row7);
|
|
|
|
|
row1 = _mm_add_ph(row1, row8);
|
2026-03-18 03:16:29 +01:00
|
|
|
|
2026-03-19 02:19:01 +01:00
|
|
|
return row1;
|
|
|
|
|
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
|
|
|
|
__m256h mulA = _mm256_mul_ph(A0.v, A1.v);
|
|
|
|
|
__m256h mulB = _mm256_mul_ph(B0.v, B1.v);
|
|
|
|
|
__m256i row12Temp1 = _mm256_unpacklo_epi16(_mm256_castph_si256(mulA), _mm256_castph_si256(mulB)); // A1 B1 A2 B2 A3 B3 A4 B4
|
|
|
|
|
__m256i row56Temp1 = _mm256_unpackhi_epi16(_mm256_castph_si256(mulA), _mm256_castph_si256(mulB)); // A5 B5 A6 B6 A7 B7 A8 B8
|
|
|
|
|
__m256i row1TempTemp1 = row12Temp1;
|
|
|
|
|
__m256i row5TempTemp1 = row56Temp1;
|
|
|
|
|
|
|
|
|
|
__m256h mulC = _mm256_mul_ph(C0.v, C1.v);
|
|
|
|
|
__m256h mulD = _mm256_mul_ph(D0.v, D1.v);
|
|
|
|
|
__m256i row34Temp1 = _mm256_unpacklo_epi16(_mm256_castph_si256(mulC), _mm256_castph_si256(mulD)); // C1 D1 C2 D2 C3 D3 C4 D4
|
|
|
|
|
__m256i row78Temp1 = _mm256_unpackhi_epi16(_mm256_castph_si256(mulA), _mm256_castph_si256(mulB)); // C5 D5 C6 D6 C7 D7 C8 D8
|
|
|
|
|
|
|
|
|
|
row12Temp1 = _mm256_unpacklo_epi16(row12Temp1, row34Temp1); // A1 C1 B1 D1 A2 C2 B2 D2
|
|
|
|
|
row34Temp1 = _mm256_unpackhi_epi16(row1TempTemp1, row34Temp1); // A3 C3 B3 D3 A4 C4 B4 D4
|
|
|
|
|
row56Temp1 = _mm256_unpacklo_epi16(row56Temp1, row78Temp1); // A5 C5 B5 D5 A6 C6 B6 D6
|
|
|
|
|
row78Temp1 = _mm256_unpackhi_epi16(row5TempTemp1, row78Temp1); // A7 C7 B7 D7 A8 C8 B8 D8
|
|
|
|
|
|
|
|
|
|
__m256h mulE = _mm256_mul_ph(E0.v, E1.v);
|
|
|
|
|
__m256h mulF = _mm256_mul_ph(F0.v, F1.v);
|
|
|
|
|
__m256i row12Temp2 = _mm256_unpacklo_epi16(_mm256_castph_si256(mulE), _mm256_castph_si256(mulF)); //E1 F1 E2 F2 E3 F3 E4 F4
|
|
|
|
|
__m256i row56Temp2 = _mm256_unpackhi_epi16(_mm256_castph_si256(mulE), _mm256_castph_si256(mulF)); //E5 F5 E6 F6 E7 F7 E8 F8
|
|
|
|
|
__m256i row1TempTemp2 = row12Temp2;
|
|
|
|
|
__m256i row5TempTemp2 = row56Temp2;
|
|
|
|
|
|
|
|
|
|
__m256h mulG = _mm256_mul_ph(G0.v, G1.v);
|
|
|
|
|
__m256h mulH = _mm256_mul_ph(H0.v, H1.v);
|
|
|
|
|
__m256i row34Temp2 = _mm256_unpacklo_epi16(_mm256_castph_si256(mulG), _mm256_castph_si256(mulH)); //G1 H1 G2 H2 G3 H3 G4 H4
|
|
|
|
|
__m256i row78Temp2 = _mm256_unpackhi_epi16(_mm256_castph_si256(mulE), _mm256_castph_si256(mulF)); //G5 H5 G6 H6 G7 H7 G8 H8
|
|
|
|
|
|
|
|
|
|
row12Temp2 = _mm256_unpacklo_epi16(row12Temp2, row34Temp2); // E1 G1 F1 H1 E2 G2 F2 H2
|
|
|
|
|
row34Temp2 = _mm256_unpackhi_epi16(row1TempTemp2, row34Temp2); // E3 G3 F3 H3 E4 G4 F4 H4
|
|
|
|
|
row56Temp2 = _mm256_unpacklo_epi16(row56Temp2, row78Temp2); // E5 G5 F5 H5 E6 G6 F6 H6
|
|
|
|
|
row78Temp2 = _mm256_unpackhi_epi16(row5TempTemp2, row78Temp2); // E7 G7 F7 H7 E8 G8 F8 H8
|
|
|
|
|
|
|
|
|
|
__m256h row1 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(row12Temp1, row12Temp2));// A1 E1 C1 G1 B1 F1 D1 H1
|
|
|
|
|
__m256h row2 = _mm256_castsi256_ph(_mm256_unpacklo_epi16(row12Temp1, row12Temp2));// A2 E2 C2 G2 B2 F2 D2 H2
|
|
|
|
|
__m256h row3 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(row34Temp1, row34Temp2));// A3 E3 C3 G3 B3 F3 D3 H3
|
|
|
|
|
__m256h row4 = _mm256_castsi256_ph(_mm256_unpacklo_epi16(row34Temp1, row34Temp2));// A4 E4 C4 G4 B4 F4 D4 H4
|
|
|
|
|
__m256h row5 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(row56Temp1, row56Temp2));// A5 E5 C5 G5 B5 F5 D5 H5
|
|
|
|
|
__m256h row6 = _mm256_castsi256_ph(_mm256_unpacklo_epi16(row56Temp1, row56Temp2));// A6 E6 C6 G6 B6 F6 D6 H6
|
|
|
|
|
__m256h row7 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(row78Temp1, row78Temp2));// A7 E7 C7 G7 B7 F7 D7 H7
|
|
|
|
|
__m256h row8 = _mm256_castsi256_ph(_mm256_unpacklo_epi16(row78Temp1, row78Temp2));// A8 E8 C8 G8 B8 F8 D8 H8
|
|
|
|
|
|
|
|
|
|
row1 = _mm256_add_ph(row1, row2);
|
|
|
|
|
row1 = _mm256_add_ph(row1, row3);
|
|
|
|
|
row1 = _mm256_add_ph(row1, row4);
|
|
|
|
|
row1 = _mm256_add_ph(row1, row5);
|
|
|
|
|
row1 = _mm256_add_ph(row1, row6);
|
|
|
|
|
row1 = _mm256_add_ph(row1, row7);
|
|
|
|
|
row1 = _mm256_add_ph(row1, row8);
|
|
|
|
|
|
|
|
|
|
return row1;
|
|
|
|
|
} else {
|
|
|
|
|
__m512h mulA = _mm512_mul_ph(A0.v, A1.v);
|
|
|
|
|
__m512h mulB = _mm512_mul_ph(B0.v, B1.v);
|
|
|
|
|
__m512i row12Temp1 = _mm512_unpacklo_epi16(_mm512_castph_si512(mulA), _mm512_castph_si512(mulB)); // A1 B1 A2 B2 A3 B3 A4 B4
|
|
|
|
|
__m512i row56Temp1 = _mm512_unpackhi_epi16(_mm512_castph_si512(mulA), _mm512_castph_si512(mulB)); // A5 B5 A6 B6 A7 B7 A8 B8
|
|
|
|
|
__m512i row1TempTemp1 = row12Temp1;
|
|
|
|
|
__m512i row5TempTemp1 = row56Temp1;
|
|
|
|
|
|
|
|
|
|
__m512h mulC = _mm512_mul_ph(C0.v, C1.v);
|
|
|
|
|
__m512h mulD = _mm512_mul_ph(D0.v, D1.v);
|
|
|
|
|
__m512i row34Temp1 = _mm512_unpacklo_epi16(_mm512_castph_si512(mulC), _mm512_castph_si512(mulD)); // C1 D1 C2 D2 C3 D3 C4 D4
|
|
|
|
|
__m512i row78Temp1 = _mm512_unpackhi_epi16(_mm512_castph_si512(mulA), _mm512_castph_si512(mulB)); // C5 D5 C6 D6 C7 D7 C8 D8
|
|
|
|
|
|
|
|
|
|
row12Temp1 = _mm512_unpacklo_epi16(row12Temp1, row34Temp1); // A1 C1 B1 D1 A2 C2 B2 D2
|
|
|
|
|
row34Temp1 = _mm512_unpackhi_epi16(row1TempTemp1, row34Temp1); // A3 C3 B3 D3 A4 C4 B4 D4
|
|
|
|
|
row56Temp1 = _mm512_unpacklo_epi16(row56Temp1, row78Temp1); // A5 C5 B5 D5 A6 C6 B6 D6
|
|
|
|
|
row78Temp1 = _mm512_unpackhi_epi16(row5TempTemp1, row78Temp1); // A7 C7 B7 D7 A8 C8 B8 D8
|
|
|
|
|
|
|
|
|
|
__m512h mulE = _mm512_mul_ph(E0.v, E1.v);
|
|
|
|
|
__m512h mulF = _mm512_mul_ph(F0.v, F1.v);
|
|
|
|
|
__m512i row12Temp2 = _mm512_unpacklo_epi16(_mm512_castph_si512(mulE), _mm512_castph_si512(mulF)); //E1 F1 E2 F2 E3 F3 E4 F4
|
|
|
|
|
__m512i row56Temp2 = _mm512_unpackhi_epi16(_mm512_castph_si512(mulE), _mm512_castph_si512(mulF)); //E5 F5 E6 F6 E7 F7 E8 F8
|
|
|
|
|
__m512i row1TempTemp2 = row12Temp2;
|
|
|
|
|
__m512i row5TempTemp2 = row56Temp2;
|
|
|
|
|
|
|
|
|
|
__m512h mulG = _mm512_mul_ph(G0.v, G1.v);
|
|
|
|
|
__m512h mulH = _mm512_mul_ph(H0.v, H1.v);
|
|
|
|
|
__m512i row34Temp2 = _mm512_unpacklo_epi16(_mm512_castph_si512(mulG), _mm512_castph_si512(mulH)); //G1 H1 G2 H2 G3 H3 G4 H4
|
|
|
|
|
__m512i row78Temp2 = _mm512_unpackhi_epi16(_mm512_castph_si512(mulE), _mm512_castph_si512(mulF)); //G5 H5 G6 H6 G7 H7 G8 H8
|
|
|
|
|
|
|
|
|
|
row12Temp2 = _mm512_unpacklo_epi16(row12Temp2, row34Temp2); // E1 G1 F1 H1 E2 G2 F2 H2
|
|
|
|
|
row34Temp2 = _mm512_unpackhi_epi16(row1TempTemp2, row34Temp2); // E3 G3 F3 H3 E4 G4 F4 H4
|
|
|
|
|
row56Temp2 = _mm512_unpacklo_epi16(row56Temp2, row78Temp2); // E5 G5 F5 H5 E6 G6 F6 H6
|
|
|
|
|
row78Temp2 = _mm512_unpackhi_epi16(row5TempTemp2, row78Temp2); // E7 G7 F7 H7 E8 G8 F8 H8
|
|
|
|
|
|
|
|
|
|
__m512h row1 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(row12Temp1, row12Temp2));// A1 E1 C1 G1 B1 F1 D1 H1
|
|
|
|
|
__m512h row2 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(row12Temp1, row12Temp2));// A2 E2 C2 G2 B2 F2 D2 H2
|
|
|
|
|
__m512h row3 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(row34Temp1, row34Temp2));// A3 E3 C3 G3 B3 F3 D3 H3
|
|
|
|
|
__m512h row4 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(row34Temp1, row34Temp2));// A4 E4 C4 G4 B4 F4 D4 H4
|
|
|
|
|
__m512h row5 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(row56Temp1, row56Temp2));// A5 E5 C5 G5 B5 F5 D5 H5
|
|
|
|
|
__m512h row6 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(row56Temp1, row56Temp2));// A6 E6 C6 G6 B6 F6 D6 H6
|
|
|
|
|
__m512h row7 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(row78Temp1, row78Temp2));// A7 E7 C7 G7 B7 F7 D7 H7
|
|
|
|
|
__m512h row8 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(row78Temp1, row78Temp2));// A8 E8 C8 G8 B8 F8 D8 H8
|
|
|
|
|
|
|
|
|
|
row1 = _mm512_add_ph(row1, row2);
|
|
|
|
|
row1 = _mm512_add_ph(row1, row3);
|
|
|
|
|
row1 = _mm512_add_ph(row1, row4);
|
|
|
|
|
row1 = _mm512_add_ph(row1, row5);
|
|
|
|
|
row1 = _mm512_add_ph(row1, row6);
|
|
|
|
|
row1 = _mm512_add_ph(row1, row7);
|
|
|
|
|
row1 = _mm512_add_ph(row1, row8);
|
|
|
|
|
|
|
|
|
|
return row1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Dot products of four register pairs at once (Packing == 2).
//
// Every argument register carries two packed vectors per 128-bit lane: the low four
// 16-bit slots hold one vector and the high four slots hold the next. In the lane
// comments below A/C/E/G name the low-half vectors of the four pairs, B/D/F/H the
// high-half vectors, and the digit is the element slot (1..4).
//
// The element-wise products are transposed with 16-bit unpacks so that each "row"
// holds one element slot of all eight vectors; adding the four rows yields the eight
// dot products. Result order inside each 128-bit lane: A B E F C D G H.
// The unpack intrinsics work per 128-bit lane, so the 256/512-bit branches repeat the
// same pattern in every lane.
//
// NOTE(review): all four slots of every vector are summed, so this presumably relies
// on the padding slot being zero when Len < 4 - confirm against the loaders.
constexpr static VectorF16<Len, Packing, Repeats> Dot(
    VectorF16<Len, Packing, Repeats> A0, VectorF16<Len, Packing, Repeats> A1,
    VectorF16<Len, Packing, Repeats> C0, VectorF16<Len, Packing, Repeats> C1,
    VectorF16<Len, Packing, Repeats> E0, VectorF16<Len, Packing, Repeats> E1,
    VectorF16<Len, Packing, Repeats> G0, VectorF16<Len, Packing, Repeats> G1
) requires(Packing == 2) {
    if constexpr(std::is_same_v<VectorType, __m128h>) {
        const __m128i mulA = _mm_castph_si128(_mm_mul_ph(A0.v, A1.v)); // A1 A2 A3 A4 B1 B2 B3 B4
        const __m128i mulC = _mm_castph_si128(_mm_mul_ph(C0.v, C1.v)); // C1 C2 C3 C4 D1 D2 D3 D4
        const __m128i mulE = _mm_castph_si128(_mm_mul_ph(E0.v, E1.v)); // E1 E2 E3 E4 F1 F2 F3 F4
        const __m128i mulG = _mm_castph_si128(_mm_mul_ph(G0.v, G1.v)); // G1 G2 G3 G4 H1 H2 H3 H4

        // Stage 1: interleave the pairs (A,C) and (E,G); low halves and high halves separately.
        const __m128i ac = _mm_unpacklo_epi16(mulA, mulC); // A1 C1 A2 C2 A3 C3 A4 C4
        const __m128i bd = _mm_unpackhi_epi16(mulA, mulC); // B1 D1 B2 D2 B3 D3 B4 D4
        const __m128i eg = _mm_unpacklo_epi16(mulE, mulG); // E1 G1 E2 G2 E3 G3 E4 G4
        const __m128i fh = _mm_unpackhi_epi16(mulE, mulG); // F1 H1 F2 H2 F3 H3 F4 H4

        // Stage 2: every intermediate keeps its own name so that no input is overwritten
        // before its high half has been consumed.
        const __m128i aecg12 = _mm_unpacklo_epi16(ac, eg); // A1 E1 C1 G1 A2 E2 C2 G2
        const __m128i aecg34 = _mm_unpackhi_epi16(ac, eg); // A3 E3 C3 G3 A4 E4 C4 G4
        const __m128i bfdh12 = _mm_unpacklo_epi16(bd, fh); // B1 F1 D1 H1 B2 F2 D2 H2
        const __m128i bfdh34 = _mm_unpackhi_epi16(bd, fh); // B3 F3 D3 H3 B4 F4 D4 H4

        // Stage 3: one row per element slot.
        const __m128h row1 = _mm_castsi128_ph(_mm_unpacklo_epi16(aecg12, bfdh12)); // A1 B1 E1 F1 C1 D1 G1 H1
        const __m128h row2 = _mm_castsi128_ph(_mm_unpackhi_epi16(aecg12, bfdh12)); // A2 B2 E2 F2 C2 D2 G2 H2
        const __m128h row3 = _mm_castsi128_ph(_mm_unpacklo_epi16(aecg34, bfdh34)); // A3 B3 E3 F3 C3 D3 G3 H3
        const __m128h row4 = _mm_castsi128_ph(_mm_unpackhi_epi16(aecg34, bfdh34)); // A4 B4 E4 F4 C4 D4 G4 H4

        return _mm_add_ph(_mm_add_ph(_mm_add_ph(row1, row2), row3), row4);
    } else if constexpr(std::is_same_v<VectorType, __m256h>) {
        const __m256i mulA = _mm256_castph_si256(_mm256_mul_ph(A0.v, A1.v));
        const __m256i mulC = _mm256_castph_si256(_mm256_mul_ph(C0.v, C1.v));
        const __m256i mulE = _mm256_castph_si256(_mm256_mul_ph(E0.v, E1.v));
        const __m256i mulG = _mm256_castph_si256(_mm256_mul_ph(G0.v, G1.v));

        const __m256i ac = _mm256_unpacklo_epi16(mulA, mulC); // A1 C1 A2 C2 A3 C3 A4 C4
        const __m256i bd = _mm256_unpackhi_epi16(mulA, mulC); // B1 D1 B2 D2 B3 D3 B4 D4
        const __m256i eg = _mm256_unpacklo_epi16(mulE, mulG); // E1 G1 E2 G2 E3 G3 E4 G4
        const __m256i fh = _mm256_unpackhi_epi16(mulE, mulG); // F1 H1 F2 H2 F3 H3 F4 H4

        const __m256i aecg12 = _mm256_unpacklo_epi16(ac, eg); // A1 E1 C1 G1 A2 E2 C2 G2
        const __m256i aecg34 = _mm256_unpackhi_epi16(ac, eg); // A3 E3 C3 G3 A4 E4 C4 G4
        const __m256i bfdh12 = _mm256_unpacklo_epi16(bd, fh); // B1 F1 D1 H1 B2 F2 D2 H2
        const __m256i bfdh34 = _mm256_unpackhi_epi16(bd, fh); // B3 F3 D3 H3 B4 F4 D4 H4

        const __m256h row1 = _mm256_castsi256_ph(_mm256_unpacklo_epi16(aecg12, bfdh12)); // A1 B1 E1 F1 C1 D1 G1 H1
        const __m256h row2 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(aecg12, bfdh12)); // A2 B2 E2 F2 C2 D2 G2 H2
        const __m256h row3 = _mm256_castsi256_ph(_mm256_unpacklo_epi16(aecg34, bfdh34)); // A3 B3 E3 F3 C3 D3 G3 H3
        const __m256h row4 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(aecg34, bfdh34)); // A4 B4 E4 F4 C4 D4 G4 H4

        return _mm256_add_ph(_mm256_add_ph(_mm256_add_ph(row1, row2), row3), row4);
    } else {
        const __m512i mulA = _mm512_castph_si512(_mm512_mul_ph(A0.v, A1.v));
        const __m512i mulC = _mm512_castph_si512(_mm512_mul_ph(C0.v, C1.v));
        const __m512i mulE = _mm512_castph_si512(_mm512_mul_ph(E0.v, E1.v));
        const __m512i mulG = _mm512_castph_si512(_mm512_mul_ph(G0.v, G1.v));

        const __m512i ac = _mm512_unpacklo_epi16(mulA, mulC); // A1 C1 A2 C2 A3 C3 A4 C4
        const __m512i bd = _mm512_unpackhi_epi16(mulA, mulC); // B1 D1 B2 D2 B3 D3 B4 D4
        const __m512i eg = _mm512_unpacklo_epi16(mulE, mulG); // E1 G1 E2 G2 E3 G3 E4 G4
        const __m512i fh = _mm512_unpackhi_epi16(mulE, mulG); // F1 H1 F2 H2 F3 H3 F4 H4

        const __m512i aecg12 = _mm512_unpacklo_epi16(ac, eg); // A1 E1 C1 G1 A2 E2 C2 G2
        const __m512i aecg34 = _mm512_unpackhi_epi16(ac, eg); // A3 E3 C3 G3 A4 E4 C4 G4
        const __m512i bfdh12 = _mm512_unpacklo_epi16(bd, fh); // B1 F1 D1 H1 B2 F2 D2 H2
        const __m512i bfdh34 = _mm512_unpackhi_epi16(bd, fh); // B3 F3 D3 H3 B4 F4 D4 H4

        const __m512h row1 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(aecg12, bfdh12)); // A1 B1 E1 F1 C1 D1 G1 H1
        const __m512h row2 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(aecg12, bfdh12)); // A2 B2 E2 F2 C2 D2 G2 H2
        const __m512h row3 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(aecg34, bfdh34)); // A3 B3 E3 F3 C3 D3 G3 H3
        const __m512h row4 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(aecg34, bfdh34)); // A4 B4 E4 F4 C4 D4 G4 H4

        return _mm512_add_ph(_mm512_add_ph(_mm512_add_ph(row1, row2), row3), row4);
    }
}
|
|
|
|
|
|
|
|
|
|
// Dot products of two register pairs at once (Packing == 4).
//
// Each argument register carries four packed vectors per 128-bit lane, two 16-bit
// slots apiece. In the lane comments A..D are the vectors of the first pair, E..H the
// vectors of the second pair, and the digit is the element slot (1..2).
// The products are transposed with 16-bit unpacks into one row per element slot and
// the two rows are added. Result order inside each 128-bit lane: A B C D E F G H.
constexpr static VectorF16<Len, Packing, Repeats> Dot(
    VectorF16<Len, Packing, Repeats> A0, VectorF16<Len, Packing, Repeats> A1,
    VectorF16<Len, Packing, Repeats> E0, VectorF16<Len, Packing, Repeats> E1
) requires(Packing == 4) {
    if constexpr(std::is_same_v<VectorType, __m128h>) {
        const __m128i prodAD = _mm_castph_si128(_mm_mul_ph(A0.v, A1.v));
        const __m128i prodEH = _mm_castph_si128(_mm_mul_ph(E0.v, E1.v));

        const __m128i lowPairs  = _mm_unpacklo_epi16(prodAD, prodEH); // A1 E1 A2 E2 B1 F1 B2 F2
        const __m128i highPairs = _mm_unpackhi_epi16(prodAD, prodEH); // C1 G1 C2 G2 D1 H1 D2 H2

        const __m128i evenVectors = _mm_unpacklo_epi16(lowPairs, highPairs); // A1 C1 E1 G1 A2 C2 E2 G2
        const __m128i oddVectors  = _mm_unpackhi_epi16(lowPairs, highPairs); // B1 D1 F1 H1 B2 D2 F2 H2

        const __m128h slot1 = _mm_castsi128_ph(_mm_unpacklo_epi16(evenVectors, oddVectors)); // A1 B1 C1 D1 E1 F1 G1 H1
        const __m128h slot2 = _mm_castsi128_ph(_mm_unpackhi_epi16(evenVectors, oddVectors)); // A2 B2 C2 D2 E2 F2 G2 H2

        return _mm_add_ph(slot1, slot2);
    } else if constexpr(std::is_same_v<VectorType, __m256h>) {
        const __m256i prodAD = _mm256_castph_si256(_mm256_mul_ph(A0.v, A1.v));
        const __m256i prodEH = _mm256_castph_si256(_mm256_mul_ph(E0.v, E1.v));

        const __m256i lowPairs  = _mm256_unpacklo_epi16(prodAD, prodEH); // A1 E1 A2 E2 B1 F1 B2 F2
        const __m256i highPairs = _mm256_unpackhi_epi16(prodAD, prodEH); // C1 G1 C2 G2 D1 H1 D2 H2

        const __m256i evenVectors = _mm256_unpacklo_epi16(lowPairs, highPairs); // A1 C1 E1 G1 A2 C2 E2 G2
        const __m256i oddVectors  = _mm256_unpackhi_epi16(lowPairs, highPairs); // B1 D1 F1 H1 B2 D2 F2 H2

        const __m256h slot1 = _mm256_castsi256_ph(_mm256_unpacklo_epi16(evenVectors, oddVectors)); // A1 B1 C1 D1 E1 F1 G1 H1
        const __m256h slot2 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(evenVectors, oddVectors)); // A2 B2 C2 D2 E2 F2 G2 H2

        return _mm256_add_ph(slot1, slot2);
    } else {
        const __m512i prodAD = _mm512_castph_si512(_mm512_mul_ph(A0.v, A1.v));
        const __m512i prodEH = _mm512_castph_si512(_mm512_mul_ph(E0.v, E1.v));

        const __m512i lowPairs  = _mm512_unpacklo_epi16(prodAD, prodEH); // A1 E1 A2 E2 B1 F1 B2 F2
        const __m512i highPairs = _mm512_unpackhi_epi16(prodAD, prodEH); // C1 G1 C2 G2 D1 H1 D2 H2

        const __m512i evenVectors = _mm512_unpacklo_epi16(lowPairs, highPairs); // A1 C1 E1 G1 A2 C2 E2 G2
        const __m512i oddVectors  = _mm512_unpackhi_epi16(lowPairs, highPairs); // B1 D1 F1 H1 B2 D2 F2 H2

        const __m512h slot1 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(evenVectors, oddVectors)); // A1 B1 C1 D1 E1 F1 G1 H1
        const __m512h slot2 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(evenVectors, oddVectors)); // A2 B2 C2 D2 E2 F2 G2 H2

        return _mm512_add_ph(slot1, slot2);
    }
}
|
2026-03-22 03:51:09 +01:00
|
|
|
|
|
|
|
|
template <std::uint8_t A, std::uint8_t B, std::uint8_t C, std::uint8_t D, std::uint8_t E, std::uint8_t F, std::uint8_t G, std::uint8_t H>
|
|
|
|
|
constexpr static VectorF16<Len, Packing, Repeats> Blend(VectorF16<Len, Packing, Repeats> a, VectorF16<Len, Packing, Repeats> b) {
|
|
|
|
|
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
|
|
|
|
constexpr std::uint8_t val =
|
|
|
|
|
(A & 1) |
|
|
|
|
|
((B & 1) << 1) |
|
|
|
|
|
((C & 1) << 2) |
|
|
|
|
|
((D & 1) << 3) |
|
|
|
|
|
((E & 1) << 4) |
|
|
|
|
|
((F & 1) << 5) |
|
|
|
|
|
((G & 1) << 6) |
|
|
|
|
|
((H & 1) << 7);
|
|
|
|
|
return _mm_castsi128_ph(_mm_blend_epi16(_mm_castph_si128(a.v), _mm_castph_si128(b), val));
|
|
|
|
|
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
|
|
|
|
constexpr std::uint8_t val =
|
|
|
|
|
(A & 1) |
|
|
|
|
|
((B & 1) << 1) |
|
|
|
|
|
((C & 1) << 2) |
|
|
|
|
|
((D & 1) << 3) |
|
|
|
|
|
((E & 1) << 4) |
|
|
|
|
|
((F & 1) << 5) |
|
|
|
|
|
((G & 1) << 6) |
|
|
|
|
|
((H & 1) << 7);
|
|
|
|
|
return _mm256_castsi256_ph(_mm256_blend_epi16(_mm256_castph_si256(a.v), _mm256_castph_si256(b), val));
|
|
|
|
|
} else {
|
|
|
|
|
constexpr std::uint8_t byte =
|
|
|
|
|
(A & 1) |
|
|
|
|
|
((B & 1) << 1) |
|
|
|
|
|
((C & 1) << 2) |
|
|
|
|
|
((D & 1) << 3) |
|
|
|
|
|
((E & 1) << 4) |
|
|
|
|
|
((F & 1) << 5) |
|
|
|
|
|
((G & 1) << 6) |
|
|
|
|
|
((H & 1) << 7);
|
|
|
|
|
|
|
|
|
|
constexpr std::uint32_t val = byte * 0x01010101u;
|
|
|
|
|
return _mm512_castsi512_ph(_mm512_mask_blend_epi16(val, _mm512_castph_si512(a.v), _mm512_castph_si512(b)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Rotates the packed 3-vectors in `v` by the packed quaternions in `q`:
//     t  = 2 * (qv x v)
//     v' = v + w * t + (qv x t)
// where qv reinterprets the quaternion register as a 3-vector and w is slot 3 of each
// quaternion (slots 3 and 7 of the lane), broadcast across its vector by the shuffle.
// NOTE(review): assumes quaternions are stored x,y,z,w and are unit length - confirm.
constexpr static VectorF16<Len, Packing, Repeats> Rotate(VectorF16<3, 2, Repeats> v, VectorF16<4, 2, Repeats> q) requires(Len == 3 && Packing == 2) {
    VectorF16<3, 2, Repeats> qv(q.v);
    VectorF16<Len, Packing, Repeats> t = Cross(qv, v) * _Float16(2);
    // All three terms belong to one expression; the cross term must not be cut off
    // from the return statement (RotatePivot evaluates the same sum).
    return v + t * q.template Shuffle<3,3,3,3,7,7,7,7>() + Cross(qv, t);
}
|
2026-03-19 02:19:01 +01:00
|
|
|
|
2026-03-22 03:51:09 +01:00
|
|
|
// Rotates the packed 3-vectors in `v` by the packed quaternions in `q` around `pivot`:
// the pivot is subtracted, the offset is rotated with
//     t = 2 * (qv x offset),   rotated = offset + w * t + (qv x t)
// and the pivot is added back. qv reinterprets the quaternion register as a 3-vector
// and w is slot 3 (and 7) of the lane, broadcast by the shuffle.
// NOTE(review): the declared return type has Len 4 while the sibling Rotate returns
// Len 3 - confirm this is intentional.
constexpr static VectorF16<4, 2, Repeats> RotatePivot(VectorF16<3, 2, Repeats> v, VectorF16<4, 2, Repeats> q, VectorF16<3, 2, Repeats> pivot) requires(Len == 3 && Packing == 2) {
    VectorF16<3, 2, Repeats> axis(q.v);
    VectorF16<Len, Packing, Repeats> offset = v - pivot;

    VectorF16<Len, Packing, Repeats> twiceCross = Cross(axis, offset) * _Float16(2);
    auto scaledByW = twiceCross * q.template Shuffle<3,3,3,3,7,7,7,7>();

    VectorF16<Len, Packing, Repeats> rotated = offset + scaledByW + Cross(axis, twiceCross);
    return rotated + pivot;
}
|
|
|
|
|
|
|
|
|
|
// Builds packed quaternions from packed half Euler angles.
//
// With s? = sin(EulerHalf.?) and c? = cos(EulerHalf.?) the four slots of every
// quaternion are assembled as a sum of two triple products:
//     x = sx*cy*cz + cx*sy*sz
//     y = cx*sy*cz - sx*cy*sz
//     z = cx*cy*sz + sx*sy*cz
//     w = cx*cy*cz - sx*sy*sz
// i.e.   (sx cx cx cx)*(cy sy cy cy)*(cz cz sz cz)
//      + (cx sx sx sx)*(sy cy sy sy)*(sz sz cz sz) with the sign flipped on slots 1 and 3.
// Only slots 0..2 of the sine/cosine registers are read, so the padding slot of the
// Len == 3 input never reaches the result.
//
// NOTE(review): the argument is named EulerHalf, so the caller is presumably expected
// to pass angles already divided by two - confirm. The sign pattern corresponds to
// the X-Y-Z rotation order - confirm against the engine's convention.
constexpr static VectorF16<4, 2, Repeats> QuanternionFromEuler(VectorF16<3, 2, Repeats> EulerHalf) requires(Len == 3 && Packing == 2) {
    VectorF16<3, 2, Repeats> sinHalf = EulerHalf.Sin();
    VectorF16<3, 2, Repeats> cosHalf = EulerHalf.Cos();

    // Blend<...>(a, b) takes a slot from `b` where the flag is 1 and from `a` where it is 0.
    // The type is dependent here, hence the `template` disambiguator on every call.

    // First product: broadcast one cosine and patch the matching sine into "its" slot.
    VectorF16<3, 2, Repeats> row1 = cosHalf.template Shuffle<0,0,0,0,4,4,4,4>();
    row1 = VectorF16<3, 2, Repeats>::template Blend<0,1,1,1, 0,1,1,1>(sinHalf, row1); // sx cx cx cx

    VectorF16<3, 2, Repeats> row2 = cosHalf.template Shuffle<1,1,1,1,5,5,5,5>();
    row2 = VectorF16<3, 2, Repeats>::template Blend<1,0,1,1, 1,0,1,1>(sinHalf, row2); // cy sy cy cy

    row1 *= row2;

    VectorF16<3, 2, Repeats> row3 = cosHalf.template Shuffle<2,2,2,2,6,6,6,6>();
    row3 = VectorF16<3, 2, Repeats>::template Blend<1,1,0,1, 1,1,0,1>(sinHalf, row3); // cz cz sz cz

    // Second product: the mirror image - broadcast one sine and patch in the matching cosine.
    VectorF16<3, 2, Repeats> row4 = sinHalf.template Shuffle<0,0,0,0,4,4,4,4>();
    row4 = VectorF16<3, 2, Repeats>::template Blend<0,1,1,1, 0,1,1,1>(cosHalf, row4); // cx sx sx sx

    // Flip the sign of slots 1 and 3 of every quaternion. Bit 31 of each 32-bit group
    // is the sign bit of the odd 16-bit slot. The constant is broadcast from an
    // immediate, so no aligned load from an under-aligned array is needed.
    if constexpr(std::is_same_v<VectorType, __m128h>) {
        const __m128i signMask = _mm_set1_epi32(static_cast<int>(0x80000000u));
        row4.v = _mm_castsi128_ph(_mm_xor_si128(signMask, _mm_castph_si128(row4.v)));
    } else if constexpr(std::is_same_v<VectorType, __m256h>) {
        const __m256i signMask = _mm256_set1_epi32(static_cast<int>(0x80000000u));
        row4.v = _mm256_castsi256_ph(_mm256_xor_si256(signMask, _mm256_castph_si256(row4.v)));
    } else {
        const __m512i signMask = _mm512_set1_epi32(static_cast<int>(0x80000000u));
        row4.v = _mm512_castsi512_ph(_mm512_xor_si512(signMask, _mm512_castph_si512(row4.v)));
    }

    VectorF16<3, 2, Repeats> row5 = sinHalf.template Shuffle<1,1,1,1,5,5,5,5>();
    row5 = VectorF16<3, 2, Repeats>::template Blend<1,0,1,1, 1,0,1,1>(cosHalf, row5); // sy cy sy sy
    row4 *= row5;

    VectorF16<3, 2, Repeats> row6 = sinHalf.template Shuffle<2,2,2,2,6,6,6,6>();
    row6 = VectorF16<3, 2, Repeats>::template Blend<1,1,0,1, 1,1,0,1>(cosHalf, row6); // sz sz cz sz
    row4 *= row6;

    // result = row1 * row3 + row4 (MulitplyAdd is presumably a*b + c - defined elsewhere in the struct).
    row1 = MulitplyAdd(row1, row3, row4);
    return row1;
}
|
2026-03-19 02:19:01 +01:00
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// std::format support for Crafter::VectorF16.
// The register is stored to a flat array and printed as one brace group per repeat,
// containing one brace group per packed vector, containing the Len elements separated
// by commas, e.g. "{{a,b,c}{d,e,f}}". Elements are converted to float for printing.
// Format-spec parsing and final output are inherited from std::formatter<std::string>.
export template <std::uint32_t Len, std::uint32_t Packing, std::uint32_t Repeats>
struct std::formatter<Crafter::VectorF16<Len, Packing, Repeats>> : std::formatter<std::string> {
    auto format(const Crafter::VectorF16<Len, Packing, Repeats>& obj, format_context& ctx) const {
        Crafter::Vector<_Float16, Len * Packing * Repeats, 0> elements = obj.template Store<Len * Packing * Repeats, 0>();

        std::string text;
        // Elements are consumed in storage order, so a running index replaces
        // repeat * Packing * Len + pack * Len + component.
        std::uint32_t next = 0;
        for (std::uint32_t repeat = 0; repeat != Repeats; ++repeat) {
            text.push_back('{');
            for (std::uint32_t pack = 0; pack != Packing; ++pack) {
                text.push_back('{');
                for (std::uint32_t component = 0; component != Len; ++component, ++next) {
                    if (component != 0) {
                        text.push_back(',');
                    }
                    std::format_to(std::back_inserter(text), "{}", static_cast<float>(elements.v[next]));
                }
                text.push_back('}');
            }
            text.push_back('}');
        }
        return std::formatter<std::string>::format(text, ctx);
    }
};
|
|
|
|
|
#endif
|