x86 gating
This commit is contained in:
parent
739ad8e59f
commit
99972d8c81
3 changed files with 22 additions and 4 deletions
|
|
@ -32,16 +32,22 @@ namespace Crafter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef __x86_64
|
||||||
using VectorType = std::conditional_t<std::is_same_v<T, _Float16>,
|
using VectorType = std::conditional_t<std::is_same_v<T, _Float16>,
|
||||||
|
|
||||||
|
#ifdef __AVX512FP16__
|
||||||
std::conditional_t<(Len * Packing > 16), __m512h,
|
std::conditional_t<(Len * Packing > 16), __m512h,
|
||||||
std::conditional_t<(Len * Packing > 8), __m256h, __m128h>>,
|
std::conditional_t<(Len * Packing > 8), __m256h, __m128h>>,
|
||||||
|
#else
|
||||||
|
void,
|
||||||
|
#endif
|
||||||
|
|
||||||
std::conditional_t<(Len * Packing > 8), __m512,
|
std::conditional_t<(Len * Packing > 8), __m512,
|
||||||
std::conditional_t<(Len * Packing > 4), __m256, __m128>>
|
std::conditional_t<(Len * Packing > 4), __m256, __m128>>
|
||||||
>;
|
>;
|
||||||
|
|
||||||
VectorType v;
|
VectorType v;
|
||||||
|
#endif
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
|
@ -100,7 +106,7 @@ namespace Crafter {
|
||||||
template <std::array<std::uint8_t, Len> ShuffleValues>
|
template <std::array<std::uint8_t, Len> ShuffleValues>
|
||||||
static consteval std::array<std::uint8_t, Alignment> GetShuffleMaskEpi8() {
|
static consteval std::array<std::uint8_t, Alignment> GetShuffleMaskEpi8() {
|
||||||
std::array<std::uint8_t, Alignment> shuffleMask {{0}};
|
std::array<std::uint8_t, Alignment> shuffleMask {{0}};
|
||||||
if constexpr(std::same_as<T, _Float16>) {
|
if constexpr(sizeof(T) == 2) {
|
||||||
for(std::uint8_t i2 = 0; i2 < Packing; i2++) {
|
for(std::uint8_t i2 = 0; i2 < Packing; i2++) {
|
||||||
for(std::uint8_t i = 0; i < Len; i++) {
|
for(std::uint8_t i = 0; i < Len; i++) {
|
||||||
shuffleMask[(i2*Len*sizeof(T))+(i*sizeof(T))] = ShuffleValues[i]*sizeof(T)+(i2*Len*sizeof(T));
|
shuffleMask[(i2*Len*sizeof(T))+(i*sizeof(T))] = ShuffleValues[i]*sizeof(T)+(i2*Len*sizeof(T));
|
||||||
|
|
@ -213,6 +219,8 @@ namespace Crafter {
|
||||||
return shuffleMask;
|
return shuffleMask;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef __x86_64
|
||||||
|
#ifdef __AVX512FP16__
|
||||||
template <std::array<bool, Len> ShuffleValues>
|
template <std::array<bool, Len> ShuffleValues>
|
||||||
static consteval std::uint8_t GetBlendMaskEpi16() requires (std::is_same_v<VectorType, __m128h>){
|
static consteval std::uint8_t GetBlendMaskEpi16() requires (std::is_same_v<VectorType, __m128h>){
|
||||||
std::uint8_t mask = 0;
|
std::uint8_t mask = 0;
|
||||||
|
|
@ -251,6 +259,7 @@ namespace Crafter {
|
||||||
}
|
}
|
||||||
return mask;
|
return mask;
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
template <std::array<bool, Len> ShuffleValues>
|
template <std::array<bool, Len> ShuffleValues>
|
||||||
static consteval std::uint8_t GetBlendMaskEpi32() requires (std::is_same_v<VectorType, __m128>){
|
static consteval std::uint8_t GetBlendMaskEpi32() requires (std::is_same_v<VectorType, __m128>){
|
||||||
|
|
@ -290,6 +299,7 @@ namespace Crafter {
|
||||||
}
|
}
|
||||||
return mask;
|
return mask;
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
static constexpr float two_over_pi = 0.6366197723675814f;
|
static constexpr float two_over_pi = 0.6366197723675814f;
|
||||||
static constexpr float pi_over_2_hi = 1.5707963267341256f;
|
static constexpr float pi_over_2_hi = 1.5707963267341256f;
|
||||||
|
|
@ -310,6 +320,7 @@ namespace Crafter {
|
||||||
static constexpr float s7 = 0.0000027526372f;
|
static constexpr float s7 = 0.0000027526372f;
|
||||||
static constexpr float s9 = -0.0000000239013f;
|
static constexpr float s9 = -0.0000000239013f;
|
||||||
|
|
||||||
|
#ifdef __x86_64
|
||||||
// --- 128-bit (SSE) helpers ---
|
// --- 128-bit (SSE) helpers ---
|
||||||
static constexpr void range_reduce_f32x4(__m128 ax, __m128& r, __m128& r2, __m128i& q) {
|
static constexpr void range_reduce_f32x4(__m128 ax, __m128& r, __m128& r2, __m128i& q) {
|
||||||
__m128 fq = _mm_round_ps(_mm_mul_ps(ax, _mm_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
|
__m128 fq = _mm_round_ps(_mm_mul_ps(ax, _mm_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
|
||||||
|
|
@ -590,5 +601,6 @@ namespace Crafter {
|
||||||
out_sin = _mm512_xor_ps(out_sin, _mm512_castsi512_ps(_mm512_slli_epi32(sin_neg, 30)));
|
out_sin = _mm512_xor_ps(out_sin, _mm512_castsi512_ps(_mm512_slli_epi32(sin_neg, 30)));
|
||||||
out_sin = _mm512_xor_ps(out_sin, x_sign);
|
out_sin = _mm512_xor_ps(out_sin, x_sign);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
@ -464,9 +464,12 @@ struct std::formatter<Crafter::Vector<T, Len, Aligment>> : std::formatter<std::s
|
||||||
auto format(const Crafter::Vector<T, Len, Aligment>& vec, format_context& ctx) const {
|
auto format(const Crafter::Vector<T, Len, Aligment>& vec, format_context& ctx) const {
|
||||||
std::string out = "{";
|
std::string out = "{";
|
||||||
for(std::uint32_t i2 = 0; i2 < Len; i2++) {
|
for(std::uint32_t i2 = 0; i2 < Len; i2++) {
|
||||||
|
#ifdef __x86_64
|
||||||
if constexpr(std::same_as<T, _Float16>) {
|
if constexpr(std::same_as<T, _Float16>) {
|
||||||
out += std::format("{}", static_cast<float>(vec.v[i2]));
|
out += std::format("{}", static_cast<float>(vec.v[i2]));
|
||||||
} else {
|
} else
|
||||||
|
#endif
|
||||||
|
{
|
||||||
out += std::format("{}", vec.v[i2]);
|
out += std::format("{}", vec.v[i2]);
|
||||||
}
|
}
|
||||||
if (i2 + 1 < Len) out += ",";
|
if (i2 + 1 < Len) out += ",";
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ import std;
|
||||||
import :Common;
|
import :Common;
|
||||||
|
|
||||||
namespace Crafter {
|
namespace Crafter {
|
||||||
|
#ifdef __x86_64
|
||||||
export template <std::uint8_t Len, std::uint8_t Packing>
|
export template <std::uint8_t Len, std::uint8_t Packing>
|
||||||
struct VectorF32 : public VectorBase<Len, Packing, float> {
|
struct VectorF32 : public VectorBase<Len, Packing, float> {
|
||||||
template <std::uint8_t Len2, std::uint8_t Packing2>
|
template <std::uint8_t Len2, std::uint8_t Packing2>
|
||||||
|
|
@ -1382,9 +1383,10 @@ namespace Crafter {
|
||||||
return row1;
|
return row1;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef __x86_64
|
||||||
export template <std::uint32_t Len, std::uint32_t Packing>
|
export template <std::uint32_t Len, std::uint32_t Packing>
|
||||||
struct std::formatter<Crafter::VectorF32<Len, Packing>> : std::formatter<std::string> {
|
struct std::formatter<Crafter::VectorF32<Len, Packing>> : std::formatter<std::string> {
|
||||||
constexpr auto format(const Crafter::VectorF32<Len, Packing>& obj, format_context& ctx) const {
|
constexpr auto format(const Crafter::VectorF32<Len, Packing>& obj, format_context& ctx) const {
|
||||||
|
|
@ -1401,4 +1403,5 @@ struct std::formatter<Crafter::VectorF32<Len, Packing>> : std::formatter<std::st
|
||||||
out += "}";
|
out += "}";
|
||||||
return std::formatter<std::string>::format(out, ctx);
|
return std::formatter<std::string>::format(out, ctx);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
#endif
|
||||||
Loading…
Add table
Add a link
Reference in a new issue