all tests passing again
This commit is contained in:
parent
d09155736f
commit
8999c8b9ec
7 changed files with 912 additions and 914 deletions
|
|
@ -21,7 +21,6 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
|||
export module Crafter.Math:Basic;
|
||||
import std;
|
||||
import :VectorF16;
|
||||
import :VectorF32;
|
||||
|
||||
namespace Crafter {
|
||||
template<typename T>
|
||||
|
|
|
|||
409
interfaces/Crafter.Math-Common.cppm
Normal file
409
interfaces/Crafter.Math-Common.cppm
Normal file
|
|
@ -0,0 +1,409 @@
|
|||
module;
|
||||
#ifdef __x86_64
|
||||
#include <immintrin.h>
|
||||
#endif
|
||||
export module Crafter.Math:Common;
|
||||
import std;
|
||||
|
||||
namespace Crafter {
|
||||
export template <std::uint8_t Len, std::uint8_t Packing>
|
||||
struct VectorF16;
|
||||
|
||||
template <std::uint8_t Len, std::uint8_t Packing, typename T>
|
||||
struct VectorBase {
|
||||
template <std::uint8_t L, std::uint8_t P>
|
||||
friend struct VectorF16;
|
||||
protected:
|
||||
static consteval std::uint8_t GetAlingment() {
|
||||
if(Len * Packing * sizeof(T) <= 16) {
|
||||
return 16;
|
||||
} else if(Len * Packing * sizeof(T) <= 32) {
|
||||
return 32;
|
||||
} else if(Len * Packing * sizeof(T) <= 64) {
|
||||
return 64;
|
||||
}
|
||||
}
|
||||
using VectorType = std::conditional_t<
|
||||
(Len * Packing > 16), __m512h,
|
||||
std::conditional_t<(Len * Packing > 8), __m256h, __m128h>
|
||||
>;
|
||||
|
||||
VectorType v;
|
||||
|
||||
public:
|
||||
|
||||
template <std::uint8_t Len2, std::uint8_t Packing2, typename T2>
|
||||
friend struct VectorBase;
|
||||
|
||||
#ifdef __AVX512F__
|
||||
static constexpr std::uint8_t Max = 64;
|
||||
#else
|
||||
static constexpr std::uint8_t Max = 32;
|
||||
#endif
|
||||
static constexpr std::uint8_t MaxElement = Max/sizeof(T);
|
||||
|
||||
static constexpr std::uint8_t AlignmentElement = GetAlingment()/sizeof(T);
|
||||
static constexpr std::uint8_t Alignment = GetAlingment();
|
||||
static_assert(Len * Packing <= MaxElement, "Len * Packing exceeds MaxElement");
|
||||
|
||||
protected:
|
||||
static constexpr std::uint8_t PerLane = 16/sizeof(T);
|
||||
static consteval std::array<bool, Len> GetAllTrue() {
|
||||
std::array<bool, Len> arr{};
|
||||
arr.fill(true);
|
||||
return arr;
|
||||
}
|
||||
|
||||
template <std::array<std::uint8_t, Len> ShuffleValues>
|
||||
static consteval bool CheckEpi32Shuffle() {
|
||||
if constexpr (PerLane == 8) {
|
||||
for(std::uint8_t i = 1; i < Len; i+=2) {
|
||||
if(ShuffleValues[i-1] != ShuffleValues[i] - 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
for(std::uint8_t i = 0; i < Len; i++) {
|
||||
for(std::uint8_t i2 = PerLane; i2 < Len; i2 += PerLane) {
|
||||
if(ShuffleValues[i] != ShuffleValues[i2]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <std::array<std::uint8_t, Len> ShuffleValues>
|
||||
static consteval bool CheckEpi8Shuffle() {
|
||||
for(std::uint8_t i = 0; i < Len; i++) {
|
||||
std::uint8_t lane = i / PerLane;
|
||||
if(ShuffleValues[i] < lane * PerLane || ShuffleValues[i] > lane * PerLane + PerLane-1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <std::array<std::uint8_t, Len> ShuffleValues>
|
||||
static consteval std::array<std::uint8_t, Alignment> GetShuffleMaskEpi8() {
|
||||
std::array<std::uint8_t, Alignment> shuffleMask {{0}};
|
||||
for(std::uint8_t i2 = 0; i2 < Packing; i2++) {
|
||||
for(std::uint8_t i = 0; i < Len; i++) {
|
||||
shuffleMask[(i2*Len*sizeof(T))+(i*sizeof(T))] = ShuffleValues[i]*sizeof(T)+(i2*Len*sizeof(T));
|
||||
shuffleMask[(i2*Len*sizeof(T))+(i*sizeof(T)+1)] = ShuffleValues[i]*sizeof(T)+1+(i2*Len*sizeof(T));
|
||||
}
|
||||
}
|
||||
return shuffleMask;
|
||||
}
|
||||
|
||||
|
||||
template<std::array<bool, Len> values>
|
||||
static consteval std::array<T, AlignmentElement> GetNegateMask() {
|
||||
std::array<T, AlignmentElement> mask{};
|
||||
|
||||
T high_bit = 0;
|
||||
|
||||
if constexpr(sizeof(T) == 2) {
|
||||
high_bit = std::bit_cast<T>(
|
||||
static_cast<std::uint16_t>(1u << (std::numeric_limits<std::uint16_t>::digits - 1))
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
|
||||
for (std::uint8_t i2 = 0; i2 < Packing; ++i2) {
|
||||
for (std::uint8_t i = 0; i < Len; ++i) {
|
||||
mask[i2 * Len + i] = values[i] ? high_bit : T(0);
|
||||
}
|
||||
}
|
||||
|
||||
return mask;
|
||||
}
|
||||
|
||||
template <std::uint32_t ExtractLen>
|
||||
static constexpr std::array<std::uint8_t, Alignment> GetExtractLoMaskEpi8() {
|
||||
std::array<std::uint8_t, Alignment> mask {{0}};
|
||||
for(std::uint8_t i2 = 0; i2 < Packing; i2++) {
|
||||
for(std::uint8_t i = 0; i < ExtractLen; i++) {
|
||||
mask[(i2*ExtractLen*sizeof(T))+(i*sizeof(T))] = i*sizeof(T)+(i2*Len*sizeof(T));
|
||||
mask[(i2*ExtractLen*sizeof(T))+(i*sizeof(T)+1)] = i*sizeof(T)+1+(i2*Len*sizeof(T));
|
||||
}
|
||||
}
|
||||
return mask;
|
||||
}
|
||||
|
||||
template <std::uint32_t ExtractLen>
|
||||
static consteval std::array<std::uint16_t, AlignmentElement> GetExtractLoMaskEpi16() {
|
||||
std::array<std::uint16_t, AlignmentElement> mask{};
|
||||
for (std::uint16_t i2 = 0; i2 < Packing; i2++) {
|
||||
for (std::uint16_t i = 0; i < ExtractLen; i++) {
|
||||
mask[i2 * ExtractLen + i] = i + (i2 * Len);
|
||||
}
|
||||
}
|
||||
return mask;
|
||||
}
|
||||
|
||||
template <std::array<std::uint8_t, Len> ShuffleValues>
|
||||
static consteval std::uint8_t GetShuffleMaskEpi32() {
|
||||
std::uint8_t mask = 0;
|
||||
for(std::uint8_t i = 0; i < std::min(Len, std::uint8_t(8)); i+=2) {
|
||||
mask = mask | (ShuffleValues[i] & 0b11) << i;
|
||||
}
|
||||
return mask;
|
||||
}
|
||||
|
||||
template <std::array<std::uint8_t, Len> ShuffleValues>
|
||||
static consteval std::array<std::uint16_t, AlignmentElement> GetPermuteMaskEpi16() {
|
||||
std::array<std::uint16_t, AlignmentElement> shuffleMask {{0}};
|
||||
for(std::uint8_t i2 = 0; i2 < Packing; i2++) {
|
||||
for(std::uint8_t i = 0; i < Len; i++) {
|
||||
shuffleMask[i2*Len+i] = ShuffleValues[i]+i2*Len;
|
||||
}
|
||||
}
|
||||
return shuffleMask;
|
||||
}
|
||||
|
||||
template <std::array<bool, Len> ShuffleValues>
|
||||
static consteval std::uint8_t GetBlendMaskEpi16() requires (std::is_same_v<VectorType, __m128h>){
|
||||
std::uint8_t mask = 0;
|
||||
for (std::uint8_t i2 = 0; i2 < Packing; i2++) {
|
||||
for (std::uint8_t i = 0; i < Len; i++) {
|
||||
if (ShuffleValues[i]) {
|
||||
mask |= (1u << (i2 * Len + i));
|
||||
}
|
||||
}
|
||||
}
|
||||
return mask;
|
||||
}
|
||||
|
||||
template <std::array<bool, Len> ShuffleValues>
|
||||
static consteval std::uint16_t GetBlendMaskEpi16() requires (std::is_same_v<VectorType, __m256h>){
|
||||
std::uint16_t mask = 0;
|
||||
for (std::uint8_t i2 = 0; i2 < Packing; i2++) {
|
||||
for (std::uint8_t i = 0; i < Len; i++) {
|
||||
if (ShuffleValues[i]) {
|
||||
mask |= (1u << (i2 * Len + i));
|
||||
}
|
||||
}
|
||||
}
|
||||
return mask;
|
||||
}
|
||||
|
||||
template <std::array<bool, Len> ShuffleValues>
|
||||
static consteval std::uint32_t GetBlendMaskEpi16() requires (std::is_same_v<VectorType, __m512h>){
|
||||
std::uint32_t mask = 0;
|
||||
for (std::uint8_t i2 = 0; i2 < Packing; i2++) {
|
||||
for (std::uint8_t i = 0; i < Len; i++) {
|
||||
if (ShuffleValues[i]) {
|
||||
mask |= (1u << (i2 * Len + i));
|
||||
}
|
||||
}
|
||||
}
|
||||
return mask;
|
||||
}
|
||||
|
||||
static constexpr float two_over_pi = 0.6366197723675814f;
|
||||
static constexpr float pi_over_2_hi = 1.5707963267341256f;
|
||||
static constexpr float pi_over_2_lo = 6.077100506506192e-11f;
|
||||
|
||||
// Cos polynomial on [-pi/4, pi/4]: c0 + c2*r^2 + c4*r^4 + ...
|
||||
static constexpr float c0 = 1.0f;
|
||||
static constexpr float c2 = -0.4999999642372f;
|
||||
static constexpr float c4 = 0.0416666418707f;
|
||||
static constexpr float c6 = -0.0013888397720f;
|
||||
static constexpr float c8 = 0.0000248015873f;
|
||||
static constexpr float c10 = -0.0000002752258f;
|
||||
|
||||
// Sin polynomial on [-pi/4, pi/4]: r * (1 + s1*r^2 + s3*r^4 + ...)
|
||||
static constexpr float s1 = -0.1666666641831f;
|
||||
static constexpr float s3 = 0.0083333293858f;
|
||||
static constexpr float s5 = -0.0001984090955f;
|
||||
static constexpr float s7 = 0.0000027526372f;
|
||||
static constexpr float s9 = -0.0000000239013f;
|
||||
|
||||
// Reduce |x| into [-pi/4, pi/4], return reduced value and quadrant
|
||||
static constexpr void range_reduce_f32x8(__m256 ax, __m256& r, __m256& r2, __m256i& q) {
|
||||
__m256 fq = _mm256_round_ps(_mm256_mul_ps(ax, _mm256_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
|
||||
q = _mm256_cvtps_epi32(fq);
|
||||
r = _mm256_sub_ps(ax, _mm256_mul_ps(fq, _mm256_set1_ps(pi_over_2_hi)));
|
||||
r = _mm256_sub_ps(r, _mm256_mul_ps(fq, _mm256_set1_ps(pi_over_2_lo)));
|
||||
r2 = _mm256_mul_ps(r, r);
|
||||
}
|
||||
|
||||
|
||||
static constexpr void sincos_poly_f32x8(__m256 r, __m256 r2, __m256& cos_r, __m256& sin_r) {
|
||||
cos_r = _mm256_fmadd_ps(_mm256_set1_ps(c10), r2, _mm256_set1_ps(c8));
|
||||
cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c6));
|
||||
cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c4));
|
||||
cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c2));
|
||||
cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c0));
|
||||
|
||||
sin_r = _mm256_fmadd_ps(_mm256_set1_ps(s9), r2, _mm256_set1_ps(s7));
|
||||
sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s5));
|
||||
sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s3));
|
||||
sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s1));
|
||||
sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(1.0f));
|
||||
sin_r = _mm256_mul_ps(sin_r, r);
|
||||
}
|
||||
|
||||
// cos(x): use cos_poly when q even, sin_poly when q odd; negate if (q+1)&2
|
||||
static constexpr __m256 cos_f32x8(__m256 x) {
|
||||
const __m256 sign_mask = _mm256_set1_ps(-0.0f);
|
||||
__m256 ax = _mm256_andnot_ps(sign_mask, x);
|
||||
|
||||
__m256 r, r2; __m256i q;
|
||||
range_reduce_f32x8(ax, r, r2, q);
|
||||
|
||||
__m256 cos_r, sin_r;
|
||||
sincos_poly_f32x8(r, r2, cos_r, sin_r);
|
||||
|
||||
__m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1));
|
||||
__m256 use_sin = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1)));
|
||||
__m256 result = _mm256_blendv_ps(cos_r, sin_r, use_sin);
|
||||
|
||||
__m256i need_neg = _mm256_and_si256(
|
||||
_mm256_add_epi32(q, _mm256_set1_epi32(1)), _mm256_set1_epi32(2));
|
||||
__m256 neg_mask = _mm256_castsi256_ps(_mm256_slli_epi32(need_neg, 30));
|
||||
return _mm256_xor_ps(result, neg_mask);
|
||||
}
|
||||
|
||||
// sin(x): use sin_poly when q even, cos_poly when q odd; negate if q&2; respect input sign
|
||||
static constexpr __m256 sin_f32x8(__m256 x) {
|
||||
const __m256 sign_mask = _mm256_set1_ps(-0.0f);
|
||||
__m256 x_sign = _mm256_and_ps(x, sign_mask);
|
||||
__m256 ax = _mm256_andnot_ps(sign_mask, x);
|
||||
|
||||
__m256 r, r2; __m256i q;
|
||||
range_reduce_f32x8(ax, r, r2, q);
|
||||
|
||||
__m256 cos_r, sin_r;
|
||||
sincos_poly_f32x8(r, r2, cos_r, sin_r);
|
||||
|
||||
__m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1));
|
||||
__m256 use_cos = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1)));
|
||||
__m256 result = _mm256_blendv_ps(sin_r, cos_r, use_cos);
|
||||
|
||||
__m256i need_neg = _mm256_and_si256(q, _mm256_set1_epi32(2));
|
||||
__m256 neg_mask = _mm256_castsi256_ps(_mm256_slli_epi32(need_neg, 30));
|
||||
result = _mm256_xor_ps(result, neg_mask);
|
||||
|
||||
// Apply original sign of x
|
||||
return _mm256_xor_ps(result, x_sign);
|
||||
}
|
||||
|
||||
// // --- 512-bit helpers ---
|
||||
|
||||
static constexpr void range_reduce_f32x16(__m512 ax, __m512& r, __m512& r2, __m512i& q) {
|
||||
__m512 fq = _mm512_roundscale_ps(_mm512_mul_ps(ax, _mm512_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
|
||||
q = _mm512_cvtps_epi32(fq);
|
||||
r = _mm512_sub_ps(ax, _mm512_mul_ps(fq, _mm512_set1_ps(pi_over_2_hi)));
|
||||
r = _mm512_sub_ps(r, _mm512_mul_ps(fq, _mm512_set1_ps(pi_over_2_lo)));
|
||||
r2 = _mm512_mul_ps(r, r);
|
||||
}
|
||||
|
||||
static constexpr void sincos_poly_f32x16(__m512 r, __m512 r2, __m512& cos_r, __m512& sin_r) {
|
||||
cos_r = _mm512_fmadd_ps(_mm512_set1_ps(c10), r2, _mm512_set1_ps(c8));
|
||||
cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c6));
|
||||
cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c4));
|
||||
cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c2));
|
||||
cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c0));
|
||||
|
||||
sin_r = _mm512_fmadd_ps(_mm512_set1_ps(s9), r2, _mm512_set1_ps(s7));
|
||||
sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s5));
|
||||
sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s3));
|
||||
sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s1));
|
||||
sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(1.0f));
|
||||
sin_r = _mm512_mul_ps(sin_r, r);
|
||||
}
|
||||
|
||||
static constexpr __m512 cos_f32x16(__m512 x) {
|
||||
__m512 ax = _mm512_abs_ps(x);
|
||||
|
||||
__m512 r, r2; __m512i q;
|
||||
range_reduce_f32x16(ax, r, r2, q);
|
||||
|
||||
__m512 cos_r, sin_r;
|
||||
sincos_poly_f32x16(r, r2, cos_r, sin_r);
|
||||
|
||||
__mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1));
|
||||
__m512 result = _mm512_mask_blend_ps(odd, cos_r, sin_r);
|
||||
|
||||
__m512i need_neg = _mm512_and_si512(
|
||||
_mm512_add_epi32(q, _mm512_set1_epi32(1)), _mm512_set1_epi32(2));
|
||||
__m512 neg_mask = _mm512_castsi512_ps(_mm512_slli_epi32(need_neg, 30));
|
||||
return _mm512_xor_ps(result, neg_mask);
|
||||
}
|
||||
|
||||
static constexpr __m512 sin_f32x16(__m512 x) {
|
||||
__m512 x_sign = _mm512_and_ps(x, _mm512_set1_ps(-0.0f));
|
||||
__m512 ax = _mm512_abs_ps(x);
|
||||
|
||||
__m512 r, r2; __m512i q;
|
||||
range_reduce_f32x16(ax, r, r2, q);
|
||||
|
||||
__m512 cos_r, sin_r;
|
||||
sincos_poly_f32x16(r, r2, cos_r, sin_r);
|
||||
|
||||
__mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1));
|
||||
__m512 result = _mm512_mask_blend_ps(odd, sin_r, cos_r);
|
||||
|
||||
__m512i need_neg = _mm512_and_si512(q, _mm512_set1_epi32(2));
|
||||
__m512 neg_mask = _mm512_castsi512_ps(_mm512_slli_epi32(need_neg, 30));
|
||||
result = _mm512_xor_ps(result, neg_mask);
|
||||
|
||||
return _mm512_xor_ps(result, x_sign);
|
||||
}
|
||||
|
||||
// --- 256-bit sincos ---
|
||||
static constexpr void sincos_f32x8(__m256 x, __m256& out_sin, __m256& out_cos) {
|
||||
const __m256 sign_mask = _mm256_set1_ps(-0.0f);
|
||||
__m256 x_sign = _mm256_and_ps(x, sign_mask);
|
||||
__m256 ax = _mm256_andnot_ps(sign_mask, x);
|
||||
|
||||
__m256 r, r2; __m256i q;
|
||||
range_reduce_f32x8(ax, r, r2, q);
|
||||
|
||||
__m256 cos_r, sin_r;
|
||||
sincos_poly_f32x8(r, r2, cos_r, sin_r);
|
||||
|
||||
__m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1));
|
||||
__m256 is_odd = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1)));
|
||||
|
||||
// cos: swap on odd, negate if (q+1)&2
|
||||
out_cos = _mm256_blendv_ps(cos_r, sin_r, is_odd);
|
||||
__m256i cos_neg = _mm256_and_si256(_mm256_add_epi32(q, _mm256_set1_epi32(1)), _mm256_set1_epi32(2));
|
||||
out_cos = _mm256_xor_ps(out_cos, _mm256_castsi256_ps(_mm256_slli_epi32(cos_neg, 30)));
|
||||
|
||||
// sin: swap on odd, negate if q&2, apply input sign
|
||||
out_sin = _mm256_blendv_ps(sin_r, cos_r, is_odd);
|
||||
__m256i sin_neg = _mm256_and_si256(q, _mm256_set1_epi32(2));
|
||||
out_sin = _mm256_xor_ps(out_sin,_mm256_castsi256_ps(_mm256_slli_epi32(sin_neg, 30)));
|
||||
out_sin = _mm256_xor_ps(out_sin, x_sign);
|
||||
}
|
||||
|
||||
// --- 512-bit sincos ---
|
||||
static constexpr void sincos_f32x16(__m512 x, __m512& out_sin, __m512& out_cos) {
|
||||
__m512 x_sign = _mm512_and_ps(x, _mm512_set1_ps(-0.0f));
|
||||
__m512 ax = _mm512_abs_ps(x);
|
||||
|
||||
__m512 r, r2; __m512i q;
|
||||
range_reduce_f32x16(ax, r, r2, q);
|
||||
|
||||
__m512 cos_r, sin_r;
|
||||
sincos_poly_f32x16(r, r2, cos_r, sin_r);
|
||||
|
||||
__mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1));
|
||||
|
||||
// cos
|
||||
out_cos = _mm512_mask_blend_ps(odd, cos_r, sin_r);
|
||||
__m512i cos_neg = _mm512_and_si512(_mm512_add_epi32(q, _mm512_set1_epi32(1)), _mm512_set1_epi32(2));
|
||||
out_cos = _mm512_xor_ps(out_cos, _mm512_castsi512_ps(_mm512_slli_epi32(cos_neg, 30)));
|
||||
|
||||
// sin
|
||||
out_sin = _mm512_mask_blend_ps(odd, sin_r, cos_r);
|
||||
__m512i sin_neg = _mm512_and_si512(q, _mm512_set1_epi32(2));
|
||||
out_sin = _mm512_xor_ps(out_sin, _mm512_castsi512_ps(_mm512_slli_epi32(sin_neg, 30)));
|
||||
out_sin = _mm512_xor_ps(out_sin, x_sign);
|
||||
}
|
||||
};
|
||||
}
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -18,14 +18,15 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
|||
*/
|
||||
module;
|
||||
#ifdef __x86_64
|
||||
#include <immintrin.h>
|
||||
#include <immintrin.>
|
||||
#endif
|
||||
export module Crafter.Math:VectorF32;
|
||||
import std;
|
||||
import :Vector;
|
||||
import :Common;
|
||||
|
||||
namespace Crafter {
|
||||
export template <std::uint32_t Len, std::uint32_t Packing, std::uint32_t Repeats>
|
||||
export template <std::uint32_t Len, std::uint32_t Packing>
|
||||
struct VectorF32 {
|
||||
#ifdef __AVX512F__
|
||||
static constexpr std::uint32_t MaxSize = 16;
|
||||
|
|
@ -45,8 +46,6 @@ namespace Crafter {
|
|||
return 16;
|
||||
}
|
||||
static_assert(Len * Packing <= 16, "Len * Packing is larger than supported max size of 16");
|
||||
static_assert(Len * Packing <= 4, "Len * Packing is larger than supported packed size of 4");
|
||||
static_assert(Len * Packing * Repeats <= 16, "Len * Packing * Repeats is larger than supported max of 16");
|
||||
#else
|
||||
if constexpr (Len * Packing <= 4) {
|
||||
return 4;
|
||||
|
|
@ -55,17 +54,12 @@ namespace Crafter {
|
|||
return 8;
|
||||
}
|
||||
static_assert(Len * Packing <= 8, "Len * Packing is larger than supported max size of 8");
|
||||
static_assert(Len * Packing <= 4, "Len * Packing is larger than supported packed size of 4");
|
||||
static_assert(Len * Packing * Repeats <= 8, "Len * Packing * Repeats is larger than supported max of 8");
|
||||
#endif
|
||||
}
|
||||
static consteval std::uint32_t GetTotalSize() {
|
||||
return GetAlignment() * Repeats;
|
||||
}
|
||||
|
||||
using VectorType = std::conditional_t<
|
||||
(GetTotalSize() == 16), __m512,
|
||||
std::conditional_t<(GetTotalSize() == 8), __m256, __m128>
|
||||
(Len * Packing > 8), __m512h,
|
||||
std::conditional_t<(Len * Packing > 4), __m256h, __m128>
|
||||
>;
|
||||
|
||||
VectorType v;
|
||||
|
|
@ -107,91 +101,96 @@ namespace Crafter {
|
|||
}
|
||||
constexpr void Load(const _Float16* vB) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
v = _mm_cvtph_ps(_mm_loadu_si128(reinterpret_cast<__m128i const*>(vB)));
|
||||
v = _mm_cvtps_ps(_mm_loadu_si128(reinterpret_cast<__m128i const*>(vB)));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
v = _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<__m128i const*>(vB)));
|
||||
v = _mm256_cvtps_ps(_mm_loadu_si128(reinterpret_cast<__m128i const*>(vB)));
|
||||
} else {
|
||||
v = _mm512_cvtph_ps(_mm256_loadu_si256(reinterpret_cast<__m256i const*>(vB)));
|
||||
v = _mm512_cvtps_ps(_mm256_loadu_si256(reinterpret_cast<__m256i const*>(vB)));
|
||||
}
|
||||
}
|
||||
constexpr void Store(_Float16* vB) const {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
_mm_storeu_si128(_mm_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT), v);
|
||||
_mm_storeu_si128(_mm_cvtps_ps(v, _MM_FROUND_TO_NEAREST_INT), v);
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
_mm_storeu_si128(_mm256_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT), v);
|
||||
_mm_storeu_si128(_mm256_cvtps_ps(v, _MM_FROUND_TO_NEAREST_INT), v);
|
||||
} else {
|
||||
_mm256_storeu_si256(_mm512_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT), v);
|
||||
_mm256_storeu_si256(_mm512_cvtps_ps(v, _MM_FROUND_TO_NEAREST_INT), v);
|
||||
}
|
||||
}
|
||||
|
||||
template <std::uint32_t VLen, std::uint32_t VAlign>
|
||||
constexpr Vector<float, VLen, VAlign> Store() const {
|
||||
Vector<float, VLen, VAlign> returnVec;
|
||||
Store(returnVec.v);
|
||||
return returnVec;
|
||||
constexpr std::array<float, Alignment> Store() const {
|
||||
std::array<float, Alignment> returnArray;
|
||||
Store(returnArray.data());
|
||||
return returnArray;
|
||||
}
|
||||
|
||||
template <std::uint32_t BLen, std::uint32_t BPacking, std::uint32_t BRepeats>
|
||||
constexpr operator VectorF32<BLen, BPacking, BRepeats>() const {
|
||||
if constexpr(std::is_same_v<VectorType, __m256> && std::is_same_v<typename VectorF32<BLen, BPacking, BRepeats>::VectorType, __m128>) {
|
||||
return VectorF32<BLen, BPacking, BRepeats>(_mm256_castps256_ps128(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m512> && std::is_same_v<typename VectorF32<BLen, BPacking, BRepeats>::VectorType, __m128>) {
|
||||
return VectorF32<BLen, BPacking, BRepeats>(_mm512_castps512_ps128(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m512> && std::is_same_v<typename VectorF32<BLen, BPacking, BRepeats>::VectorType, __m256>) {
|
||||
return VectorF32<BLen, BPacking, BRepeats>(_mm512_castps512_ps256(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m128> && std::is_same_v<typename VectorF32<BLen, BPacking, BRepeats>::VectorType, __m256>) {
|
||||
return VectorF32<BLen, BPacking, BRepeats>(_mm256_castps128_ps256(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m128> && std::is_same_v<typename VectorF32<BLen, BPacking, BRepeats>::VectorType, __m512>) {
|
||||
return VectorF32<BLen, BPacking, BRepeats>(_mm512_castps128_ps512(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256> && std::is_same_v<typename VectorF32<BLen, BPacking, BRepeats>::VectorType, __m512>) {
|
||||
return VectorF32<BLen, BPacking, BRepeats>(_mm512_castps256_ps512(v));
|
||||
template <std::uint32_t BLen, std::uint32_t BPacking>
|
||||
constexpr operator VectorF32<BLen, BPacking>() const {
|
||||
if constexpr (Len == BLen) {
|
||||
if constexpr(std::is_same_v<VectorType, __m256> && std::is_same_v<typename VectorF32<BLen, BPacking>::VectorType, __m128>) {
|
||||
return VectorF32<BLen, BPacking>(_mm256_castps256_ps128(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m512> && std::is_same_v<typename VectorF32<BLen, BPacking>::VectorType, __m128>) {
|
||||
return VectorF32<BLen, BPacking>(_mm512_castps512_ps128(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m512> && std::is_same_v<typename VectorF32<BLen, BPacking>::VectorType, __m256>) {
|
||||
return VectorF32<BLen, BPacking>(_mm512_castps512_ps256(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m128> && std::is_same_v<typename VectorF32<BLen, BPacking>::VectorType, __m256>) {
|
||||
return VectorF32<BLen, BPacking>(_mm256_castps128_ps256(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m128> && std::is_same_v<typename VectorF32<BLen, BPacking>::VectorType, __m512>) {
|
||||
return VectorF32<BLen, BPacking>(_mm512_castps128_ps512(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256> && std::is_same_v<typename VectorF32<BLen, BPacking>::VectorType, __m512>) {
|
||||
return VectorF32<BLen, BPacking>(_mm512_castps256_ps512(v));
|
||||
} else {
|
||||
return VectorF32<BLen, BPacking, BRepeats>(v);
|
||||
return VectorF32<BLen, BPacking>(v);
|
||||
}
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> operator+(VectorF32<Len, Packing, Repeats> b) const {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm_add_ps(v, b.v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm256_add_ps(v, b.v));
|
||||
} else if constexpr (BLen <= Len) {
|
||||
return this->template ExtractLo<BLen>();
|
||||
} else {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm512_add_ps(v, b.v));
|
||||
return VectorF32<BLen, BPacking>(v);
|
||||
}
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> operator-(VectorF32<Len, Packing, Repeats> b) const {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm_sub_ps(v, b.v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm256_sub_ps(v, b.v));
|
||||
constexpr VectorF32<Len, Packing> operator+(VectorF32<Len, Packing> b) const {
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
return VectorF32<Len, Packing>(_mm_add_ph(v, b.v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
return VectorF32<Len, Packing>(_mm256_add_ph(v, b.v));
|
||||
} else {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm512_sub_ps(v, b.v));
|
||||
return VectorF32<Len, Packing>(_mm512_add_ph(v, b.v));
|
||||
}
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> operator*(VectorF32<Len, Packing, Repeats> b) const {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm_mul_ps(v, b.v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm256_mul_ps(v, b.v));
|
||||
constexpr VectorF32<Len, Packing> operator-(VectorF32<Len, Packing> b) const {
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
return VectorF32<Len, Packing>(_mm_sub_ph(v, b.v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
return VectorF32<Len, Packing>(_mm256_sub_ph(v, b.v));
|
||||
} else {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm512_mul_ps(v, b.v));
|
||||
return VectorF32<Len, Packing>(_mm512_sub_ph(v, b.v));
|
||||
}
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> operator/(VectorF32<Len, Packing, Repeats> b) const {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm_div_ps(v, b.v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm256_div_ps(v, b.v));
|
||||
constexpr VectorF32<Len, Packing> operator*(VectorF32<Len, Packing> b) const {
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
return VectorF32<Len, Packing>(_mm_mul_ph(v, b.v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
return VectorF32<Len, Packing>(_mm256_mul_ph(v, b.v));
|
||||
} else {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm512_div_ps(v, b.v));
|
||||
return VectorF32<Len, Packing>(_mm512_mul_ph(v, b.v));
|
||||
}
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing> operator/(VectorF32<Len, Packing> b) const {
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
return VectorF32<Len, Packing>(_mm_div_ph(v, b.v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
return VectorF32<Len, Packing>(_mm256_div_ph(v, b.v));
|
||||
} else {
|
||||
return VectorF32<Len, Packing>(_mm512_div_ph(v, b.v));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
constexpr void operator+=(VectorF32<Len, Packing, Repeats> b) const {
|
||||
constexpr void operator+=(VectorF32<Len, Packing> b) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
v = _mm_add_ps(v, b.v);
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
|
|
@ -201,7 +200,7 @@ namespace Crafter {
|
|||
}
|
||||
}
|
||||
|
||||
constexpr void operator-=(VectorF32<Len, Packing, Repeats> b) const {
|
||||
constexpr void operator-=(VectorF32<Len, Packing> b) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
v = _mm_sub_ps(v, b.v);
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
|
|
@ -211,7 +210,7 @@ namespace Crafter {
|
|||
}
|
||||
}
|
||||
|
||||
constexpr void operator*=(VectorF32<Len, Packing, Repeats> b) const {
|
||||
constexpr void operator*=(VectorF32<Len, Packing> b) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
v = _mm_mul_ps(v, b.v);
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
|
|
@ -221,7 +220,7 @@ namespace Crafter {
|
|||
}
|
||||
}
|
||||
|
||||
constexpr void operator/=(VectorF32<Len, Packing, Repeats> b) const {
|
||||
constexpr void operator/=(VectorF32<Len, Packing> b) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
v = _mm_div_ps(v, b.v);
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
|
|
@ -231,60 +230,48 @@ namespace Crafter {
|
|||
}
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> operator+(float b) const {
|
||||
VectorF32<Len, Packing, Repeats> vB(b);
|
||||
return this + vB;
|
||||
constexpr VectorF32<Len, Packing> operator+(float b) {
|
||||
VectorF32<Len, Packing> vB(b);
|
||||
return *this + vB;
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> operator-(float b) const {
|
||||
VectorF32<Len, Packing, Repeats> vB(b);
|
||||
return this - vB;
|
||||
constexpr VectorF32<Len, Packing> operator-(float b) {
|
||||
VectorF32<Len, Packing> vB(b);
|
||||
return *this - vB;
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> operator*(float b) const {
|
||||
VectorF32<Len, Packing, Repeats> vB(b);
|
||||
return this * vB;
|
||||
constexpr VectorF32<Len, Packing> operator*(float b) {
|
||||
VectorF32<Len, Packing> vB(b);
|
||||
return *this * vB;
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> operator/(float b) const {
|
||||
VectorF32<Len, Packing, Repeats> vB(b);
|
||||
return this / vB;
|
||||
constexpr VectorF32<Len, Packing> operator/(float b) {
|
||||
VectorF32<Len, Packing> vB(b);
|
||||
return *this / vB;
|
||||
}
|
||||
|
||||
constexpr void operator+=(float b) const {
|
||||
VectorF32<Len, Packing, Repeats> vB(b);
|
||||
this += vB;
|
||||
constexpr void operator+=(float b) {
|
||||
VectorF32<Len, Packing> vB(b);
|
||||
*this += vB;
|
||||
}
|
||||
|
||||
constexpr void operator-=(float b) const {
|
||||
VectorF32<Len, Packing, Repeats> vB(b);
|
||||
this -= vB;
|
||||
constexpr void operator-=(float b) {
|
||||
VectorF32<Len, Packing> vB(b);
|
||||
*this -= vB;
|
||||
}
|
||||
|
||||
constexpr void operator*=(float b) const {
|
||||
VectorF32<Len, Packing, Repeats> vB(b);
|
||||
this *= vB;
|
||||
constexpr void operator*=(float b) {
|
||||
VectorF32<Len, Packing> vB(b);
|
||||
*this *= vB;
|
||||
}
|
||||
|
||||
constexpr void operator/=(float b) const {
|
||||
VectorF32<Len, Packing, Repeats> vB(b);
|
||||
this /= vB;
|
||||
constexpr void operator/=(float b) {
|
||||
VectorF32<Len, Packing> vB(b);
|
||||
*this /= vB;
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> operator-(){
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
constexpr std::uint64_t mask[] {0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000};
|
||||
__m128i sign_mask = _mm_loadu_si128(reinterpret_cast<const __m128i*>(mask));
|
||||
return VectorF32<Len, Packing, Repeats>(_mm_castsi128_ps(_mm_xor_si128(sign_mask, _mm_castps_si128(v))));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256>) {
|
||||
constexpr std::uint64_t mask[] {0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000};
|
||||
__m256i sign_mask = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(mask));
|
||||
return VectorF32<Len, Packing, Repeats>(_mm256_castsi256_ps(_mm256_xor_si256(sign_mask, _mm256_castps_si256(v))));
|
||||
} else {
|
||||
constexpr std::uint64_t mask[] {0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000};
|
||||
__m512i sign_mask = _mm512_loadu_si512(reinterpret_cast<const __m256i*>(mask));
|
||||
return VectorF32<Len, Packing, Repeats>(_mm512_castsi512_ps(_mm512_xor_si512(sign_mask, _mm512_castps_si512(v))));
|
||||
}
|
||||
constexpr VectorF32<Len, Packing> operator-(){
|
||||
return Negate<GetAllTrue<Len>()>();
|
||||
}
|
||||
|
||||
constexpr bool operator==(VectorF32<Len, Packing, Repeats> b) const {
|
||||
|
|
@ -335,71 +322,47 @@ namespace Crafter {
|
|||
return Dot(*this, *this);
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> Cos() requires(Len == 3) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm_cos_ps(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm256_cos_ps(v));
|
||||
template <const std::array<std::uint8_t, Len> ShuffleValues>
|
||||
constexpr VectorF32<Len, Packing> Shuffle() {
|
||||
if constexpr(CheckEpi32Shuffle<ShuffleValues>()) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(v), GetShuffleMaskEpi32<ShuffleValues>())));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(v), GetShuffleMaskEpi32<ShuffleValues>())));
|
||||
} else {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm512_cos_ps(v));
|
||||
return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(v), GetShuffleMaskEpi32<ShuffleValues>())));
|
||||
}
|
||||
}
|
||||
|
||||
constexpr VectorF32<Len, Packing, Repeats> Sin() requires(Len == 3) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm_sin_ps(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm256_sin_ps(v));
|
||||
} else if constexpr(CheckEpi8Shuffle<ShuffleValues>()){
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
constexpr std::array<std::uint8_t, VectorF32<Len, Packing>::Alignment*2> shuffleMask = GetShuffleMaskEpi8<ShuffleValues>();
|
||||
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
|
||||
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(v), shuffleVec)));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
constexpr std::array<std::uint8_t, VectorF32<Len, Packing>::Alignment*2> shuffleMask = GetShuffleMaskEpi8<ShuffleValues>();
|
||||
__m256i shuffleVec = _mm256_loadu_epi8(shuffleMask.data());
|
||||
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(v)), _mm512_castsi256_si512(shuffleVec)))));
|
||||
} else {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm512_sin_ps(v));
|
||||
constexpr std::array<std::uint8_t, VectorF32<Len, Packing>::Alignment*2> shuffleMask = GetShuffleMaskEpi8<ShuffleValues>();
|
||||
__m512i shuffleVec = _mm512_loadu_epi8(shuffleMask.data());
|
||||
return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(v), shuffleVec)));
|
||||
}
|
||||
}
|
||||
|
||||
template <std::uint8_t A, std::uint8_t B, std::uint8_t C, std::uint8_t D>
|
||||
constexpr VectorF32<Len, Packing, Repeats> Shuffle() {
|
||||
constexpr std::uint32_t val =
|
||||
(A & 0x3) |
|
||||
((B & 0x3) << 2) |
|
||||
((C & 0x3) << 4) |
|
||||
((D & 0x3) << 6);
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(v), val)));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(v), val)));
|
||||
} else {
|
||||
return VectorF32<Len, Packing, Repeats>(_mm512_castsi512_ps(_mm512_shuffle_epi32(_mm_512castps_si512(v), val)));
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
constexpr std::array<std::uint8_t, VectorF32<Len, Packing>::Alignment*2> shuffleMask = GetShuffleMaskEpi8<ShuffleValues>();
|
||||
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
|
||||
return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(v), shuffleVec)));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
constexpr std::array<std::uint16_t, VectorF32<Len, Packing>::Alignment> permMask = GetPermuteMaskEpi32<ShuffleValues>();
|
||||
__m256i permIdx = _mm256_loadu_epi16(permMask.data());
|
||||
return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_permutexvar_epi16(permIdx, _mm256_castps_si256(v))));
|
||||
} else {
|
||||
constexpr std::array<std::uint16_t, VectorF32<Len, Packing>::Alignment> permMask = GetPermuteMaskEpi32<ShuffleValues>();
|
||||
__m512i permIdx = _mm512_loadu_epi16(permMask.data());
|
||||
return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_permutexvar_epi16(permIdx, _mm512_castps_si512(v))));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
std::uint8_t A0, std::uint8_t B0, std::uint8_t C0, std::uint8_t D0,
|
||||
std::uint8_t A1, std::uint8_t B1, std::uint8_t C1, std::uint8_t D1
|
||||
>
|
||||
constexpr VectorF32<Len, Packing, Repeats> Shuffle() requires(Repeats == 2) {
|
||||
constexpr std::uint8_t shuffleMask[] {
|
||||
A0,A0,A0,A0,B0,B0,B0,B0,C0,C0,C0,C0,D0,D0,D0,D0,
|
||||
A1,A1,A1,A1,B1,B1,B1,B1,C1,C1,C1,C1,D1,D1,D1,D1,
|
||||
};
|
||||
__m256 shuffleVec = _mm256_loadu_epi8(shuffleMask);
|
||||
return VectorF32<Len, Packing, Repeats>(_mm256_castsi256_ps(_mm256_shuffle_epi8(_mm256_castps_si256(v), shuffleVec)));
|
||||
}
|
||||
|
||||
template <
|
||||
std::uint8_t A0, std::uint8_t B0, std::uint8_t C0, std::uint8_t D0, std::uint8_t E0, std::uint8_t F0, std::uint8_t G0, std::uint8_t H0,
|
||||
std::uint8_t A1, std::uint8_t B1, std::uint8_t C1, std::uint8_t D1, std::uint8_t E1, std::uint8_t F1, std::uint8_t G1, std::uint8_t H1,
|
||||
std::uint8_t A2, std::uint8_t B2, std::uint8_t C2, std::uint8_t D2, std::uint8_t E2, std::uint8_t F2, std::uint8_t G2, std::uint8_t H2,
|
||||
std::uint8_t A3, std::uint8_t B3, std::uint8_t C3, std::uint8_t D3, std::uint8_t E3, std::uint8_t F3, std::uint8_t G3, std::uint8_t H3
|
||||
>
|
||||
constexpr VectorF32<Len, Packing, Repeats> Shuffle() requires(Repeats == 4) {
|
||||
constexpr std::uint8_t shuffleMask[] {
|
||||
A0,A0,A0,A0,B0,B0,B0,B0,C0,C0,C0,C0,D0,D0,D0,D0,
|
||||
A1,A1,A1,A1,B1,B1,B1,B1,C1,C1,C1,C1,D1,D1,D1,D1,
|
||||
A2,A2,A2,A2,B2,B2,B2,B2,C2,C2,C2,C2,D2,D2,D2,D2,
|
||||
A3,A3,A3,A3,B3,B3,B3,B3,C3,C3,C3,C3,D3,D3,D3,D3,
|
||||
};
|
||||
__m512 shuffleVec = _mm512_loadu_epi8(shuffleMask);
|
||||
return VectorF32<Len, Packing, Repeats>(_mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(v), shuffleVec)));
|
||||
}
|
||||
|
||||
static constexpr VectorF32<Len, Packing, Repeats> MulitplyAdd(VectorF32<Len, Packing, Repeats> a, VectorF32<Len, Packing, Repeats> b, VectorF32<Len, Packing, Repeats> add) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128>) {
|
||||
|
|
|
|||
|
|
@ -20,8 +20,6 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
|||
|
||||
export module Crafter.Math;
|
||||
export import :Basic;
|
||||
export import :Vector;
|
||||
export import :MatrixRowMajor;
|
||||
export import :Intersection;
|
||||
export import :Common;
|
||||
export import :VectorF16;
|
||||
export import :VectorF32;
|
||||
// export import :VectorF32;
|
||||
|
|
@ -4,13 +4,10 @@
|
|||
{
|
||||
"name": "base",
|
||||
"interfaces": [
|
||||
"interfaces/Crafter.Math-Vector",
|
||||
"interfaces/Crafter.Math-Basic",
|
||||
"interfaces/Crafter.Math-MatrixRowMajor",
|
||||
"interfaces/Crafter.Math",
|
||||
"interfaces/Crafter.Math-Intersection",
|
||||
"interfaces/Crafter.Math-VectorF16",
|
||||
"interfaces/Crafter.Math-VectorF32"
|
||||
"interfaces/Crafter.Math-Common",
|
||||
"interfaces/Crafter.Math-VectorF16"
|
||||
],
|
||||
"implementations": []
|
||||
},
|
||||
|
|
|
|||
209
tests/Vector.cpp
209
tests/Vector.cpp
|
|
@ -45,17 +45,26 @@ consteval std::array<std::uint8_t, Len> GetCountReverse() {
|
|||
return result;
|
||||
}
|
||||
|
||||
template <typename T, template<std::uint32_t, std::uint32_t> class VectorType, std::uint32_t MaxSize, std::uint32_t Len = 1, std::uint32_t Packing = 1>
|
||||
template <typename To, typename From, std::size_t N>
|
||||
constexpr std::array<To, N> array_cast(const std::array<From, N>& src) {
|
||||
std::array<To, N> dst{};
|
||||
for (std::size_t i = 0; i < N; ++i) {
|
||||
dst[i] = static_cast<To>(src[i]);
|
||||
}
|
||||
return dst;
|
||||
}
|
||||
|
||||
template <typename T, template<std::uint8_t, std::uint8_t> class VectorType, std::uint32_t MaxSize, std::uint32_t Len = 1, std::uint32_t Packing = 1>
|
||||
std::string* TestAllCombinations() {
|
||||
if constexpr (Len > MaxSize) {
|
||||
return nullptr;
|
||||
} else if constexpr (Len * Packing > MaxSize) {
|
||||
return TestAllCombinations<T, VectorType, MaxSize, Len + 1, 1>();
|
||||
} else {
|
||||
T floats[VectorType<Len, Packing>::Alignment];
|
||||
T floats1[VectorType<Len, Packing>::Alignment];
|
||||
T floats2[VectorType<Len, Packing>::Alignment];
|
||||
for (std::uint32_t i = 0; i < VectorType<Len, Packing>::Alignment; i++) {
|
||||
T floats[VectorType<Len, Packing>::AlignmentElement];
|
||||
T floats1[VectorType<Len, Packing>::AlignmentElement];
|
||||
T floats2[VectorType<Len, Packing>::AlignmentElement];
|
||||
for (std::uint32_t i = 0; i < VectorType<Len, Packing>::AlignmentElement; i++) {
|
||||
floats[i] = static_cast<T>(i+1);
|
||||
}
|
||||
for (std::uint32_t i = 0; i < Packing*Len; i++) {
|
||||
|
|
@ -64,7 +73,7 @@ std::string* TestAllCombinations() {
|
|||
for (std::uint32_t i = 0; i < Packing*Len; i++) {
|
||||
floats2[i] = static_cast<T>(i+1+Len);
|
||||
}
|
||||
for (std::uint32_t i = Len*Packing; i < VectorType<Len, Packing>::Alignment; i++) {
|
||||
for (std::uint32_t i = Len*Packing; i < VectorType<Len, Packing>::AlignmentElement; i++) {
|
||||
floats1[i] = 0;
|
||||
floats2[i] = 0;
|
||||
}
|
||||
|
|
@ -81,10 +90,10 @@ std::string* TestAllCombinations() {
|
|||
if constexpr(total > 0 && (total & (total - 1)) == 0) {
|
||||
{
|
||||
VectorType<Len, Packing> vec(floats);
|
||||
Vector<T, Len*Packing, VectorType<Len, Packing>::Alignment> stored = vec.Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> stored = vec.Store();
|
||||
for (std::uint32_t i = 0; i < Len * Packing; i++) {
|
||||
if (!FloatEquals(stored.v[i], floats[i])) {
|
||||
return new std::string(std::format("Load/Store mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i]), (float)stored.v[i]));
|
||||
if (!FloatEquals(stored[i], floats[i])) {
|
||||
return new std::string(std::format("Load/Store mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i]), (float)stored[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -92,10 +101,10 @@ std::string* TestAllCombinations() {
|
|||
{
|
||||
VectorType<Len, Packing> vec(floats);
|
||||
vec = vec + vec;
|
||||
Vector<T, Len*Packing, VectorType<Len, Packing>::Alignment> stored = vec.Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> stored = vec.Store();
|
||||
for (std::uint32_t i = 0; i < Len * Packing; i++) {
|
||||
if (!FloatEquals(stored.v[i], floats[i] + floats[i])) {
|
||||
return new std::string(std::format("Add mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] + floats[i]), (float)stored.v[i]));
|
||||
if (!FloatEquals(stored[i], floats[i] + floats[i])) {
|
||||
return new std::string(std::format("Add mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] + floats[i]), (float)stored[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -103,10 +112,10 @@ std::string* TestAllCombinations() {
|
|||
{
|
||||
VectorType<Len, Packing> vec(floats);
|
||||
vec = vec - vec;
|
||||
Vector<T, Len*Packing, VectorType<Len, Packing>::Alignment> stored = vec.Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> stored = vec.Store();
|
||||
for (std::uint32_t i = 0; i < Len * Packing; i++) {
|
||||
if (!FloatEquals(stored.v[i], T(0))) {
|
||||
return new std::string(std::format("Subtract mismatch at Len={} Packing={}, Expected: 0, Got: {}", Len, Packing, (float)stored.v[i]));
|
||||
if (!FloatEquals(stored[i], T(0))) {
|
||||
return new std::string(std::format("Subtract mismatch at Len={} Packing={}, Expected: 0, Got: {}", Len, Packing, (float)stored[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -114,10 +123,10 @@ std::string* TestAllCombinations() {
|
|||
{
|
||||
VectorType<Len, Packing> vec(floats);
|
||||
vec = vec * vec;
|
||||
Vector<T, Len*Packing, VectorType<Len, Packing>::Alignment> stored = vec.Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> stored = vec.Store();
|
||||
for (std::uint32_t i = 0; i < Len * Packing; i++) {
|
||||
if (!FloatEquals(stored.v[i], floats[i] * floats[i])) {
|
||||
return new std::string(std::format("Multiply mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] * floats[i]), (float)stored.v[i]));
|
||||
if (!FloatEquals(stored[i], floats[i] * floats[i])) {
|
||||
return new std::string(std::format("Multiply mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] * floats[i]), (float)stored[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -125,10 +134,10 @@ std::string* TestAllCombinations() {
|
|||
{
|
||||
VectorType<Len, Packing> vec(floats);
|
||||
vec = vec / vec;
|
||||
Vector<T, Len*Packing, VectorType<Len, Packing>::Alignment> stored = vec.Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> stored = vec.Store();
|
||||
for (std::uint32_t i = 0; i < Len * Packing; i++) {
|
||||
if (!FloatEquals(stored.v[i], T(1))) {
|
||||
return new std::string(std::format("Divide mismatch at Len={} Packing={}, Expected: 1, Got: {}", Len, Packing, (float)stored.v[i]));
|
||||
if (!FloatEquals(stored[i], T(1))) {
|
||||
return new std::string(std::format("Divide mismatch at Len={} Packing={}, Expected: 1, Got: {}", Len, Packing, (float)stored[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -136,10 +145,10 @@ std::string* TestAllCombinations() {
|
|||
{
|
||||
VectorType<Len, Packing> vec(floats);
|
||||
vec = vec + T(2);
|
||||
Vector<T, Len*Packing, VectorType<Len, Packing>::Alignment> stored = vec.Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> stored = vec.Store();
|
||||
for (std::uint32_t i = 0; i < Len * Packing; i++) {
|
||||
if (!FloatEquals(stored.v[i], floats[i] + T(2))) {
|
||||
return new std::string(std::format("Scalar add mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] + T(2)), (float)stored.v[i]));
|
||||
if (!FloatEquals(stored[i], floats[i] + T(2))) {
|
||||
return new std::string(std::format("Scalar add mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] + T(2)), (float)stored[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -147,10 +156,10 @@ std::string* TestAllCombinations() {
|
|||
{
|
||||
VectorType<Len, Packing> vec(floats);
|
||||
vec = vec - T(2);
|
||||
Vector<T, Len*Packing, VectorType<Len, Packing>::Alignment> stored = vec.Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> stored = vec.Store();
|
||||
for (std::uint32_t i = 0; i < Len * Packing; i++) {
|
||||
if (!FloatEquals(stored.v[i], floats[i] - T(2))) {
|
||||
return new std::string(std::format("Scalar add mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] + T(2)), (float)stored.v[i]));
|
||||
if (!FloatEquals(stored[i], floats[i] - T(2))) {
|
||||
return new std::string(std::format("Scalar add mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] + T(2)), (float)stored[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -158,10 +167,10 @@ std::string* TestAllCombinations() {
|
|||
{
|
||||
VectorType<Len, Packing> vec(floats);
|
||||
vec = vec * T(2);
|
||||
Vector<T, Len*Packing, VectorType<Len, Packing>::Alignment> stored = vec.Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> stored = vec.Store();
|
||||
for (std::uint32_t i = 0; i < Len * Packing; i++) {
|
||||
if (!FloatEquals(stored.v[i], floats[i] * T(2))) {
|
||||
return new std::string(std::format("Scalar multiply mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] * T(2)), (float)stored.v[i]));
|
||||
if (!FloatEquals(stored[i], floats[i] * T(2))) {
|
||||
return new std::string(std::format("Scalar multiply mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] * T(2)), (float)stored[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -169,10 +178,10 @@ std::string* TestAllCombinations() {
|
|||
{
|
||||
VectorType<Len, Packing> vec(floats);
|
||||
vec = vec / T(2);
|
||||
Vector<T, Len*Packing, VectorType<Len, Packing>::Alignment> stored = vec.Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> stored = vec.Store();
|
||||
for (std::uint32_t i = 0; i < Len * Packing; i++) {
|
||||
if (!FloatEquals(stored.v[i], floats[i] / T(2))) {
|
||||
return new std::string(std::format("Scalar divide mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] * T(2)), (float)stored.v[i]));
|
||||
if (!FloatEquals(stored[i], floats[i] / T(2))) {
|
||||
return new std::string(std::format("Scalar divide mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] * T(2)), (float)stored[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -216,10 +225,10 @@ std::string* TestAllCombinations() {
|
|||
{
|
||||
VectorType<Len, Packing> vec(floats);
|
||||
vec = -vec;
|
||||
Vector<T, Len*Packing, VectorType<Len, Packing>::Alignment> result = vec.Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> result = vec.Store();
|
||||
for (std::uint32_t i = 0; i < Len * Packing; i++) {
|
||||
if (!FloatEquals(result.v[i], -floats[i])) {
|
||||
return new std::string(std::format("Negate mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(-floats[i]), (float)result.v[i]));
|
||||
if (!FloatEquals(result[i], -floats[i])) {
|
||||
return new std::string(std::format("Negate mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(-floats[i]), (float)result[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -228,12 +237,12 @@ std::string* TestAllCombinations() {
|
|||
VectorType<Len, Packing> vecA(floats1);
|
||||
VectorType<Len, Packing> vecB(floats2);
|
||||
VectorType<Len, Packing> result = VectorType<Len, Packing>::template Blend<AlternateTrueFalse<Len>()>(vecA, vecB);
|
||||
Vector<T, Len*Packing, VectorType<Len, Packing>::Alignment> stored = result.Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> stored = result.Store();
|
||||
for (std::uint32_t i = 0; i < Len; i++) {
|
||||
bool useB = (i % 2 == 0);
|
||||
T expected = useB ? floats2[i]: floats1[i];
|
||||
if (!FloatEquals(stored.v[i], expected)) {
|
||||
return new std::string(std::format("Blend mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored.v[i]));
|
||||
if (!FloatEquals(stored[i], expected)) {
|
||||
return new std::string(std::format("Blend mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -243,11 +252,11 @@ std::string* TestAllCombinations() {
|
|||
VectorType<Len, Packing> vecB(floats);
|
||||
VectorType<Len, Packing> vecAdd(floats);
|
||||
VectorType<Len, Packing> result = VectorType<Len, Packing>::MulitplyAdd(vecA, vecB, vecAdd);
|
||||
Vector<T, Len*Packing, VectorType<Len, Packing>::Alignment> stored = result.Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> stored = result.Store();
|
||||
for (std::uint32_t i = 0; i < Len; i++) {
|
||||
T expected = floats[i] * floats[i] + floats[i];
|
||||
if (!FloatEquals(stored.v[i], expected)) {
|
||||
return new std::string(std::format("MulitplyAdd mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored.v[i]));
|
||||
if (!FloatEquals(stored[i], expected)) {
|
||||
return new std::string(std::format("MulitplyAdd mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -257,11 +266,11 @@ std::string* TestAllCombinations() {
|
|||
VectorType<Len, Packing> vecB(floats);
|
||||
VectorType<Len, Packing> vecSub(floats);
|
||||
VectorType<Len, Packing> result = VectorType<Len, Packing>::MulitplySub(vecA, vecB, vecSub);
|
||||
Vector<T, Len*Packing, VectorType<Len, Packing>::Alignment> stored = result.Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> stored = result.Store();
|
||||
for (std::uint32_t i = 0; i < Len; i++) {
|
||||
T expected = floats[i] * floats[i] - floats[i];
|
||||
if (!FloatEquals(stored.v[i], expected)) {
|
||||
return new std::string(std::format("MulitplySub mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored.v[i]));
|
||||
if (!FloatEquals(stored[i], expected)) {
|
||||
return new std::string(std::format("MulitplySub mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -269,12 +278,12 @@ std::string* TestAllCombinations() {
|
|||
if constexpr(Len > 2){
|
||||
VectorType<Len, Packing> vec(floats);
|
||||
VectorType<Len-1, Packing> result = vec.template ExtractLo<Len-1>();
|
||||
Vector<T, (Len-1)*Packing, VectorType<Len-1, Packing>::Alignment> stored = result.Store();
|
||||
std::array<T, VectorType<Len-1, Packing>::AlignmentElement> stored = result.Store();
|
||||
for(std::uint32_t i2 = 0; i2 < Packing; i2++){
|
||||
for (std::uint32_t i = 0; i < Len-1; i++) {
|
||||
T expected = floats[i2*(Len)+i];
|
||||
if (!FloatEquals(stored.v[i2*(Len-1)+i], expected)) {
|
||||
return new std::string(std::format("ExtractLo mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored.v[i2*(Len-1)+i]));
|
||||
if (!FloatEquals(stored[i2*(Len-1)+i], expected)) {
|
||||
return new std::string(std::format("ExtractLo mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored[i2*(Len-1)+i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -283,12 +292,12 @@ std::string* TestAllCombinations() {
|
|||
{
|
||||
VectorType<Len, Packing> vec(floats);
|
||||
VectorType<Len, Packing> result = vec.Sin();
|
||||
Vector<T, Len*Packing, VectorType<Len, Packing>::Alignment> stored = result.Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> stored = result.Store();
|
||||
for(std::uint32_t i2 = 0; i2 < Packing; i2++){
|
||||
for (std::uint32_t i = 0; i < Len; i++) {
|
||||
T expected = (T)std::sin((float)floats[i2*Len+i]);
|
||||
if (!FloatEquals(stored.v[i2*Len+i], expected)) {
|
||||
return new std::string(std::format("Sin mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored.v[i2*(Len-1)+i]));
|
||||
if (!FloatEquals(stored[i2*Len+i], expected)) {
|
||||
return new std::string(std::format("Sin mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored[i2*(Len-1)+i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -297,12 +306,12 @@ std::string* TestAllCombinations() {
|
|||
{
|
||||
VectorType<Len, Packing> vec(floats);
|
||||
VectorType<Len, Packing> result = vec.Cos();
|
||||
Vector<T, Len*Packing, VectorType<Len, Packing>::Alignment> stored = result.Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> stored = result.Store();
|
||||
for(std::uint32_t i2 = 0; i2 < Packing; i2++){
|
||||
for (std::uint32_t i = 0; i < Len; i++) {
|
||||
T expected = (T)std::cos((float)floats[i2*Len+i]);
|
||||
if (!FloatEquals(stored.v[i2*Len+i], expected)) {
|
||||
return new std::string(std::format("Cos mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored.v[i2*(Len-1)+i]));
|
||||
if (!FloatEquals(stored[i2*Len+i], expected)) {
|
||||
return new std::string(std::format("Cos mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored[i2*(Len-1)+i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -311,18 +320,18 @@ std::string* TestAllCombinations() {
|
|||
{
|
||||
VectorType<Len, Packing> vec(floats);
|
||||
auto result = vec.SinCos();
|
||||
Vector<T, Len*Packing, VectorType<Len, Packing>::Alignment> storedSin = std::get<0>(result).Store();
|
||||
Vector<T, Len*Packing, VectorType<Len, Packing>::Alignment> storedCos = std::get<1>(result).Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> storedSin = std::get<0>(result).Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> storedCos = std::get<1>(result).Store();
|
||||
for(std::uint32_t i2 = 0; i2 < Packing; i2++){
|
||||
for (std::uint32_t i = 0; i < Len; i++) {
|
||||
T expected = (T)std::sin((float)floats[i2*Len+i]);
|
||||
if (!FloatEquals(storedSin.v[i2*Len+i], expected)) {
|
||||
return new std::string(std::format("SinCos sin mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)storedSin.v[i2*(Len-1)+i]));
|
||||
if (!FloatEquals(storedSin[i2*Len+i], expected)) {
|
||||
return new std::string(std::format("SinCos sin mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)storedSin[i2*(Len-1)+i]));
|
||||
}
|
||||
|
||||
expected = (T)std::cos((float)floats[i2*Len+i]);
|
||||
if (!FloatEquals(storedCos.v[i2*Len+i], expected)) {
|
||||
return new std::string(std::format("SinCos cos mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)storedCos.v[i2*(Len-1)+i]));
|
||||
if (!FloatEquals(storedCos[i2*Len+i], expected)) {
|
||||
return new std::string(std::format("SinCos cos mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)storedCos[i2*(Len-1)+i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -331,11 +340,11 @@ std::string* TestAllCombinations() {
|
|||
{
|
||||
VectorType<Len, Packing> vec(floats);
|
||||
VectorType<Len, Packing> result = vec.template Shuffle<GetCountReverse<Len>()>();
|
||||
Vector<T, Len*Packing, VectorType<Len, Packing>::Alignment> stored = result.Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> stored = result.Store();
|
||||
for (std::uint32_t i = 0; i < Len; i++) {
|
||||
T expected = floats[Len - 1 - i];
|
||||
if (!FloatEquals(stored.v[i], expected)) {
|
||||
return new std::string(std::format("Shuffle mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored.v[i]));
|
||||
if (!FloatEquals(stored[i], expected)) {
|
||||
return new std::string(std::format("Shuffle mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -343,7 +352,7 @@ std::string* TestAllCombinations() {
|
|||
|
||||
if constexpr(Packing == 1) {
|
||||
T expectedLengthSq = T(0);
|
||||
for (std::uint32_t i = 0; i < VectorType<Len, Packing>::Alignment; i++) {
|
||||
for (std::uint32_t i = 0; i < VectorType<Len, Packing>::AlignmentElement; i++) {
|
||||
expectedLengthSq += floats[i] * floats[i];
|
||||
}
|
||||
|
||||
|
|
@ -387,13 +396,13 @@ std::string* TestAllCombinations() {
|
|||
VectorType<Len, Packing> vec1(floats1);
|
||||
VectorType<Len, Packing> vec2(floats2);
|
||||
VectorType<Len, Packing> result = VectorType<Len, Packing>::Cross(vec1, vec2);
|
||||
Vector<T, Len*Packing, VectorType<Len, Packing>::Alignment> stored = result.Store();
|
||||
if (!FloatEquals(stored.v[0], T(-3)) || !FloatEquals(stored.v[1], T(6)) || !FloatEquals(stored.v[2], T(-3))) {
|
||||
return new std::string(std::format("Cross mismatch at Len={} Packing={}, Expected: -3,6,-3, Got: {},{},{}", Len, Packing, (float)stored.v[0], (float)stored.v[1], (float)stored.v[2]));
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> stored = result.Store();
|
||||
if (!FloatEquals(stored[0], T(-3)) || !FloatEquals(stored[1], T(6)) || !FloatEquals(stored[2], T(-3))) {
|
||||
return new std::string(std::format("Cross mismatch at Len={} Packing={}, Expected: -3,6,-3, Got: {},{},{}", Len, Packing, (float)stored[0], (float)stored[1], (float)stored[2]));
|
||||
}
|
||||
}
|
||||
if constexpr(4 * Packing < VectorType<1, 1>::MaxSize) {
|
||||
T qData[VectorType<4, Packing>::Alignment];
|
||||
if constexpr(4 * Packing < VectorType<1, 1>::MaxElement) {
|
||||
T qData[VectorType<4, Packing>::AlignmentElement];
|
||||
qData[0] = T(0);
|
||||
qData[1] = T(0);
|
||||
qData[2] = T(0);
|
||||
|
|
@ -402,18 +411,18 @@ std::string* TestAllCombinations() {
|
|||
VectorType<3, Packing> vecV(floats);
|
||||
VectorType<4, Packing> vecQ(qData);
|
||||
VectorType<3, Packing> result = VectorType<3, Packing>::Rotate(vecV, vecQ);
|
||||
Vector<T, 3*Packing, VectorType<3, Packing>::Alignment> stored = result.Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> stored = result.Store();
|
||||
|
||||
for (std::uint32_t i = 0; i < 3; i++) {
|
||||
if (!FloatEquals(stored.v[i], floats[i])) {
|
||||
return new std::string(std::format("Rotate mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)floats[i], (float)stored.v[i]));
|
||||
if (!FloatEquals(stored[i], floats[i])) {
|
||||
return new std::string(std::format("Rotate mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)floats[i], (float)stored[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(Len == 4) {
|
||||
T eulerData[VectorType<3, Packing>::Alignment];
|
||||
T eulerData[VectorType<3, Packing>::AlignmentElement];
|
||||
for(std::uint8_t i = 0; i < Packing; i++) {
|
||||
eulerData[i*3] = T(0.7853981);
|
||||
eulerData[i*3+1] = T(0.1243412);
|
||||
|
|
@ -421,27 +430,27 @@ std::string* TestAllCombinations() {
|
|||
}
|
||||
VectorType<3, Packing> eulerVec(eulerData);
|
||||
VectorType<4, Packing> result = VectorType<4, Packing>::QuanternionFromEuler(eulerVec);
|
||||
Vector<T, 4*Packing, VectorType<4, Packing>::Alignment> stored = result.Store();
|
||||
std::array<T, VectorType<4, Packing>::AlignmentElement> stored = result.Store();
|
||||
|
||||
if (!FloatEquals(stored.v[0], T(0.63720703)) || !FloatEquals(stored.v[1], T(0.30688477)) ||
|
||||
!FloatEquals(stored.v[2], T(0.14074707)) || !FloatEquals(stored.v[3], T(0.6933594))) {
|
||||
return new std::string(std::format("QuanternionFromEuler mismatch at Len={} Packing={}, Expected: 0.63720703,0.30688477,0.14074707,0.6933594, Got: {},{},{},{}", Len, Packing, (float)stored.v[0], (float)stored.v[1], (float)stored.v[2], (float)stored.v[3]));
|
||||
if (!FloatEquals(stored[0], T(0.63720703)) || !FloatEquals(stored[1], T(0.30688477)) ||
|
||||
!FloatEquals(stored[2], T(0.14074707)) || !FloatEquals(stored[3], T(0.6933594))) {
|
||||
return new std::string(std::format("QuanternionFromEuler mismatch at Len={} Packing={}, Expected: 0.63720703,0.30688477,0.14074707,0.6933594, Got: {},{},{},{}", Len, Packing, (float)stored[0], (float)stored[1], (float)stored[2], (float)stored[3]));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(Len == 2 && Packing*Len == VectorType<Len, Packing>::Alignment) {
|
||||
if constexpr(Len == 2 && Packing*Len == VectorType<Len, Packing>::AlignmentElement) {
|
||||
{
|
||||
VectorType<Len, Packing> vecA(floats);
|
||||
VectorType<Len, Packing> vecE = vecA *2;
|
||||
VectorType<1, Packing*2> result = VectorType<Len, Packing>::Length(vecA, vecE);
|
||||
Vector<T, Packing*2, VectorType<Len, Packing>::Alignment> stored = result.Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> stored = result.Store();
|
||||
|
||||
if (!FloatEquals(stored.v[0], expectedLength[0])) {
|
||||
return new std::string(std::format("Length 2 vecA test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0], (float)stored.v[0]));
|
||||
if (!FloatEquals(stored[0], expectedLength[0])) {
|
||||
return new std::string(std::format("Length 2 vecA test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0], (float)stored[0]));
|
||||
}
|
||||
|
||||
if (!FloatEquals(stored.v[(Len*Packing)/2], expectedLength[0] * 2)) {
|
||||
return new std::string(std::format("Length 2 vecE test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 2, (float)stored.v[(Len*Packing)/2]));
|
||||
if (!FloatEquals(stored[(Len*Packing)/2], expectedLength[0] * 2)) {
|
||||
return new std::string(std::format("Length 2 vecE test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 2, (float)stored[(Len*Packing)/2]));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -450,39 +459,39 @@ std::string* TestAllCombinations() {
|
|||
VectorType<Len, Packing> vecE = vecA * 2;
|
||||
auto result = VectorType<Len, Packing>::Normalize(vecA, vecE);
|
||||
VectorType<1, Packing*2> result2 = VectorType<Len, Packing>::Length(std::get<0>(result), std::get<1>(result));
|
||||
Vector<T, Packing*2, VectorType<Len, Packing>::Alignment> stored = result2.Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> stored = result2.Store();
|
||||
|
||||
for(std::uint8_t i = 0; i < Len*Packing; i++) {
|
||||
if (!FloatEquals(stored.v[i], T(1))) {
|
||||
return new std::string(std::format("Normalize {} test failed at Len={} Packing={} Expected: {}, Got: {}", i, Len, Packing, 1, (float)stored.v[i]));
|
||||
if (!FloatEquals(stored[i], T(1))) {
|
||||
return new std::string(std::format("Normalize {} test failed at Len={} Packing={} Expected: {}, Got: {}", i, Len, Packing, 1, (float)stored[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(Len == 4 && Packing*Len == VectorType<Len, Packing>::Alignment) {
|
||||
if constexpr(Len == 4 && Packing*Len == VectorType<Len, Packing>::AlignmentElement) {
|
||||
{
|
||||
VectorType<Len, Packing> vecA(floats);
|
||||
VectorType<Len, Packing> vecC = vecA * 2;
|
||||
VectorType<Len, Packing> vecE = vecA * 3;
|
||||
VectorType<Len, Packing> vecG = vecA * 4;
|
||||
VectorType<1, Packing*4> result = VectorType<Len, Packing>::Length(vecA, vecC, vecE, vecG);
|
||||
Vector<T, Packing*4, VectorType<Len, Packing>::Alignment> stored = result.Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> stored = result.Store();
|
||||
|
||||
if (!FloatEquals(stored.v[0], expectedLength[0])) {
|
||||
return new std::string(std::format("Length 4 vecA test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0], (float)stored.v[0]));
|
||||
if (!FloatEquals(stored[0], expectedLength[0])) {
|
||||
return new std::string(std::format("Length 4 vecA test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0], (float)stored[0]));
|
||||
}
|
||||
|
||||
if (!FloatEquals(stored.v[Packing], expectedLength[0] * 2)) {
|
||||
return new std::string(std::format("Length 4 vecC test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 2, (float)stored.v[Packing]));
|
||||
if (!FloatEquals(stored[Packing], expectedLength[0] * 2)) {
|
||||
return new std::string(std::format("Length 4 vecC test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 2, (float)stored[Packing]));
|
||||
}
|
||||
|
||||
if (!FloatEquals(stored.v[Packing*2], expectedLength[0] * 3)) {
|
||||
return new std::string(std::format("Length 4 vecE test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 3, (float)stored.v[Packing*2]));
|
||||
if (!FloatEquals(stored[Packing*2], expectedLength[0] * 3)) {
|
||||
return new std::string(std::format("Length 4 vecE test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 3, (float)stored[Packing*2]));
|
||||
}
|
||||
|
||||
if (!FloatEquals(stored.v[Packing*3], expectedLength[0] * 4)) {
|
||||
return new std::string(std::format("Length 4 vecG test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 4, (float)stored.v[Packing*3]));
|
||||
if (!FloatEquals(stored[Packing*3], expectedLength[0] * 4)) {
|
||||
return new std::string(std::format("Length 4 vecG test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 4, (float)stored[Packing*3]));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -493,11 +502,11 @@ std::string* TestAllCombinations() {
|
|||
VectorType<Len, Packing> vecG = vecA * 4;
|
||||
auto result = VectorType<Len, Packing>::Normalize(vecA, vecC, vecE, vecG);
|
||||
VectorType<1, Packing*4> result2 = VectorType<Len, Packing>::Length(std::get<0>(result), std::get<1>(result), std::get<2>(result), std::get<3>(result));
|
||||
Vector<T, Packing*4, VectorType<Len, Packing>::Alignment> stored = result2.Store();
|
||||
std::array<T, VectorType<Len, Packing>::AlignmentElement> stored = result2.Store();
|
||||
|
||||
for(std::uint8_t i = 0; i < Len*Packing; i++) {
|
||||
if (!FloatEquals(stored.v[i], T(1))) {
|
||||
return new std::string(std::format("Normalize {} test failed at Len={} Packing={} Expected: {}, Got: {}", i, Len, Packing, 1, (float)stored.v[i]));
|
||||
if (!FloatEquals(stored[i], T(1))) {
|
||||
return new std::string(std::format("Normalize {} test failed at Len={} Packing={} Expected: {}, Got: {}", i, Len, Packing, 1, (float)stored[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -509,7 +518,7 @@ std::string* TestAllCombinations() {
|
|||
|
||||
extern "C" {
|
||||
std::string* RunTest() {
|
||||
std::string* err = TestAllCombinations<_Float16, VectorF16, VectorF16<1, 1>::MaxSize>();
|
||||
std::string* err = TestAllCombinations<_Float16, VectorF16, VectorF16<1, 1>::MaxElement>();
|
||||
if (err) {
|
||||
return err;
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue