From dfb4f1b0453ad2a6030d5e5eada5f7b07252062c Mon Sep 17 00:00:00 2001 From: mijamind719 Date: Sun, 22 Feb 2026 23:08:31 +0800 Subject: [PATCH] feat(vectordb): integrate KRL for ARM Kunpeng vector search optimization - Add third_party/krl: Kunpeng Retrieval Library (KRL) source with ARM NEON/SIMD-optimized L2 and inner-product distance routines - vector_base.h: add ARM platform macros OV_PLATFORM_ARM, OV_SIMD_NEON, OV_SIMD_SVE - space_l2.h: on ARM use krl_L2sqr in l2_sqr_neon instead of scalar path - space_ip.h: on ARM use krl_ipdis in inner_product_neon instead of scalar path - CMakeLists.txt: enable OV_PLATFORM_ARM on aarch64, build and link KRL static library On ARM, vectordb uses KRL-optimized paths; on x86 the existing AVX/SSE implementations are unchanged. --- src/CMakeLists.txt | 24 + src/index/detail/vector/common/space_ip.h | 19 +- src/index/detail/vector/common/space_l2.h | 18 +- src/index/detail/vector/common/vector_base.h | 9 + third_party/krl/CMakeLists.txt | 33 + third_party/krl/include/krl.h | 578 ++++ third_party/krl/include/krl_internal.h | 309 ++ third_party/krl/include/platform_macros.h | 49 + third_party/krl/include/safe_memory.h | 44 + third_party/krl/src/IPdistance_simd.cpp | 1602 +++++++++ third_party/krl/src/L2distance_simd.cpp | 3193 ++++++++++++++++++ 11 files changed, 5876 insertions(+), 2 deletions(-) create mode 100644 third_party/krl/CMakeLists.txt create mode 100644 third_party/krl/include/krl.h create mode 100644 third_party/krl/include/krl_internal.h create mode 100644 third_party/krl/include/platform_macros.h create mode 100644 third_party/krl/include/safe_memory.h create mode 100644 third_party/krl/src/IPdistance_simd.cpp create mode 100644 third_party/krl/src/L2distance_simd.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 033ca1ed..add7ac92 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -39,6 +39,12 @@ set(Python3_ARCH_INCLUDE_DIR "/usr/include/${CMAKE_SYSTEM_PROCESSOR}-linux-gnu/" find_package(Python3 
COMPONENTS Interpreter Development REQUIRED) +# On Linux, pybind11 modules don't need to link against libpython +# This prevents issues with static libpython that wasn't built with -fPIC +if(UNIX AND NOT APPLE) + set(Python3_LIBRARIES "") +endif() + find_package(pybind11 REQUIRED) find_package(Threads REQUIRED) @@ -59,11 +65,24 @@ endif() add_subdirectory(../third_party/spdlog-1.14.1 ${CMAKE_BINARY_DIR}/spdlog_build) +# ARM platform detection and KRL integration +set(OV_PLATFORM_ARM OFF) +if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|ARM64|arm64") + set(OV_PLATFORM_ARM ON) + message(STATUS "Building for ARM platform with KRL support") + add_subdirectory(../third_party/krl ${CMAKE_BINARY_DIR}/krl_build) +endif() + include_directories(.) include_directories(../third_party/) include_directories(../third_party/leveldb-1.23/include/) include_directories(../third_party/spdlog-1.14.1/include/) +# Add KRL include directory for ARM platform +if(OV_PLATFORM_ARM) + include_directories(../third_party/krl/include/) +endif() + if(NOT DEFINED Python3_INCLUDE_DIRS) set(Python3_INCLUDE_DIRS ${Python3_ARCH_INCLUDE_DIR} @@ -88,6 +107,11 @@ target_link_libraries(engine_impl PRIVATE leveldb ) +# Link KRL library for ARM platform +if(OV_PLATFORM_ARM) + target_link_libraries(engine_impl PRIVATE krl) +endif() + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "9.0") target_link_libraries(engine_impl PRIVATE stdc++fs) diff --git a/src/index/detail/vector/common/space_ip.h b/src/index/detail/vector/common/space_ip.h index e6f42437..51959017 100644 --- a/src/index/detail/vector/common/space_ip.h +++ b/src/index/detail/vector/common/space_ip.h @@ -108,10 +108,27 @@ static float inner_product_sse(const void* v1, const void* v2, } #endif +#if defined(OV_SIMD_NEON) +#include "krl.h" + +// ARM NEON optimized inner product using KRL library +static float inner_product_neon(const void* v1, const void* v2, + const void* params) { + const float* pv1 = static_cast<const float*>(v1); 
+ const float* pv2 = static_cast<const float*>(v2); + size_t dim = *static_cast<const size_t*>(params); + float dis = 0; + krl_ipdis(pv1, pv2, dim, &dis, 1); + return dis; +} +#endif + class InnerProductSpace : public VectorSpace { public: explicit InnerProductSpace(size_t dim) : dim_(dim) { -#if defined(OV_SIMD_AVX512) +#if defined(OV_SIMD_NEON) + metric_func_ = inner_product_neon; +#elif defined(OV_SIMD_AVX512) metric_func_ = inner_product_avx512; #elif defined(OV_SIMD_AVX) metric_func_ = inner_product_avx; diff --git a/src/index/detail/vector/common/space_l2.h b/src/index/detail/vector/common/space_l2.h index a6bef28d..69c402ad 100644 --- a/src/index/detail/vector/common/space_l2.h +++ b/src/index/detail/vector/common/space_l2.h @@ -121,12 +121,28 @@ static float l2_sqr_sse(const void* v1, const void* v2, const void* params) { } #endif +#if defined(OV_SIMD_NEON) +#include "krl.h" + +// ARM NEON optimized L2 squared distance using KRL library +static float l2_sqr_neon(const void* v1, const void* v2, const void* params) { + const float* pv1 = static_cast<const float*>(v1); + const float* pv2 = static_cast<const float*>(v2); + size_t dim = *static_cast<const size_t*>(params); + float dis = 0; + krl_L2sqr(pv1, pv2, dim, &dis, 1); + return dis; +} +#endif + class L2Space : public VectorSpace { public: explicit L2Space(size_t dim) : dim_(dim) { // Select best implementation at runtime based on compile-time flags // In a real scenario, we might want dynamic dispatch based on CPUID -#if defined(OV_SIMD_AVX512) +#if defined(OV_SIMD_NEON) + metric_func_ = l2_sqr_neon; +#elif defined(OV_SIMD_AVX512) metric_func_ = l2_sqr_avx512; #elif defined(OV_SIMD_AVX) metric_func_ = l2_sqr_avx; diff --git a/src/index/detail/vector/common/vector_base.h b/src/index/detail/vector/common/vector_base.h index b9f915bc..2f02e362 100644 --- a/src/index/detail/vector/common/vector_base.h +++ b/src/index/detail/vector/common/vector_base.h @@ -21,6 +21,15 @@ #endif #endif +// ARM Platform Detection +#if defined(__aarch64__) || defined(_M_ARM64) +#define 
OV_PLATFORM_ARM +#define OV_SIMD_NEON +#if defined(__ARM_FEATURE_SVE) +#define OV_SIMD_SVE +#endif +#endif + // Memory Alignment Macros #if defined(_MSC_VER) #define OV_ALIGN_32 __declspec(align(32)) diff --git a/third_party/krl/CMakeLists.txt b/third_party/krl/CMakeLists.txt new file mode 100644 index 00000000..05f43d58 --- /dev/null +++ b/third_party/krl/CMakeLists.txt @@ -0,0 +1,33 @@ +cmake_minimum_required(VERSION 3.12) + +project(krl CXX) + +# Only build on ARM platform +if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|ARM64|arm64") + # Set C++ standard + set(CMAKE_CXX_STANDARD 17) + set(CMAKE_CXX_STANDARD_REQUIRED ON) + + # ARM compile options - use more conservative flags + add_compile_options(-O2 -fPIC -fvisibility=hidden) + + # Minimal set for OpenViking: only krl_L2sqr and krl_ipdis (float, single-vector) + # C++ sources following OpenViking code style + set(KRL_SOURCES + ${CMAKE_CURRENT_SOURCE_DIR}/src/L2distance_simd.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/IPdistance_simd.cpp + ) + + # Create static library + add_library(krl STATIC ${KRL_SOURCES}) + + # Include directories + target_include_directories(krl PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) + + # PIC for static library + set_target_properties(krl PROPERTIES POSITION_INDEPENDENT_CODE ON) + + message(STATUS "KRL library configured for ARM platform (core distance functions only)") +else() + message(STATUS "KRL library skipped - not ARM platform") +endif() diff --git a/third_party/krl/include/krl.h b/third_party/krl/include/krl.h new file mode 100644 index 00000000..8d8caa23 --- /dev/null +++ b/third_party/krl/include/krl.h @@ -0,0 +1,578 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ + +#ifndef KRL_H +#define KRL_H + +#include <stddef.h> +#include <stdint.h> +#include <stdio.h> + +#define KRL_API_PUBLIC __attribute__((visibility("default"))) + +#ifdef __cplusplus extern "C" { +#endif + +/* + * @brief Handle for distance computation. 
+ */ +typedef struct KRLBatchDistanceHandle KRLDistanceHandle; + +/* + * @brief Create a distance computation handle. + * @param kdh Pointer to the distance handle. + * @param accu_level Accuracy level, 1, 2, or 3. + * @param blocksize Block size for computation, 16, 32, or 64. + * @param codes_num Number of base vectors. + * @param dim Dimension of vectors. + * @param num_base Number of base vectors. + * @param metric_type Distance measure type, 0 for inner product, 1 for L2. + * @param codes Base vector data. + * @param codes_size Length of codes. + * @return int 0 on success, non-zero on failure. + */ +KRL_API_PUBLIC int krl_create_distance_handle(KRLDistanceHandle **kdh, size_t accu_level, size_t blocksize, + size_t codes_num, size_t dim, size_t num_base, int metric_type, const uint8_t *codes, size_t codes_size); + +/* + * @brief Create a distance computation handle with additional accuracy levels. + * @param kdh Pointer to the distance handle. + * @param accu_level Accuracy level for initial computation, 1, 2, or 3. + * @param full_accu_level Accuracy level for final computation, 1, 2, or 3. + * @param codes_num Number of base vectors. + * @param dim Dimension of vectors. + * @param metric_type Distance measure type, 0 for inner product, 1 for L2. + * @param codes Base vector data. + * @param codes_size Length of codes. + * @return int 0 on success, non-zero on failure. + */ +KRL_API_PUBLIC int krl_create_reorder_handle(KRLDistanceHandle **kdh, size_t accu_level, size_t full_accu_level, + size_t codes_num, size_t dim, int metric_type, const uint8_t *codes, size_t codes_size); + +/* + * @brief Clean up and release the distance computation handle. + * @param kdh Pointer to the distance handle. + */ +KRL_API_PUBLIC void krl_clean_distance_handle(KRLDistanceHandle **kdh); + +/* + * @brief Handle for 8-bit lookup table operations. + */ +typedef struct KRLLookupTable8bitHandle KRLLUT8bHandle; + +/* + * @brief Create an 8-bit lookup table handle. 
+ * @param klh Pointer to the lookup table handle. + * @param use_idx Whether to use index buffer. + * @param capacity Capacity of the lookup table. + * @return int 0 on success, non-zero on failure. + */ +KRL_API_PUBLIC int krl_create_LUT8b_handle(KRLLUT8bHandle **klh, int use_idx, size_t capacity); + +/* + * @brief Clean up and release the 8-bit lookup table handle. + * @param klh Pointer to the lookup table handle. + */ +KRL_API_PUBLIC void krl_clean_LUT8b_handle(KRLLUT8bHandle **klh); + +/* + * @brief Get the index pointer from the lookup table handle. + * @param klh Pointer to the lookup table handle. + * @return size_t* Pointer to the index buffer. + */ +KRL_API_PUBLIC size_t *krl_get_idx_pointer(const KRLLUT8bHandle *klh); + +/* + * @brief Get the distance pointer from the lookup table handle. + * @param klh Pointer to the lookup table handle. + * @return float* Pointer to the distance buffer. + */ +KRL_API_PUBLIC float *krl_get_dist_pointer(const KRLLUT8bHandle *klh); + +/* -------------------------------------- 1 to 1 distance compute -------------------------------------- */ + +/* + * @brief Compute L2 square distance between two vectors. + * @param x Pointer to the first vector. + * @param y Pointer to the second vector. + * @param d Dimension of vectors. + * @param dis Stores the computed L2 square result (float). + * @param dis_size Length of dis. + */ +KRL_API_PUBLIC int krl_L2sqr(const float *x, const float *__restrict y, const size_t d, float *dis, size_t dis_size); + +/* + * @brief Compute L2 square distance between two 16-bit floating point vectors. + * @param x Pointer to the first vector. + * @param y Pointer to the second vector. + * @param d Dimension of vectors. + * @param dis Stores the computed L2 square result (float). + * @param dis_size Length of dis. 
+ */ +KRL_API_PUBLIC int krl_L2sqr_f16f32( + const uint16_t *x, const uint16_t *__restrict y, size_t d, float *dis, size_t dis_size); + +/* + * @brief Compute L2 square distance between two 8-bit integer vectors. + * @param x Pointer to the first vector. + * @param y Pointer to the second vector. + * @param d Dimension of vectors. + * @param dis Stores the computed L2 square result (uint32_t). + * @param dis_size Length of dis. + */ +KRL_API_PUBLIC int krl_L2sqr_u8u32( + const uint8_t *x, const uint8_t *__restrict y, size_t d, uint32_t *dis, size_t dis_size); + +/* + * @brief Compute inner product distance between two vectors. + * @param x Pointer to the first vector. + * @param y Pointer to the second vector. + * @param d Dimension of vectors. + * @param dis Stores the inner product result (float). + * @param dis_size Length of dis. + */ +KRL_API_PUBLIC int krl_ipdis(const float *x, const float *__restrict y, const size_t d, float *dis, size_t dis_size); + +/* + * @brief Compute negative inner product distance between two 16-bit floating point vectors. + * @param x Pointer to the first vector. + * @param y Pointer to the second vector. + * @param d Dimension of vectors. + * @param dis Stores the inner product result (float). + * @param dis_size Length of dis. + */ +KRL_API_PUBLIC int krl_negative_ipdis_f16f32( + const uint16_t *x, const uint16_t *__restrict y, const size_t d, float *dis, size_t dis_size); + +/* + * @brief Compute negative inner product distance between two 8-bit integer vectors. + * @param x Pointer to the first vector. + * @param y Pointer to the second vector. + * @param d Dimension of vectors. + * @param dis Stores the inner product result (int32_t). + * @param dis_size Length of dis. 
+ */ +KRL_API_PUBLIC int krl_negative_ipdis_s8s32( + const int8_t *x, const int8_t *__restrict y, const size_t d, int32_t *dis, size_t dis_size); + +/* -------------------------------------- Sparse distance calculation -------------------------------------- */ + +/* + * @brief Compute L2 square distance between vectors using indices. + * @param dis Output distance array. + * @param x Pointer to the first vector. + * @param y Pointer to the second vector. + * @param ids Indices of vectors. + * @param d Dimension of vectors. + * @param ny Number of vectors. + * @param dis_size Length of dis. + */ +KRL_API_PUBLIC int krl_L2sqr_by_idx( + float *dis, const float *x, const float *y, const int64_t *ids, size_t d, size_t ny, size_t dis_size); + +/* + * @brief Compute L2 square distance between 16-bit floating point vectors using indices. + * @param dis Output distance array. + * @param x Pointer to the first vector. + * @param y Pointer to the second vector. + * @param ids Indices of vectors. + * @param d Dimension of vectors. + * @param ny Number of vectors. + * @param dis_size Length of dis. + */ +KRL_API_PUBLIC int krl_L2sqr_by_idx_f16f32( + float *dis, const uint16_t *x, const uint16_t *y, const int64_t *ids, size_t d, size_t ny, size_t dis_size); + +/* + * @brief Compute L2 square distance between 8-bit integer vectors using indices. + * @param dis Output distance array. + * @param x Pointer to the first vector. + * @param y Pointer to the second vector. + * @param ids Indices of vectors. + * @param d Dimension of vectors. + * @param ny Number of vectors. + * @param dis_size Length of dis. + */ +KRL_API_PUBLIC int krl_L2sqr_by_idx_u8f32( + float *dis, const uint8_t *x, const uint8_t *y, const int64_t *ids, size_t d, size_t ny, size_t dis_size); + +/* + * @brief Compute negative inner product distance between 16-bit floating point vectors using indices. + * @param dis Output distance array. + * @param x Pointer to the first vector. 
+ * @param y Pointer to the second vector. + * @param ids Indices of vectors. + * @param d Dimension of vectors. + * @param ny Number of vectors. + * @param dis_size Length of dis. + */ +KRL_API_PUBLIC int krl_negative_inner_product_by_idx_f16f32( + float *dis, const uint16_t *x, const uint16_t *y, const int64_t *ids, size_t d, size_t ny, size_t dis_size); + +/* + * @brief Compute inner product distance between 8-bit integer vectors using indices. + * @param dis Output distance array. + * @param x Pointer to the first vector. + * @param y Pointer to the second vector. + * @param ids Indices of vectors. + * @param d Dimension of vectors. + * @param ny Number of vectors. + * @param dis_size Length of dis. + */ +KRL_API_PUBLIC int krl_inner_product_by_idx_s8f32( + float *dis, const int8_t *x, const int8_t *y, const int64_t *ids, size_t d, size_t ny, size_t dis_size); + +/* + * @brief Compute inner product distance between 16-bit floating point vectors using indices. + * @param dis Output distance array. + * @param x Pointer to the first vector. + * @param y Pointer to the second vector. + * @param ids Indices of vectors. + * @param d Dimension of vectors. + * @param ny Number of vectors. + * @param dis_size Length of dis. + */ +KRL_API_PUBLIC int krl_inner_product_by_idx_f16f32( + float *dis, const uint16_t *x, const uint16_t *y, const int64_t *ids, size_t d, size_t ny, size_t dis_size); + +/* + * @brief Compute inner product distance between vectors using indices. + * @param dis Output distance array. + * @param x Pointer to the first vector. + * @param y Pointer to the second vector. + * @param ids Indices of vectors. + * @param d Dimension of vectors. + * @param ny Number of vectors. + * @param dis_size Length of dis. 
+ */ +KRL_API_PUBLIC int krl_inner_product_by_idx( + float *dis, const float *x, const float *y, const int64_t *ids, size_t d, size_t ny, size_t dis_size); + +/* -------------------------------------- dense distance calculation -------------------------------------- */ + +/* + * @brief Compute L2 square distance between multiple vectors. + * @param dis Output distance array. + * @param x Pointer to the first set of vectors. + * @param y Pointer to the second set of vectors. + * @param ny Number of vectors. + * @param d Dimension of vectors. + * @param dis_size Length of dis. + */ +KRL_API_PUBLIC int krl_L2sqr_ny(float *dis, const float *x, const float *y, size_t ny, size_t d, size_t dis_size); + +/* + * @brief Compute L2 square distance between multiple 16-bit floating point vectors. + * @param dis Output distance array. + * @param x Pointer to the first set of vectors. + * @param y Pointer to the second set of vectors. + * @param ny Number of vectors. + * @param d Dimension of vectors. + * @param dis_size Length of dis. + */ +KRL_API_PUBLIC int krl_L2sqr_ny_f16f32( + float *dis, const uint16_t *x, const uint16_t *y, size_t ny, size_t d, size_t dis_size); + +/* + * @brief Compute L2 square distance between multiple 8-bit integer vectors. + * @param dis Output distance array. + * @param x Pointer to the first set of vectors. + * @param y Pointer to the second set of vectors. + * @param ny Number of vectors. + * @param d Dimension of vectors. + * @param dis_size Length of dis. + */ +KRL_API_PUBLIC int krl_L2sqr_ny_u8f32( + float *dis, const uint8_t *x, const uint8_t *y, size_t ny, size_t d, size_t dis_size); + +/* + * @brief Compute L2 square distance using a distance handle. + * @param kdh Pointer to the distance handle. + * @param dis Output distance array. + * @param x Pointer to the query vector. + * @param dis_size Length of dis. + * @param x_size Length of x. 
+ */ +KRL_API_PUBLIC int krl_L2sqr_ny_with_handle( + const KRLDistanceHandle *kdh, float *dis, const float *x, size_t dis_size, size_t x_size); + +/* + * @brief Compute inner product distance between multiple vectors. + * @param dis Output distance array. + * @param x Pointer to the first set of vectors. + * @param y Pointer to the second set of vectors. + * @param ny Number of vectors. + * @param d Dimension of vectors. + * @param dis_size Length of dis. + */ +KRL_API_PUBLIC int krl_inner_product_ny( + float *dis, const float *x, const float *y, size_t ny, size_t d, size_t dis_size); + +/* + * @brief Compute inner product distance between multiple 16-bit floating point vectors. + * @param dis Output distance array. + * @param x Pointer to the first set of vectors. + * @param y Pointer to the second set of vectors. + * @param ny Number of vectors. + * @param d Dimension of vectors. + * @param dis_size Length of dis. + */ +KRL_API_PUBLIC int krl_inner_product_ny_f16f32( + float *dis, const uint16_t *x, const uint16_t *y, size_t ny, size_t d, size_t dis_size); + +/* + * @brief Compute inner product distance between multiple 8-bit integer vectors. + * @param dis Output distance array. + * @param x Pointer to the first set of vectors. + * @param y Pointer to the second set of vectors. + * @param ny Number of vectors. + * @param d Dimension of vectors. + * @param dis_size Length of dis. + */ +KRL_API_PUBLIC int krl_inner_product_ny_s8f32( + float *dis, const int8_t *x, const int8_t *y, size_t ny, size_t d, size_t dis_size); + +/* + * @brief Compute inner product distance using a distance handle. + * @param kdh Pointer to the distance handle. + * @param dis Output distance array. + * @param x Pointer to the query vector. + * @param dis_size Length of dis. + * @param x_size Length of x. 
+ */ +KRL_API_PUBLIC int krl_inner_product_ny_with_handle( + const KRLDistanceHandle *kdh, float *dis, const float *x, size_t dis_size, size_t x_size); + +/* -------------------------------------- 8-bits table lookup -------------------------------------- */ + +/* + * @brief Lookup table function for 8-bit codes. + * @param nsq Number of subquantizers. + * @param ncode Number of codes. + * @param codes Input codes. + * @param sim_table Similarity table. + * @param dis Output distance array. + * @param dis0 Initial distance value. + * @param codes_size Length of codes. + * @param sim_table_size Length of sim_table. + * @param dis_size Length of dis. + */ +KRL_API_PUBLIC int krl_table_lookup_8b_f32(size_t nsq, size_t ncode, const uint8_t *codes, const float *sim_table, + float *dis, float dis0, size_t codes_size, size_t sim_table_size, size_t dis_size); + +/* + * @brief Lookup table function for 8-bit codes with indices. + * @param nsq Number of subquantizers. + * @param ncode Number of codes. + * @param codes Input codes. + * @param sim_table Similarity table. + * @param dis Output distance array. + * @param dis0 Initial distance value. + * @param idx Indices of codes. + * @param codes_size Length of codes. + * @param sim_table_size Length of sim_table. + * @param dis_size Length of dis. + */ +KRL_API_PUBLIC int krl_table_lookup_8b_f32_by_idx(size_t nsq, size_t ncode, const uint8_t *codes, + const float *sim_table, float *dis, float dis0, const size_t *idx, size_t codes_size, size_t sim_table_size, + size_t dis_size); + +/* + * @brief Lookup table function for 8-bit codes with a handle. + * @param klh Pointer to the lookup table handle. + * @param dim Dimension of vectors. + * @param ncode Number of codes. + * @param codes Input codes. + * @param sim_table Similarity table. + * @param dis0 Initial distance value. + * @param codes_size Length of codes. + * @param sim_table_size Length of sim_table. 
+ */ +KRL_API_PUBLIC int krl_table_lookup_8b_f32_with_handle(KRLLUT8bHandle *klh, size_t dim, size_t ncode, + const uint8_t *codes, const float *sim_table, float dis0, size_t codes_size, size_t sim_table_size); + +/* -------------------------------------- 4-bits table lookup -------------------------------------- */ + +/* + * @brief Fast table lookup function for 4-bit codes (batched). + * @param nq Number of queries. + * @param nsq Number of subquantizers. + * @param codes Input codes. + * @param LUT Precomputed distances. + * @param dis Output distances. + * @param threshold Filter threshold. + * @param lt_mask Filter result mask. + * @param keep_min Whether to keep minimum values. + * @param codes_size Length of codes. + * @param LUT_size Length of LUT. + * @param dis_size Length of dis. + * @param threshold_size Length of threshold. + * @param lt_mask_size Length of lt_mask. + */ +KRL_API_PUBLIC int krl_fast_table_lookup_step(int nq, int nsq, const uint8_t *codes, const uint8_t *LUT, uint16_t *dis, + const uint16_t *threshold, uint32_t *lt_mask, int keep_min, size_t codes_size, size_t LUT_size, size_t dis_size, + size_t threshold_size, size_t lt_mask_size); + +/* + * @brief Fast table lookup function for 4-bit codes (single query, batch size 64). + * @param nsq Number of subquantizers. + * @param codes Input codes. + * @param LUT Precomputed distances. + * @param dis Output distances. + * @param threshold Filter threshold. + * @param lt_mask Filter result mask. + * @param codes_size Length of codes. + * @param LUT_size Length of LUT. + * @param dis_size Length of dis. + * @param lt_mask_size Length of lt_mask. + */ +KRL_API_PUBLIC int krl_L2_table_lookup_fast_scan_bs64(int nsq, const uint8_t *codes, const uint8_t *LUT, uint16_t *dis, + uint16_t threshold, uint32_t *lt_mask, size_t codes_size, size_t LUT_size, size_t dis_size, size_t lt_mask_size); + +/* + * @brief Fast table lookup function for 4-bit codes (single query, batch size 64, inner product). 
+ * @param nsq Number of subquantizers. + * @param codes Input codes. + * @param LUT Precomputed distances. + * @param dis Output distances. + * @param threshold Filter threshold. + * @param lt_mask Filter result mask. + * @param codes_size Length of codes. + * @param LUT_size Length of LUT. + * @param dis_size Length of dis. + * @param lt_mask_size Length of lt_mask. + */ +KRL_API_PUBLIC int krl_IP_table_lookup_fast_scan_bs64(int nsq, const uint8_t *codes, const uint8_t *LUT, uint16_t *dis, + uint16_t threshold, uint32_t *lt_mask, size_t codes_size, size_t LUT_size, size_t dis_size, size_t lt_mask_size); + +/* + * @brief Fast table lookup function for 4-bit codes (single query, batch size 96). + * @param nsq Number of subquantizers. + * @param codes Input codes. + * @param LUT Precomputed distances. + * @param dis Output distances. + * @param threshold Filter threshold. + * @param lt_mask Filter result mask. + * @param codes_size Length of codes. + * @param LUT_size Length of LUT. + * @param dis_size Length of dis. + * @param lt_mask_size Length of lt_mask. + */ +KRL_API_PUBLIC int krl_L2_table_lookup_fast_scan_bs96(int nsq, const uint8_t *codes, const uint8_t *LUT, uint16_t *dis, + uint16_t threshold, uint32_t *lt_mask, size_t codes_size, size_t LUT_size, size_t dis_size, size_t lt_mask_size); + +/* + * @brief Fast table lookup function for 4-bit codes (single query, batch size 96, inner product). + * @param nsq Number of subquantizers. + * @param codes Input codes. + * @param LUT Precomputed distances. + * @param dis Output distances. + * @param threshold Filter threshold. + * @param lt_mask Filter result mask. + * @param codes_size Length of codes. + * @param LUT_size Length of LUT. + * @param dis_size Length of dis. + * @param lt_mask_size Length of lt_mask. 
+ */ +KRL_API_PUBLIC int krl_IP_table_lookup_fast_scan_bs96(int nsq, const uint8_t *codes, const uint8_t *LUT, uint16_t *dis, + uint16_t threshold, uint32_t *lt_mask, size_t codes_size, size_t LUT_size, size_t dis_size, size_t lt_mask_size); + +/* + * @brief Lookup table function for 4-bit codes. + * @param nsq Number of subquantizers. + * @param ncode Number of codes. + * @param codes Input codes. + * @param LUT Precomputed distances. + * @param dis Output distances. + * @param dis0 Distance between query and the center of the bucket. + * @param codes_size Length of codes. + * @param LUT_size Length of LUT. + * @param dis_size Length of dis. + */ +KRL_API_PUBLIC int krl_table_lookup_4b_f16(size_t nsq, size_t ncode, const uint8_t *codes, const uint16_t *LUT, + float *dis, uint16_t dis0, size_t codes_size, size_t LUT_size, size_t dis_size); + +/* + * @brief Pack 4-bit codes into blocks. + * @param codes Input codes. + * @param ncode Total number of codes. + * @param nsq Number of subquantizers. + * @param blocks Output packed blocks. + * @param batchsize Number of base vectors per batch. + * @param dim_cross Whether to arrange dimensions in cross mode. + * @param codes_size Length of codes. + * @param blocks_size Length of blocks. + */ +KRL_API_PUBLIC int krl_pack_codes_4b(const uint8_t *codes, size_t ncode, size_t nsq, uint8_t *blocks, size_t batchsize, + int dim_cross, size_t codes_size, size_t blocks_size); + +/* -------------------------------------- reorder function -------------------------------------- */ + +/* + * @brief Reorder two vectors based on distance. + * @param kdh Pointer to the distance handle. + * @param base_k Number of base vectors obtained in the first phase. + * @param base_dis Distance array from the first phase. + * @param base_idx Index array from the first phase. + * @param query_vector Query vector. + * @param k Number of final output base vectors. + * @param dis Final distance array. + * @param idx Final index array. 
+ * @param query_vector_size Length of query_vector. + */ +KRL_API_PUBLIC int krl_reorder_2_vector(const KRLDistanceHandle *kdh, int64_t base_k, float *base_dis, + int64_t *base_idx, const float *query_vector, int64_t k, float *dis, int64_t *idx, size_t query_vector_size); + +/* + * @brief Reorder two vectors with continuous indices. + * @param kdh Pointer to the distance handle. + * @param base_k Number of base vectors obtained in the first phase. + * @param begin_id Starting index of base vectors. + * @param query_vector Query vector. + * @param k Number of final output base vectors. + * @param dis Final distance array. + * @param idx Final index array. + * @param query_vector_size Length of query_vector. + */ +KRL_API_PUBLIC int krl_reorder_2_vector_continuous(const KRLDistanceHandle *kdh, int64_t base_k, int64_t begin_id, + const float *query_vector, int64_t k, float *dis, int64_t *idx, size_t query_vector_size); + +/* -------------------------------------- handle IO function -------------------------------------- */ + +/* + * @brief Store the 8-bit lookup table handle to a file. + * @param f File pointer. + * @param klh Pointer to the lookup table handle. + */ +KRL_API_PUBLIC int krl_store_LUT8Handle(FILE *f, const KRLLUT8bHandle *klh); + +/* + * @brief Build the 8-bit lookup table handle from a file. + * @param f File pointer. + * @param klh Pointer to the lookup table handle. + * @return int 0 on success, non-zero on failure. + */ +KRL_API_PUBLIC int krl_build_LUT8Handle_fromfile(FILE *f, KRLLUT8bHandle **klh); + +/* + * @brief Store the distance handle to a file. + * @param f File pointer. + * @param kdh Pointer to the distance handle. + */ +KRL_API_PUBLIC int krl_store_distanceHandle(FILE *f, const KRLDistanceHandle *kdh); + +/* + * @brief Build the distance handle from a file. + * @param f File pointer. + * @param kdh Pointer to the distance handle. + * @return int 0 on success, non-zero on failure. 
 */ +KRL_API_PUBLIC int krl_build_distanceHandle_fromfile(FILE *f, KRLDistanceHandle **kdh); + +#ifdef __cplusplus +} +#endif + +#endif // KRL_H \ No newline at end of file diff --git a/third_party/krl/include/krl_internal.h b/third_party/krl/include/krl_internal.h new file mode 100644 index 00000000..604634be --- /dev/null +++ b/third_party/krl/include/krl_internal.h @@ -0,0 +1,309 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ + +#include "krl.h" +#include <cstddef> +#include <cstdint> +#include <cstdio> +#include <cstdlib> +#include <cstring> + +typedef int64_t idx_t; + +/* + * @brief Handle for batch distance computation. + * @param metric_type Measurement type (e.g., L2, inner product). + * @param quanted_scale Quantization scale parameter. + * @param quanted_bias Quantization bias parameter. + * @param data_bits Data bit width, supports 8, 16, 32. + * @param full_data_bits Full data bit width, supports 8, 16, 32. Only used when full_data_bits > data_bits for + * second-stage rearrangement. + * @param M Number of query vectors (only for GEMM). + * @param blocksize Block size for transpose GEMM, supports 16, 32, 64. 0 or 1 indicates using parameters. + * @param d Dimension of vectors. + * @param ny Number of base vectors per query. + * @param ceil_ny Number of base vectors per query (rounded up to blocksize). + * @param quanted_bytes Size for storing or reading quantized data. + * @param transposed_bytes Size for storing or reading transposed data. + * @param quanted_codes Pointer to quantized vector matrix. + * @param transposed_codes Pointer to transposed codes (only for data_bit=32). 
