wasm SIMD

This commit is contained in:
Jorijn van der Graaf 2026-05-18 05:23:49 +02:00
commit 48e3b8e26c
4 changed files with 803 additions and 12 deletions

View file

@ -2,20 +2,31 @@ module;
#ifdef __x86_64
#include <immintrin.h>
#endif
#ifdef __wasm_simd128__
#include <wasm_simd128.h>
#endif
export module Crafter.Math:Common;
import std;
// VectorF16 exists as a real struct when _Float16 is available AND we are not
// on x86_64 without AVX512FP16 (that path aliases VectorF16 to VectorF32 in
// Crafter.Math:Basic for performance). Each translation unit that needs this
// distinction redefines the same condition since macros do not cross module
// boundaries.
#if defined(__FLT16_MAX__) && (!defined(__x86_64) || defined(__AVX512FP16__))
namespace Crafter {
#ifdef __AVX512FP16__
export template <std::uint8_t Len, std::uint8_t Packing>
struct VectorF16;
#endif
}
#endif
namespace Crafter {
export template <std::uint8_t Len, std::uint8_t Packing>
struct VectorF32;
template <std::uint8_t Len, std::uint8_t Packing, typename T>
struct VectorBase {
#ifdef __AVX512FP16__
#if defined(__FLT16_MAX__) && (!defined(__x86_64) || defined(__AVX512FP16__))
template <std::uint8_t L, std::uint8_t P>
friend struct VectorF16;
#endif
@ -33,8 +44,13 @@ namespace Crafter {
}
#ifdef __x86_64
using VectorType = std::conditional_t<std::is_same_v<T, _Float16>,
using VectorType = std::conditional_t<
#ifdef __FLT16_MAX__
std::is_same_v<T, _Float16>
#else
false
#endif
,
#ifdef __AVX512FP16__
std::conditional_t<(Len * Packing > 16), __m512h,
std::conditional_t<(Len * Packing > 8), __m256h, __m128h>>,
@ -45,9 +61,13 @@ namespace Crafter {
std::conditional_t<(Len * Packing > 8), __m512,
std::conditional_t<(Len * Packing > 4), __m256, __m128>>
>;
#elif defined(__wasm_simd128__)
using VectorType = v128_t;
#else
using VectorType = std::array<T, GetAlingment()/sizeof(T)>;
#endif
VectorType v;
#endif
public:
@ -56,6 +76,10 @@ namespace Crafter {
#ifdef __AVX512F__
static constexpr std::uint8_t Max = 64;
#elif defined(__wasm_simd128__)
// WASM SIMD only has 128-bit vectors; cap at 16 bytes so the entire
// VectorType always fits in a single v128_t.
static constexpr std::uint8_t Max = 16;
#else
static constexpr std::uint8_t Max = 32;
#endif