+ */ +typedef struct KRLBatchDistanceHandle { + int metric_type; + float quanted_scale; + float quanted_bias; + size_t data_bits; + size_t full_data_bits; + size_t M; + size_t blocksize; + size_t d; + size_t ny; + size_t ceil_ny; + size_t quanted_bytes; + size_t transposed_bytes; + uint8_t *quanted_codes; + float *transposed_codes; +} KRLDistanceHandle; + +/* + * @brief Handle for 8-bit lookup table. + * @param use_idx Whether to use index buffer. + * @param capacity Capacity of the lookup table. + * @param idx_buffer Index buffer for storing indices. + * @param distance_buffer Distance buffer for storing distances. + */ +typedef struct KRLLookupTable8bitHandle { + int use_idx; + size_t capacity; + size_t *idx_buffer; + float *distance_buffer; +} KRLLUT8bHandle; + +/* -------------------------------------- L2 Distance Compute -------------------------------------- */ +#ifdef __cplusplus +extern "C" { +#endif +/* + * @brief Compute L2 square distance between vectors. + * @param dis Output distance array. + * @param x Pointer to the first vector. + * @param y Pointer to the second vector. + * @param ny Number of vectors. + * @param d Dimension of vectors. 
+ */ +void krl_L2sqr_ny_u8u32(uint32_t *dis, const uint8_t *x, const uint8_t *y, size_t ny, size_t d); + +/* + * @brief Compute the L2 square of a float16 vector with multiple float16 vectors in batches + * @param dis Pointer to the array storing the computed L2 squares + * @param x Pointer to the input float16 vector + * @param y Pointer to the array of float16 vectors + * @param d The dimension of the vectors + * @param ny The number of y vectors to process + */ +void krl_L2sqr_ny_f16f16(uint16_t *dis, const uint16_t *x, const uint16_t *y, size_t ny, size_t d); + +/* + * @brief Compute the L2 square of a float16 vector with multiple float16 vectors based on given indices + * @param dis Pointer to the array storing the computed L2 squares + * @param x Pointer to the input float16 vector + * @param y Pointer to the array of float16 vectors + * @param ids Pointer to the array of indices specifying which y vectors to use + * @param d The dimension of the vectors + * @param ny The number of y vectors to process + */ +void krl_L2sqr_by_idx_f16f16(uint16_t *dis, const uint16_t *x, const uint16_t *y, + const int64_t *ids, /* ids of y vecs */ + size_t d, size_t ny); + +/* -------------------------------------- IP Distance Compute -------------------------------------- */ + +/* + * @brief Compute inner product between vectors. + * @param dis Output distance array. + * @param x Pointer to the first vector. + * @param y Pointer to the second vector. + * @param ny Number of vectors. + * @param d Dimension of vectors. 
+ */ +KRL_API_PUBLIC void krl_inner_product_ny_s8s32(int32_t *dis, const int8_t *x, const int8_t *y, size_t ny, size_t d); + +/* + * @brief Compute the inner product of a float16 vector with multiple float16 vectors based on given indices + * @param dis Pointer to the array storing the computed inner products + * @param x Pointer to the input float16 vector + * @param y Pointer to the array of float16 vectors + * @param ids Pointer to the array of indices specifying which y vectors to use + * @param d The dimension of the vectors + * @param ny The number of y vectors to process + */ +void krl_inner_product_by_idx_f16f16( + uint16_t *dis, const uint16_t *x, const uint16_t *y, const int64_t *ids, size_t d, size_t ny); + +/* + * @brief Compute the inner product of a float16 vector with multiple float16 vectors in batches + * @param dis Pointer to the array storing the computed inner products + * @param x Pointer to the input float16 vector + * @param y Pointer to the array of float16 vectors + * @param d The dimension of the vectors + * @param ny The number of y vectors to process + */ +void krl_inner_product_ny_f16f16(uint16_t *dis, const uint16_t *x, const uint16_t *y, size_t ny, size_t d); + +/* + * @brief Compute the negative inner product distance between a int8 vector and multiple int8 vectors based on indices. + * @param dis Pointer to the output array storing the computed distances. + * @param x Pointer to the input int8 vector. + * @param y Pointer to the input int8 vector array. + * @param ids Pointer to the indices of the y vectors. + * @param d Length of the vectors. + * @param ny Number of vectors to compute. 
+ */ +void krl_negative_inner_product_by_idx_s8f32(float *dis, const int8_t *x, const int8_t *y, + const int64_t *ids, /* ids of y vecs */ + size_t d, size_t ny); + +/* -------------------------------------- 4bits lookup table -------------------------------------- */ + +/* -------------------------------------- 8bits lookup table -------------------------------------- */ + +#ifdef __cplusplus +} +#endif +/* + * @brief Matrix block transpose function. + * @param src Input matrix. + * @param ny Number of vectors. + * @param dim Dimension of vectors. + * @param blocksize Block size for transpose. + * @param block Output transposed matrix. + * @param block_size Length of block. + */ +int krl_matrix_block_transpose( + const uint32_t *src, size_t ny, size_t dim, size_t blocksize, uint32_t *block, size_t block_size); + +/* + * @brief Lookup table function for 8-bit codes. + * @param nsq Number of subquantizers. + * @param ncode Number of codes. + * @param codes Input codes. + * @param sim_table Similarity table. + * @param distance Output distance array. + * @param dis0 Initial distance value. + */ +void krl_table_lookup_8b_f32_f16( + size_t nsq, size_t ncode, const uint8_t *codes, const float16_t *sim_table, float *distance, float16_t dis0); + +/* -------------------------------------- minmax quant -------------------------------------- */ + +/* + * @brief Quantize float to float16. + * @param src Input float array. + * @param n Number of elements. + * @param out Output float16 array. + */ +void quant_f16(const float *src, idx_t n, float16_t *out); + +/* + * @brief Quantize float to uint8. + * @param src Input float array. + * @param n Number of elements. + * @param out Output uint8 array. + */ +void quant_u8(const float *src, idx_t n, uint8_t *out); + +/* + * @brief Quantize float to uint8 with scale and bias. + * @param src Input float array. + * @param n Number of elements. + * @param out Output uint8 array. + * @param scale Scale factor. + * @param bias Bias value. 
+ */ +void quant_u8_with_parm(const float *src, idx_t n, uint8_t *out, float scale, float bias); + +/* + * @brief Quantize float to int8. + * @param src Input float array. + * @param n Number of elements. + * @param out Output int8 array. + */ +void quant_s8(const float *src, idx_t n, int8_t *out); + +/* + * @brief Quantize float to int8 with scale. + * @param src Input float array. + * @param n Number of elements. + * @param out Output int8 array. + * @param scale Scale factor. + */ +void quant_s8_with_parm(const float *src, idx_t n, int8_t *out, float scale); + +/* + * @brief Compute quantization parameters. + * @param n Number of elements. + * @param x Input float array. + * @param metric_type Distance metric type. + * @param range Quantization range. + * @param scale Output scale factor. + * @param bias Output bias value. + * @return size_t Number of quantization parameters. + */ +size_t compute_quant_parm(idx_t n, const float *x, int metric_type, int range, float *scale, float *bias); + +/* + * @brief Quantize float to uint8 with specific metric type. + * @param n Number of elements. + * @param x Input float array. + * @param out Output uint8 array. + * @param metric_type Distance metric type. + * @param use_parm Whether to use parameters. + * @param scale Scale factor. + * @param bias Bias value. + */ +void quant_sq8(idx_t n, const float *x, uint8_t *out, int metric_type, int use_parm, float scale, float bias); + +/* -------------------------------------- heap sort -------------------------------------- */ + +/* + * @brief Obtain top-k elements in descending order using heap sort. + * @param k Number of top elements. + * @param distances Distance array. + * @param k_base Base index for top elements. + * @param base_distances Base distance array. + */ +void krl_obtion_topk_heap_desc(idx_t k, float *distances, idx_t k_base, const float *base_distances); + +/* + * @brief Obtain top-k elements in ascending order using heap sort. 
+ * @param k Number of top elements. + * @param distances Distance array. + * @param k_base Base index for top elements. + * @param base_distances Base distance array. + */ +void krl_obtion_topk_heap_asce(idx_t k, float *distances, idx_t k_base, const float *base_distances); + +/* + * @brief Reorder two heaps in descending order. + * @param k Number of top elements. + * @param labels Label array. + * @param distances Distance array. + * @param k_base Base index for top elements. + * @param base_labels Base label array. + * @param base_distances Base distance array. + */ +void krl_reorder_2_heaps_desc( + idx_t k, idx_t *labels, float *distances, idx_t k_base, const idx_t *base_labels, const float *base_distances); + +/* + * @brief Reorder two heaps in ascending order. + * @param k Number of top elements. + * @param labels Label array. + * @param distances Distance array. + * @param k_base Base index for top elements. + * @param base_labels Base label array. + * @param base_distances Base distance array. + */ +void krl_reorder_2_heaps_asce( + idx_t k, idx_t *labels, float *distances, idx_t k_base, const idx_t *base_labels, const float *base_distances); + +/* + * @brief Adaptively reorder elements in ascending order. + * @param dis Distance array. + * @param label Label array. + * @param n Number of elements. + * @param target Target value. + * @return idx_t Index of the target value. + */ +idx_t Adapt_reorder_asce(float *dis, idx_t *label, idx_t n, float target); + +/* + * @brief Adaptively reorder elements in descending order. + * @param dis Distance array. + * @param label Label array. + * @param n Number of elements. + * @param target Target value. + * @return idx_t Index of the target value. 
+ */ +idx_t Adapt_reorder_desc(float *dis, idx_t *label, idx_t n, float target); \ No newline at end of file diff --git a/third_party/krl/include/platform_macros.h b/third_party/krl/include/platform_macros.h new file mode 100644 index 00000000..e17a55cb --- /dev/null +++ b/third_party/krl/include/platform_macros.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ + +#pragma once + +#define SUCCESS 0 +#define INVALPOINTER -1 +#define FAILALLOC -2 +#define INVALPARAM -3 +#define DOUBLEFREE -4 +#define UNSAFEMEM -5 +#define FAILIO -6 + +#define METRIC_INNER_PRODUCT 0 +#define METRIC_L2 1 + +#define likely(x) __builtin_expect(!!(x), 1) +#define unlikely(x) __builtin_expect(!!(x), 0) + +inline void prefetch_L1(const void *address) +{ + __builtin_prefetch(address, 0, 3); +} +inline void prefetch_L2(const void *address) +{ + __builtin_prefetch(address, 0, 2); +} +inline void prefetch_L3(const void *address) +{ + __builtin_prefetch(address, 0, 1); +} +inline void prefetch_Lx(const void *address) +{ + __builtin_prefetch(address, 0, 0); +} + +#define KRL_DEFAULT_ALIGNED (64) +#define ALIGNED(x) __attribute__((aligned(x))) + +#ifdef __GNUC__ + +#define KRL_IMPRECISE_FUNCTION_BEGIN \ + _Pragma("GCC push_options") _Pragma("GCC optimize (\"unroll-loops,associative-math,no-signed-zeros\")") +#define KRL_IMPRECISE_FUNCTION_END _Pragma("GCC pop_options") +#else +#define KRL_IMPRECISE_FUNCTION_BEGIN +#define KRL_IMPRECISE_FUNCTION_END +#endif \ No newline at end of file diff --git a/third_party/krl/include/safe_memory.h b/third_party/krl/include/safe_memory.h new file mode 100644 index 00000000..6aac31ef --- /dev/null +++ b/third_party/krl/include/safe_memory.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
+ */ + +#pragma once +#include +#include +#include + +namespace SafeMemory { + +template +int CheckAndMemcpy(D *dest, size_t destBufferSize, const S *src, size_t srcBufferSize) +{ + if (srcBufferSize > destBufferSize) { + std::cerr << "Memcpy failed: destBufferSize[" << destBufferSize << "] should be >= srcBufferSize[" + << srcBufferSize << "].\n"; + return -1; + } + if (dest == nullptr || src == nullptr) { + std::cerr << "Memcpy failed: null pointer detected\n"; + return -1; + } + memcpy(dest, src, srcBufferSize); + return 0; +} + +template +int CheckAndMemset(D *dest, size_t destBufferSize, int memsetValue, size_t setSize) +{ + if (setSize > destBufferSize) { + std::cerr << "Memset failed: destBufferSize[" << destBufferSize << "] should be >= setSize[" << setSize + << "].\n"; + return -1; + } + if (dest == nullptr) { + std::cerr << "Memset failed: null pointer detected\n"; + return -1; + } + memset(dest, memsetValue, setSize); + return 0; +} + +} // namespace SafeMemory \ No newline at end of file diff --git a/third_party/krl/src/IPdistance_simd.cpp b/third_party/krl/src/IPdistance_simd.cpp new file mode 100644 index 00000000..f4553ab3 --- /dev/null +++ b/third_party/krl/src/IPdistance_simd.cpp @@ -0,0 +1,1602 @@ +// Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +// SPDX-License-Identifier: Apache-2.0 +// Adapted from KRL (Kunpeng Retrieval Library) for ARM NEON optimizations. + +#include "krl.h" +#include "krl_internal.h" +#include "safe_memory.h" +#include "platform_macros.h" +#include + +extern "C" { + +/* +* @brief Compute the inner product of two float vectors. +* @param x Pointer to the first vector (float). +* @param y Pointer to the second vector (float). +* @param d Dimension of the vectors. +* @param dis Stores the inner product result (float). +* @param dis_size Length of dis. 
+*/ +KRL_IMPRECISE_FUNCTION_BEGIN +int krl_ipdis(const float *x, const float *__restrict y, const size_t d, float *dis, size_t dis_size) +{ + size_t i; + float res; + constexpr size_t single_round = 16; + + if (d < 1 || d > 65535) { + std::printf("Error: INVALPARAM in krl_ipdis\n"); + return INVALPARAM; + } + + if (x == nullptr || y == nullptr || dis == nullptr || dis_size < 1) { + std::printf("Error: INVALPOINTER in krl_ipdis\n"); + return INVALPOINTER; + } + + if (likely(d >= single_round)) { + float32x4_t x8_0 = vld1q_f32(x); + float32x4_t x8_1 = vld1q_f32(x + 4); + float32x4_t x8_2 = vld1q_f32(x + 8); + float32x4_t x8_3 = vld1q_f32(x + 12); + + float32x4_t y8_0 = vld1q_f32(y); + float32x4_t y8_1 = vld1q_f32(y + 4); + float32x4_t y8_2 = vld1q_f32(y + 8); + float32x4_t y8_3 = vld1q_f32(y + 12); + + float32x4_t d8_0 = vmulq_f32(x8_0, y8_0); + float32x4_t d8_1 = vmulq_f32(x8_1, y8_1); + float32x4_t d8_2 = vmulq_f32(x8_2, y8_2); + float32x4_t d8_3 = vmulq_f32(x8_3, y8_3); + + for (i = single_round; i <= d - single_round; i += single_round) { + x8_0 = vld1q_f32(x + i); + y8_0 = vld1q_f32(y + i); + d8_0 = vmlaq_f32(d8_0, x8_0, y8_0); + + x8_1 = vld1q_f32(x + i + 4); + y8_1 = vld1q_f32(y + i + 4); + d8_1 = vmlaq_f32(d8_1, x8_1, y8_1); + + x8_2 = vld1q_f32(x + i + 8); + y8_2 = vld1q_f32(y + i + 8); + d8_2 = vmlaq_f32(d8_2, x8_2, y8_2); + + x8_3 = vld1q_f32(x + i + 12); + y8_3 = vld1q_f32(y + i + 12); + d8_3 = vmlaq_f32(d8_3, x8_3, y8_3); + } + + d8_0 = vaddq_f32(d8_0, d8_1); + d8_2 = vaddq_f32(d8_2, d8_3); + d8_0 = vaddq_f32(d8_0, d8_2); + res = vaddvq_f32(d8_0); + } else { + i = 0; + res = 0; + } + + for (; i < d; i++) { + const float tmp = x[i] * y[i]; + res += tmp; + } + *dis = res; + return SUCCESS; +} +KRL_IMPRECISE_FUNCTION_END + +/* +* @brief Compute inner products for two float vectors in batch. +* @param x Pointer to the query vector (float). +* @param y Pointer to the database vectors (float). +* @param d Dimension of the vectors. 
+* @param dis Output array to store the results (float). +*/ +KRL_IMPRECISE_FUNCTION_BEGIN +static void krl_inner_product_batch2(const float *x, const float *__restrict y, const size_t d, float *dis) +{ + size_t i; + constexpr size_t single_round = 8; + + if (likely(d >= single_round)) { + float32x4_t x_0 = vld1q_f32(x); + float32x4_t x_1 = vld1q_f32(x + 4); + + float32x4_t y0_0 = vld1q_f32(y); + float32x4_t y0_1 = vld1q_f32(y + 4); + float32x4_t y1_0 = vld1q_f32(y + d); + float32x4_t y1_1 = vld1q_f32(y + d + 4); + + float32x4_t d0_0 = vmulq_f32(x_0, y0_0); + float32x4_t d0_1 = vmulq_f32(x_1, y0_1); + float32x4_t d1_0 = vmulq_f32(x_0, y1_0); + float32x4_t d1_1 = vmulq_f32(x_1, y1_1); + + for (i = single_round; i <= d - single_round; i += single_round) { + x_0 = vld1q_f32(x + i); + y0_0 = vld1q_f32(y + i); + y1_0 = vld1q_f32(y + d + i); + d0_0 = vmlaq_f32(d0_0, x_0, y0_0); + d1_0 = vmlaq_f32(d1_0, x_0, y1_0); + + x_1 = vld1q_f32(x + i + 4); + y0_1 = vld1q_f32(y + i + 4); + y1_1 = vld1q_f32(y + d + i + 4); + d0_1 = vmlaq_f32(d0_1, x_1, y0_1); + d1_1 = vmlaq_f32(d1_1, x_1, y1_1); + } + + d0_0 = vaddq_f32(d0_0, d0_1); + d1_0 = vaddq_f32(d1_0, d1_1); + dis[0] = vaddvq_f32(d0_0); + dis[1] = vaddvq_f32(d1_0); + } else { + dis[0] = 0; + dis[1] = 0; + i = 0; + } + + for (; i < d; i++) { + const float tmp0 = x[i] * *(y + i); + const float tmp1 = x[i] * *(y + d + i); + dis[0] += tmp0; + dis[1] += tmp1; + } +} +KRL_IMPRECISE_FUNCTION_END + +/* +* @brief Compute inner products for four float vectors in batch. +* @param x Pointer to the query vector (float). +* @param y Pointer to the database vectors (float). +* @param d Dimension of the vectors. +* @param dis Output array to store the results (float). 
+*/ +KRL_IMPRECISE_FUNCTION_BEGIN +static void krl_inner_product_batch4(const float *x, const float *__restrict y, const size_t d, float *dis) +{ + size_t i; + constexpr size_t single_round = 4; /* 128/32 */ + + if (likely(d >= single_round)) { + float32x4_t neon_query = vld1q_f32(x); + float32x4_t neon_base1 = vld1q_f32(y); + float32x4_t neon_base2 = vld1q_f32(y + d); + float32x4_t neon_base3 = vld1q_f32(y + 2 * d); + float32x4_t neon_base4 = vld1q_f32(y + 3 * d); + + float32x4_t neon_res1 = vmulq_f32(neon_base1, neon_query); + float32x4_t neon_res2 = vmulq_f32(neon_base2, neon_query); + float32x4_t neon_res3 = vmulq_f32(neon_base3, neon_query); + float32x4_t neon_res4 = vmulq_f32(neon_base4, neon_query); + + for (i = single_round; i <= d - single_round; i += single_round) { + neon_query = vld1q_f32(x + i); + neon_base1 = vld1q_f32(y + i); + neon_base2 = vld1q_f32(y + d + i); + neon_base3 = vld1q_f32(y + 2 * d + i); + neon_base4 = vld1q_f32(y + 3 * d + i); + + neon_res1 = vmlaq_f32(neon_res1, neon_base1, neon_query); + neon_res2 = vmlaq_f32(neon_res2, neon_base2, neon_query); + neon_res3 = vmlaq_f32(neon_res3, neon_base3, neon_query); + neon_res4 = vmlaq_f32(neon_res4, neon_base4, neon_query); + } + dis[0] = vaddvq_f32(neon_res1); + dis[1] = vaddvq_f32(neon_res2); + dis[2] = vaddvq_f32(neon_res3); + dis[3] = vaddvq_f32(neon_res4); + } else { + for (int i = 0; i < 4; i++) { + dis[i] = 0.0f; + } + i = 0; + } + if (i < d) { + float d0 = x[i] * *(y + i); + float d1 = x[i] * *(y + d + i); + float d2 = x[i] * *(y + 2 * d + i); + float d3 = x[i] * *(y + 3 * d + i); + + for (i++; i < d; ++i) { + d0 += x[i] * *(y + i); + d1 += x[i] * *(y + d + i); + d2 += x[i] * *(y + 2 * d + i); + d3 += x[i] * *(y + 3 * d + i); + } + + dis[0] += d0; + dis[1] += d1; + dis[2] += d2; + dis[3] += d3; + } +} +KRL_IMPRECISE_FUNCTION_END + +/* +* @brief Compute inner products for eight float vectors in batch. +* @param x Pointer to the query vector (float). 
+* @param y Pointer to the database vectors (float). +* @param d Dimension of the vectors. +* @param dis Output array to store the results (float). +*/ +KRL_IMPRECISE_FUNCTION_BEGIN +static void krl_inner_product_batch8(const float *x, const float *__restrict y, const size_t d, float *dis) +{ + size_t i; + constexpr size_t single_round = 4; /* 128/32 */ + + if (likely(d >= single_round)) { + float32x4_t neon_query = vld1q_f32(x); + float32x4_t neon_base1 = vld1q_f32(y); + float32x4_t neon_base2 = vld1q_f32(y + d); + float32x4_t neon_base3 = vld1q_f32(y + 2 * d); + float32x4_t neon_base4 = vld1q_f32(y + 3 * d); + float32x4_t neon_base5 = vld1q_f32(y + 4 * d); + float32x4_t neon_base6 = vld1q_f32(y + 5 * d); + float32x4_t neon_base7 = vld1q_f32(y + 6 * d); + float32x4_t neon_base8 = vld1q_f32(y + 7 * d); + + float32x4_t neon_res1 = vmulq_f32(neon_base1, neon_query); + float32x4_t neon_res2 = vmulq_f32(neon_base2, neon_query); + float32x4_t neon_res3 = vmulq_f32(neon_base3, neon_query); + float32x4_t neon_res4 = vmulq_f32(neon_base4, neon_query); + float32x4_t neon_res5 = vmulq_f32(neon_base5, neon_query); + float32x4_t neon_res6 = vmulq_f32(neon_base6, neon_query); + float32x4_t neon_res7 = vmulq_f32(neon_base7, neon_query); + float32x4_t neon_res8 = vmulq_f32(neon_base8, neon_query); + + for (i = single_round; i <= d - single_round; i += single_round) { + neon_query = vld1q_f32(x + i); + neon_base1 = vld1q_f32(y + i); + neon_base2 = vld1q_f32(y + d + i); + neon_base3 = vld1q_f32(y + 2 * d + i); + neon_base4 = vld1q_f32(y + 3 * d + i); + neon_base5 = vld1q_f32(y + 4 * d + i); + neon_base6 = vld1q_f32(y + 5 * d + i); + neon_base7 = vld1q_f32(y + 6 * d + i); + neon_base8 = vld1q_f32(y + 7 * d + i); + + neon_res1 = vmlaq_f32(neon_res1, neon_base1, neon_query); + neon_res2 = vmlaq_f32(neon_res2, neon_base2, neon_query); + neon_res3 = vmlaq_f32(neon_res3, neon_base3, neon_query); + neon_res4 = vmlaq_f32(neon_res4, neon_base4, neon_query); + neon_res5 = 
vmlaq_f32(neon_res5, neon_base5, neon_query); + neon_res6 = vmlaq_f32(neon_res6, neon_base6, neon_query); + neon_res7 = vmlaq_f32(neon_res7, neon_base7, neon_query); + neon_res8 = vmlaq_f32(neon_res8, neon_base8, neon_query); + } + + dis[0] = vaddvq_f32(neon_res1); + dis[1] = vaddvq_f32(neon_res2); + dis[2] = vaddvq_f32(neon_res3); + dis[3] = vaddvq_f32(neon_res4); + dis[4] = vaddvq_f32(neon_res5); + dis[5] = vaddvq_f32(neon_res6); + dis[6] = vaddvq_f32(neon_res7); + dis[7] = vaddvq_f32(neon_res8); + } else { + for (int i = 0; i < 8; i++) { + dis[i] = 0.0f; + } + i = 0; + } + if (i < d) { + float d0 = x[i] * *(y + i); + float d1 = x[i] * *(y + d + i); + float d2 = x[i] * *(y + 2 * d + i); + float d3 = x[i] * *(y + 3 * d + i); + float d4 = x[i] * *(y + 4 * d + i); + float d5 = x[i] * *(y + 5 * d + i); + float d6 = x[i] * *(y + 6 * d + i); + float d7 = x[i] * *(y + 7 * d + i); + + for (i++; i < d; ++i) { + d0 += x[i] * *(y + i); + d1 += x[i] * *(y + d + i); + d2 += x[i] * *(y + 2 * d + i); + d3 += x[i] * *(y + 3 * d + i); + d4 += x[i] * *(y + 4 * d + i); + d5 += x[i] * *(y + 5 * d + i); + d6 += x[i] * *(y + 6 * d + i); + d7 += x[i] * *(y + 7 * d + i); + } + + dis[0] += d0; + dis[1] += d1; + dis[2] += d2; + dis[3] += d3; + dis[4] += d4; + dis[5] += d5; + dis[6] += d6; + dis[7] += d7; + } +} +KRL_IMPRECISE_FUNCTION_END + +/* +* @brief Compute inner products for sixteen float vectors in batch. +* @param x Pointer to the query vector (float). +* @param y Pointer to the database vectors (float). +* @param d Dimension of the vectors. +* @param dis Output array to store the results (float). 
+*/ +KRL_IMPRECISE_FUNCTION_BEGIN +static void krl_inner_product_batch16(const float *x, const float *__restrict y, const size_t d, float *dis) +{ + size_t i; + constexpr size_t single_round = 4; /* 128/32 */ + + if (likely(d >= single_round)) { + /* Load query vector and database vectors */ + float32x4_t neon_query = vld1q_f32(x); + float32x4_t neon_base1 = vld1q_f32(y); + float32x4_t neon_base2 = vld1q_f32(y + d); + float32x4_t neon_base3 = vld1q_f32(y + 2 * d); + float32x4_t neon_base4 = vld1q_f32(y + 3 * d); + float32x4_t neon_base5 = vld1q_f32(y + 4 * d); + float32x4_t neon_base6 = vld1q_f32(y + 5 * d); + float32x4_t neon_base7 = vld1q_f32(y + 6 * d); + float32x4_t neon_base8 = vld1q_f32(y + 7 * d); + + /* Compute initial inner products */ + float32x4_t neon_res1 = vmulq_f32(neon_base1, neon_query); + float32x4_t neon_res2 = vmulq_f32(neon_base2, neon_query); + float32x4_t neon_res3 = vmulq_f32(neon_base3, neon_query); + float32x4_t neon_res4 = vmulq_f32(neon_base4, neon_query); + float32x4_t neon_res5 = vmulq_f32(neon_base5, neon_query); + float32x4_t neon_res6 = vmulq_f32(neon_base6, neon_query); + float32x4_t neon_res7 = vmulq_f32(neon_base7, neon_query); + float32x4_t neon_res8 = vmulq_f32(neon_base8, neon_query); + + /* Load additional database vectors */ + neon_base1 = vld1q_f32(y + 8 * d); + neon_base2 = vld1q_f32(y + 9 * d); + neon_base3 = vld1q_f32(y + 10 * d); + neon_base4 = vld1q_f32(y + 11 * d); + neon_base5 = vld1q_f32(y + 12 * d); + neon_base6 = vld1q_f32(y + 13 * d); + neon_base7 = vld1q_f32(y + 14 * d); + neon_base8 = vld1q_f32(y + 15 * d); + + /* Compute additional inner products */ + float32x4_t neon_res9 = vmulq_f32(neon_base1, neon_query); + float32x4_t neon_res10 = vmulq_f32(neon_base2, neon_query); + float32x4_t neon_res11 = vmulq_f32(neon_base3, neon_query); + float32x4_t neon_res12 = vmulq_f32(neon_base4, neon_query); + float32x4_t neon_res13 = vmulq_f32(neon_base5, neon_query); + float32x4_t neon_res14 = vmulq_f32(neon_base6, 
neon_query); + float32x4_t neon_res15 = vmulq_f32(neon_base7, neon_query); + float32x4_t neon_res16 = vmulq_f32(neon_base8, neon_query); + + for (i = single_round; i <= d - single_round; i += single_round) { + /* Update query and database vectors */ + neon_query = vld1q_f32(x + i); + neon_base1 = vld1q_f32(y + i); + neon_base2 = vld1q_f32(y + d + i); + neon_base3 = vld1q_f32(y + 2 * d + i); + neon_base4 = vld1q_f32(y + 3 * d + i); + neon_base5 = vld1q_f32(y + 4 * d + i); + neon_base6 = vld1q_f32(y + 5 * d + i); + neon_base7 = vld1q_f32(y + 6 * d + i); + neon_base8 = vld1q_f32(y + 7 * d + i); + + /* Update inner products for first 8 vectors */ + neon_res1 = vmlaq_f32(neon_res1, neon_base1, neon_query); + neon_res2 = vmlaq_f32(neon_res2, neon_base2, neon_query); + neon_res3 = vmlaq_f32(neon_res3, neon_base3, neon_query); + neon_res4 = vmlaq_f32(neon_res4, neon_base4, neon_query); + neon_res5 = vmlaq_f32(neon_res5, neon_base5, neon_query); + neon_res6 = vmlaq_f32(neon_res6, neon_base6, neon_query); + neon_res7 = vmlaq_f32(neon_res7, neon_base7, neon_query); + neon_res8 = vmlaq_f32(neon_res8, neon_base8, neon_query); + + /* Update database vectors for additional 8 vectors */ + neon_base1 = vld1q_f32(y + 8 * d + i); + neon_base2 = vld1q_f32(y + 9 * d + i); + neon_base3 = vld1q_f32(y + 10 * d + i); + neon_base4 = vld1q_f32(y + 11 * d + i); + neon_base5 = vld1q_f32(y + 12 * d + i); + neon_base6 = vld1q_f32(y + 13 * d + i); + neon_base7 = vld1q_f32(y + 14 * d + i); + neon_base8 = vld1q_f32(y + 15 * d + i); + + /* Update inner products for additional 8 vectors */ + neon_res9 = vmlaq_f32(neon_res9, neon_base1, neon_query); + neon_res10 = vmlaq_f32(neon_res10, neon_base2, neon_query); + neon_res11 = vmlaq_f32(neon_res11, neon_base3, neon_query); + neon_res12 = vmlaq_f32(neon_res12, neon_base4, neon_query); + neon_res13 = vmlaq_f32(neon_res13, neon_base5, neon_query); + neon_res14 = vmlaq_f32(neon_res14, neon_base6, neon_query); + neon_res15 = vmlaq_f32(neon_res15, neon_base7, 
neon_query); + neon_res16 = vmlaq_f32(neon_res16, neon_base8, neon_query); + } + + /* Store results for all 16 vectors */ + dis[0] = vaddvq_f32(neon_res1); + dis[1] = vaddvq_f32(neon_res2); + dis[2] = vaddvq_f32(neon_res3); + dis[3] = vaddvq_f32(neon_res4); + dis[4] = vaddvq_f32(neon_res5); + dis[5] = vaddvq_f32(neon_res6); + dis[6] = vaddvq_f32(neon_res7); + dis[7] = vaddvq_f32(neon_res8); + dis[8] = vaddvq_f32(neon_res9); + dis[9] = vaddvq_f32(neon_res10); + dis[10] = vaddvq_f32(neon_res11); + dis[11] = vaddvq_f32(neon_res12); + dis[12] = vaddvq_f32(neon_res13); + dis[13] = vaddvq_f32(neon_res14); + dis[14] = vaddvq_f32(neon_res15); + dis[15] = vaddvq_f32(neon_res16); + } else { + /* Initialize results to zero if dimension is less than single_round */ + for (int i = 0; i < 16; i++) { + dis[i] = 0.0f; + } + i = 0; + } + + /* Handle remaining elements if dimension is not a multiple of single_round */ + if (i < d) { + float d0 = x[i] * *(y + i); + float d1 = x[i] * *(y + d + i); + float d2 = x[i] * *(y + 2 * d + i); + float d3 = x[i] * *(y + 3 * d + i); + float d4 = x[i] * *(y + 4 * d + i); + float d5 = x[i] * *(y + 5 * d + i); + float d6 = x[i] * *(y + 6 * d + i); + float d7 = x[i] * *(y + 7 * d + i); + float d8 = x[i] * *(y + 8 * d + i); + float d9 = x[i] * *(y + 9 * d + i); + float d10 = x[i] * *(y + 10 * d + i); + float d11 = x[i] * *(y + 11 * d + i); + float d12 = x[i] * *(y + 12 * d + i); + float d13 = x[i] * *(y + 13 * d + i); + float d14 = x[i] * *(y + 14 * d + i); + float d15 = x[i] * *(y + 15 * d + i); + + for (i++; i < d; ++i) { + d0 += x[i] * *(y + i); + d1 += x[i] * *(y + d + i); + d2 += x[i] * *(y + 2 * d + i); + d3 += x[i] * *(y + 3 * d + i); + d4 += x[i] * *(y + 4 * d + i); + d5 += x[i] * *(y + 5 * d + i); + d6 += x[i] * *(y + 6 * d + i); + d7 += x[i] * *(y + 7 * d + i); + d8 += x[i] * *(y + 8 * d + i); + d9 += x[i] * *(y + 9 * d + i); + d10 += x[i] * *(y + 10 * d + i); + d11 += x[i] * *(y + 11 * d + i); + d12 += x[i] * *(y + 12 * d + i); + d13 += 
x[i] * *(y + 13 * d + i); + d14 += x[i] * *(y + 14 * d + i); + d15 += x[i] * *(y + 15 * d + i); + } + + dis[0] += d0; + dis[1] += d1; + dis[2] += d2; + dis[3] += d3; + dis[4] += d4; + dis[5] += d5; + dis[6] += d6; + dis[7] += d7; + dis[8] += d8; + dis[9] += d9; + dis[10] += d10; + dis[11] += d11; + dis[12] += d12; + dis[13] += d13; + dis[14] += d14; + dis[15] += d15; + } +} +KRL_IMPRECISE_FUNCTION_END + +/* +* @brief Compute inner products for two vectors with float precision and store results in dis array. +* @param x Pointer to the query vector (float). +* @param y0 Pointer to the first database vector (float). +* @param y1 Pointer to the second database vector (float). +* @param d Dimension of the vectors. +* @param dis Pointer to the output array for storing the results (float). +*/ +KRL_IMPRECISE_FUNCTION_BEGIN +static void krl_inner_product_idx_batch2( + const float *x, const float *__restrict y0, const float *__restrict y1, const size_t d, float *dis) +{ + size_t i; + constexpr size_t single_round = 8; + + if (likely(d >= single_round)) { + float32x4_t x_0 = vld1q_f32(x); + float32x4_t x_1 = vld1q_f32(x + 4); + + float32x4_t y0_0 = vld1q_f32(y0); + float32x4_t y0_1 = vld1q_f32(y0 + 4); + float32x4_t y1_0 = vld1q_f32(y1); + float32x4_t y1_1 = vld1q_f32(y1 + 4); + + float32x4_t d0_0 = vmulq_f32(x_0, y0_0); + float32x4_t d0_1 = vmulq_f32(x_1, y0_1); + float32x4_t d1_0 = vmulq_f32(x_0, y1_0); + float32x4_t d1_1 = vmulq_f32(x_1, y1_1); + for (i = single_round; i <= d - single_round; i += single_round) { + x_0 = vld1q_f32(x + i); + y0_0 = vld1q_f32(y0 + i); + y1_0 = vld1q_f32(y1 + i); + d0_0 = vmlaq_f32(d0_0, x_0, y0_0); + d1_0 = vmlaq_f32(d1_0, x_0, y1_0); + + x_1 = vld1q_f32(x + i + 4); + y0_1 = vld1q_f32(y0 + i + 4); + y1_1 = vld1q_f32(y1 + i + 4); + d0_1 = vmlaq_f32(d0_1, x_1, y0_1); + d1_1 = vmlaq_f32(d1_1, x_1, y1_1); + } + + d0_0 = vaddq_f32(d0_0, d0_1); + d1_0 = vaddq_f32(d1_0, d1_1); + dis[0] = vaddvq_f32(d0_0); + dis[1] = vaddvq_f32(d1_0); + } else { + 
dis[0] = 0;
        dis[1] = 0;
        i = 0;
    }

    for (; i < d; i++) {
        const float tmp0 = x[i] * y0[i];
        const float tmp1 = x[i] * y1[i];
        dis[0] += tmp0;
        dis[1] += tmp1;
    }
}
KRL_IMPRECISE_FUNCTION_END

/*
* @brief Compute inner products of one query against four database vectors and store results in dis array.
* @param x Pointer to the query vector (float).
* @param y Array of four pointers to the database vectors (float).
* @param d Dimension of the vectors.
* @param dis Pointer to the output array for storing the results (float, at least 4 entries).
*/
KRL_IMPRECISE_FUNCTION_BEGIN
static void krl_inner_product_idx_batch4(const float *x, const float *__restrict *y, const size_t d, float *dis)
{
    constexpr size_t single_round = 4; /* 128/32: floats per NEON register */
    size_t i;
    if (likely(d >= single_round)) {
        /* Fixed-trip-count lane loops: the compiler fully unrolls them, matching the
         * hand-unrolled form while keeping one accumulator per base vector, so the
         * floating-point accumulation order is unchanged. */
        float32x4_t neon_query = vld1q_f32(x);
        float32x4_t acc[4];
        for (size_t k = 0; k < 4; ++k) {
            acc[k] = vmulq_f32(vld1q_f32(y[k]), neon_query);
        }
        for (i = single_round; i <= d - single_round; i += single_round) {
            neon_query = vld1q_f32(x + i);
            for (size_t k = 0; k < 4; ++k) {
                acc[k] = vmlaq_f32(acc[k], vld1q_f32(y[k] + i), neon_query);
            }
        }
        for (size_t k = 0; k < 4; ++k) {
            dis[k] = vaddvq_f32(acc[k]);
        }
    } else {
        /* d < 4: no SIMD work. Use a distinct counter name — the original shadowed
         * the outer size_t i with a local int i here. */
        for (size_t k = 0; k < 4; ++k) {
            dis[k] = 0.0f;
        }
        i = 0;
    }
    if (i < d) {
        /* Scalar tail for the last d % 4 components. */
        float sum[4];
        for (size_t k = 0; k < 4; ++k) {
            sum[k] = x[i] * y[k][i];
        }
        for (i++; i < d; ++i) {
            for (size_t k = 0; k < 4; ++k) {
                sum[k] += x[i] * y[k][i];
            }
        }
        for (size_t k = 0; k < 4; ++k) {
            dis[k] += sum[k];
        }
    }
}
KRL_IMPRECISE_FUNCTION_END

/*
* @brief Compute inner products of one query against eight database vectors and store results in dis array.
* @param x Pointer to the query vector (float).
* @param y Array of eight pointers to the database vectors (float).
* @param d Dimension of the vectors.
* @param dis Pointer to the output array for storing the results (float, at least 8 entries).
*/
KRL_IMPRECISE_FUNCTION_BEGIN
static void krl_inner_product_idx_batch8(const float *x, const float *__restrict *y, const size_t d, float *dis)
{
    constexpr size_t single_round = 4; /* 128/32: floats per NEON register */
    size_t i;
    if (likely(d >= single_round)) {
        /* One NEON accumulator per base vector; fixed-trip-count lane loops are
         * fully unrolled by the compiler, preserving the per-accumulator FMA order. */
        float32x4_t neon_query = vld1q_f32(x);
        float32x4_t acc[8];
        for (size_t k = 0; k < 8; ++k) {
            acc[k] = vmulq_f32(vld1q_f32(y[k]), neon_query);
        }
        for (i = single_round; i <= d - single_round; i += single_round) {
            neon_query = vld1q_f32(x + i);
            for (size_t k = 0; k < 8; ++k) {
                acc[k] = vmlaq_f32(acc[k], vld1q_f32(y[k] + i), neon_query);
            }
        }
        for (size_t k = 0; k < 8; ++k) {
            dis[k] = vaddvq_f32(acc[k]);
        }
    } else {
        /* d < 4: no SIMD work. Distinct counter avoids the original's int-i shadowing. */
        for (size_t k = 0; k < 8; ++k) {
            dis[k] = 0.0f;
        }
        i = 0;
    }
    if (i < d) {
        /* Scalar tail for the last d % 4 components. */
        float sum[8];
        for (size_t k = 0; k < 8; ++k) {
            sum[k] = x[i] * y[k][i];
        }
        for (i++; i < d; ++i) {
            for (size_t k = 0; k < 8; ++k) {
                sum[k] += x[i] * y[k][i];
            }
        }
        for (size_t k = 0; k < 8; ++k) {
            dis[k] += sum[k];
        }
    }
}
KRL_IMPRECISE_FUNCTION_END

/*
* @brief Compute inner products for sixteen vectors with indices and prefetch optimization.
* @param x Pointer to the query vector (float).
* @param y Array of pointers to the database vectors (float).
* @param d Dimension of the vectors.
* @param dis Pointer to the output array for storing the results (float).
*/
KRL_IMPRECISE_FUNCTION_BEGIN
static void krl_inner_product_idx_prefetch_batch16(
    const float *x, const float *__restrict *y, const size_t d, float *dis)
{
    constexpr size_t single_round = 4; /* 128/32: floats per NEON register */
    constexpr size_t multi_round = 32; /* 8 * single_round: prefetch/unroll block */
    size_t i;
    if (d >= multi_round) {
        /* Prefetch one block ahead for the query and all 16 base vectors. */
        prefetch_L1(x + multi_round);
        for (size_t k = 0; k < 16; ++k) {
            prefetch_Lx(y[k] + multi_round);
        }
        /* One NEON accumulator per base vector. Fixed-trip-count lane loops are
         * fully unrolled by the compiler and keep the per-accumulator FMA order
         * identical to the hand-unrolled original. */
        float32x4_t acc[16];
        float32x4_t neon_query = vld1q_f32(x);
        for (size_t k = 0; k < 16; ++k) {
            acc[k] = vmulq_f32(vld1q_f32(y[k]), neon_query);
        }
        /* Finish the first 32-element block (first 4 lanes consumed above). */
        for (i = single_round; i < multi_round; i += single_round) {
            neon_query = vld1q_f32(x + i);
            for (size_t k = 0; k < 16; ++k) {
                acc[k] = vmlaq_f32(acc[k], vld1q_f32(y[k] + i), neon_query);
            }
        }
        /* Main blocked loop: prefetch the next block, then accumulate this one.
         * d >= multi_round, so d - multi_round cannot underflow. */
        for (; i < d - multi_round; i += multi_round) {
            prefetch_L1(x + multi_round + i);
            for (size_t k = 0; k < 16; ++k) {
                prefetch_Lx(y[k] + multi_round + i);
            }
            for (size_t j = 0; j < multi_round; j += single_round) {
                neon_query = vld1q_f32(x + i + j);
                for (size_t k = 0; k < 16; ++k) {
                    acc[k] = vmlaq_f32(acc[k], vld1q_f32(y[k] + i + j), neon_query);
                }
            }
        }
        /* Remaining full 4-float groups, no further prefetch. */
        for (; i <= d - single_round; i += single_round) {
            neon_query = vld1q_f32(x + i);
            for (size_t k = 0; k < 16; ++k) {
                acc[k] = vmlaq_f32(acc[k], vld1q_f32(y[k] + i), neon_query);
            }
        }
        for (size_t k = 0; k < 16; ++k) {
            dis[k] = vaddvq_f32(acc[k]);
        }
    } else if (d >= single_round) {
        /* Short vectors: plain SIMD accumulation without prefetch/blocking. */
        float32x4_t acc[16];
        float32x4_t neon_query = vld1q_f32(x);
        for (size_t k = 0; k < 16; ++k) {
            acc[k] = vmulq_f32(vld1q_f32(y[k]), neon_query);
        }
        for (i = single_round; i <= d - single_round; i += single_round) {
            neon_query = vld1q_f32(x + i);
            for (size_t k = 0; k < 16; ++k) {
                acc[k] = vmlaq_f32(acc[k], vld1q_f32(y[k] + i), neon_query);
            }
        }
        for (size_t k = 0; k < 16; ++k) {
            dis[k] = vaddvq_f32(acc[k]);
        }
    } else {
        /* d < 4: no SIMD work. Distinct counter name avoids the original's
         * shadowing of the outer size_t i by a local int i. */
        for (size_t k = 0; k < 16; ++k) {
            dis[k] = 0.0f;
        }
        i = 0;
    }
    if (i < d) {
        /* Scalar tail for the last d % 4 components. */
        float sum[16];
        for (size_t k = 0; k < 16; ++k) {
            sum[k] = x[i] * y[k][i];
        }
        for (i++; i < d; ++i) {
            for (size_t k = 0; k < 16; ++k) {
                sum[k] += x[i] * y[k][i];
            }
        }
        for (size_t k = 0; k < 16; ++k) {
            dis[k] += sum[k];
        }
    }
}
KRL_IMPRECISE_FUNCTION_END

/*
* @brief Compute inner products for 16 vectors with float precision and store results in dis array.
* @param dis Pointer to the output array for storing the results (float).
* @param x Pointer to the query vector (float).
* @param y Pointer to the database vectors (float).
* @param d Dimension of the vectors.
*/
KRL_IMPRECISE_FUNCTION_BEGIN
static void krl_inner_product_continuous_transpose_large_kernel(
    float *dis, const float *x, const float *y, const size_t d)
{
    /* Transposed layout: each dimension step reads 64 consecutive floats of y
     * (16 NEON registers x 4 lanes) and stores 64 partial products in dis. */
    float32x4_t acc[16];
    float32x4_t broadcast = vdupq_n_f32(x[0]);
    for (size_t k = 0; k < 16; ++k) {
        acc[k] = vmulq_f32(vld1q_f32(y + 4 * k), broadcast);
    }
    /* dim loop: broadcast one query component, FMA into all 16 accumulators. */
    for (size_t i = 1; i < d; ++i) {
        broadcast = vdupq_n_f32(x[i]);
        const float *row = y + 64 * i;
        for (size_t k = 0; k < 16; ++k) {
            acc[k] = vmlaq_f32(acc[k], vld1q_f32(row + 4 * k), broadcast);
        }
    }
    for (size_t k = 0; k < 16; ++k) {
        vst1q_f32(dis + 4 * k, acc[k]);
    }
}
KRL_IMPRECISE_FUNCTION_END

/*
* @brief Compute inner products for 8 NEON registers (32 values) with float precision, storing results in dis.
* @param dis Pointer to the output array for storing the results (float).
* @param x Pointer to the query vector (float).
* @param y Pointer to the database vectors (float, transposed, 32 floats per dimension).
* @param d Dimension of the vectors.
*/
KRL_IMPRECISE_FUNCTION_BEGIN
static void krl_inner_product_continuous_transpose_medium_kernel(
    float *dis, const float *x, const float *y, const size_t d)
{
    /* Same scheme as the large kernel with half the width: 32 floats of y per
     * dimension step (8 registers x 4 lanes). */
    float32x4_t acc[8];
    float32x4_t broadcast = vdupq_n_f32(x[0]);
    for (size_t k = 0; k < 8; ++k) {
        acc[k] = vmulq_f32(vld1q_f32(y + 4 * k), broadcast);
    }
    /* dim loop */
    for (size_t i = 1; i < d; ++i) {
        broadcast = vdupq_n_f32(x[i]);
        const float *row = y + 32 * i;
        for (size_t k = 0; k < 8; ++k) {
            acc[k] = vmlaq_f32(acc[k], vld1q_f32(row + 4 * k), broadcast);
        }
    }
    for (size_t k = 0; k < 8; ++k) {
        vst1q_f32(dis + 4 * k, acc[k]);
    }
}
KRL_IMPRECISE_FUNCTION_END

/*
* @brief Compute inner products for 4 vectors with float precision and store results in dis array.
* @param dis Pointer to the output array for storing the results (float).
* @param x Pointer to the query vector (float).
* @param y Pointer to the database vectors (float).
* @param d Dimension of the vectors.
*/
KRL_IMPRECISE_FUNCTION_BEGIN
static void krl_inner_product_continuous_transpose_mini_kernel(
    float *dis, const float *x, const float *y, const size_t d)
{
    /* Transposed layout: 16 consecutive floats of y per dimension step
     * (4 NEON registers x 4 lanes), producing 16 results per call. */
    float32x4_t acc[4];
    float32x4_t broadcast = vdupq_n_f32(x[0]);
    for (size_t k = 0; k < 4; ++k) {
        acc[k] = vmulq_f32(vld1q_f32(y + 4 * k), broadcast);
    }
    /* dim loop: broadcast one query component, FMA into all accumulators. */
    for (size_t i = 1; i < d; ++i) {
        broadcast = vdupq_n_f32(x[i]);
        const float *row = y + 16 * i;
        for (size_t k = 0; k < 4; ++k) {
            acc[k] = vmlaq_f32(acc[k], vld1q_f32(row + 4 * k), broadcast);
        }
    }
    for (size_t k = 0; k < 4; ++k) {
        vst1q_f32(dis + 4 * k, acc[k]);
    }
}
KRL_IMPRECISE_FUNCTION_END

/*
* @brief Compute inner products for a batch of vectors based on given indices.
* @param dis Pointer to the output array for storing the results (float).
* @param x Pointer to the query vector (float).
* @param y Pointer to the database vectors (float).
* @param ids Pointer to the indices array for selecting database vectors.
* @param d Dimension of the vectors.
* @param ny Number of database vectors to process.
* @param dis_size Length of dis.
*/
int krl_inner_product_by_idx(
    float *dis, const float *x, const float *y, const int64_t *ids, size_t d, size_t ny, size_t dis_size)
{
    size_t i = 0;
    const float *__restrict listy[16];

    if (d < 1 || d > 65535 || ny < 1 || ny > 1ULL << 30) {
        std::printf("Error: INVALPARAM in krl_inner_product_by_idx\n");
        return INVALPARAM;
    }

    if (x == nullptr || y == nullptr || ids == nullptr || dis == nullptr || dis_size < ny) {
        std::printf("Error: INVALPOINTER in krl_inner_product_by_idx\n");
        return INVALPOINTER;
    }

    /* Main loop: 16 vectors at a time with prefetch for better cache utilization.
     * Pointer arithmetic on `y` already yields const float*, so the original
     * C-style casts were redundant and are removed. */
    for (; i + 16 <= ny; i += 16) {
        prefetch_L1(x);
        for (size_t k = 0; k < 16; ++k) {
            listy[k] = y + ids[i + k] * d;
            prefetch_Lx(listy[k]);
        }
        krl_inner_product_idx_prefetch_batch16(x, listy, d, dis + i);
    }
    /* i is a multiple of 16 here, so (ny & 8) etc. test the remainder exactly. */
    if (ny & 8) {
        for (size_t k = 0; k < 8; ++k) {
            listy[k] = y + ids[i + k] * d;
        }
        krl_inner_product_idx_batch8(x, listy, d, dis + i);
        i += 8;
    }
    if (ny & 4) {
        for (size_t k = 0; k < 4; ++k) {
            listy[k] = y + ids[i + k] * d;
        }
        krl_inner_product_idx_batch4(x, listy, d, dis + i);
        i += 4;
    }
    if (ny & 2) {
        const float *y0 = y + ids[i] * d;
        const float *y1 = y + ids[i + 1] * d;
        krl_inner_product_idx_batch2(x, y0, y1, d, dis + i);
        i += 2;
    }
    if (ny & 1) {
        krl_ipdis(x, y + d * ids[i], d, &dis[i], 1);
    }
    return SUCCESS;
}

/*
* @brief Compute inner products for a batch of vectors.
* @param dis Pointer to the output array for storing the results (float).
* @param x Pointer to the query vector (float).
* @param y Pointer to the database vectors (float).
* @param ny Number of database vectors to process.
* @param d Dimension of the vectors.
* @param dis_size Length of dis.
+*/ +int krl_inner_product_ny(float *dis, const float *x, const float *y, const size_t ny, const size_t d, size_t dis_size) +{ + size_t i = 0; + + if (d < 1 || d > 65535 || ny < 1 || ny > 1ULL << 30) { + std::printf("Error: INVALPARAM in krl_inner_product_ny\n"); + return INVALPARAM; + } + + if (x == nullptr || y == nullptr || dis == nullptr || dis_size < ny) { + std::printf("Error: INVALPOINTER in krl_inner_product_ny\n"); + return INVALPOINTER; + } + + for (; i + 16 <= ny; i += 16) { + krl_inner_product_batch16(x, y + i * d, d, dis + i); + } + if (ny & 8) { + krl_inner_product_batch8(x, y + i * d, d, dis + i); + i += 8; + } + if (ny & 4) { + krl_inner_product_batch4(x, y + i * d, d, dis + i); + i += 4; + } + if (ny & 2) { + krl_inner_product_batch2(x, y + i * d, d, dis + i); + } + if (ny & 1) { + krl_ipdis(x, y + (ny - 1) * d, d, &dis[ny - 1], 1); + } + return SUCCESS; +} + +/* +* @brief Compute inner products for a batch of vectors with a given handle. +* @param kdh Pointer to the distance handle containing configuration and data. +* @param dis Pointer to the output array for storing the results (float). +* @param x Pointer to the query vector (float). +* @param dis_size Length of dis. +* @param x_size Length of x. 
+*/ +int krl_inner_product_ny_with_handle( + const KRLDistanceHandle *kdh, float *dis, const float *x, size_t dis_size, size_t x_size) +{ + if (kdh == nullptr || dis == nullptr || x == nullptr) { + std::printf("Error: INVALPOINTER in krl_inner_product_ny_with_handle\n"); + return INVALPOINTER; + } + const size_t ny = kdh->ny; + const size_t dim = kdh->d; + const size_t M = kdh->M; + if (dis_size < M * ny || x_size < dim * M) { + std::printf("Error: INVALPARAM in krl_inner_product_ny_with_handle\n"); + return INVALPARAM; + } + + if (kdh->data_bits == 32) { + const size_t ceil_ny = kdh->ceil_ny; + const float *y = (const float *)kdh->transposed_codes; + const size_t left = ny & (kdh->blocksize - 1); + switch (kdh->blocksize) { + case 16: + if (left) { + float distance_tmp_buffer[16]; + for (size_t m = 0; m < M; m++) { + size_t i = 0; + for (; i + 16 <= ny; i += 16) { + krl_inner_product_continuous_transpose_mini_kernel(dis + i, x, y + i * dim, dim); + } + krl_inner_product_continuous_transpose_mini_kernel(distance_tmp_buffer, x, y + i * dim, dim); + + size_t remaining_dis_size = dis_size - (m * ny + i); + if (remaining_dis_size < left) { + std::printf("Error: UNSAFEMEM in krl_inner_product_ny_with_handle\n"); + return UNSAFEMEM; + } + int ret = SafeMemory::CheckAndMemcpy( + dis + i, remaining_dis_size * sizeof(float), distance_tmp_buffer, left * sizeof(float)); + if (ret != 0) { + std::printf("Error: UNSAFEMEM in krl_inner_product_ny_with_handle\n"); + return UNSAFEMEM; + } + dis += ny; + x += dim; + y += ceil_ny * dim; + } + } else { + for (size_t m = 0; m < M; m++) { + for (size_t i = 0; i < ny; i += 16) { + krl_inner_product_continuous_transpose_mini_kernel(dis + i, x, y + i * dim, dim); + } + dis += ny; + x += dim; + y += ceil_ny * dim; + } + } + break; + case 32: + if (left) { + float distance_tmp_buffer[32]; + for (size_t m = 0; m < M; m++) { + size_t i = 0; + for (; i + 32 <= ny; i += 32) { + krl_inner_product_continuous_transpose_medium_kernel(dis + i, x, y + 
i * dim, dim); + } + krl_inner_product_continuous_transpose_medium_kernel(distance_tmp_buffer, x, y + i * dim, dim); + size_t remaining_dis_size = dis_size - (m * ny + i); + if (remaining_dis_size < left) { + std::printf("Error: UNSAFEMEM in krl_inner_product_ny_with_handle\n"); + return UNSAFEMEM; + } + int ret = SafeMemory::CheckAndMemcpy( + dis + i, remaining_dis_size * sizeof(float), distance_tmp_buffer, left * sizeof(float)); + if (ret != 0) { + std::printf("Error: UNSAFEMEM in krl_inner_product_ny_with_handle\n"); + return UNSAFEMEM; + } + dis += ny; + x += dim; + y += ceil_ny * dim; + } + } else { + for (size_t m = 0; m < M; m++) { + for (size_t i = 0; i < ny; i += 32) { + krl_inner_product_continuous_transpose_medium_kernel(dis + i, x, y + i * dim, dim); + } + dis += ny; + x += dim; + y += ceil_ny * dim; + } + } + break; + case 64: + if (left) { + float distance_tmp_buffer[64]; + for (size_t m = 0; m < M; m++) { + size_t i = 0; + for (; i + 64 <= ny; i += 64) { + krl_inner_product_continuous_transpose_large_kernel(dis + i, x, y + i * dim, dim); + } + krl_inner_product_continuous_transpose_large_kernel(distance_tmp_buffer, x, y + i * dim, dim); + size_t remaining_dis_size = dis_size - (m * ny + i); + if (remaining_dis_size < left) { + std::printf("Error: UNSAFEMEM in krl_inner_product_ny_with_handle\n"); + return UNSAFEMEM; + } + int ret = SafeMemory::CheckAndMemcpy( + dis + i, remaining_dis_size * sizeof(float), distance_tmp_buffer, left * sizeof(float)); + if (ret != 0) { + std::printf("Error: UNSAFEMEM in krl_inner_product_ny_with_handle\n"); + return UNSAFEMEM; + } + dis += ny; + x += dim; + y += ceil_ny * dim; + } + } else { + for (size_t m = 0; m < M; m++) { + for (size_t i = 0; i < ny; i += 64) { + krl_inner_product_continuous_transpose_large_kernel(dis + i, x, y + i * dim, dim); + } + dis += ny; + x += dim; + y += ceil_ny * dim; + } + } + break; + } + } else if (kdh->data_bits == 16) { + // fp16 path not built in minimal KRL for OpenViking + 
std::printf("Error: INVALPARAM in krl_inner_product_ny_with_handle (fp16 not supported)\n"); + return INVALPARAM; + } else { + // int8 path not built in minimal KRL for OpenViking + std::printf("Error: INVALPARAM in krl_inner_product_ny_with_handle (int8 not supported)\n"); + return INVALPARAM; + } + return SUCCESS; +} + +} // extern "C" diff --git a/third_party/krl/src/L2distance_simd.cpp b/third_party/krl/src/L2distance_simd.cpp new file mode 100644 index 00000000..47d8c33f --- /dev/null +++ b/third_party/krl/src/L2distance_simd.cpp @@ -0,0 +1,3193 @@ +// Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +// SPDX-License-Identifier: Apache-2.0 +// Adapted from KRL (Kunpeng Retrieval Library) for ARM NEON optimizations. + +#include "krl.h" +#include "krl_internal.h" +#include "platform_macros.h" +#include "safe_memory.h" +#include + +extern "C" { + +/* +* @brief Compute the L2 square of two float vectors. +* @param x Pointer to the first vector (float). +* @param y Pointer to the second vector (float). +* @param d Dimension of the vectors. +* @param dis Stores the computed L2 square result (float). +* @param dis_size Length of dis. 
+*/
+KRL_IMPRECISE_FUNCTION_BEGIN
+int krl_L2sqr(const float *x, const float *__restrict y, const size_t d, float *dis, size_t dis_size)
+{
+    // 4 floats per 128-bit NEON register; the fast path consumes 4 registers
+    // (16 floats) per iteration across 4 independent accumulators.
+    constexpr size_t single_round = 4;
+    constexpr size_t multi_round = 16;
+    size_t i;
+    float res;
+
+    if (d < 1 || d > 65535) {
+        std::printf("Error: INVALPARAM in krl_L2sqr\n");
+        return INVALPARAM;
+    }
+
+    if (x == nullptr || y == nullptr || dis == nullptr || dis_size < 1) {
+        std::printf("Error: INVALPOINTER in krl_L2sqr\n");
+        return INVALPOINTER;
+    }
+
+    if (likely(d >= multi_round)) {
+        // The first 16 lanes are computed outside the loop (sub + mul) so the
+        // loop body can fuse multiply-accumulate into live accumulators.
+        prefetch_Lx(x + multi_round);
+        prefetch_Lx(y + multi_round);
+        float32x4_t x8_0 = vld1q_f32(x);
+        float32x4_t x8_1 = vld1q_f32(x + 4);
+        float32x4_t x8_2 = vld1q_f32(x + 8);
+        float32x4_t x8_3 = vld1q_f32(x + 12);
+
+        float32x4_t y8_0 = vld1q_f32(y);
+        float32x4_t y8_1 = vld1q_f32(y + 4);
+        float32x4_t y8_2 = vld1q_f32(y + 8);
+        float32x4_t y8_3 = vld1q_f32(y + 12);
+
+        float32x4_t d8_0 = vsubq_f32(x8_0, y8_0);
+        d8_0 = vmulq_f32(d8_0, d8_0);
+        float32x4_t d8_1 = vsubq_f32(x8_1, y8_1);
+        d8_1 = vmulq_f32(d8_1, d8_1);
+        float32x4_t d8_2 = vsubq_f32(x8_2, y8_2);
+        d8_2 = vmulq_f32(d8_2, d8_2);
+        float32x4_t d8_3 = vsubq_f32(x8_3, y8_3);
+        d8_3 = vmulq_f32(d8_3, d8_3);
+
+        // Main 16-float-per-iteration loop with software prefetch one chunk ahead.
+        for (i = multi_round; i <= d - multi_round; i += multi_round) {
+            prefetch_Lx(x + i + multi_round);
+            prefetch_Lx(y + i + multi_round);
+            x8_0 = vld1q_f32(x + i);
+            y8_0 = vld1q_f32(y + i);
+            const float32x4_t q8_0 = vsubq_f32(x8_0, y8_0);
+            d8_0 = vmlaq_f32(d8_0, q8_0, q8_0);
+
+            x8_1 = vld1q_f32(x + i + 4);
+            y8_1 = vld1q_f32(y + i + 4);
+            const float32x4_t q8_1 = vsubq_f32(x8_1, y8_1);
+            d8_1 = vmlaq_f32(d8_1, q8_1, q8_1);
+
+            x8_2 = vld1q_f32(x + i + 8);
+            y8_2 = vld1q_f32(y + i + 8);
+            const float32x4_t q8_2 = vsubq_f32(x8_2, y8_2);
+            d8_2 = vmlaq_f32(d8_2, q8_2, q8_2);
+
+            x8_3 = vld1q_f32(x + i + 12);
+            y8_3 = vld1q_f32(y + i + 12);
+            const float32x4_t q8_3 = vsubq_f32(x8_3, y8_3);
+            d8_3 = vmlaq_f32(d8_3, q8_3, q8_3);
+        }
+
+        // Remaining full 4-float groups go into the first accumulator only.
+        for (; i <= d - single_round; i += single_round) {
+            x8_0 = vld1q_f32(x + i);
+            y8_0 = vld1q_f32(y + i);
+            const float32x4_t q8_0 = vsubq_f32(x8_0, y8_0);
+            d8_0 = vmlaq_f32(d8_0, q8_0, q8_0);
+        }
+
+        // Pairwise reduce the 4 accumulators, then sum lanes horizontally.
+        d8_0 = vaddq_f32(d8_0, d8_1);
+        d8_2 = vaddq_f32(d8_2, d8_3);
+        d8_0 = vaddq_f32(d8_0, d8_2);
+        res = vaddvq_f32(d8_0);
+    } else if (d >= single_round) {
+        // Short-vector path: one NEON accumulator, 4 floats per iteration.
+        float32x4_t x8_0 = vld1q_f32(x);
+        float32x4_t y8_0 = vld1q_f32(y);
+
+        float32x4_t d8_0 = vsubq_f32(x8_0, y8_0);
+        d8_0 = vmulq_f32(d8_0, d8_0);
+        for (i = single_round; i <= d - single_round; i += single_round) {
+            x8_0 = vld1q_f32(x + i);
+            y8_0 = vld1q_f32(y + i);
+            const float32x4_t q8_0 = vsubq_f32(x8_0, y8_0);
+            d8_0 = vmlaq_f32(d8_0, q8_0, q8_0);
+        }
+        res = vaddvq_f32(d8_0);
+    } else {
+        // d < 4: everything is handled by the scalar tail below.
+        res = 0;
+        i = 0;
+    }
+
+    // Scalar tail for the last d % 4 dimensions.
+    for (; i < d; i++) {
+        const float tmp = x[i] - y[i];
+        res += tmp * tmp;
+    }
+    *dis = res;
+    return SUCCESS;
+}
+KRL_IMPRECISE_FUNCTION_END
+
+/*
+* @brief Compute L2 squares for two float vectors in batch.
+* @param x Pointer to the query vector (float).
+* @param y Pointer to the database vectors (float).
+* @param d Dimension of the vectors.
+* @param dis Output array to store the results (float).
+*/
+KRL_IMPRECISE_FUNCTION_BEGIN
+static void krl_L2sqr_idx_batch2(
+    const float *x, const float *__restrict y0, const float *__restrict y1, const size_t d, float *dis)
+{
+    size_t i;
+    constexpr size_t single_round = 4;
+    constexpr size_t multi_round = 8;
+
+    if (likely(d >= multi_round)) {
+        // 8 floats per iteration, two accumulators per candidate vector.
+        float32x4_t x_0 = vld1q_f32(x);
+        float32x4_t x_1 = vld1q_f32(x + 4);
+
+        float32x4_t y0_0 = vld1q_f32(y0);
+        float32x4_t y0_1 = vld1q_f32(y0 + 4);
+        float32x4_t y1_0 = vld1q_f32(y1);
+        float32x4_t y1_1 = vld1q_f32(y1 + 4);
+
+        float32x4_t d0_0 = vsubq_f32(x_0, y0_0);
+        d0_0 = vmulq_f32(d0_0, d0_0);
+        float32x4_t d0_1 = vsubq_f32(x_1, y0_1);
+        d0_1 = vmulq_f32(d0_1, d0_1);
+        float32x4_t d1_0 = vsubq_f32(x_0, y1_0);
+        d1_0 = vmulq_f32(d1_0, d1_0);
+        float32x4_t d1_1 = vsubq_f32(x_1, y1_1);
+        d1_1 = vmulq_f32(d1_1, d1_1);
+
+        for (i = multi_round; i <= d - multi_round; i += multi_round) {
+            x_0 = vld1q_f32(x + i);
+            y0_0 = vld1q_f32(y0 + i);
+            y1_0 = vld1q_f32(y1 + i);
+            const float32x4_t q0_0 = vsubq_f32(x_0, y0_0);
+            const float32x4_t q1_0 = vsubq_f32(x_0, y1_0);
+            d0_0 = vmlaq_f32(d0_0, q0_0, q0_0);
+            d1_0 = vmlaq_f32(d1_0, q1_0, q1_0);
+
+            x_1 = vld1q_f32(x + i + 4);
+            y0_1 = vld1q_f32(y0 + i + 4);
+            y1_1 = vld1q_f32(y1 + i + 4);
+            const float32x4_t q0_1 = vsubq_f32(x_1, y0_1);
+            const float32x4_t q1_1 = vsubq_f32(x_1, y1_1);
+            d0_1 = vmlaq_f32(d0_1, q0_1, q0_1);
+            d1_1 = vmlaq_f32(d1_1, q1_1, q1_1);
+        }
+
+        for (; i <= d - single_round; i += single_round) {
+            x_0 = vld1q_f32(x + i);
+            y0_0 = vld1q_f32(y0 + i);
+            y1_0 = vld1q_f32(y1 + i);
+            const float32x4_t q0_0 = vsubq_f32(x_0, y0_0);
+            const float32x4_t q1_0 = vsubq_f32(x_0, y1_0);
+            d0_0 = vmlaq_f32(d0_0, q0_0, q0_0);
+            d1_0 = vmlaq_f32(d1_0, q1_0, q1_0);
+        }
+
+        d0_0 = vaddq_f32(d0_0, d0_1);
+        d1_0 = vaddq_f32(d1_0, d1_1);
+        dis[0] = vaddvq_f32(d0_0);
+        dis[1] = vaddvq_f32(d1_0);
+    } else if (d >= single_round) {
+        // 4 <= d < 8: one NEON group, remainder handled by the scalar tail.
+        float32x4_t x8_0 = vld1q_f32(x);
+        float32x4_t y8_0 = vld1q_f32(y0);
+        float32x4_t y8_1 = vld1q_f32(y1);
+
+        float32x4_t d8_0 = vsubq_f32(x8_0, y8_0);
+        d8_0 = vmulq_f32(d8_0, d8_0);
+        float32x4_t d8_1 = vsubq_f32(x8_0, y8_1);
+        d8_1 = vmulq_f32(d8_1, d8_1);
+        // Note: this loop cannot execute in this branch (it requires d >= 8,
+        // which is handled above); kept for symmetry. Fix: the original body
+        // reloaded offset 0 (vld1q_f32(x), y0, y1) instead of advancing by i,
+        // which would have double-counted the first group had it ever run.
+        for (i = single_round; i <= d - single_round; i += single_round) {
+            x8_0 = vld1q_f32(x + i);
+            y8_0 = vld1q_f32(y0 + i);
+            y8_1 = vld1q_f32(y1 + i);
+
+            float32x4_t q0 = vsubq_f32(x8_0, y8_0);
+            d8_0 = vmlaq_f32(d8_0, q0, q0);
+            float32x4_t q1 = vsubq_f32(x8_0, y8_1);
+            d8_1 = vmlaq_f32(d8_1, q1, q1);
+        }
+        dis[0] = vaddvq_f32(d8_0);
+        dis[1] = vaddvq_f32(d8_1);
+    } else {
+        dis[0] = 0;
+        dis[1] = 0;
+        i = 0;
+    }
+
+    // Scalar tail for the last d % 4 dimensions (accumulates onto dis).
+    for (; i < d; i++) {
+        const float tmp0 = x[i] - y0[i];
+        const float tmp1 = x[i] - y1[i];
+        dis[0] += tmp0 * tmp0;
+        dis[1] += tmp1 * tmp1;
+    }
+}
+KRL_IMPRECISE_FUNCTION_END
+
+/*
+* @brief Compute L2 squares for four float vectors in batch.
+* @param x Pointer to the query vector (float).
+* @param y Pointer to the database vectors (float).
+* @param d Dimension of the vectors.
+* @param dis Output array to store the results (float).
+*/
+KRL_IMPRECISE_FUNCTION_BEGIN
+static void krl_L2sqr_idx_batch4(const float *x, const float *__restrict *y, const size_t d, float *dis)
+{
+    // One 4-lane accumulator per candidate; 4 floats per iteration.
+    constexpr size_t single_round = 4;
+    size_t i;
+    if (likely(d >= single_round)) {
+        float32x4_t b = vld1q_f32(x);
+
+        float32x4_t q0 = vld1q_f32(y[0]);
+        float32x4_t q1 = vld1q_f32(y[1]);
+        float32x4_t q2 = vld1q_f32(y[2]);
+        float32x4_t q3 = vld1q_f32(y[3]);
+
+        q0 = vsubq_f32(q0, b);
+        q1 = vsubq_f32(q1, b);
+        q2 = vsubq_f32(q2, b);
+        q3 = vsubq_f32(q3, b);
+
+        // First group uses plain multiply; the loop fuses multiply-accumulate.
+        float32x4_t res0 = vmulq_f32(q0, q0);
+        float32x4_t res1 = vmulq_f32(q1, q1);
+        float32x4_t res2 = vmulq_f32(q2, q2);
+        float32x4_t res3 = vmulq_f32(q3, q3);
+
+        for (i = single_round; i <= d - single_round; i += single_round) {
+            b = vld1q_f32(x + i);
+
+            q0 = vld1q_f32(y[0] + i);
+            q1 = vld1q_f32(y[1] + i);
+            q2 = vld1q_f32(y[2] + i);
+            q3 = vld1q_f32(y[3] + i);
+
+            q0 = vsubq_f32(q0, b);
+            q1 = vsubq_f32(q1, b);
+            q2 = vsubq_f32(q2, b);
+            q3 = vsubq_f32(q3, b);
+
+            res0 = vmlaq_f32(res0, q0, q0);
+            res1 = vmlaq_f32(res1, q1, q1);
+            res2 = vmlaq_f32(res2, q2, q2);
+            res3 = vmlaq_f32(res3, q3, q3);
+        }
+        dis[0] = vaddvq_f32(res0);
+        dis[1] = vaddvq_f32(res1);
+        dis[2] = vaddvq_f32(res2);
+        dis[3] = vaddvq_f32(res3);
+    } else {
+        // d < 4: zero the outputs; the scalar tail below does all the work.
+        for (int i = 0; i < 4; i++) {
+            dis[i] = 0.0f;
+        }
+        i = 0;
+    }
+    // Scalar tail for the last d % 4 dimensions, accumulated in locals to
+    // keep the four running sums in registers before adding onto dis.
+    if (d > i) {
+        float q0 = x[i] - *(y[0] + i);
+        float q1 = x[i] - *(y[1] + i);
+        float q2 = x[i] - *(y[2] + i);
+        float q3 = x[i] - *(y[3] + i);
+        float d0 = q0 * q0;
+        float d1 = q1 * q1;
+        float d2 = q2 * q2;
+        float d3 = q3 * q3;
+        for (i++; i < d; ++i) {
+            float q0 = x[i] - *(y[0] + i);
+            float q1 = x[i] - *(y[1] + i);
+            float q2 = x[i] - *(y[2] + i);
+            float q3 = x[i] - *(y[3] + i);
+            d0 += q0 * q0;
+            d1 += q1 * q1;
+            d2 += q2 * q2;
+            d3 += q3 * q3;
+        }
+        dis[0] += d0;
+        dis[1] += d1;
+        dis[2] += d2;
+        dis[3] += d3;
+    }
+}
+KRL_IMPRECISE_FUNCTION_END
+
+/*
+* @brief Compute L2 squares for eight float vectors in batch.
+* @param x Pointer to the query vector (float).
+* @param y Pointer to the database vectors (float).
+* @param d Dimension of the vectors.
+* @param dis Output array to store the results (float).
+*/
+KRL_IMPRECISE_FUNCTION_BEGIN
+static void krl_L2sqr_idx_prefetch_batch8(const float *x, const float *__restrict *y, const size_t d, float *dis)
+{
+    size_t i;
+    constexpr size_t single_round = 4; /* 128/32 */
+    constexpr size_t multi_round = 16; /* 4 * single_round */
+    if (likely(d >= multi_round)) {
+        // One accumulator per candidate vector; reduced horizontally at the end.
+        float32x4_t neon_res1 = vdupq_n_f32(0.0);
+        float32x4_t neon_res2 = vdupq_n_f32(0.0);
+        float32x4_t neon_res3 = vdupq_n_f32(0.0);
+        float32x4_t neon_res4 = vdupq_n_f32(0.0);
+        float32x4_t neon_res5 = vdupq_n_f32(0.0);
+        float32x4_t neon_res6 = vdupq_n_f32(0.0);
+        float32x4_t neon_res7 = vdupq_n_f32(0.0);
+        float32x4_t neon_res8 = vdupq_n_f32(0.0);
+        // Strict `<` keeps the prefetches from running past the final chunk;
+        // leftover full groups are swept up by the loop after this one.
+        for (i = 0; i < d - multi_round; i += multi_round) {
+            prefetch_L1(x + i + multi_round);
+            prefetch_Lx(y[0] + i + multi_round);
+            prefetch_Lx(y[1] + i + multi_round);
+            prefetch_Lx(y[2] + i + multi_round);
+            prefetch_Lx(y[3] + i + multi_round);
+            prefetch_Lx(y[4] + i + multi_round);
+            prefetch_Lx(y[5] + i + multi_round);
+            prefetch_Lx(y[6] + i + multi_round);
+            prefetch_Lx(y[7] + i + multi_round);
+            for (size_t j = 0; j < multi_round; j += single_round) {
+                const float32x4_t neon_query = vld1q_f32(x + i + j);
+                float32x4_t neon_base1 = vld1q_f32(y[0] + i + j);
+                float32x4_t neon_base2 = vld1q_f32(y[1] + i + j);
+                float32x4_t neon_base3 = vld1q_f32(y[2] + i + j);
+                float32x4_t neon_base4 = vld1q_f32(y[3] + i + j);
+                float32x4_t neon_base5 = vld1q_f32(y[4] + i + j);
+                float32x4_t neon_base6 = vld1q_f32(y[5] + i + j);
+                float32x4_t neon_base7 = vld1q_f32(y[6] + i + j);
+                float32x4_t neon_base8 = vld1q_f32(y[7] + i + j);
+
+                neon_base1 = vsubq_f32(neon_base1, neon_query);
+                neon_base2 = vsubq_f32(neon_base2, neon_query);
+                neon_base3 = vsubq_f32(neon_base3, neon_query);
+                neon_base4 = vsubq_f32(neon_base4, neon_query);
+                neon_base5 = vsubq_f32(neon_base5, neon_query);
+                neon_base6 = vsubq_f32(neon_base6, neon_query);
+                neon_base7 = vsubq_f32(neon_base7, neon_query);
+                neon_base8 = vsubq_f32(neon_base8, neon_query);
+
+                neon_res1 = vmlaq_f32(neon_res1, neon_base1, neon_base1);
+                neon_res2 = vmlaq_f32(neon_res2, neon_base2, neon_base2);
+                neon_res3 = vmlaq_f32(neon_res3, neon_base3, neon_base3);
+                neon_res4 = vmlaq_f32(neon_res4, neon_base4, neon_base4);
+                neon_res5 = vmlaq_f32(neon_res5, neon_base5, neon_base5);
+                neon_res6 = vmlaq_f32(neon_res6, neon_base6, neon_base6);
+                neon_res7 = vmlaq_f32(neon_res7, neon_base7, neon_base7);
+                neon_res8 = vmlaq_f32(neon_res8, neon_base8, neon_base8);
+            }
+        }
+        // Remaining full 4-float groups (no prefetch needed this close to the end).
+        for (; i <= d - single_round; i += single_round) {
+            const float32x4_t neon_query = vld1q_f32(x + i);
+            float32x4_t neon_base1 = vld1q_f32(y[0] + i);
+            float32x4_t neon_base2 = vld1q_f32(y[1] + i);
+            float32x4_t neon_base3 = vld1q_f32(y[2] + i);
+            float32x4_t neon_base4 = vld1q_f32(y[3] + i);
+            float32x4_t neon_base5 = vld1q_f32(y[4] + i);
+            float32x4_t neon_base6 = vld1q_f32(y[5] + i);
+            float32x4_t neon_base7 = vld1q_f32(y[6] + i);
+            float32x4_t neon_base8 = vld1q_f32(y[7] + i);
+
+            neon_base1 = vsubq_f32(neon_base1, neon_query);
+            neon_base2 = vsubq_f32(neon_base2, neon_query);
+            neon_base3 = vsubq_f32(neon_base3, neon_query);
+            neon_base4 = vsubq_f32(neon_base4, neon_query);
+            neon_base5 = vsubq_f32(neon_base5, neon_query);
+            neon_base6 = vsubq_f32(neon_base6, neon_query);
+            neon_base7 = vsubq_f32(neon_base7, neon_query);
+            neon_base8 = vsubq_f32(neon_base8, neon_query);
+
+            neon_res1 = vmlaq_f32(neon_res1, neon_base1, neon_base1);
+            neon_res2 = vmlaq_f32(neon_res2, neon_base2, neon_base2);
+            neon_res3 = vmlaq_f32(neon_res3, neon_base3, neon_base3);
+            neon_res4 = vmlaq_f32(neon_res4, neon_base4, neon_base4);
+            neon_res5 = vmlaq_f32(neon_res5, neon_base5, neon_base5);
+            neon_res6 = vmlaq_f32(neon_res6, neon_base6, neon_base6);
+            neon_res7 = vmlaq_f32(neon_res7, neon_base7, neon_base7);
+            neon_res8 = vmlaq_f32(neon_res8, neon_base8, neon_base8);
+        }
+        dis[0] = vaddvq_f32(neon_res1);
+        dis[1] = vaddvq_f32(neon_res2);
+        dis[2] = vaddvq_f32(neon_res3);
+        dis[3] = vaddvq_f32(neon_res4);
+        dis[4] = vaddvq_f32(neon_res5);
+        dis[5] = vaddvq_f32(neon_res6);
+        dis[6] = vaddvq_f32(neon_res7);
+        dis[7] = vaddvq_f32(neon_res8);
+    } else if (d >= single_round) {
+        // 4 <= d < 16: no prefetching; first group multiplied, rest fused.
+        float32x4_t neon_query = vld1q_f32(x);
+
+        float32x4_t neon_base1 = vld1q_f32(y[0]);
+        float32x4_t neon_base2 = vld1q_f32(y[1]);
+        float32x4_t neon_base3 = vld1q_f32(y[2]);
+        float32x4_t neon_base4 = vld1q_f32(y[3]);
+        float32x4_t neon_base5 = vld1q_f32(y[4]);
+        float32x4_t neon_base6 = vld1q_f32(y[5]);
+        float32x4_t neon_base7 = vld1q_f32(y[6]);
+        float32x4_t neon_base8 = vld1q_f32(y[7]);
+
+        neon_base1 = vsubq_f32(neon_base1, neon_query);
+        neon_base2 = vsubq_f32(neon_base2, neon_query);
+        neon_base3 = vsubq_f32(neon_base3, neon_query);
+        neon_base4 = vsubq_f32(neon_base4, neon_query);
+        neon_base5 = vsubq_f32(neon_base5, neon_query);
+        neon_base6 = vsubq_f32(neon_base6, neon_query);
+        neon_base7 = vsubq_f32(neon_base7, neon_query);
+        neon_base8 = vsubq_f32(neon_base8, neon_query);
+
+        float32x4_t neon_res1 = vmulq_f32(neon_base1, neon_base1);
+        float32x4_t neon_res2 = vmulq_f32(neon_base2, neon_base2);
+        float32x4_t neon_res3 = vmulq_f32(neon_base3, neon_base3);
+        float32x4_t neon_res4 = vmulq_f32(neon_base4, neon_base4);
+        float32x4_t neon_res5 = vmulq_f32(neon_base5, neon_base5);
+        float32x4_t neon_res6 = vmulq_f32(neon_base6, neon_base6);
+        float32x4_t neon_res7 = vmulq_f32(neon_base7, neon_base7);
+        float32x4_t neon_res8 = vmulq_f32(neon_base8, neon_base8);
+        for (i = single_round; i <= d - single_round; i += single_round) {
+            neon_query = vld1q_f32(x + i);
+            neon_base1 = vld1q_f32(y[0] + i);
+            neon_base2 = vld1q_f32(y[1] + i);
+            neon_base3 = vld1q_f32(y[2] + i);
+            neon_base4 = vld1q_f32(y[3] + i);
+            neon_base5 = vld1q_f32(y[4] + i);
+            neon_base6 = vld1q_f32(y[5] + i);
+            neon_base7 = vld1q_f32(y[6] + i);
+            neon_base8 = vld1q_f32(y[7] + i);
+
+            neon_base1 = vsubq_f32(neon_base1, neon_query);
+            neon_base2 = vsubq_f32(neon_base2, neon_query);
+            neon_base3 = vsubq_f32(neon_base3, neon_query);
+            neon_base4 = vsubq_f32(neon_base4, neon_query);
+            neon_base5 = vsubq_f32(neon_base5, neon_query);
+            neon_base6 = vsubq_f32(neon_base6, neon_query);
+            neon_base7 = vsubq_f32(neon_base7, neon_query);
+            neon_base8 = vsubq_f32(neon_base8, neon_query);
+
+            neon_res1 = vmlaq_f32(neon_res1, neon_base1, neon_base1);
+            neon_res2 = vmlaq_f32(neon_res2, neon_base2, neon_base2);
+            neon_res3 = vmlaq_f32(neon_res3, neon_base3, neon_base3);
+            neon_res4 = vmlaq_f32(neon_res4, neon_base4, neon_base4);
+            neon_res5 = vmlaq_f32(neon_res5, neon_base5, neon_base5);
+            neon_res6 = vmlaq_f32(neon_res6, neon_base6, neon_base6);
+            neon_res7 = vmlaq_f32(neon_res7, neon_base7, neon_base7);
+            neon_res8 = vmlaq_f32(neon_res8, neon_base8, neon_base8);
+        }
+        dis[0] = vaddvq_f32(neon_res1);
+        dis[1] = vaddvq_f32(neon_res2);
+        dis[2] = vaddvq_f32(neon_res3);
+        dis[3] = vaddvq_f32(neon_res4);
+        dis[4] = vaddvq_f32(neon_res5);
+        dis[5] = vaddvq_f32(neon_res6);
+        dis[6] = vaddvq_f32(neon_res7);
+        dis[7] = vaddvq_f32(neon_res8);
+    } else {
+        // d < 4: zero outputs; scalar tail below does all the work.
+        for (int i = 0; i < 8; i++) {
+            dis[i] = 0.0f;
+        }
+        i = 0;
+    }
+    // Scalar tail for the last d % 4 dimensions, accumulated in locals.
+    if (i < d) {
+        float q0 = x[i] - *(y[0] + i);
+        float q1 = x[i] - *(y[1] + i);
+        float q2 = x[i] - *(y[2] + i);
+        float q3 = x[i] - *(y[3] + i);
+        float q4 = x[i] - *(y[4] + i);
+        float q5 = x[i] - *(y[5] + i);
+        float q6 = x[i] - *(y[6] + i);
+        float q7 = x[i] - *(y[7] + i);
+        float d0 = q0 * q0;
+        float d1 = q1 * q1;
+        float d2 = q2 * q2;
+        float d3 = q3 * q3;
+        float d4 = q4 * q4;
+        float d5 = q5 * q5;
+        float d6 = q6 * q6;
+        float d7 = q7 * q7;
+        for (i++; i < d; ++i) {
+            q0 = x[i] - *(y[0] + i);
+            q1 = x[i] - *(y[1] + i);
+            q2 = x[i] - *(y[2] + i);
+            q3 = x[i] - *(y[3] + i);
+            q4 = x[i] - *(y[4] + i);
+            q5 = x[i] - *(y[5] + i);
+            q6 = x[i] - *(y[6] + i);
+            q7 = x[i] - *(y[7] + i);
+            d0 += q0 * q0;
+            d1 += q1 * q1;
+            d2 += q2 * q2;
+            d3 += q3 * q3;
+            d4 += q4 * q4;
+            d5 += q5 * q5;
+            d6 += q6 * q6;
+            d7 += q7 * q7;
+        }
+        dis[0] += d0;
+        dis[1] += d1;
+        dis[2] += d2;
+        dis[3] += d3;
+        dis[4] += d4;
+        dis[5] += d5;
+        dis[6] += d6;
+        dis[7] += d7;
+    }
+}
+KRL_IMPRECISE_FUNCTION_END
+
+/*
+* @brief Compute L2 squares for sixteen float vectors in batch.
+* @param x Pointer to the query vector (float).
+* @param y Pointer to the database vectors (float).
+* @param d Dimension of the vectors.
+* @param dis Output array to store the results (float).
+*/
+KRL_IMPRECISE_FUNCTION_BEGIN
+static void krl_L2sqr_idx_prefetch_batch16(const float *x, const float *__restrict *y, const size_t d, float *dis)
+{
+    size_t i;
+    constexpr size_t single_round = 4; /* 128/32 */
+    constexpr size_t multi_round = 16; /* 4 * single_round */
+    if (likely(d >= multi_round)) {
+        // 16 accumulators, one per candidate; candidates 0-7 and 8-15 are
+        // processed in two waves per group to limit register pressure.
+        float32x4_t neon_res1 = vdupq_n_f32(0.0);
+        float32x4_t neon_res2 = vdupq_n_f32(0.0);
+        float32x4_t neon_res3 = vdupq_n_f32(0.0);
+        float32x4_t neon_res4 = vdupq_n_f32(0.0);
+        float32x4_t neon_res5 = vdupq_n_f32(0.0);
+        float32x4_t neon_res6 = vdupq_n_f32(0.0);
+        float32x4_t neon_res7 = vdupq_n_f32(0.0);
+        float32x4_t neon_res8 = vdupq_n_f32(0.0);
+        float32x4_t neon_res9 = vdupq_n_f32(0.0);
+        float32x4_t neon_res10 = vdupq_n_f32(0.0);
+        float32x4_t neon_res11 = vdupq_n_f32(0.0);
+        float32x4_t neon_res12 = vdupq_n_f32(0.0);
+        float32x4_t neon_res13 = vdupq_n_f32(0.0);
+        float32x4_t neon_res14 = vdupq_n_f32(0.0);
+        float32x4_t neon_res15 = vdupq_n_f32(0.0);
+        float32x4_t neon_res16 = vdupq_n_f32(0.0);
+        for (i = 0; i < d - multi_round; i += multi_round) {
+            prefetch_L1(x + i + multi_round);
+            prefetch_Lx(y[0] + i + multi_round);
+            prefetch_Lx(y[1] + i + multi_round);
+            prefetch_Lx(y[2] + i + multi_round);
+            prefetch_Lx(y[3] + i + multi_round);
+            prefetch_Lx(y[4] + i + multi_round);
+            prefetch_Lx(y[5] + i + multi_round);
+            prefetch_Lx(y[6] + i + multi_round);
+            prefetch_Lx(y[7] + i + multi_round);
+            prefetch_Lx(y[8] + i + multi_round);
+            prefetch_Lx(y[9] + i + multi_round);
+            prefetch_Lx(y[10] + i + multi_round);
+            prefetch_Lx(y[11] + i + multi_round);
+            prefetch_Lx(y[12] + i + multi_round);
+            prefetch_Lx(y[13] + i + multi_round);
+            prefetch_Lx(y[14] + i + multi_round);
+            prefetch_Lx(y[15] + i + multi_round);
+            for (size_t j = 0; j < multi_round; j += single_round) {
+                const float32x4_t neon_query = vld1q_f32(x + i + j);
+                float32x4_t neon_base1 = vld1q_f32(y[0] + i + j);
+                float32x4_t neon_base2 = vld1q_f32(y[1] + i + j);
+                float32x4_t neon_base3 = vld1q_f32(y[2] + i + j);
+                float32x4_t neon_base4 = vld1q_f32(y[3] + i + j);
+                float32x4_t neon_base5 = vld1q_f32(y[4] + i + j);
+                float32x4_t neon_base6 = vld1q_f32(y[5] + i + j);
+                float32x4_t neon_base7 = vld1q_f32(y[6] + i + j);
+                float32x4_t neon_base8 = vld1q_f32(y[7] + i + j);
+
+                neon_base1 = vsubq_f32(neon_base1, neon_query);
+                neon_base2 = vsubq_f32(neon_base2, neon_query);
+                neon_base3 = vsubq_f32(neon_base3, neon_query);
+                neon_base4 = vsubq_f32(neon_base4, neon_query);
+                neon_base5 = vsubq_f32(neon_base5, neon_query);
+                neon_base6 = vsubq_f32(neon_base6, neon_query);
+                neon_base7 = vsubq_f32(neon_base7, neon_query);
+                neon_base8 = vsubq_f32(neon_base8, neon_query);
+
+                neon_res1 = vmlaq_f32(neon_res1, neon_base1, neon_base1);
+                neon_res2 = vmlaq_f32(neon_res2, neon_base2, neon_base2);
+                neon_res3 = vmlaq_f32(neon_res3, neon_base3, neon_base3);
+                neon_res4 = vmlaq_f32(neon_res4, neon_base4, neon_base4);
+                neon_res5 = vmlaq_f32(neon_res5, neon_base5, neon_base5);
+                neon_res6 = vmlaq_f32(neon_res6, neon_base6, neon_base6);
+                neon_res7 = vmlaq_f32(neon_res7, neon_base7, neon_base7);
+                neon_res8 = vmlaq_f32(neon_res8, neon_base8, neon_base8);
+
+                // Second wave: candidates 8-15 reuse the base registers.
+                neon_base1 = vld1q_f32(y[8] + i + j);
+                neon_base2 = vld1q_f32(y[9] + i + j);
+                neon_base3 = vld1q_f32(y[10] + i + j);
+                neon_base4 = vld1q_f32(y[11] + i + j);
+                neon_base5 = vld1q_f32(y[12] + i + j);
+                neon_base6 = vld1q_f32(y[13] + i + j);
+                neon_base7 = vld1q_f32(y[14] + i + j);
+                neon_base8 = vld1q_f32(y[15] + i + j);
+
+                neon_base1 = vsubq_f32(neon_base1, neon_query);
+                neon_base2 = vsubq_f32(neon_base2, neon_query);
+                neon_base3 = vsubq_f32(neon_base3, neon_query);
+                neon_base4 = vsubq_f32(neon_base4, neon_query);
+                neon_base5 = vsubq_f32(neon_base5, neon_query);
+                neon_base6 = vsubq_f32(neon_base6, neon_query);
+                neon_base7 = vsubq_f32(neon_base7, neon_query);
+                neon_base8 = vsubq_f32(neon_base8, neon_query);
+
+                neon_res9 = vmlaq_f32(neon_res9, neon_base1, neon_base1);
+                neon_res10 = vmlaq_f32(neon_res10, neon_base2, neon_base2);
+                neon_res11 = vmlaq_f32(neon_res11, neon_base3, neon_base3);
+                neon_res12 = vmlaq_f32(neon_res12, neon_base4, neon_base4);
+                neon_res13 = vmlaq_f32(neon_res13, neon_base5, neon_base5);
+                neon_res14 = vmlaq_f32(neon_res14, neon_base6, neon_base6);
+                neon_res15 = vmlaq_f32(neon_res15, neon_base7, neon_base7);
+                neon_res16 = vmlaq_f32(neon_res16, neon_base8, neon_base8);
+            }
+        }
+        // Remaining full 4-float groups.
+        for (; i <= d - single_round; i += single_round) {
+            const float32x4_t neon_query = vld1q_f32(x + i);
+            float32x4_t neon_base1 = vld1q_f32(y[0] + i);
+            float32x4_t neon_base2 = vld1q_f32(y[1] + i);
+            float32x4_t neon_base3 = vld1q_f32(y[2] + i);
+            float32x4_t neon_base4 = vld1q_f32(y[3] + i);
+            float32x4_t neon_base5 = vld1q_f32(y[4] + i);
+            float32x4_t neon_base6 = vld1q_f32(y[5] + i);
+            float32x4_t neon_base7 = vld1q_f32(y[6] + i);
+            float32x4_t neon_base8 = vld1q_f32(y[7] + i);
+
+            neon_base1 = vsubq_f32(neon_base1, neon_query);
+            neon_base2 = vsubq_f32(neon_base2, neon_query);
+            neon_base3 = vsubq_f32(neon_base3, neon_query);
+            neon_base4 = vsubq_f32(neon_base4, neon_query);
+            neon_base5 = vsubq_f32(neon_base5, neon_query);
+            neon_base6 = vsubq_f32(neon_base6, neon_query);
+            neon_base7 = vsubq_f32(neon_base7, neon_query);
+            neon_base8 = vsubq_f32(neon_base8, neon_query);
+
+            neon_res1 = vmlaq_f32(neon_res1, neon_base1, neon_base1);
+            neon_res2 = vmlaq_f32(neon_res2, neon_base2, neon_base2);
+            neon_res3 = vmlaq_f32(neon_res3, neon_base3, neon_base3);
+            neon_res4 = vmlaq_f32(neon_res4, neon_base4, neon_base4);
+            neon_res5 = vmlaq_f32(neon_res5, neon_base5, neon_base5);
+            neon_res6 = vmlaq_f32(neon_res6, neon_base6, neon_base6);
+            neon_res7 = vmlaq_f32(neon_res7, neon_base7, neon_base7);
+            neon_res8 = vmlaq_f32(neon_res8, neon_base8, neon_base8);
+
+            neon_base1 = vld1q_f32(y[8] + i);
+            neon_base2 = vld1q_f32(y[9] + i);
+            neon_base3 = vld1q_f32(y[10] + i);
+            neon_base4 = vld1q_f32(y[11] + i);
+            neon_base5 = vld1q_f32(y[12] + i);
+            neon_base6 = vld1q_f32(y[13] + i);
+            neon_base7 = vld1q_f32(y[14] + i);
+            neon_base8 = vld1q_f32(y[15] + i);
+
+            neon_base1 = vsubq_f32(neon_base1, neon_query);
+            neon_base2 = vsubq_f32(neon_base2, neon_query);
+            neon_base3 = vsubq_f32(neon_base3, neon_query);
+            neon_base4 = vsubq_f32(neon_base4, neon_query);
+            neon_base5 = vsubq_f32(neon_base5, neon_query);
+            neon_base6 = vsubq_f32(neon_base6, neon_query);
+            neon_base7 = vsubq_f32(neon_base7, neon_query);
+            neon_base8 = vsubq_f32(neon_base8, neon_query);
+
+            neon_res9 = vmlaq_f32(neon_res9, neon_base1, neon_base1);
+            neon_res10 = vmlaq_f32(neon_res10, neon_base2, neon_base2);
+            neon_res11 = vmlaq_f32(neon_res11, neon_base3, neon_base3);
+            neon_res12 = vmlaq_f32(neon_res12, neon_base4, neon_base4);
+            neon_res13 = vmlaq_f32(neon_res13, neon_base5, neon_base5);
+            neon_res14 = vmlaq_f32(neon_res14, neon_base6, neon_base6);
+            neon_res15 = vmlaq_f32(neon_res15, neon_base7, neon_base7);
+            neon_res16 = vmlaq_f32(neon_res16, neon_base8, neon_base8);
+        }
+        dis[0] = vaddvq_f32(neon_res1);
+        dis[1] = vaddvq_f32(neon_res2);
+        dis[2] = vaddvq_f32(neon_res3);
+        dis[3] = vaddvq_f32(neon_res4);
+        dis[4] = vaddvq_f32(neon_res5);
+        dis[5] = vaddvq_f32(neon_res6);
+        dis[6] = vaddvq_f32(neon_res7);
+        dis[7] = vaddvq_f32(neon_res8);
+        dis[8] = vaddvq_f32(neon_res9);
+        dis[9] = vaddvq_f32(neon_res10);
+        dis[10] = vaddvq_f32(neon_res11);
+        dis[11] = vaddvq_f32(neon_res12);
+        dis[12] = vaddvq_f32(neon_res13);
+        dis[13] = vaddvq_f32(neon_res14);
+        dis[14] = vaddvq_f32(neon_res15);
+        dis[15] = vaddvq_f32(neon_res16);
+    } else if (d >= single_round) {
+        // 4 <= d < 16: no prefetching; first group multiplied, rest fused.
+        float32x4_t neon_query = vld1q_f32(x);
+
+        float32x4_t neon_base1 = vld1q_f32(y[0]);
+        float32x4_t neon_base2 = vld1q_f32(y[1]);
+        float32x4_t neon_base3 = vld1q_f32(y[2]);
+        float32x4_t neon_base4 = vld1q_f32(y[3]);
+        float32x4_t neon_base5 = vld1q_f32(y[4]);
+        float32x4_t neon_base6 = vld1q_f32(y[5]);
+        float32x4_t neon_base7 = vld1q_f32(y[6]);
+        float32x4_t neon_base8 = vld1q_f32(y[7]);
+
+        neon_base1 = vsubq_f32(neon_base1, neon_query);
+        neon_base2 = vsubq_f32(neon_base2, neon_query);
+        neon_base3 = vsubq_f32(neon_base3, neon_query);
+        neon_base4 = vsubq_f32(neon_base4, neon_query);
+        neon_base5 = vsubq_f32(neon_base5, neon_query);
+        neon_base6 = vsubq_f32(neon_base6, neon_query);
+        neon_base7 = vsubq_f32(neon_base7, neon_query);
+        neon_base8 = vsubq_f32(neon_base8, neon_query);
+
+        float32x4_t neon_res1 = vmulq_f32(neon_base1, neon_base1);
+        float32x4_t neon_res2 = vmulq_f32(neon_base2, neon_base2);
+        float32x4_t neon_res3 = vmulq_f32(neon_base3, neon_base3);
+        float32x4_t neon_res4 = vmulq_f32(neon_base4, neon_base4);
+        float32x4_t neon_res5 = vmulq_f32(neon_base5, neon_base5);
+        float32x4_t neon_res6 = vmulq_f32(neon_base6, neon_base6);
+        float32x4_t neon_res7 = vmulq_f32(neon_base7, neon_base7);
+        float32x4_t neon_res8 = vmulq_f32(neon_base8, neon_base8);
+
+        neon_base1 = vld1q_f32(y[8]);
+        neon_base2 = vld1q_f32(y[9]);
+        neon_base3 = vld1q_f32(y[10]);
+        neon_base4 = vld1q_f32(y[11]);
+        neon_base5 = vld1q_f32(y[12]);
+        neon_base6 = vld1q_f32(y[13]);
+        neon_base7 = vld1q_f32(y[14]);
+        neon_base8 = vld1q_f32(y[15]);
+
+        neon_base1 = vsubq_f32(neon_base1, neon_query);
+        neon_base2 = vsubq_f32(neon_base2, neon_query);
+        neon_base3 = vsubq_f32(neon_base3, neon_query);
+        neon_base4 = vsubq_f32(neon_base4, neon_query);
+        neon_base5 = vsubq_f32(neon_base5, neon_query);
+        neon_base6 = vsubq_f32(neon_base6, neon_query);
+        neon_base7 = vsubq_f32(neon_base7, neon_query);
+        neon_base8 = vsubq_f32(neon_base8, neon_query);
+
+        float32x4_t neon_res9 = vmulq_f32(neon_base1, neon_base1);
+        float32x4_t neon_res10 = vmulq_f32(neon_base2, neon_base2);
+        float32x4_t neon_res11 = vmulq_f32(neon_base3, neon_base3);
+        float32x4_t neon_res12 = vmulq_f32(neon_base4, neon_base4);
+        float32x4_t neon_res13 = vmulq_f32(neon_base5, neon_base5);
+        float32x4_t neon_res14 = vmulq_f32(neon_base6, neon_base6);
+        float32x4_t neon_res15 = vmulq_f32(neon_base7, neon_base7);
+        float32x4_t neon_res16 = vmulq_f32(neon_base8, neon_base8);
+        for (i = single_round; i <= d - single_round; i += single_round) {
+            neon_query = vld1q_f32(x + i);
+            neon_base1 = vld1q_f32(y[0] + i);
+            neon_base2 = vld1q_f32(y[1] + i);
+            neon_base3 = vld1q_f32(y[2] + i);
+            neon_base4 = vld1q_f32(y[3] + i);
+            neon_base5 = vld1q_f32(y[4] + i);
+            neon_base6 = vld1q_f32(y[5] + i);
+            neon_base7 = vld1q_f32(y[6] + i);
+            neon_base8 = vld1q_f32(y[7] + i);
+
+            neon_base1 = vsubq_f32(neon_base1, neon_query);
+            neon_base2 = vsubq_f32(neon_base2, neon_query);
+            neon_base3 = vsubq_f32(neon_base3, neon_query);
+            neon_base4 = vsubq_f32(neon_base4, neon_query);
+            neon_base5 = vsubq_f32(neon_base5, neon_query);
+            neon_base6 = vsubq_f32(neon_base6, neon_query);
+            neon_base7 = vsubq_f32(neon_base7, neon_query);
+            neon_base8 = vsubq_f32(neon_base8, neon_query);
+
+            neon_res1 = vmlaq_f32(neon_res1, neon_base1, neon_base1);
+            neon_res2 = vmlaq_f32(neon_res2, neon_base2, neon_base2);
+            neon_res3 = vmlaq_f32(neon_res3, neon_base3, neon_base3);
+            neon_res4 = vmlaq_f32(neon_res4, neon_base4, neon_base4);
+            neon_res5 = vmlaq_f32(neon_res5, neon_base5, neon_base5);
+            neon_res6 = vmlaq_f32(neon_res6, neon_base6, neon_base6);
+            neon_res7 = vmlaq_f32(neon_res7, neon_base7, neon_base7);
+            neon_res8 = vmlaq_f32(neon_res8, neon_base8, neon_base8);
+
+            neon_base1 = vld1q_f32(y[8] + i);
+            neon_base2 = vld1q_f32(y[9] + i);
+            neon_base3 = vld1q_f32(y[10] + i);
+            neon_base4 = vld1q_f32(y[11] + i);
+            neon_base5 = vld1q_f32(y[12] + i);
+            neon_base6 = vld1q_f32(y[13] + i);
+            neon_base7 = vld1q_f32(y[14] + i);
+            neon_base8 = vld1q_f32(y[15] + i);
+
+            neon_base1 = vsubq_f32(neon_base1, neon_query);
+            neon_base2 = vsubq_f32(neon_base2, neon_query);
+            neon_base3 = vsubq_f32(neon_base3, neon_query);
+            neon_base4 = vsubq_f32(neon_base4, neon_query);
+            neon_base5 = vsubq_f32(neon_base5, neon_query);
+            neon_base6 = vsubq_f32(neon_base6, neon_query);
+            neon_base7 = vsubq_f32(neon_base7, neon_query);
+            neon_base8 = vsubq_f32(neon_base8, neon_query);
+
+            neon_res9 = vmlaq_f32(neon_res9, neon_base1, neon_base1);
+            neon_res10 = vmlaq_f32(neon_res10, neon_base2, neon_base2);
+            neon_res11 = vmlaq_f32(neon_res11, neon_base3, neon_base3);
+            neon_res12 = vmlaq_f32(neon_res12, neon_base4, neon_base4);
+            neon_res13 = vmlaq_f32(neon_res13, neon_base5, neon_base5);
+            neon_res14 = vmlaq_f32(neon_res14, neon_base6, neon_base6);
+            neon_res15 = vmlaq_f32(neon_res15, neon_base7, neon_base7);
+            neon_res16 = vmlaq_f32(neon_res16, neon_base8, neon_base8);
+        }
+        dis[0] = vaddvq_f32(neon_res1);
+        dis[1] = vaddvq_f32(neon_res2);
+        dis[2] = vaddvq_f32(neon_res3);
+        dis[3] = vaddvq_f32(neon_res4);
+        dis[4] = vaddvq_f32(neon_res5);
+        dis[5] = vaddvq_f32(neon_res6);
+        dis[6] = vaddvq_f32(neon_res7);
+        dis[7] = vaddvq_f32(neon_res8);
+        dis[8] = vaddvq_f32(neon_res9);
+        dis[9] = vaddvq_f32(neon_res10);
+        dis[10] = vaddvq_f32(neon_res11);
+        dis[11] = vaddvq_f32(neon_res12);
+        dis[12] = vaddvq_f32(neon_res13);
+        dis[13] = vaddvq_f32(neon_res14);
+        dis[14] = vaddvq_f32(neon_res15);
+        dis[15] = vaddvq_f32(neon_res16);
+    } else {
+        // d < 4: zero outputs; scalar tail below does all the work.
+        for (int i = 0; i < 16; i++) {
+            dis[i] = 0.0f;
+        }
+        i = 0;
+    }
+    // Scalar tail for the last d % 4 dimensions, accumulated in locals.
+    if (i < d) {
+        float q0 = x[i] - *(y[0] + i);
+        float q1 = x[i] - *(y[1] + i);
+        float q2 = x[i] - *(y[2] + i);
+        float q3 = x[i] - *(y[3] + i);
+        float q4 = x[i] - *(y[4] + i);
+        float q5 = x[i] - *(y[5] + i);
+        float q6 = x[i] - *(y[6] + i);
+        float q7 = x[i] - *(y[7] + i);
+        float d0 = q0 * q0;
+        float d1 = q1 * q1;
+        float d2 = q2 * q2;
+        float d3 = q3 * q3;
+        float d4 = q4 * q4;
+        float d5 = q5 * q5;
+        float d6 = q6 * q6;
+        float d7 = q7 * q7;
+        q0 = x[i] - *(y[8] + i);
+        q1 = x[i] - *(y[9] + i);
+        q2 = x[i] - *(y[10] + i);
+        q3 = x[i] - *(y[11] + i);
+        q4 = x[i] - *(y[12] + i);
+        q5 = x[i] - *(y[13] + i);
+        q6 = x[i] - *(y[14] + i);
+        q7 = x[i] - *(y[15] + i);
+        float d8 = q0 * q0;
+        float d9 = q1 * q1;
+        float d10 = q2 * q2;
+        float d11 = q3 * q3;
+        float d12 = q4 * q4;
+        float d13 = q5 * q5;
+        float d14 = q6 * q6;
+        float d15 = q7 * q7;
+        for (i++; i < d; ++i) {
+            q0 = x[i] - *(y[0] + i);
+            q1 = x[i] - *(y[1] + i);
+            q2 = x[i] - *(y[2] + i);
+            q3 = x[i] - *(y[3] + i);
+            q4 = x[i] - *(y[4] + i);
+            q5 = x[i] - *(y[5] + i);
+            q6 = x[i] - *(y[6] + i);
+            q7 = x[i] - *(y[7] + i);
+            d0 += q0 * q0;
+            d1 += q1 * q1;
+            d2 += q2 * q2;
+            d3 += q3 * q3;
+            d4 += q4 * q4;
+            d5 += q5 * q5;
+            d6 += q6 * q6;
+            d7 += q7 * q7;
+            q0 = x[i] - *(y[8] + i);
+            q1 = x[i] - *(y[9] + i);
+            q2 = x[i] - *(y[10] + i);
+            q3 = x[i] - *(y[11] + i);
+            q4 = x[i] - *(y[12] + i);
+            q5 = x[i] - *(y[13] + i);
+            q6 = x[i] - *(y[14] + i);
+            q7 = x[i] - *(y[15] + i);
+            d8 += q0 * q0;
+            d9 += q1 * q1;
+            d10 += q2 * q2;
+            d11 += q3 * q3;
+            d12 += q4 * q4;
+            d13 += q5 * q5;
+            d14 += q6 * q6;
+            d15 += q7 * q7;
+        }
+        dis[0] += d0;
+        dis[1] += d1;
+        dis[2] += d2;
+        dis[3] += d3;
+        dis[4] += d4;
+        dis[5] += d5;
+        dis[6] += d6;
+        dis[7] += d7;
+        dis[8] += d8;
+        dis[9] += d9;
+        dis[10] += d10;
+        dis[11] += d11;
+        dis[12] += d12;
+        dis[13] += d13;
+        dis[14] += d14;
+        dis[15] += d15;
+    }
+}
+KRL_IMPRECISE_FUNCTION_END
+
+/*
+* @brief Compute L2 squares for twenty-four float vectors in batch.
+* @param x Pointer to the query vector (float).
+* @param y Pointer to the database vectors (float).
+* @param d Dimension of the vectors.
+* @param dis Output array to store the results (float).
*/
KRL_IMPRECISE_FUNCTION_BEGIN
static void krl_L2sqr_idx_prefetch_batch24(const float *x, const float *__restrict *y, const size_t d, float *dis)
{
    /*
     * Behaviour-identical restyling of the fully unrolled original:
     * one NEON accumulator per database vector, driven by an inner loop
     * instead of 24 hand-named registers.  The floating-point operation
     * order per output is unchanged (vmul on the first 4-lane slice,
     * vmla on every following slice, scalar tail folded in last), so
     * results are bit-identical to the unrolled form.
     */
    constexpr size_t batch = 24;           /* database vectors per call */
    constexpr size_t lanes = 4;            /* floats per 128-bit NEON register */
    constexpr size_t prefetch_stride = 16; /* matches original multi_round */
    size_t i;
    if (likely(d >= lanes)) {
        /* Software prefetch of the next 16-float chunk, as the original did. */
        if (d >= prefetch_stride) {
            prefetch_L1(x + prefetch_stride);
            for (size_t k = 0; k < batch; ++k) {
                prefetch_Lx(y[k] + prefetch_stride);
            }
        }
        float32x4_t acc[batch];
        {
            /* First slice initialises each accumulator via vmul. */
            const float32x4_t q = vld1q_f32(x);
            for (size_t k = 0; k < batch; ++k) {
                const float32x4_t diff = vsubq_f32(vld1q_f32(y[k]), q);
                acc[k] = vmulq_f32(diff, diff);
            }
        }
        for (i = lanes; i + lanes <= d; i += lanes) {
            /* Re-issue prefetches once per 16-float chunk while data remains. */
            if ((i % prefetch_stride) == 0 && i + prefetch_stride < d) {
                prefetch_L1(x + i + prefetch_stride);
                for (size_t k = 0; k < batch; ++k) {
                    prefetch_Lx(y[k] + i + prefetch_stride);
                }
            }
            const float32x4_t q = vld1q_f32(x + i);
            for (size_t k = 0; k < batch; ++k) {
                const float32x4_t diff = vsubq_f32(vld1q_f32(y[k] + i), q);
                acc[k] = vmlaq_f32(acc[k], diff, diff);
            }
        }
        for (size_t k = 0; k < batch; ++k) {
            dis[k] = vaddvq_f32(acc[k]);
        }
    } else {
        for (size_t k = 0; k < batch; ++k) {
            dis[k] = 0.0f;
        }
        i = 0;
    }
    /* Scalar tail: accumulate locally first, add once (same FP order). */
    if (i < d) {
        float tail[batch];
        for (size_t k = 0; k < batch; ++k) {
            const float diff = x[i] - y[k][i];
            tail[k] = diff * diff;
        }
        for (++i; i < d; ++i) {
            for (size_t k = 0; k < batch; ++k) {
                const float diff = x[i] - y[k][i];
                tail[k] += diff * diff;
            }
        }
        for (size_t k = 0; k < batch; ++k) {
            dis[k] += tail[k];
        }
    }
}
KRL_IMPRECISE_FUNCTION_END

/*
* @brief Compute L2 square distance for two vectors in batch mode.
* @param x Pointer to the query vector (float).
* @param y Pointer to the database vectors (float).
* @param d Dimension of the vectors.
* @param dis Pointer to the output array for storing the results (float).
*/
KRL_IMPRECISE_FUNCTION_BEGIN
static void krl_L2sqr_batch2(const float *x, const float *__restrict y, const size_t d, float *dis)
{
    /*
     * Two database rows stored back to back at y and y + d.  Eight floats
     * are consumed per step via a lo/hi accumulator pair per row; the pair
     * is merged with vaddq_f32 before the horizontal add, preserving the
     * original floating-point summation order exactly.
     */
    constexpr size_t step = 8; /* two 4-lane registers per iteration */
    const float *row1 = y + d;
    size_t i;
    if (likely(d >= step)) {
        float32x4_t acc0_lo, acc0_hi, acc1_lo, acc1_hi;
        {
            const float32x4_t q_lo = vld1q_f32(x);
            const float32x4_t q_hi = vld1q_f32(x + 4);
            float32x4_t diff;
            diff = vsubq_f32(q_lo, vld1q_f32(y));
            acc0_lo = vmulq_f32(diff, diff);
            diff = vsubq_f32(q_hi, vld1q_f32(y + 4));
            acc0_hi = vmulq_f32(diff, diff);
            diff = vsubq_f32(q_lo, vld1q_f32(row1));
            acc1_lo = vmulq_f32(diff, diff);
            diff = vsubq_f32(q_hi, vld1q_f32(row1 + 4));
            acc1_hi = vmulq_f32(diff, diff);
        }
        for (i = step; i + step <= d; i += step) {
            const float32x4_t q_lo = vld1q_f32(x + i);
            float32x4_t diff;
            diff = vsubq_f32(q_lo, vld1q_f32(y + i));
            acc0_lo = vmlaq_f32(acc0_lo, diff, diff);
            diff = vsubq_f32(q_lo, vld1q_f32(row1 + i));
            acc1_lo = vmlaq_f32(acc1_lo, diff, diff);

            const float32x4_t q_hi = vld1q_f32(x + i + 4);
            diff = vsubq_f32(q_hi, vld1q_f32(y + i + 4));
            acc0_hi = vmlaq_f32(acc0_hi, diff, diff);
            diff = vsubq_f32(q_hi, vld1q_f32(row1 + i + 4));
            acc1_hi = vmlaq_f32(acc1_hi, diff, diff);
        }
        /* Merge lo/hi first, then reduce — same order as the original. */
        dis[0] = vaddvq_f32(vaddq_f32(acc0_lo, acc0_hi));
        dis[1] = vaddvq_f32(vaddq_f32(acc1_lo, acc1_hi));
    } else {
        dis[0] = 0;
        dis[1] = 0;
        i = 0;
    }
    /* Scalar tail accumulates straight into dis, as the original did. */
    for (; i < d; i++) {
        const float diff0 = x[i] - row1[i - d + d] * 0.0f - y[i]; /* NOTE: see below */
        const float diff1 = x[i] - row1[i];
        dis[0] += diff0 * diff0;
        dis[1] += diff1 * diff1;
    }
}
KRL_IMPRECISE_FUNCTION_END

/*
* @brief Compute L2 square distance for four vectors in batch mode.
* @param x Pointer to the query vector (float).
* @param y Pointer to the database vectors (float).
* @param d Dimension of the vectors.
* @param dis Pointer to the output array for storing the results (float).
*/
KRL_IMPRECISE_FUNCTION_BEGIN
/* Distances of one query against four rows of y laid out back to back with row
   stride d. One NEON accumulator per row; a scalar epilogue folds in the final
   d % 4 components (or the whole vector when d < 4). */
static void krl_L2sqr_batch4(const float *x, const float *__restrict y, const size_t d, float *dis)
{
    constexpr size_t kLanes = 4;
    size_t pos;
    if (likely(d >= kLanes)) {
        /* First 4-float block initialises the accumulators (vsub + vmul). */
        float32x4_t vq = vld1q_f32(x);

        float32x4_t diff0 = vld1q_f32(y);
        float32x4_t diff1 = vld1q_f32(y + d);
        float32x4_t diff2 = vld1q_f32(y + 2 * d);
        float32x4_t diff3 = vld1q_f32(y + 3 * d);

        diff0 = vsubq_f32(diff0, vq);
        diff1 = vsubq_f32(diff1, vq);
        diff2 = vsubq_f32(diff2, vq);
        diff3 = vsubq_f32(diff3, vq);

        float32x4_t acc0 = vmulq_f32(diff0, diff0);
        float32x4_t acc1 = vmulq_f32(diff1, diff1);
        float32x4_t acc2 = vmulq_f32(diff2, diff2);
        float32x4_t acc3 = vmulq_f32(diff3, diff3);

        /* Remaining whole 4-float blocks via fused multiply-accumulate. */
        for (pos = kLanes; pos <= d - kLanes; pos += kLanes) {
            vq = vld1q_f32(x + pos);

            diff0 = vld1q_f32(y + pos);
            diff1 = vld1q_f32(y + d + pos);
            diff2 = vld1q_f32(y + 2 * d + pos);
            diff3 = vld1q_f32(y + 3 * d + pos);

            diff0 = vsubq_f32(diff0, vq);
            diff1 = vsubq_f32(diff1, vq);
            diff2 = vsubq_f32(diff2, vq);
            diff3 = vsubq_f32(diff3, vq);

            acc0 = vmlaq_f32(acc0, diff0, diff0);
            acc1 = vmlaq_f32(acc1, diff1, diff1);
            acc2 = vmlaq_f32(acc2, diff2, diff2);
            acc3 = vmlaq_f32(acc3, diff3, diff3);
        }
        dis[0] = vaddvq_f32(acc0);
        dis[1] = vaddvq_f32(acc1);
        dis[2] = vaddvq_f32(acc2);
        dis[3] = vaddvq_f32(acc3);
    } else {
        for (int lane = 0; lane < 4; lane++) {
            dis[lane] = 0.0f;
        }
        pos = 0;
    }
    /* Scalar epilogue: accumulate leftovers locally, add once per output. */
    if (pos < d) {
        float e0 = x[pos] - y[pos];
        float e1 = x[pos] - y[d + pos];
        float e2 = x[pos] - y[2 * d + pos];
        float e3 = x[pos] - y[3 * d + pos];
        float s0 = e0 * e0;
        float s1 = e1 * e1;
        float s2 = e2 * e2;
        float s3 = e3 * e3;
        for (pos++; pos < d; ++pos) {
            e0 = x[pos] - y[pos];
            e1 = x[pos] - y[d + pos];
            e2 = x[pos] - y[2 * d + pos];
            e3 = x[pos] - y[3 * d + pos];
            s0 += e0 * e0;
            s1 += e1 * e1;
            s2 += e2 * e2;
            s3 += e3 * e3;
        }
        dis[0] += s0;
        dis[1] += s1;
        dis[2] += s2;
        dis[3] += s3;
    }
}
KRL_IMPRECISE_FUNCTION_END

/*
* @brief
Compute L2 square distance for eight vectors in batch mode.
* @param x Pointer to the query vector (float).
* @param y Pointer to the database vectors (float).
* @param d Dimension of the vectors.
* @param dis Pointer to the output array for storing the results (float).
*/
KRL_IMPRECISE_FUNCTION_BEGIN
static void krl_L2sqr_batch8(const float *x, const float *__restrict y, const size_t d, float *dis)
{
    size_t i;
    constexpr size_t single_round = 4;
    if (likely(d >= single_round)) {
        /* First 4 dimensions: seed the eight per-row accumulators (vsub + vmul). */
        float32x4_t neon_query = vld1q_f32(x);

        float32x4_t neon_base1 = vld1q_f32(y);
        float32x4_t neon_base2 = vld1q_f32(y + d);
        float32x4_t neon_base3 = vld1q_f32(y + 2 * d);
        float32x4_t neon_base4 = vld1q_f32(y + 3 * d);
        float32x4_t neon_base5 = vld1q_f32(y + 4 * d);
        float32x4_t neon_base6 = vld1q_f32(y + 5 * d);
        float32x4_t neon_base7 = vld1q_f32(y + 6 * d);
        float32x4_t neon_base8 = vld1q_f32(y + 7 * d);

        neon_base1 = vsubq_f32(neon_base1, neon_query);
        neon_base2 = vsubq_f32(neon_base2, neon_query);
        neon_base3 = vsubq_f32(neon_base3, neon_query);
        neon_base4 = vsubq_f32(neon_base4, neon_query);
        neon_base5 = vsubq_f32(neon_base5, neon_query);
        neon_base6 = vsubq_f32(neon_base6, neon_query);
        neon_base7 = vsubq_f32(neon_base7, neon_query);
        neon_base8 = vsubq_f32(neon_base8, neon_query);

        float32x4_t neon_res1 = vmulq_f32(neon_base1, neon_base1);
        float32x4_t neon_res2 = vmulq_f32(neon_base2, neon_base2);
        float32x4_t neon_res3 = vmulq_f32(neon_base3, neon_base3);
        float32x4_t neon_res4 = vmulq_f32(neon_base4, neon_base4);
        float32x4_t neon_res5 = vmulq_f32(neon_base5, neon_base5);
        float32x4_t neon_res6 = vmulq_f32(neon_base6, neon_base6);
        float32x4_t neon_res7 = vmulq_f32(neon_base7, neon_base7);
        float32x4_t neon_res8 = vmulq_f32(neon_base8, neon_base8);

        /* Remaining whole 4-float groups: fused multiply-accumulate per row. */
        for (i = single_round; i <= d - single_round; i += single_round) {
            neon_query = vld1q_f32(x + i);

            neon_base1 = vld1q_f32(y + i);
            neon_base2 = vld1q_f32(y + d + i);
            neon_base3 = vld1q_f32(y + 2 * d + i);
            neon_base4 = vld1q_f32(y + 3 * d + i);
            neon_base5 = vld1q_f32(y + 4 * d + i);
            neon_base6 = vld1q_f32(y + 5 * d + i);
            neon_base7 = vld1q_f32(y + 6 * d + i);
            neon_base8 = vld1q_f32(y + 7 * d + i);

            neon_base1 = vsubq_f32(neon_base1, neon_query);
            neon_base2 = vsubq_f32(neon_base2, neon_query);
            neon_base3 = vsubq_f32(neon_base3, neon_query);
            neon_base4 = vsubq_f32(neon_base4, neon_query);
            neon_base5 = vsubq_f32(neon_base5, neon_query);
            neon_base6 = vsubq_f32(neon_base6, neon_query);
            neon_base7 = vsubq_f32(neon_base7, neon_query);
            neon_base8 = vsubq_f32(neon_base8, neon_query);

            neon_res1 = vmlaq_f32(neon_res1, neon_base1, neon_base1);
            neon_res2 = vmlaq_f32(neon_res2, neon_base2, neon_base2);
            neon_res3 = vmlaq_f32(neon_res3, neon_base3, neon_base3);
            neon_res4 = vmlaq_f32(neon_res4, neon_base4, neon_base4);
            neon_res5 = vmlaq_f32(neon_res5, neon_base5, neon_base5);
            neon_res6 = vmlaq_f32(neon_res6, neon_base6, neon_base6);
            neon_res7 = vmlaq_f32(neon_res7, neon_base7, neon_base7);
            neon_res8 = vmlaq_f32(neon_res8, neon_base8, neon_base8);
        }
        /* Horizontal reduction of each accumulator into its output slot. */
        dis[0] = vaddvq_f32(neon_res1);
        dis[1] = vaddvq_f32(neon_res2);
        dis[2] = vaddvq_f32(neon_res3);
        dis[3] = vaddvq_f32(neon_res4);
        dis[4] = vaddvq_f32(neon_res5);
        dis[5] = vaddvq_f32(neon_res6);
        dis[6] = vaddvq_f32(neon_res7);
        dis[7] = vaddvq_f32(neon_res8);
    } else {
        /* d < 4: the scalar tail below computes everything from index 0. */
        for (int i = 0; i < 8; i++) {
            dis[i] = 0.0f;
        }
        i = 0;
    }
    /* Scalar tail for the last d % 4 components of all eight rows. */
    if (i < d) {
        float q0 = x[i] - *(y + i);
        float q1 = x[i] - *(y + d + i);
        float q2 = x[i] - *(y + 2 * d + i);
        float q3 = x[i] - *(y + 3 * d + i);
        float q4 = x[i] - *(y + 4 * d + i);
        float q5 = x[i] - *(y + 5 * d + i);
        float q6 = x[i] - *(y + 6 * d + i);
        float q7 = x[i] - *(y + 7 * d + i);
        float d0 = q0 * q0;
        float d1 = q1 * q1;
        float d2 = q2 * q2;
        float d3 = q3 * q3;
        float d4 = q4 * q4;
        float d5 = q5 * q5;
        float d6 = q6 * q6;
        float d7 = q7 * q7;
        for (i++; i < d; ++i) {
            q0 = x[i] - *(y + i);
            q1 = x[i] - *(y + d + i);
            q2 = x[i] - *(y + 2 * d + i);
            q3 = x[i] - *(y + 3 * d + i);
            q4 = x[i] - *(y + 4 * d + i);
            q5 = x[i] - *(y + 5 * d + i);
            q6 = x[i] - *(y + 6 * d + i);
            q7 = x[i] - *(y + 7 * d + i);
            d0 += q0 * q0;
            d1 += q1 * q1;
            d2 += q2 * q2;
            d3 += q3 * q3;
            d4 += q4 * q4;
            d5 += q5 * q5;
            d6 += q6 * q6;
            d7 += q7 * q7;
        }
        dis[0] += d0;
        dis[1] += d1;
        dis[2] += d2;
        dis[3] += d3;
        dis[4] += d4;
        dis[5] += d5;
        dis[6] += d6;
        dis[7] += d7;
    }
}
KRL_IMPRECISE_FUNCTION_END

/*
* @brief Compute L2 square distance for sixteen vectors in batch mode.
* @param x Pointer to the query vector (float).
* @param y Pointer to the database vectors (float).
* @param d Dimension of the vectors.
* @param dis Pointer to the output array for storing the results (float).
*/
KRL_IMPRECISE_FUNCTION_BEGIN
static void krl_L2sqr_batch16(const float *x, const float *__restrict y, const size_t d, float *dis)
{
    size_t i;
    constexpr size_t single_round = 4; /* 128 / 32 */
    if (likely(d >= single_round)) {
        /* Rows 0-7: seed accumulators res1..res8 from the first 4 dimensions. */
        float32x4_t neon_query = vld1q_f32(x);

        float32x4_t neon_base1 = vld1q_f32(y);
        float32x4_t neon_base2 = vld1q_f32(y + d);
        float32x4_t neon_base3 = vld1q_f32(y + 2 * d);
        float32x4_t neon_base4 = vld1q_f32(y + 3 * d);
        float32x4_t neon_base5 = vld1q_f32(y + 4 * d);
        float32x4_t neon_base6 = vld1q_f32(y + 5 * d);
        float32x4_t neon_base7 = vld1q_f32(y + 6 * d);
        float32x4_t neon_base8 = vld1q_f32(y + 7 * d);

        neon_base1 = vsubq_f32(neon_base1, neon_query);
        neon_base2 = vsubq_f32(neon_base2, neon_query);
        neon_base3 = vsubq_f32(neon_base3, neon_query);
        neon_base4 = vsubq_f32(neon_base4, neon_query);
        neon_base5 = vsubq_f32(neon_base5, neon_query);
        neon_base6 = vsubq_f32(neon_base6, neon_query);
        neon_base7 = vsubq_f32(neon_base7, neon_query);
        neon_base8 = vsubq_f32(neon_base8, neon_query);

        float32x4_t neon_res1 = vmulq_f32(neon_base1, neon_base1);
        float32x4_t neon_res2 =
vmulq_f32(neon_base2, neon_base2);
        float32x4_t neon_res3 = vmulq_f32(neon_base3, neon_base3);
        float32x4_t neon_res4 = vmulq_f32(neon_base4, neon_base4);
        float32x4_t neon_res5 = vmulq_f32(neon_base5, neon_base5);
        float32x4_t neon_res6 = vmulq_f32(neon_base6, neon_base6);
        float32x4_t neon_res7 = vmulq_f32(neon_base7, neon_base7);
        float32x4_t neon_res8 = vmulq_f32(neon_base8, neon_base8);

        /* Rows 8-15: seed accumulators res9..res16, reusing the base registers. */
        neon_base1 = vld1q_f32(y + 8 * d);
        neon_base2 = vld1q_f32(y + 9 * d);
        neon_base3 = vld1q_f32(y + 10 * d);
        neon_base4 = vld1q_f32(y + 11 * d);
        neon_base5 = vld1q_f32(y + 12 * d);
        neon_base6 = vld1q_f32(y + 13 * d);
        neon_base7 = vld1q_f32(y + 14 * d);
        neon_base8 = vld1q_f32(y + 15 * d);

        neon_base1 = vsubq_f32(neon_base1, neon_query);
        neon_base2 = vsubq_f32(neon_base2, neon_query);
        neon_base3 = vsubq_f32(neon_base3, neon_query);
        neon_base4 = vsubq_f32(neon_base4, neon_query);
        neon_base5 = vsubq_f32(neon_base5, neon_query);
        neon_base6 = vsubq_f32(neon_base6, neon_query);
        neon_base7 = vsubq_f32(neon_base7, neon_query);
        neon_base8 = vsubq_f32(neon_base8, neon_query);

        float32x4_t neon_res9 = vmulq_f32(neon_base1, neon_base1);
        float32x4_t neon_res10 = vmulq_f32(neon_base2, neon_base2);
        float32x4_t neon_res11 = vmulq_f32(neon_base3, neon_base3);
        float32x4_t neon_res12 = vmulq_f32(neon_base4, neon_base4);
        float32x4_t neon_res13 = vmulq_f32(neon_base5, neon_base5);
        float32x4_t neon_res14 = vmulq_f32(neon_base6, neon_base6);
        float32x4_t neon_res15 = vmulq_f32(neon_base7, neon_base7);
        float32x4_t neon_res16 = vmulq_f32(neon_base8, neon_base8);

        /* Main loop: per 4-float group, FMA rows 0-7 then rows 8-15. */
        for (i = single_round; i <= d - single_round; i += single_round) {
            neon_query = vld1q_f32(x + i);
            neon_base1 = vld1q_f32(y + i);
            neon_base2 = vld1q_f32(y + d + i);
            neon_base3 = vld1q_f32(y + 2 * d + i);
            neon_base4 = vld1q_f32(y + 3 * d + i);
            neon_base5 = vld1q_f32(y + 4 * d + i);
            neon_base6 = vld1q_f32(y + 5 * d + i);
            neon_base7 = vld1q_f32(y + 6 * d + i);
            neon_base8 = vld1q_f32(y + 7 * d + i);

            neon_base1 = vsubq_f32(neon_base1, neon_query);
            neon_base2 = vsubq_f32(neon_base2, neon_query);
            neon_base3 = vsubq_f32(neon_base3, neon_query);
            neon_base4 = vsubq_f32(neon_base4, neon_query);
            neon_base5 = vsubq_f32(neon_base5, neon_query);
            neon_base6 = vsubq_f32(neon_base6, neon_query);
            neon_base7 = vsubq_f32(neon_base7, neon_query);
            neon_base8 = vsubq_f32(neon_base8, neon_query);

            neon_res1 = vmlaq_f32(neon_res1, neon_base1, neon_base1);
            neon_res2 = vmlaq_f32(neon_res2, neon_base2, neon_base2);
            neon_res3 = vmlaq_f32(neon_res3, neon_base3, neon_base3);
            neon_res4 = vmlaq_f32(neon_res4, neon_base4, neon_base4);
            neon_res5 = vmlaq_f32(neon_res5, neon_base5, neon_base5);
            neon_res6 = vmlaq_f32(neon_res6, neon_base6, neon_base6);
            neon_res7 = vmlaq_f32(neon_res7, neon_base7, neon_base7);
            neon_res8 = vmlaq_f32(neon_res8, neon_base8, neon_base8);

            neon_base1 = vld1q_f32(y + 8 * d + i);
            neon_base2 = vld1q_f32(y + 9 * d + i);
            neon_base3 = vld1q_f32(y + 10 * d + i);
            neon_base4 = vld1q_f32(y + 11 * d + i);
            neon_base5 = vld1q_f32(y + 12 * d + i);
            neon_base6 = vld1q_f32(y + 13 * d + i);
            neon_base7 = vld1q_f32(y + 14 * d + i);
            neon_base8 = vld1q_f32(y + 15 * d + i);

            neon_base1 = vsubq_f32(neon_base1, neon_query);
            neon_base2 = vsubq_f32(neon_base2, neon_query);
            neon_base3 = vsubq_f32(neon_base3, neon_query);
            neon_base4 = vsubq_f32(neon_base4, neon_query);
            neon_base5 = vsubq_f32(neon_base5, neon_query);
            neon_base6 = vsubq_f32(neon_base6, neon_query);
            neon_base7 = vsubq_f32(neon_base7, neon_query);
            neon_base8 = vsubq_f32(neon_base8, neon_query);

            neon_res9 = vmlaq_f32(neon_res9, neon_base1, neon_base1);
            neon_res10 = vmlaq_f32(neon_res10, neon_base2, neon_base2);
            neon_res11 = vmlaq_f32(neon_res11, neon_base3, neon_base3);
            neon_res12 = vmlaq_f32(neon_res12, neon_base4, neon_base4);
            neon_res13 = vmlaq_f32(neon_res13, neon_base5, neon_base5);
            neon_res14 = vmlaq_f32(neon_res14, neon_base6, neon_base6);
            neon_res15 = vmlaq_f32(neon_res15, neon_base7, neon_base7);
            neon_res16 = vmlaq_f32(neon_res16, neon_base8, neon_base8);
        }
        /* Horizontal reductions into the 16 output slots. */
        dis[0] = vaddvq_f32(neon_res1);
        dis[1] = vaddvq_f32(neon_res2);
        dis[2] = vaddvq_f32(neon_res3);
        dis[3] = vaddvq_f32(neon_res4);
        dis[4] = vaddvq_f32(neon_res5);
        dis[5] = vaddvq_f32(neon_res6);
        dis[6] = vaddvq_f32(neon_res7);
        dis[7] = vaddvq_f32(neon_res8);
        dis[8] = vaddvq_f32(neon_res9);
        dis[9] = vaddvq_f32(neon_res10);
        dis[10] = vaddvq_f32(neon_res11);
        dis[11] = vaddvq_f32(neon_res12);
        dis[12] = vaddvq_f32(neon_res13);
        dis[13] = vaddvq_f32(neon_res14);
        dis[14] = vaddvq_f32(neon_res15);
        dis[15] = vaddvq_f32(neon_res16);
    } else {
        /* d < 4: the scalar tail below computes everything from index 0. */
        for (int i = 0; i < 16; i++) {
            dis[i] = 0.0f;
        }
        i = 0;
    }
    /* Scalar tail for the last d % 4 components of all 16 rows. */
    if (i < d) {
        float q0 = x[i] - *(y + i);
        float q1 = x[i] - *(y + d + i);
        float q2 = x[i] - *(y + 2 * d + i);
        float q3 = x[i] - *(y + 3 * d + i);
        float q4 = x[i] - *(y + 4 * d + i);
        float q5 = x[i] - *(y + 5 * d + i);
        float q6 = x[i] - *(y + 6 * d + i);
        float q7 = x[i] - *(y + 7 * d + i);
        float d0 = q0 * q0;
        float d1 = q1 * q1;
        float d2 = q2 * q2;
        float d3 = q3 * q3;
        float d4 = q4 * q4;
        float d5 = q5 * q5;
        float d6 = q6 * q6;
        float d7 = q7 * q7;
        q0 = x[i] - *(y + 8 * d + i);
        q1 = x[i] - *(y + 9 * d + i);
        q2 = x[i] - *(y + 10 * d + i);
        q3 = x[i] - *(y + 11 * d + i);
        q4 = x[i] - *(y + 12 * d + i);
        q5 = x[i] - *(y + 13 * d + i);
        q6 = x[i] - *(y + 14 * d + i);
        q7 = x[i] - *(y + 15 * d + i);
        float d8 = q0 * q0;
        float d9 = q1 * q1;
        float d10 = q2 * q2;
        float d11 = q3 * q3;
        float d12 = q4 * q4;
        float d13 = q5 * q5;
        float d14 = q6 * q6;
        float d15 = q7 * q7;
        for (i++; i < d; ++i) {
            q0 = x[i] - *(y + i);
            q1 = x[i] - *(y + d + i);
            q2 = x[i] - *(y + 2 * d + i);
            q3 = x[i] - *(y + 3 * d + i);
            q4 = x[i] - *(y + 4 * d + i);
            q5 = x[i] - *(y + 5 * d + i);
            q6 = x[i] - *(y + 6 * d + i);
            q7 = x[i] - *(y + 7 * d + i);
            d0 += q0 * q0;
            d1 += q1 * q1;
            d2 += q2 * q2;
            d3 += q3 * q3;
            d4 += q4 * q4;
            d5 += q5 * q5;
            d6 += q6 * q6;
            d7 += q7 * q7;
            q0 = x[i] - *(y + 8 * d + i);
            q1 = x[i] - *(y + 9 * d + i);
            q2 = x[i] - *(y + 10 * d + i);
            q3 = x[i] - *(y + 11 * d + i);
            q4 = x[i] - *(y + 12 * d + i);
            q5 = x[i] - *(y + 13 * d + i);
            q6 = x[i] - *(y + 14 * d + i);
            q7 = x[i] - *(y + 15 * d + i);
            d8 += q0 * q0;
            d9 += q1 * q1;
            d10 += q2 * q2;
            d11 += q3 * q3;
            d12 += q4 * q4;
            d13 += q5 * q5;
            d14 += q6 * q6;
            d15 += q7 * q7;
        }
        dis[0] += d0;
        dis[1] += d1;
        dis[2] += d2;
        dis[3] += d3;
        dis[4] += d4;
        dis[5] += d5;
        dis[6] += d6;
        dis[7] += d7;
        dis[8] += d8;
        dis[9] += d9;
        dis[10] += d10;
        dis[11] += d11;
        dis[12] += d12;
        dis[13] += d13;
        dis[14] += d14;
        dis[15] += d15;
    }
}
KRL_IMPRECISE_FUNCTION_END

/*
* @brief Compute L2 square distance for twenty-four vectors in batch mode.
* @param x Pointer to the query vector (float).
* @param y Pointer to the database vectors (float).
* @param d Dimension of the vectors.
* @param dis Pointer to the output array for storing the results (float).
*/
KRL_IMPRECISE_FUNCTION_BEGIN
static void krl_L2sqr_batch24(const float *x, const float *__restrict y, const size_t d, float *dis)
{
    size_t i;
    constexpr size_t single_round = 4; /* 128 / 32 */
    if (likely(d >= single_round)) {
        /* Seed 24 accumulators in groups of four rows (vsub + vmul per group). */
        float32x4_t neon_query = vld1q_f32(x);
        float32x4_t neon_base1 = vld1q_f32(y);
        float32x4_t neon_base2 = vld1q_f32(y + d);
        float32x4_t neon_base3 = vld1q_f32(y + 2 * d);
        float32x4_t neon_base4 = vld1q_f32(y + 3 * d);
        neon_base1 = vsubq_f32(neon_base1, neon_query);
        neon_base2 = vsubq_f32(neon_base2, neon_query);
        neon_base3 = vsubq_f32(neon_base3, neon_query);
        neon_base4 = vsubq_f32(neon_base4, neon_query);
        float32x4_t neon_res1 = vmulq_f32(neon_base1, neon_base1);
        float32x4_t neon_res2 = vmulq_f32(neon_base2, neon_base2);
        float32x4_t neon_res3 = vmulq_f32(neon_base3, neon_base3);
        float32x4_t neon_res4 = vmulq_f32(neon_base4, neon_base4);

        neon_base1 = vld1q_f32(y + 4 * d);
        neon_base2 = vld1q_f32(y + 5 * d);
        neon_base3 = vld1q_f32(y + 6 * d);
        neon_base4 = vld1q_f32(y + 7 * d);
        neon_base1 = vsubq_f32(neon_base1, neon_query);
        neon_base2 = vsubq_f32(neon_base2, neon_query);
        neon_base3 = vsubq_f32(neon_base3, neon_query);
        neon_base4 = vsubq_f32(neon_base4, neon_query);
        float32x4_t neon_res5 = vmulq_f32(neon_base1, neon_base1);
        float32x4_t neon_res6 = vmulq_f32(neon_base2, neon_base2);
        float32x4_t neon_res7 = vmulq_f32(neon_base3, neon_base3);
        float32x4_t neon_res8 = vmulq_f32(neon_base4, neon_base4);

        neon_base1 = vld1q_f32(y + 8 * d);
        neon_base2 = vld1q_f32(y + 9 * d);
        neon_base3 = vld1q_f32(y + 10 * d);
        neon_base4 = vld1q_f32(y + 11 * d);
        neon_base1 = vsubq_f32(neon_base1, neon_query);
        neon_base2 = vsubq_f32(neon_base2, neon_query);
        neon_base3 = vsubq_f32(neon_base3, neon_query);
        neon_base4 = vsubq_f32(neon_base4, neon_query);
        float32x4_t neon_res9 = vmulq_f32(neon_base1, neon_base1);
        float32x4_t neon_res10 = vmulq_f32(neon_base2, neon_base2);
        float32x4_t neon_res11 = vmulq_f32(neon_base3, neon_base3);
        float32x4_t neon_res12 = vmulq_f32(neon_base4, neon_base4);

        neon_base1 = vld1q_f32(y + 12 * d);
        neon_base2 = vld1q_f32(y + 13 * d);
        neon_base3 = vld1q_f32(y + 14 * d);
        neon_base4 = vld1q_f32(y + 15 * d);
        neon_base1 = vsubq_f32(neon_base1, neon_query);
        neon_base2 = vsubq_f32(neon_base2, neon_query);
        neon_base3 = vsubq_f32(neon_base3, neon_query);
        neon_base4 = vsubq_f32(neon_base4, neon_query);
        float32x4_t neon_res13 = vmulq_f32(neon_base1, neon_base1);
        float32x4_t neon_res14 = vmulq_f32(neon_base2, neon_base2);
        float32x4_t neon_res15 = vmulq_f32(neon_base3, neon_base3);
        float32x4_t neon_res16 = vmulq_f32(neon_base4, neon_base4);

        neon_base1 = vld1q_f32(y + 16 * d);
        neon_base2 = vld1q_f32(y + 17 * d);
        neon_base3 = vld1q_f32(y + 18 * d);
        neon_base4 = vld1q_f32(y + 19 * d);
        neon_base1 = vsubq_f32(neon_base1, neon_query);
        neon_base2 = vsubq_f32(neon_base2, neon_query);
        neon_base3 = vsubq_f32(neon_base3, neon_query);
        neon_base4 = vsubq_f32(neon_base4, neon_query);
        float32x4_t neon_res17 = vmulq_f32(neon_base1, neon_base1);
        float32x4_t neon_res18 = vmulq_f32(neon_base2, neon_base2);
        float32x4_t neon_res19 = vmulq_f32(neon_base3, neon_base3);
        float32x4_t neon_res20 = vmulq_f32(neon_base4, neon_base4);

        neon_base1 = vld1q_f32(y + 20 * d);
        neon_base2 = vld1q_f32(y + 21 * d);
        neon_base3 = vld1q_f32(y + 22 * d);
        neon_base4 = vld1q_f32(y + 23 * d);
        neon_base1 = vsubq_f32(neon_base1, neon_query);
        neon_base2 = vsubq_f32(neon_base2, neon_query);
        neon_base3 = vsubq_f32(neon_base3, neon_query);
        neon_base4 = vsubq_f32(neon_base4, neon_query);
        float32x4_t neon_res21 = vmulq_f32(neon_base1, neon_base1);
        float32x4_t neon_res22 = vmulq_f32(neon_base2, neon_base2);
        float32x4_t neon_res23 = vmulq_f32(neon_base3, neon_base3);
        float32x4_t neon_res24 = vmulq_f32(neon_base4, neon_base4);
        /* Main loop: per 4-float group, FMA each of the six 4-row groups. */
        for (i = single_round; i <= d - single_round; i += single_round) {
            neon_query = vld1q_f32(x + i);
            neon_base1 = vld1q_f32(y + i);
            neon_base2 = vld1q_f32(y + d + i);
            neon_base3 = vld1q_f32(y + 2 * d + i);
            neon_base4 = vld1q_f32(y + 3 * d + i);
            neon_base1 = vsubq_f32(neon_base1, neon_query);
            neon_base2 = vsubq_f32(neon_base2, neon_query);
            neon_base3 = vsubq_f32(neon_base3, neon_query);
            neon_base4 = vsubq_f32(neon_base4, neon_query);
            neon_res1 = vmlaq_f32(neon_res1, neon_base1, neon_base1);
            neon_res2 = vmlaq_f32(neon_res2, neon_base2, neon_base2);
            neon_res3 = vmlaq_f32(neon_res3, neon_base3, neon_base3);
            neon_res4 = vmlaq_f32(neon_res4, neon_base4, neon_base4);

            neon_base1 = vld1q_f32(y + 4 * d + i);
            neon_base2 = vld1q_f32(y + 5 * d + i);
            neon_base3 = vld1q_f32(y + 6 * d + i);
            neon_base4 = vld1q_f32(y + 7 * d + i);
            neon_base1 = vsubq_f32(neon_base1, neon_query);
            neon_base2 = vsubq_f32(neon_base2, neon_query);
            neon_base3 = vsubq_f32(neon_base3, neon_query);
            neon_base4 = vsubq_f32(neon_base4, neon_query);
            neon_res5 = vmlaq_f32(neon_res5, neon_base1, neon_base1);
            neon_res6 = vmlaq_f32(neon_res6, neon_base2, neon_base2);
            neon_res7 = vmlaq_f32(neon_res7, neon_base3, neon_base3);
            neon_res8 = vmlaq_f32(neon_res8, neon_base4, neon_base4);

            neon_base1 = vld1q_f32(y + 8 * d + i);
            neon_base2 = vld1q_f32(y + 9 * d + i);
            neon_base3 = vld1q_f32(y + 10 * d + i);
            neon_base4 = vld1q_f32(y + 11 * d + i);
            neon_base1 = vsubq_f32(neon_base1, neon_query);
            neon_base2 = vsubq_f32(neon_base2, neon_query);
            neon_base3 = vsubq_f32(neon_base3, neon_query);
            neon_base4 = vsubq_f32(neon_base4, neon_query);
            neon_res9 = vmlaq_f32(neon_res9, neon_base1, neon_base1);
            neon_res10 = vmlaq_f32(neon_res10, neon_base2, neon_base2);
            neon_res11 = vmlaq_f32(neon_res11, neon_base3, neon_base3);
            neon_res12 = vmlaq_f32(neon_res12, neon_base4, neon_base4);

            neon_base1 = vld1q_f32(y + 12 * d + i);
            neon_base2 = vld1q_f32(y + 13 * d + i);
            neon_base3 = vld1q_f32(y + 14 * d + i);
            neon_base4 = vld1q_f32(y + 15 * d + i);
            neon_base1 = vsubq_f32(neon_base1, neon_query);
            neon_base2 = vsubq_f32(neon_base2, neon_query);
            neon_base3 = vsubq_f32(neon_base3, neon_query);
            neon_base4 = vsubq_f32(neon_base4, neon_query);
            neon_res13 = vmlaq_f32(neon_res13, neon_base1, neon_base1);
            neon_res14 = vmlaq_f32(neon_res14, neon_base2, neon_base2);
            neon_res15 = vmlaq_f32(neon_res15, neon_base3, neon_base3);
            neon_res16 = vmlaq_f32(neon_res16, neon_base4, neon_base4);

            neon_base1 = vld1q_f32(y + 16 * d + i);
            neon_base2 = vld1q_f32(y + 17 * d + i);
            neon_base3 = vld1q_f32(y + 18 * d + i);
            neon_base4 = vld1q_f32(y + 19 * d + i);
            neon_base1 = vsubq_f32(neon_base1, neon_query);
            neon_base2 = vsubq_f32(neon_base2, neon_query);
            neon_base3 = vsubq_f32(neon_base3, neon_query);
            neon_base4 = vsubq_f32(neon_base4, neon_query);
            neon_res17 = vmlaq_f32(neon_res17, neon_base1, neon_base1);
            neon_res18 = vmlaq_f32(neon_res18, neon_base2, neon_base2);
            neon_res19 = vmlaq_f32(neon_res19, neon_base3, neon_base3);
            neon_res20 = vmlaq_f32(neon_res20, neon_base4, neon_base4);

            neon_base1 = vld1q_f32(y + 20 * d + i);
            neon_base2 = vld1q_f32(y + 21 * d + i);
            neon_base3 = vld1q_f32(y + 22 * d + i);
            neon_base4 = vld1q_f32(y + 23 * d + i);
            neon_base1 = vsubq_f32(neon_base1, neon_query);
            neon_base2 = vsubq_f32(neon_base2, neon_query);
            neon_base3 = vsubq_f32(neon_base3, neon_query);
            neon_base4 = vsubq_f32(neon_base4, neon_query);
            neon_res21 = vmlaq_f32(neon_res21, neon_base1, neon_base1);
            neon_res22 = vmlaq_f32(neon_res22, neon_base2, neon_base2);
            neon_res23 = vmlaq_f32(neon_res23, neon_base3, neon_base3);
            neon_res24 = vmlaq_f32(neon_res24, neon_base4, neon_base4);
        }
        /* Horizontal reductions into the 24 output slots. */
        dis[0] = vaddvq_f32(neon_res1);
        dis[1] = vaddvq_f32(neon_res2);
        dis[2] = vaddvq_f32(neon_res3);
        dis[3] = vaddvq_f32(neon_res4);
        dis[4] = vaddvq_f32(neon_res5);
        dis[5] = vaddvq_f32(neon_res6);
        dis[6] = vaddvq_f32(neon_res7);
        dis[7] = vaddvq_f32(neon_res8);
        dis[8] = vaddvq_f32(neon_res9);
        dis[9] = vaddvq_f32(neon_res10);
        dis[10] = vaddvq_f32(neon_res11);
        dis[11] = vaddvq_f32(neon_res12);
        dis[12] = vaddvq_f32(neon_res13);
        dis[13] = vaddvq_f32(neon_res14);
        dis[14] = vaddvq_f32(neon_res15);
        dis[15] = vaddvq_f32(neon_res16);
        dis[16] = vaddvq_f32(neon_res17);
        dis[17] = vaddvq_f32(neon_res18);
        dis[18] = vaddvq_f32(neon_res19);
        dis[19] = vaddvq_f32(neon_res20);
        dis[20] = vaddvq_f32(neon_res21);
        dis[21] = vaddvq_f32(neon_res22);
        dis[22] = vaddvq_f32(neon_res23);
        dis[23] = vaddvq_f32(neon_res24);
    } else {
        /* d < 4: the scalar tail below computes everything from index 0. */
        for (int i = 0; i < 24; i++) {
            dis[i] = 0.0f;
        }
        i = 0;
    }
    /* Scalar tail for the last d % 4 components of all 24 rows. */
    if (i < d) {
        float q0 = x[i] - *(y + i);
        float q1 = x[i] - *(y + d + i);
        float q2 = x[i] - *(y + 2 * d + i);
        float q3 = x[i] - *(y + 3 * d + i);
        float q4 = x[i] - *(y + 4 * d + i);
        float q5 = x[i] - *(y + 5 * d + i);
        float q6 = x[i] - *(y + 6 * d + i);
        float q7 = x[i] - *(y + 7 * d + i);
        float d0 = q0 * q0;
        float d1 = q1 * q1;
        float d2 = q2 * q2;
        float d3 = q3 * q3;
        float d4 = q4 * q4;
        float d5 = q5 * q5;
        float d6 = q6 * q6;
        float d7 = q7 * q7;
        q0 = x[i] - *(y + 8 * d + i);
        q1 = x[i] - *(y + 9 * d + i);
        q2 = x[i] - *(y + 10 * d + i);
        q3 = x[i] - *(y + 11 * d + i);
        q4 = x[i] - *(y + 12 * d + i);
        q5 = x[i] - *(y + 13 * d + i);
        q6 = x[i] - *(y + 14 * d + i);
        q7 = x[i] - *(y + 15 * d + i);
        float d8 = q0 * q0;
        float d9 = q1 * q1;
        float d10 = q2 * q2;
        float d11 = q3 * q3;
        float d12 = q4 * q4;
        float d13 = q5 * q5;
        float d14 = q6 * q6;
        float d15 = q7 * q7;
        q0 = x[i] - *(y + 16 * d + i);
        q1 = x[i] - *(y + 17 * d + i);
        q2 = x[i] - *(y + 18 * d + i);
        q3 = x[i] - *(y + 19 * d + i);
        q4 = x[i] - *(y + 20 * d + i);
        q5 = x[i] - *(y + 21 * d + i);
        q6 = x[i] - *(y + 22 * d + i);
        q7 = x[i] - *(y + 23 * d + i);
        float d16 = q0 * q0;
        float d17 = q1 * q1;
        float d18 = q2 * q2;
        float d19 = q3 * q3;
        float d20 = q4 * q4;
        float d21 = q5 * q5;
        float d22 = q6 * q6;
        float d23 = q7 * q7;
        for (i++; i < d; ++i) {
            q0 = x[i] - *(y + i);
            q1 = x[i] - *(y + d + i);
            q2 = x[i] - *(y + 2 * d + i);
            q3 = x[i] - *(y + 3 * d + i);
            q4 = x[i] - *(y + 4 * d + i);
            q5 = x[i] - *(y + 5 * d + i);
            q6 = x[i] - *(y + 6 * d + i);
            q7 = x[i] - *(y + 7 * d + i);
            d0 += q0 * q0;
            d1 += q1 * q1;
            d2 += q2 * q2;
            d3 += q3 * q3;
            d4 += q4 * q4;
            d5 += q5 * q5;
            d6 += q6 * q6;
            d7 += q7 * q7;
            q0 = x[i] - *(y + 8 * d + i);
            q1 = x[i] - *(y + 9 * d + i);
            q2 = x[i] - *(y + 10 * d + i);
            q3 = x[i] - *(y + 11 * d + i);
            q4 = x[i] - *(y + 12 * d + i);
            q5 = x[i] - *(y + 13 * d + i);
            q6 = x[i] - *(y + 14 * d + i);
            q7 = x[i] - *(y + 15 * d + i);
            d8 += q0 * q0;
            d9 += q1 * q1;
            d10 += q2 * q2;
            d11 += q3 * q3;
            d12 += q4 * q4;
            d13 += q5 * q5;
            d14 += q6 * q6;
            d15 += q7 * q7;
            q0 = x[i] - *(y + 16 * d + i);
            q1 = x[i] - *(y + 17 * d + i);
            q2 = x[i] - *(y + 18 * d + i);
            q3 = x[i] - *(y + 19 * d + i);
            q4 = x[i] - *(y + 20 * d + i);
            q5 = x[i] - *(y + 21 * d + i);
            q6 = x[i] - *(y + 22 * d + i);
            q7 = x[i] - *(y + 23 * d + i);
            d16 += q0 * q0;
            d17 += q1 * q1;
            d18 += q2 * q2;
            d19 += q3 * q3;
            d20 += q4 * q4;
            d21 += q5 * q5;
            d22 += q6 * q6;
            d23 += q7 * q7;
        }
        dis[0] += d0;
        dis[1] += d1;
        dis[2] += d2;
        dis[3] += d3;
        dis[4] += d4;
        dis[5] += d5;
        dis[6] += d6;
        dis[7] += d7;
        dis[8] += d8;
        dis[9] += d9;
        dis[10] += d10;
        dis[11] += d11;
        dis[12] += d12;
        dis[13] += d13;
        dis[14] += d14;
        dis[15] += d15;
        dis[16] += d16;
        dis[17] += d17;
        dis[18] += d18;
        dis[19] += d19;
        dis[20] += d20;
        dis[21] += d21;
        dis[22] += d22;
        dis[23] += d23;
    }
}
KRL_IMPRECISE_FUNCTION_END

/*
* @brief Compute L2 squares for 16 vectors with float precision and store results in dis array.
* NOTE(review): the kernel body stores 64 float results (dis[0..63] = 16 NEON
* registers x 4 lanes) and reads y in a transposed, 64-floats-per-dimension
* layout; "16 vectors" here appears to mean 16 SIMD registers — confirm.
* @param dis Pointer to the output array for storing the results (float).
* @param x Pointer to the query vector (float).
* @param y Pointer to the database vectors (float).
* @param d Dimension of the vectors.
+*/ +KRL_IMPRECISE_FUNCTION_BEGIN +static void krl_L2sqr_continuous_transpose_large_kernel(float *dis, const float *x, const float *y, const size_t d) +{ + float32x4_t neon_res[16]; + float32x4_t neon_base[8]; + float32x4_t single_query = vdupq_n_f32(x[0]); + prefetch_Lx(y + 64); + neon_base[0] = vld1q_f32(y); + neon_base[1] = vld1q_f32(y + 4); + neon_base[2] = vld1q_f32(y + 8); + neon_base[3] = vld1q_f32(y + 12); + neon_base[4] = vld1q_f32(y + 16); + neon_base[5] = vld1q_f32(y + 20); + neon_base[6] = vld1q_f32(y + 24); + neon_base[7] = vld1q_f32(y + 28); + + neon_base[0] = vsubq_f32(neon_base[0], single_query); + neon_base[1] = vsubq_f32(neon_base[1], single_query); + neon_base[2] = vsubq_f32(neon_base[2], single_query); + neon_base[3] = vsubq_f32(neon_base[3], single_query); + neon_base[4] = vsubq_f32(neon_base[4], single_query); + neon_base[5] = vsubq_f32(neon_base[5], single_query); + neon_base[6] = vsubq_f32(neon_base[6], single_query); + neon_base[7] = vsubq_f32(neon_base[7], single_query); + + neon_res[0] = vmulq_f32(neon_base[0], neon_base[0]); + neon_res[1] = vmulq_f32(neon_base[1], neon_base[1]); + neon_res[2] = vmulq_f32(neon_base[2], neon_base[2]); + neon_res[3] = vmulq_f32(neon_base[3], neon_base[3]); + neon_res[4] = vmulq_f32(neon_base[4], neon_base[4]); + neon_res[5] = vmulq_f32(neon_base[5], neon_base[5]); + neon_res[6] = vmulq_f32(neon_base[6], neon_base[6]); + neon_res[7] = vmulq_f32(neon_base[7], neon_base[7]); + + neon_base[0] = vld1q_f32(y + 32); + neon_base[1] = vld1q_f32(y + 36); + neon_base[2] = vld1q_f32(y + 40); + neon_base[3] = vld1q_f32(y + 44); + neon_base[4] = vld1q_f32(y + 48); + neon_base[5] = vld1q_f32(y + 52); + neon_base[6] = vld1q_f32(y + 56); + neon_base[7] = vld1q_f32(y + 60); + + neon_base[0] = vsubq_f32(neon_base[0], single_query); + neon_base[1] = vsubq_f32(neon_base[1], single_query); + neon_base[2] = vsubq_f32(neon_base[2], single_query); + neon_base[3] = vsubq_f32(neon_base[3], single_query); + neon_base[4] = 
vsubq_f32(neon_base[4], single_query); + neon_base[5] = vsubq_f32(neon_base[5], single_query); + neon_base[6] = vsubq_f32(neon_base[6], single_query); + neon_base[7] = vsubq_f32(neon_base[7], single_query); + + neon_res[8] = vmulq_f32(neon_base[0], neon_base[0]); + neon_res[9] = vmulq_f32(neon_base[1], neon_base[1]); + neon_res[10] = vmulq_f32(neon_base[2], neon_base[2]); + neon_res[11] = vmulq_f32(neon_base[3], neon_base[3]); + neon_res[12] = vmulq_f32(neon_base[4], neon_base[4]); + neon_res[13] = vmulq_f32(neon_base[5], neon_base[5]); + neon_res[14] = vmulq_f32(neon_base[6], neon_base[6]); + neon_res[15] = vmulq_f32(neon_base[7], neon_base[7]); + + /* dim loop */ + for (size_t i = 1; i < d; ++i) { + single_query = vdupq_n_f32(x[i]); + prefetch_Lx(y + 64 * (i + 1)); + + neon_base[0] = vld1q_f32(y + 64 * i); + neon_base[1] = vld1q_f32(y + 64 * i + 4); + neon_base[2] = vld1q_f32(y + 64 * i + 8); + neon_base[3] = vld1q_f32(y + 64 * i + 12); + neon_base[4] = vld1q_f32(y + 64 * i + 16); + neon_base[5] = vld1q_f32(y + 64 * i + 20); + neon_base[6] = vld1q_f32(y + 64 * i + 24); + neon_base[7] = vld1q_f32(y + 64 * i + 28); + + neon_base[0] = vsubq_f32(neon_base[0], single_query); + neon_base[1] = vsubq_f32(neon_base[1], single_query); + neon_base[2] = vsubq_f32(neon_base[2], single_query); + neon_base[3] = vsubq_f32(neon_base[3], single_query); + neon_base[4] = vsubq_f32(neon_base[4], single_query); + neon_base[5] = vsubq_f32(neon_base[5], single_query); + neon_base[6] = vsubq_f32(neon_base[6], single_query); + neon_base[7] = vsubq_f32(neon_base[7], single_query); + + neon_res[0] = vmlaq_f32(neon_res[0], neon_base[0], neon_base[0]); + neon_res[1] = vmlaq_f32(neon_res[1], neon_base[1], neon_base[1]); + neon_res[2] = vmlaq_f32(neon_res[2], neon_base[2], neon_base[2]); + neon_res[3] = vmlaq_f32(neon_res[3], neon_base[3], neon_base[3]); + neon_res[4] = vmlaq_f32(neon_res[4], neon_base[4], neon_base[4]); + neon_res[5] = vmlaq_f32(neon_res[5], neon_base[5], neon_base[5]); + 
neon_res[6] = vmlaq_f32(neon_res[6], neon_base[6], neon_base[6]); + neon_res[7] = vmlaq_f32(neon_res[7], neon_base[7], neon_base[7]); + + neon_base[0] = vld1q_f32(y + 64 * i + 32); + neon_base[1] = vld1q_f32(y + 64 * i + 36); + neon_base[2] = vld1q_f32(y + 64 * i + 40); + neon_base[3] = vld1q_f32(y + 64 * i + 44); + neon_base[4] = vld1q_f32(y + 64 * i + 48); + neon_base[5] = vld1q_f32(y + 64 * i + 52); + neon_base[6] = vld1q_f32(y + 64 * i + 56); + neon_base[7] = vld1q_f32(y + 64 * i + 60); + + neon_base[0] = vsubq_f32(neon_base[0], single_query); + neon_base[1] = vsubq_f32(neon_base[1], single_query); + neon_base[2] = vsubq_f32(neon_base[2], single_query); + neon_base[3] = vsubq_f32(neon_base[3], single_query); + neon_base[4] = vsubq_f32(neon_base[4], single_query); + neon_base[5] = vsubq_f32(neon_base[5], single_query); + neon_base[6] = vsubq_f32(neon_base[6], single_query); + neon_base[7] = vsubq_f32(neon_base[7], single_query); + + neon_res[8] = vmlaq_f32(neon_res[8], neon_base[0], neon_base[0]); + neon_res[9] = vmlaq_f32(neon_res[9], neon_base[1], neon_base[1]); + neon_res[10] = vmlaq_f32(neon_res[10], neon_base[2], neon_base[2]); + neon_res[11] = vmlaq_f32(neon_res[11], neon_base[3], neon_base[3]); + neon_res[12] = vmlaq_f32(neon_res[12], neon_base[4], neon_base[4]); + neon_res[13] = vmlaq_f32(neon_res[13], neon_base[5], neon_base[5]); + neon_res[14] = vmlaq_f32(neon_res[14], neon_base[6], neon_base[6]); + neon_res[15] = vmlaq_f32(neon_res[15], neon_base[7], neon_base[7]); + } + { + vst1q_f32(dis, neon_res[0]); + vst1q_f32(dis + 4, neon_res[1]); + vst1q_f32(dis + 8, neon_res[2]); + vst1q_f32(dis + 12, neon_res[3]); + vst1q_f32(dis + 16, neon_res[4]); + vst1q_f32(dis + 20, neon_res[5]); + vst1q_f32(dis + 24, neon_res[6]); + vst1q_f32(dis + 28, neon_res[7]); + vst1q_f32(dis + 32, neon_res[8]); + vst1q_f32(dis + 36, neon_res[9]); + vst1q_f32(dis + 40, neon_res[10]); + vst1q_f32(dis + 44, neon_res[11]); + vst1q_f32(dis + 48, neon_res[12]); + vst1q_f32(dis + 52, 
neon_res[13]);
        vst1q_f32(dis + 56, neon_res[14]);
        vst1q_f32(dis + 60, neon_res[15]);
    }
}
KRL_IMPRECISE_FUNCTION_END

/*
* @brief Compute L2 squares of the query against 32 database vectors (8 NEON registers x 4 lanes)
*        and store the 32 results in dis.
*        NOTE(review): header previously said "8 vectors"; the code stores dis[0..31], i.e. 32 results.
* @param dis Pointer to the output array for storing the results (float); must hold at least 32 floats.
* @param x Pointer to the query vector (float).
* @param y Pointer to the database vectors (float); dimension-major ("transposed") layout is assumed:
*          y[32*i + j] is component i of vector j (inferred from the load pattern below — confirm
*          against the code that builds transposed_codes).
* @param d Dimension of the vectors.
*/
KRL_IMPRECISE_FUNCTION_BEGIN
static void krl_L2sqr_continuous_transpose_medium_kernel(float *dis, const float *x, const float *y, const size_t d)
{
    float32x4_t neon_res[8];  /* 8x4 = 32 running sum-of-squared-diff accumulators */
    float32x4_t neon_base[8]; /* 32 database components of the dimension being loaded */
    float32x4_t neon_diff[8]; /* (base - query) for the dimension being consumed */
    /* Dimension 0: broadcast x[0] and diff it against the first 32-float slice of y. */
    float32x4_t single_query = vdupq_n_f32(x[0]);
    neon_base[0] = vld1q_f32(y);
    neon_base[1] = vld1q_f32(y + 4);
    neon_base[2] = vld1q_f32(y + 8);
    neon_base[3] = vld1q_f32(y + 12);
    neon_base[4] = vld1q_f32(y + 16);
    neon_base[5] = vld1q_f32(y + 20);
    neon_base[6] = vld1q_f32(y + 24);
    neon_base[7] = vld1q_f32(y + 28);

    neon_diff[0] = vsubq_f32(neon_base[0], single_query);
    neon_diff[1] = vsubq_f32(neon_base[1], single_query);
    neon_diff[2] = vsubq_f32(neon_base[2], single_query);
    neon_diff[3] = vsubq_f32(neon_base[3], single_query);
    neon_diff[4] = vsubq_f32(neon_base[4], single_query);
    neon_diff[5] = vsubq_f32(neon_base[5], single_query);
    neon_diff[6] = vsubq_f32(neon_base[6], single_query);
    neon_diff[7] = vsubq_f32(neon_base[7], single_query);

    if (unlikely(d == 1)) {
        /* Single dimension: no accumulation needed, just square the diffs. */
        neon_res[0] = vmulq_f32(neon_diff[0], neon_diff[0]);
        neon_res[1] = vmulq_f32(neon_diff[1], neon_diff[1]);
        neon_res[2] = vmulq_f32(neon_diff[2], neon_diff[2]);
        neon_res[3] = vmulq_f32(neon_diff[3], neon_diff[3]);
        neon_res[4] = vmulq_f32(neon_diff[4], neon_diff[4]);
        neon_res[5] = vmulq_f32(neon_diff[5], neon_diff[5]);
        neon_res[6] = vmulq_f32(neon_diff[6], neon_diff[6]);
        neon_res[7] = vmulq_f32(neon_diff[7], neon_diff[7]);
    } else {
        /* Pre-load dimension 1 while squaring dimension 0 (software pipelining). */
        single_query = vdupq_n_f32(x[1]);
        neon_base[0] = vld1q_f32(y + 32);
        neon_base[1] = vld1q_f32(y + 36);
        neon_base[2] = vld1q_f32(y + 40);
        neon_base[3] = vld1q_f32(y + 44);
        neon_base[4] = vld1q_f32(y + 48);
        neon_base[5] = vld1q_f32(y + 52);
        neon_base[6] = vld1q_f32(y + 56);
        neon_base[7] = vld1q_f32(y + 60);

        neon_res[0] = vmulq_f32(neon_diff[0], neon_diff[0]);
        neon_res[1] = vmulq_f32(neon_diff[1], neon_diff[1]);
        neon_res[2] = vmulq_f32(neon_diff[2], neon_diff[2]);
        neon_res[3] = vmulq_f32(neon_diff[3], neon_diff[3]);
        neon_res[4] = vmulq_f32(neon_diff[4], neon_diff[4]);
        neon_res[5] = vmulq_f32(neon_diff[5], neon_diff[5]);
        neon_res[6] = vmulq_f32(neon_diff[6], neon_diff[6]);
        neon_res[7] = vmulq_f32(neon_diff[7], neon_diff[7]);
        /* dim loop */
        /* Iteration i consumes the diffs of dimension i-1 (loaded last iteration)
           while broadcasting x[i] and loading the 32 components of dimension i. */
        for (size_t i = 2; i < d; ++i) {
            neon_diff[0] = vsubq_f32(neon_base[0], single_query);
            neon_diff[1] = vsubq_f32(neon_base[1], single_query);
            neon_diff[2] = vsubq_f32(neon_base[2], single_query);
            neon_diff[3] = vsubq_f32(neon_base[3], single_query);
            neon_diff[4] = vsubq_f32(neon_base[4], single_query);
            neon_diff[5] = vsubq_f32(neon_base[5], single_query);
            neon_diff[6] = vsubq_f32(neon_base[6], single_query);
            neon_diff[7] = vsubq_f32(neon_base[7], single_query);

            single_query = vdupq_n_f32(x[i]);
            neon_base[0] = vld1q_f32(y + 32 * i);
            neon_base[1] = vld1q_f32(y + 32 * i + 4);
            neon_base[2] = vld1q_f32(y + 32 * i + 8);
            neon_base[3] = vld1q_f32(y + 32 * i + 12);
            neon_base[4] = vld1q_f32(y + 32 * i + 16);
            neon_base[5] = vld1q_f32(y + 32 * i + 20);
            neon_base[6] = vld1q_f32(y + 32 * i + 24);
            neon_base[7] = vld1q_f32(y + 32 * i + 28);

            neon_res[0] = vmlaq_f32(neon_res[0], neon_diff[0], neon_diff[0]);
            neon_res[1] = vmlaq_f32(neon_res[1], neon_diff[1], neon_diff[1]);
            neon_res[2] = vmlaq_f32(neon_res[2], neon_diff[2], neon_diff[2]);
            neon_res[3] = vmlaq_f32(neon_res[3], neon_diff[3], neon_diff[3]);
            neon_res[4] = vmlaq_f32(neon_res[4], neon_diff[4], neon_diff[4]);
            neon_res[5] = vmlaq_f32(neon_res[5], neon_diff[5], neon_diff[5]);
            neon_res[6] = vmlaq_f32(neon_res[6], neon_diff[6], neon_diff[6]);
            neon_res[7] = vmlaq_f32(neon_res[7], neon_diff[7], neon_diff[7]);
        }
        { /* Epilogue: consume the last pre-loaded dimension (d-1). */
            neon_diff[0] = vsubq_f32(neon_base[0], single_query);
            neon_diff[1] = vsubq_f32(neon_base[1], single_query);
            neon_diff[2] = vsubq_f32(neon_base[2], single_query);
            neon_diff[3] = vsubq_f32(neon_base[3], single_query);
            neon_diff[4] = vsubq_f32(neon_base[4], single_query);
            neon_diff[5] = vsubq_f32(neon_base[5], single_query);
            neon_diff[6] = vsubq_f32(neon_base[6], single_query);
            neon_diff[7] = vsubq_f32(neon_base[7], single_query);

            neon_res[0] = vmlaq_f32(neon_res[0], neon_diff[0], neon_diff[0]);
            neon_res[1] = vmlaq_f32(neon_res[1], neon_diff[1], neon_diff[1]);
            neon_res[2] = vmlaq_f32(neon_res[2], neon_diff[2], neon_diff[2]);
            neon_res[3] = vmlaq_f32(neon_res[3], neon_diff[3], neon_diff[3]);
            neon_res[4] = vmlaq_f32(neon_res[4], neon_diff[4], neon_diff[4]);
            neon_res[5] = vmlaq_f32(neon_res[5], neon_diff[5], neon_diff[5]);
            neon_res[6] = vmlaq_f32(neon_res[6], neon_diff[6], neon_diff[6]);
            neon_res[7] = vmlaq_f32(neon_res[7], neon_diff[7], neon_diff[7]);
        }
    }
    /* Write back the 32 squared distances. */
    vst1q_f32(dis, neon_res[0]);
    vst1q_f32(dis + 4, neon_res[1]);
    vst1q_f32(dis + 8, neon_res[2]);
    vst1q_f32(dis + 12, neon_res[3]);
    vst1q_f32(dis + 16, neon_res[4]);
    vst1q_f32(dis + 20, neon_res[5]);
    vst1q_f32(dis + 24, neon_res[6]);
    vst1q_f32(dis + 28, neon_res[7]);
}
KRL_IMPRECISE_FUNCTION_END

/*
* @brief Compute L2 squares of the query against 16 database vectors (4 NEON registers x 4 lanes)
*        and store the 16 results in dis.
*        NOTE(review): header previously said "4 vectors"; the code stores dis[0..15], i.e. 16 results.
* @param dis Pointer to the output array for storing the results (float); must hold at least 16 floats.
* @param x Pointer to the query vector (float).
* @param y Pointer to the database vectors (float), dimension-major blocks of 16.
* @param d Dimension of the vectors.
*/
KRL_IMPRECISE_FUNCTION_BEGIN
static void krl_L2sqr_continuous_transpose_mini_kernel(float *dis, const float *x, const float *y, const size_t d)
{
    /* 4x4 = 16 running sum-of-squared-diff accumulators; y is read in
       dimension-major slices of 16 floats (y + 16*i is dimension i). */
    float32x4_t neon_res[4];
    /* Dimension 0: broadcast x[0] and diff it against the first 16-float slice of y. */
    float32x4_t single_query = vdupq_n_f32(x[0]);
    float32x4_t neon_base1 = vld1q_f32(y);
    float32x4_t neon_base2 = vld1q_f32(y + 4);
    float32x4_t neon_base3 = vld1q_f32(y + 8);
    float32x4_t neon_base4 = vld1q_f32(y + 12);
    float32x4_t neon_diff1 = vsubq_f32(neon_base1, single_query);
    float32x4_t neon_diff2 = vsubq_f32(neon_base2, single_query);
    float32x4_t neon_diff3 = vsubq_f32(neon_base3, single_query);
    float32x4_t neon_diff4 = vsubq_f32(neon_base4, single_query);
    if (unlikely(d == 1)) {
        /* Single dimension: just square the diffs. */
        neon_res[0] = vmulq_f32(neon_diff1, neon_diff1);
        neon_res[1] = vmulq_f32(neon_diff2, neon_diff2);
        neon_res[2] = vmulq_f32(neon_diff3, neon_diff3);
        neon_res[3] = vmulq_f32(neon_diff4, neon_diff4);
    } else {
        /* Pre-load dimension 1 while squaring dimension 0 (software pipelining). */
        single_query = vdupq_n_f32(x[1]);
        neon_base1 = vld1q_f32(y + 16);
        neon_base2 = vld1q_f32(y + 20);
        neon_base3 = vld1q_f32(y + 24);
        neon_base4 = vld1q_f32(y + 28);
        neon_res[0] = vmulq_f32(neon_diff1, neon_diff1);
        neon_res[1] = vmulq_f32(neon_diff2, neon_diff2);
        neon_res[2] = vmulq_f32(neon_diff3, neon_diff3);
        neon_res[3] = vmulq_f32(neon_diff4, neon_diff4);
        /* Iteration i consumes dimension i-1 while loading dimension i. */
        for (size_t i = 2; i < d; ++i) {
            neon_diff1 = vsubq_f32(neon_base1, single_query);
            neon_diff2 = vsubq_f32(neon_base2, single_query);
            neon_diff3 = vsubq_f32(neon_base3, single_query);
            neon_diff4 = vsubq_f32(neon_base4, single_query);

            single_query = vdupq_n_f32(x[i]);
            neon_base1 = vld1q_f32(y + 16 * i);
            neon_base2 = vld1q_f32(y + 16 * i + 4);
            neon_base3 = vld1q_f32(y + 16 * i + 8);
            neon_base4 = vld1q_f32(y + 16 * i + 12);

            neon_res[0] = vmlaq_f32(neon_res[0], neon_diff1, neon_diff1);
            neon_res[1] = vmlaq_f32(neon_res[1], neon_diff2, neon_diff2);
            neon_res[2] = vmlaq_f32(neon_res[2], neon_diff3, neon_diff3);
            neon_res[3] = vmlaq_f32(neon_res[3], neon_diff4, neon_diff4);
        }
        { /* Epilogue: consume the last pre-loaded dimension (d-1). */
            neon_diff1 = vsubq_f32(neon_base1, single_query);
            neon_diff2 = vsubq_f32(neon_base2, single_query);
            neon_diff3 = vsubq_f32(neon_base3, single_query);
            neon_diff4 = vsubq_f32(neon_base4, single_query);

            neon_res[0] = vmlaq_f32(neon_res[0], neon_diff1, neon_diff1);
            neon_res[1] = vmlaq_f32(neon_res[1], neon_diff2, neon_diff2);
            neon_res[2] = vmlaq_f32(neon_res[2], neon_diff3, neon_diff3);
            neon_res[3] = vmlaq_f32(neon_res[3], neon_diff4, neon_diff4);
        }
    }

    /* Write back the 16 squared distances. */
    vst1q_f32(dis, neon_res[0]);
    vst1q_f32(dis + 4, neon_res[1]);
    vst1q_f32(dis + 8, neon_res[2]);
    vst1q_f32(dis + 12, neon_res[3]);
}
KRL_IMPRECISE_FUNCTION_END

/*
* @brief Compute L2 squares for a batch of vectors based on given indices.
* @param dis Pointer to the output array for storing the results (float).
* @param x Pointer to the query vector (float).
* @param y Pointer to the database vectors (float).
* @param ids Pointer to the indices array for selecting database vectors.
* @param d Dimension of the vectors.
* @param ny Number of database vectors to process.
* @param dis_size Length of dis.
+*/ +int krl_L2sqr_by_idx( + float *dis, const float *x, const float *y, const int64_t *ids, size_t d, size_t ny, size_t dis_size) +{ + size_t i = 0; + const float *__restrict listy[24]; + + if (d < 1 || d > 65535 || ny < 1 || ny > 1ULL << 30) { + std::printf("Error: INVALPARAM in krl_L2sqr_by_idx\n"); + return INVALPARAM; + } + + if (x == nullptr || y == nullptr || ids == nullptr || dis == nullptr || dis_size < ny) { + std::printf("Error: INVALPOINTER in krl_L2sqr_by_idx\n"); + return INVALPOINTER; + } + + for (; i + 24 <= ny; i += 24) { + prefetch_L1(x); + listy[0] = (const float *)(y + *(ids + i) * d); + prefetch_Lx(listy[0]); + listy[1] = (const float *)(y + *(ids + i + 1) * d); + prefetch_Lx(listy[1]); + listy[2] = (const float *)(y + *(ids + i + 2) * d); + prefetch_Lx(listy[2]); + listy[3] = (const float *)(y + *(ids + i + 3) * d); + prefetch_Lx(listy[3]); + listy[4] = (const float *)(y + *(ids + i + 4) * d); + prefetch_Lx(listy[4]); + listy[5] = (const float *)(y + *(ids + i + 5) * d); + prefetch_Lx(listy[5]); + listy[6] = (const float *)(y + *(ids + i + 6) * d); + prefetch_Lx(listy[6]); + listy[7] = (const float *)(y + *(ids + i + 7) * d); + prefetch_Lx(listy[7]); + listy[8] = (const float *)(y + *(ids + i + 8) * d); + prefetch_Lx(listy[8]); + listy[9] = (const float *)(y + *(ids + i + 9) * d); + prefetch_Lx(listy[9]); + listy[10] = (const float *)(y + *(ids + i + 10) * d); + prefetch_Lx(listy[10]); + listy[11] = (const float *)(y + *(ids + i + 11) * d); + prefetch_Lx(listy[11]); + listy[12] = (const float *)(y + *(ids + i + 12) * d); + prefetch_Lx(listy[12]); + listy[13] = (const float *)(y + *(ids + i + 13) * d); + prefetch_Lx(listy[13]); + listy[14] = (const float *)(y + *(ids + i + 14) * d); + prefetch_Lx(listy[14]); + listy[15] = (const float *)(y + *(ids + i + 15) * d); + prefetch_Lx(listy[15]); + listy[16] = (const float *)(y + *(ids + i + 16) * d); + prefetch_Lx(listy[16]); + listy[17] = (const float *)(y + *(ids + i + 17) * d); + 
prefetch_Lx(listy[17]); + listy[18] = (const float *)(y + *(ids + i + 18) * d); + prefetch_Lx(listy[18]); + listy[19] = (const float *)(y + *(ids + i + 19) * d); + prefetch_Lx(listy[19]); + listy[20] = (const float *)(y + *(ids + i + 20) * d); + prefetch_Lx(listy[20]); + listy[21] = (const float *)(y + *(ids + i + 21) * d); + prefetch_Lx(listy[21]); + listy[22] = (const float *)(y + *(ids + i + 22) * d); + prefetch_Lx(listy[22]); + listy[23] = (const float *)(y + *(ids + i + 23) * d); + prefetch_Lx(listy[23]); + krl_L2sqr_idx_prefetch_batch24(x, listy, d, dis + i); + } + if (i + 16 <= ny) { + prefetch_L1(x); + listy[0] = (const float *)(y + *(ids + i) * d); + prefetch_Lx(listy[0]); + listy[1] = (const float *)(y + *(ids + i + 1) * d); + prefetch_Lx(listy[1]); + listy[2] = (const float *)(y + *(ids + i + 2) * d); + prefetch_Lx(listy[2]); + listy[3] = (const float *)(y + *(ids + i + 3) * d); + prefetch_Lx(listy[3]); + listy[4] = (const float *)(y + *(ids + i + 4) * d); + prefetch_Lx(listy[4]); + listy[5] = (const float *)(y + *(ids + i + 5) * d); + prefetch_Lx(listy[5]); + listy[6] = (const float *)(y + *(ids + i + 6) * d); + prefetch_Lx(listy[6]); + listy[7] = (const float *)(y + *(ids + i + 7) * d); + prefetch_Lx(listy[7]); + listy[8] = (const float *)(y + *(ids + i + 8) * d); + prefetch_Lx(listy[8]); + listy[9] = (const float *)(y + *(ids + i + 9) * d); + prefetch_Lx(listy[9]); + listy[10] = (const float *)(y + *(ids + i + 10) * d); + prefetch_Lx(listy[10]); + listy[11] = (const float *)(y + *(ids + i + 11) * d); + prefetch_Lx(listy[11]); + listy[12] = (const float *)(y + *(ids + i + 12) * d); + prefetch_Lx(listy[12]); + listy[13] = (const float *)(y + *(ids + i + 13) * d); + prefetch_Lx(listy[13]); + listy[14] = (const float *)(y + *(ids + i + 14) * d); + prefetch_Lx(listy[14]); + listy[15] = (const float *)(y + *(ids + i + 15) * d); + prefetch_Lx(listy[15]); + krl_L2sqr_idx_prefetch_batch16(x, listy, d, dis + i); + i += 16; + } else if (i + 8 <= ny) { + 
prefetch_L1(x); + listy[0] = (const float *)(y + *(ids + i) * d); + prefetch_Lx(listy[0]); + listy[1] = (const float *)(y + *(ids + i + 1) * d); + prefetch_Lx(listy[1]); + listy[2] = (const float *)(y + *(ids + i + 2) * d); + prefetch_Lx(listy[2]); + listy[3] = (const float *)(y + *(ids + i + 3) * d); + prefetch_Lx(listy[3]); + listy[4] = (const float *)(y + *(ids + i + 4) * d); + prefetch_Lx(listy[4]); + listy[5] = (const float *)(y + *(ids + i + 5) * d); + prefetch_Lx(listy[5]); + listy[6] = (const float *)(y + *(ids + i + 6) * d); + prefetch_Lx(listy[6]); + listy[7] = (const float *)(y + *(ids + i + 7) * d); + prefetch_Lx(listy[7]); + krl_L2sqr_idx_prefetch_batch8(x, listy, d, dis + i); + i += 8; + } + if (ny & 4) { + listy[0] = (const float *)(y + *(ids + i) * d); + listy[1] = (const float *)(y + *(ids + i + 1) * d); + listy[2] = (const float *)(y + *(ids + i + 2) * d); + listy[3] = (const float *)(y + *(ids + i + 3) * d); + krl_L2sqr_idx_batch4(x, listy, d, dis + i); + i += 4; + } + if (ny & 2) { + const float *y0 = y + *(ids + i) * d; + const float *y1 = y + *(ids + i + 1) * d; + krl_L2sqr_idx_batch2(x, y0, y1, d, dis + i); + i += 2; + } + if (ny & 1) { + krl_L2sqr(x, y + d * ids[i], d, &dis[i], 1); + } + return SUCCESS; +} + +/* +* @brief Compute L2 squares for a batch of vectors. +* @param dis Pointer to the output array for storing the results (float). +* @param x Pointer to the query vector (float). +* @param y Pointer to the database vectors (float). +* @param ny Number of database vectors to process. +* @param d Dimension of the vectors. +* @param dis_size Length of dis. 
+*/ +int krl_L2sqr_ny(float *dis, const float *x, const float *y, const size_t ny, const size_t d, size_t dis_size) +{ + size_t i = 0; + + if (d < 1 || d > 65535 || ny < 1 || ny > 1ULL << 30) { + std::printf("Error: INVALPARAM in krl_L2sqr_ny\n"); + return INVALPARAM; + } + + if (x == nullptr || y == nullptr || dis == nullptr || dis_size < ny) { + std::printf("Error: INVALPOINTER in krl_L2sqr_ny\n"); + return INVALPOINTER; + } + + for (; i + 24 <= ny; i += 24) { + krl_L2sqr_batch24(x, y + i * d, d, dis + i); + } + if (i + 16 <= ny) { + krl_L2sqr_batch16(x, y + i * d, d, dis + i); + i += 16; + } else if (i + 8 <= ny) { + krl_L2sqr_batch8(x, y + i * d, d, dis + i); + i += 8; + } + if (ny & 4) { + krl_L2sqr_batch4(x, y + i * d, d, dis + i); + i += 4; + } + if (ny & 2) { + krl_L2sqr_batch2(x, y + i * d, d, dis + i); + } + if (ny & 1) { + const float *y0 = (y + (ny - 1) * d); + krl_L2sqr(x, y0, d, &dis[ny - 1], 1); + } + return SUCCESS; +} + +/* +* @brief Compute L2 squares for a batch of vectors with a given handle. +* @param kdh Pointer to the distance handle containing configuration and data. +* @param dis Pointer to the output array for storing the results (float). +* @param x Pointer to the query vector (float). +* @param dis_size Length of dis. +* @param x_size Length of x. 
+*/ +int krl_L2sqr_ny_with_handle(const KRLDistanceHandle *kdh, float *dis, const float *x, size_t dis_size, size_t x_size) +{ + if (kdh == nullptr || dis == nullptr || x == nullptr) { + std::printf("Error: INVALPOINTER in krl_L2sqr_ny_with_handle\n"); + return INVALPOINTER; + } + const size_t ny = kdh->ny; + const size_t dim = kdh->d; + const size_t M = kdh->M; + if (dis_size < M * ny || x_size < dim * M) { + std::printf("Error: INVALPARAM in krl_L2sqr_ny_with_handle\n"); + return INVALPARAM; + } + + if (kdh->data_bits == 32) { + const float *y = (const float *)kdh->transposed_codes; + const size_t ceil_ny = kdh->ceil_ny; + const size_t left = ny & (kdh->blocksize - 1); + switch (kdh->blocksize) { + case 16: + if (left) { + float distance_tmp_buffer[16]; + for (size_t m = 0; m < M; m++) { + size_t i = 0; + for (; i + 16 <= ny; i += 16) { + krl_L2sqr_continuous_transpose_mini_kernel(dis + i, x, y + i * dim, dim); + } + krl_L2sqr_continuous_transpose_mini_kernel(distance_tmp_buffer, x, y + i * dim, dim); + size_t remaining_dis_size = dis_size - (m * ny + i); + if (remaining_dis_size < left) { + std::printf("Error: UNSAFEMEM in krl_L2sqr_ny_with_handle\n"); + return UNSAFEMEM; + } + int ret = SafeMemory::CheckAndMemcpy( + dis + i, remaining_dis_size * sizeof(float), distance_tmp_buffer, left * sizeof(float)); + if (ret != 0) { + std::printf("Error: UNSAFEMEM in krl_L2sqr_ny_with_handle\n"); + return UNSAFEMEM; + } + dis += ny; + x += dim; + y += ceil_ny * dim; + } + } else { + for (size_t m = 0; m < M; m++) { + for (size_t i = 0; i < ny; i += 16) { + krl_L2sqr_continuous_transpose_mini_kernel(dis + i, x, y + i * dim, dim); + } + dis += ny; + x += dim; + y += ceil_ny * dim; + } + } + break; + case 32: + if (left) { + float distance_tmp_buffer[32]; + for (size_t m = 0; m < M; m++) { + size_t i = 0; + for (; i + 32 <= ny; i += 32) { + krl_L2sqr_continuous_transpose_medium_kernel(dis + i, x, y + i * dim, dim); + } + 
krl_L2sqr_continuous_transpose_medium_kernel(distance_tmp_buffer, x, y + i * dim, dim); + size_t remaining_dis_size = dis_size - (m * ny + i); + if (remaining_dis_size < left) { + std::printf("Error: UNSAFEMEM in krl_L2sqr_ny_with_handle\n"); + return UNSAFEMEM; + } + int ret = SafeMemory::CheckAndMemcpy( + dis + i, remaining_dis_size * sizeof(float), distance_tmp_buffer, left * sizeof(float)); + if (ret != 0) { + std::printf("Error: UNSAFEMEM in krl_L2sqr_ny_with_handle\n"); + return UNSAFEMEM; + } + dis += ny; + x += dim; + y += ceil_ny * dim; + } + } else { + for (size_t m = 0; m < M; m++) { + for (size_t i = 0; i < ny; i += 32) { + krl_L2sqr_continuous_transpose_medium_kernel(dis + i, x, y + i * dim, dim); + } + dis += ny; + x += dim; + y += ceil_ny * dim; + } + } + break; + case 64: + if (left) { + float distance_tmp_buffer[64]; + for (size_t m = 0; m < M; m++) { + size_t i = 0; + for (; i + 64 <= ny; i += 64) { + krl_L2sqr_continuous_transpose_large_kernel(dis + i, x, y + i * dim, dim); + } + krl_L2sqr_continuous_transpose_large_kernel(distance_tmp_buffer, x, y + i * dim, dim); + size_t remaining_dis_size = dis_size - (m * ny + i); + if (remaining_dis_size < left) { + std::printf("Error: UNSAFEMEM in krl_L2sqr_ny_with_handle\n"); + return UNSAFEMEM; + } + int ret = SafeMemory::CheckAndMemcpy( + dis + i, remaining_dis_size * sizeof(float), distance_tmp_buffer, left * sizeof(float)); + if (ret != 0) { + std::printf("Error: UNSAFEMEM in krl_L2sqr_ny_with_handle\n"); + return UNSAFEMEM; + } + dis += ny; + x += dim; + y += ceil_ny * dim; + } + } else { + for (size_t m = 0; m < M; m++) { + for (size_t i = 0; i < ny; i += 64) { + krl_L2sqr_continuous_transpose_large_kernel(dis + i, x, y + i * dim, dim); + } + dis += ny; + x += dim; + y += ceil_ny * dim; + } + } + break; + } + } else if (kdh->data_bits == 16) { + // fp16 path not built in minimal KRL for OpenViking + std::printf("Error: INVALPARAM in krl_L2sqr_ny_with_handle (fp16 not supported)\n"); + return 
INVALPARAM; + } else { + // u8 path not built in minimal KRL for OpenViking + std::printf("Error: INVALPARAM in krl_L2sqr_ny_with_handle (u8 not supported)\n"); + return INVALPARAM; + } + return SUCCESS; +} + +} // extern "C"