diff --git a/dpnp/backend/extensions/lapack/evd_common.hpp b/dpnp/backend/extensions/lapack/evd_common.hpp new file mode 100644 index 000000000000..180f8ad0f651 --- /dev/null +++ b/dpnp/backend/extensions/lapack/evd_common.hpp @@ -0,0 +1,178 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include +#include + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/output_validation.hpp" +#include "utils/type_dispatch.hpp" +#include "utils/type_utils.hpp" + +#include "types_matrix.hpp" + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace lapack +{ +namespace evd +{ +typedef sycl::event (*evd_impl_fn_ptr_t)(sycl::queue &, + const oneapi::mkl::job, + const oneapi::mkl::uplo, + const std::int64_t, + char *, + char *, + std::vector &, + const std::vector &); + +namespace dpctl_td_ns = dpctl::tensor::type_dispatch; +namespace py = pybind11; + +template +std::pair + evd_func(sycl::queue &exec_q, + const std::int8_t jobz, + const std::int8_t upper_lower, + dpctl::tensor::usm_ndarray &eig_vecs, + dpctl::tensor::usm_ndarray &eig_vals, + const std::vector &depends, + const dispatchT &evd_dispatch_table) +{ + const int eig_vecs_nd = eig_vecs.get_ndim(); + const int eig_vals_nd = eig_vals.get_ndim(); + + if (eig_vecs_nd != 2) { + throw py::value_error("Unexpected ndim=" + std::to_string(eig_vecs_nd) + + " of an output array with eigenvectors"); + } + else if (eig_vals_nd != 1) { + throw py::value_error("Unexpected ndim=" + std::to_string(eig_vals_nd) + + " of an output array with eigenvalues"); + } + + const py::ssize_t *eig_vecs_shape = eig_vecs.get_shape_raw(); + const py::ssize_t *eig_vals_shape = eig_vals.get_shape_raw(); + + if (eig_vecs_shape[0] != eig_vecs_shape[1]) { + throw py::value_error("Output array with eigenvectors with be square"); + } + else if (eig_vecs_shape[0] != eig_vals_shape[0]) { + throw py::value_error( + "Eigenvectors and eigenvalues have different shapes"); + } + + size_t src_nelems(1); + + for (int i = 0; i < eig_vecs_nd; ++i) { + src_nelems *= static_cast(eig_vecs_shape[i]); + } + + if (src_nelems == 0) { + // nothing to do + return std::make_pair(sycl::event(), sycl::event()); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(eig_vecs); + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(eig_vals); + + // check compatibility of execution queue and allocation queue + if (!dpctl::utils::queues_are_compatible(exec_q, {eig_vecs, eig_vals})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(eig_vecs, eig_vals)) { + throw py::value_error("Arrays with eigenvectors and eigenvalues are " + "overlapping segments of memory"); + } + + bool is_eig_vecs_f_contig = eig_vecs.is_f_contiguous(); + bool is_eig_vals_c_contig = eig_vals.is_c_contiguous(); + if (!is_eig_vecs_f_contig) { + throw py::value_error( + "An array with input matrix / output eigenvectors " + "must be F-contiguous"); + } + else if (!is_eig_vals_c_contig) { + throw py::value_error( + "An array with output eigenvalues must be C-contiguous"); + } + + auto array_types = dpctl_td_ns::usm_ndarray_types(); + int eig_vecs_type_id = + array_types.typenum_to_lookup_id(eig_vecs.get_typenum()); + int eig_vals_type_id = + array_types.typenum_to_lookup_id(eig_vals.get_typenum()); + + evd_impl_fn_ptr_t evd_fn = + evd_dispatch_table[eig_vecs_type_id][eig_vals_type_id]; + if (evd_fn == nullptr) { + throw py::value_error( + "Types of input vectors and result array are mismatched."); + } + + char *eig_vecs_data = eig_vecs.get_data(); + char *eig_vals_data = eig_vals.get_data(); + + const std::int64_t n = eig_vecs_shape[0]; + const oneapi::mkl::job jobz_val = static_cast(jobz); + const oneapi::mkl::uplo uplo_val = + static_cast(upper_lower); + + std::vector host_task_events; + sycl::event evd_ev = evd_fn(exec_q, jobz_val, uplo_val, n, eig_vecs_data, + eig_vals_data, host_task_events, depends); + + sycl::event args_ev = dpctl::utils::keep_args_alive( + exec_q, {eig_vecs, eig_vals}, host_task_events); + + return std::make_pair(args_ev, evd_ev); +} + +template + typename factoryT> +void init_evd_dispatch_table( + dispatchT evd_dispatch_table[][dpctl_td_ns::num_types]) +{ + dpctl_td_ns::DispatchTableBuilder + contig; + contig.populate_dispatch_table(evd_dispatch_table); +} +} // namespace evd +} // namespace lapack +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/lapack/heevd.cpp b/dpnp/backend/extensions/lapack/heevd.cpp index feffb2ec4a36..4d8626d0ea91 100644 --- a/dpnp/backend/extensions/lapack/heevd.cpp +++ b/dpnp/backend/extensions/lapack/heevd.cpp @@ -1,5 +1,5 @@ //***************************************************************************** -// Copyright (c) 2023-2024, Intel Corporation +// Copyright (c) 2024, Intel Corporation // All rights reserved. // // Redistribution and use in source and binary forms, with or without @@ -23,16 +23,7 @@ // THE POSSIBILITY OF SUCH DAMAGE. //***************************************************************************** -#include - -// dpctl tensor headers -#include "utils/memory_overlap.hpp" -#include "utils/type_utils.hpp" - #include "heevd.hpp" -#include "types_matrix.hpp" - -#include "dpnp_utils.hpp" namespace dpnp { @@ -43,23 +34,10 @@ namespace ext namespace lapack { namespace mkl_lapack = oneapi::mkl::lapack; -namespace py = pybind11; namespace type_utils = dpctl::tensor::type_utils; -typedef sycl::event (*heevd_impl_fn_ptr_t)(sycl::queue, - const oneapi::mkl::job, - const oneapi::mkl::uplo, - const std::int64_t, - char *, - char *, - std::vector &, - const std::vector &); - -static heevd_impl_fn_ptr_t heevd_dispatch_table[dpctl_td_ns::num_types] - [dpctl_td_ns::num_types]; - template -static sycl::event heevd_impl(sycl::queue exec_q, +static sycl::event heevd_impl(sycl::queue &exec_q, const oneapi::mkl::job jobz, const oneapi::mkl::uplo upper_lower, const std::int64_t n, @@ -128,104 +106,8 @@ static sycl::event heevd_impl(sycl::queue exec_q, cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); }); }); host_task_events.push_back(clean_up_event); - return heevd_event; -} - -std::pair - heevd(sycl::queue exec_q, - const std::int8_t jobz, - const std::int8_t upper_lower, - dpctl::tensor::usm_ndarray eig_vecs, - dpctl::tensor::usm_ndarray eig_vals, - const std::vector &depends) -{ - const int eig_vecs_nd = eig_vecs.get_ndim(); - const int eig_vals_nd = eig_vals.get_ndim(); - - if (eig_vecs_nd != 2) { - throw py::value_error("Unexpected ndim=" + std::to_string(eig_vecs_nd) + - " of an output array with eigenvectors"); - } - else if (eig_vals_nd != 1) { - throw py::value_error("Unexpected ndim=" + std::to_string(eig_vals_nd) + - " of an output array with eigenvalues"); - } - - const py::ssize_t *eig_vecs_shape = eig_vecs.get_shape_raw(); - const py::ssize_t *eig_vals_shape = eig_vals.get_shape_raw(); - - if (eig_vecs_shape[0] != eig_vecs_shape[1]) { - throw py::value_error("Output array with eigenvectors with be square"); - } - else if (eig_vecs_shape[0] != eig_vals_shape[0]) { - throw py::value_error( - "Eigenvectors and eigenvalues have different shapes"); - } - - size_t src_nelems(1); - - for (int i = 0; i < eig_vecs_nd; ++i) { - src_nelems *= static_cast(eig_vecs_shape[i]); - } - - if (src_nelems == 0) { - // nothing to do - return std::make_pair(sycl::event(), sycl::event()); - } - - // check compatibility of execution queue and allocation queue - if (!dpctl::utils::queues_are_compatible(exec_q, {eig_vecs, eig_vals})) { - throw py::value_error( - "Execution queue is not compatible with allocation queues"); - } - - auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); - if (overlap(eig_vecs, eig_vals)) { - throw py::value_error("Arrays with eigenvectors and eigenvalues are " - "overlapping segments of memory"); - } - - bool is_eig_vecs_f_contig = eig_vecs.is_f_contiguous(); - bool is_eig_vals_c_contig = eig_vals.is_c_contiguous(); - if (!is_eig_vecs_f_contig) { - throw py::value_error( - "An array with input matrix / output eigenvectors " - "must be F-contiguous"); - } - else if (!is_eig_vals_c_contig) { - throw py::value_error( - "An array with output eigenvalues must be C-contiguous"); - } - - auto array_types = dpctl_td_ns::usm_ndarray_types(); - int eig_vecs_type_id = - array_types.typenum_to_lookup_id(eig_vecs.get_typenum()); - int eig_vals_type_id = - array_types.typenum_to_lookup_id(eig_vals.get_typenum()); - - heevd_impl_fn_ptr_t heevd_fn = - heevd_dispatch_table[eig_vecs_type_id][eig_vals_type_id]; - if (heevd_fn == nullptr) { - throw py::value_error("No heevd implementation defined for a pair of " - "type for eigenvectors and eigenvalues"); - } - - char *eig_vecs_data = eig_vecs.get_data(); - char *eig_vals_data = eig_vals.get_data(); - - const std::int64_t n = eig_vecs_shape[0]; - const oneapi::mkl::job jobz_val = static_cast(jobz); - const oneapi::mkl::uplo uplo_val = - static_cast(upper_lower); - std::vector host_task_events; - sycl::event heevd_ev = - heevd_fn(exec_q, jobz_val, uplo_val, n, eig_vecs_data, eig_vals_data, - host_task_events, depends); - - sycl::event args_ev = dpctl::utils::keep_args_alive( - exec_q, {eig_vecs, eig_vals}, host_task_events); - return std::make_pair(args_ev, heevd_ev); + return heevd_event; } template @@ -243,12 +125,35 @@ struct HeevdContigFactory } }; -void init_heevd_dispatch_table(void) +using evd::evd_impl_fn_ptr_t; + +void init_heevd(py::module_ m) { - dpctl_td_ns::DispatchTableBuilder - contig; - contig.populate_dispatch_table(heevd_dispatch_table); + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + + static evd_impl_fn_ptr_t heevd_dispatch_table[dpctl_td_ns::num_types] + [dpctl_td_ns::num_types]; + + { + evd::init_evd_dispatch_table( + heevd_dispatch_table); + + auto heevd_pyapi = [&](sycl::queue &exec_q, const std::int8_t jobz, + const std::int8_t upper_lower, arrayT &eig_vecs, + arrayT &eig_vals, + const event_vecT &depends = {}) { + return evd::evd_func(exec_q, jobz, upper_lower, eig_vecs, eig_vals, + depends, heevd_dispatch_table); + }; + + m.def("_heevd", heevd_pyapi, + "Call `heevd` from OneMKL LAPACK library to return " + "the eigenvalues and eigenvectors of a complex Hermitian matrix", + py::arg("sycl_queue"), py::arg("jobz"), py::arg("upper_lower"), + py::arg("eig_vecs"), py::arg("eig_vals"), + py::arg("depends") = py::list()); + } } } // namespace lapack } // namespace ext diff --git a/dpnp/backend/extensions/lapack/heevd.hpp b/dpnp/backend/extensions/lapack/heevd.hpp index 3eae78bde245..cdb74ad1a4e2 100644 --- a/dpnp/backend/extensions/lapack/heevd.hpp +++ b/dpnp/backend/extensions/lapack/heevd.hpp @@ -1,5 +1,5 @@ //***************************************************************************** -// Copyright (c) 2023-2024, Intel Corporation +// Copyright (c) 2024, Intel Corporation // All rights reserved. // // Redistribution and use in source and binary forms, with or without @@ -25,10 +25,12 @@ #pragma once -#include -#include +#include +#include -#include +#include "evd_common.hpp" + +namespace py = pybind11; namespace dpnp { @@ -38,15 +40,7 @@ namespace ext { namespace lapack { -extern std::pair - heevd(sycl::queue exec_q, - const std::int8_t jobz, - const std::int8_t upper_lower, - dpctl::tensor::usm_ndarray eig_vecs, - dpctl::tensor::usm_ndarray eig_vals, - const std::vector &depends = {}); - -extern void init_heevd_dispatch_table(void); +void init_heevd(py::module_ m); } // namespace lapack } // namespace ext } // namespace backend diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index e6b4365a906d..0eac8c0349f2 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -60,7 +60,6 @@ void init_dispatch_vectors(void) lapack_ext::init_orgqr_dispatch_vector(); lapack_ext::init_potrf_batch_dispatch_vector(); lapack_ext::init_potrf_dispatch_vector(); - lapack_ext::init_syevd_dispatch_vector(); lapack_ext::init_ungqr_batch_dispatch_vector(); lapack_ext::init_ungqr_dispatch_vector(); } @@ -69,7 +68,6 @@ void init_dispatch_vectors(void) void init_dispatch_tables(void) { lapack_ext::init_gesvd_dispatch_table(); - lapack_ext::init_heevd_dispatch_table(); } PYBIND11_MODULE(_lapack_impl, m) @@ -82,6 +80,9 @@ PYBIND11_MODULE(_lapack_impl, m) init_dispatch_vectors(); init_dispatch_tables(); + lapack_ext::init_heevd(m); + lapack_ext::init_syevd(m); + m.def("_geqrf_batch", &lapack_ext::geqrf_batch, "Call `geqrf_batch` from OneMKL LAPACK library to return " "the QR factorization of a batch general matrix ", @@ -139,13 +140,6 @@ PYBIND11_MODULE(_lapack_impl, m) py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"), py::arg("b_array"), py::arg("depends") = py::list()); - m.def("_heevd", &lapack_ext::heevd, - "Call `heevd` from OneMKL LAPACK library to return " - "the eigenvalues and eigenvectors of a complex Hermitian matrix", - py::arg("sycl_queue"), py::arg("jobz"), py::arg("upper_lower"), - py::arg("eig_vecs"), py::arg("eig_vals"), - py::arg("depends") = py::list()); - m.def("_orgqr_batch", &lapack_ext::orgqr_batch, "Call `_orgqr_batch` from OneMKL LAPACK library to return " "the real orthogonal matrix Qi of the QR factorization " @@ -176,13 +170,6 @@ PYBIND11_MODULE(_lapack_impl, m) py::arg("n"), py::arg("stride_a"), py::arg("batch_size"), py::arg("depends") = py::list()); - m.def("_syevd", &lapack_ext::syevd, - "Call `syevd` from OneMKL LAPACK library to return " - "the eigenvalues and eigenvectors of a real symmetric matrix", - py::arg("sycl_queue"), py::arg("jobz"), py::arg("upper_lower"), - py::arg("eig_vecs"), py::arg("eig_vals"), - py::arg("depends") = py::list()); - m.def("_ungqr_batch", &lapack_ext::ungqr_batch, "Call `_ungqr_batch` from OneMKL LAPACK library to return " "the complex unitary matrices matrix Qi of the QR factorization " diff --git a/dpnp/backend/extensions/lapack/syevd.cpp b/dpnp/backend/extensions/lapack/syevd.cpp index 0374e96b8bb8..7a8041adfb15 100644 --- a/dpnp/backend/extensions/lapack/syevd.cpp +++ b/dpnp/backend/extensions/lapack/syevd.cpp @@ -1,5 +1,5 @@ //***************************************************************************** -// Copyright (c) 2023-2024, Intel Corporation +// Copyright (c) 2024, Intel Corporation // All rights reserved. // // Redistribution and use in source and binary forms, with or without @@ -23,16 +23,7 @@ // THE POSSIBILITY OF SUCH DAMAGE. //***************************************************************************** -#include - -// dpctl tensor headers -#include "utils/memory_overlap.hpp" -#include "utils/type_utils.hpp" - #include "syevd.hpp" -#include "types_matrix.hpp" - -#include "dpnp_utils.hpp" namespace dpnp { @@ -43,22 +34,10 @@ namespace ext namespace lapack { namespace mkl_lapack = oneapi::mkl::lapack; -namespace py = pybind11; namespace type_utils = dpctl::tensor::type_utils; -typedef sycl::event (*syevd_impl_fn_ptr_t)(sycl::queue, - const oneapi::mkl::job, - const oneapi::mkl::uplo, - const std::int64_t, - char *, - char *, - std::vector &, - const std::vector &); - -static syevd_impl_fn_ptr_t syevd_dispatch_vector[dpctl_td_ns::num_types]; - -template -static sycl::event syevd_impl(sycl::queue exec_q, +template +static sycl::event syevd_impl(sycl::queue &exec_q, const oneapi::mkl::job jobz, const oneapi::mkl::uplo upper_lower, const std::int64_t n, @@ -68,9 +47,10 @@ static sycl::event syevd_impl(sycl::queue exec_q, const std::vector &depends) { type_utils::validate_type_for_device(exec_q); + type_utils::validate_type_for_device(exec_q); T *a = reinterpret_cast(in_a); - T *w = reinterpret_cast(out_w); + RealT *w = reinterpret_cast(out_w); const std::int64_t lda = std::max(1UL, n); const std::int64_t scratchpad_size = @@ -125,118 +105,19 @@ static sycl::event syevd_impl(sycl::queue exec_q, auto ctx = exec_q.get_context(); cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); }); }); + host_task_events.push_back(clean_up_event); return syevd_event; } -std::pair - syevd(sycl::queue exec_q, - const std::int8_t jobz, - const std::int8_t upper_lower, - dpctl::tensor::usm_ndarray eig_vecs, - dpctl::tensor::usm_ndarray eig_vals, - const std::vector &depends) -{ - const int eig_vecs_nd = eig_vecs.get_ndim(); - const int eig_vals_nd = eig_vals.get_ndim(); - - if (eig_vecs_nd != 2) { - throw py::value_error("Unexpected ndim=" + std::to_string(eig_vecs_nd) + - " of an output array with eigenvectors"); - } - else if (eig_vals_nd != 1) { - throw py::value_error("Unexpected ndim=" + std::to_string(eig_vals_nd) + - " of an output array with eigenvalues"); - } - - const py::ssize_t *eig_vecs_shape = eig_vecs.get_shape_raw(); - const py::ssize_t *eig_vals_shape = eig_vals.get_shape_raw(); - - if (eig_vecs_shape[0] != eig_vecs_shape[1]) { - throw py::value_error("Output array with eigenvectors with be square"); - } - else if (eig_vecs_shape[0] != eig_vals_shape[0]) { - throw py::value_error( - "Eigenvectors and eigenvalues have different shapes"); - } - - size_t src_nelems(1); - - for (int i = 0; i < eig_vecs_nd; ++i) { - src_nelems *= static_cast(eig_vecs_shape[i]); - } - - if (src_nelems == 0) { - // nothing to do - return std::make_pair(sycl::event(), sycl::event()); - } - - // check compatibility of execution queue and allocation queue - if (!dpctl::utils::queues_are_compatible(exec_q, {eig_vecs, eig_vals})) { - throw py::value_error( - "Execution queue is not compatible with allocation queues"); - } - - auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); - if (overlap(eig_vecs, eig_vals)) { - throw py::value_error("Arrays with eigenvectors and eigenvalues are " - "overlapping segments of memory"); - } - - bool is_eig_vecs_f_contig = eig_vecs.is_f_contiguous(); - bool is_eig_vals_c_contig = eig_vals.is_c_contiguous(); - if (!is_eig_vecs_f_contig) { - throw py::value_error( - "An array with input matrix / output eigenvectors " - "must be F-contiguous"); - } - else if (!is_eig_vals_c_contig) { - throw py::value_error( - "An array with output eigenvalues must be C-contiguous"); - } - - auto array_types = dpctl_td_ns::usm_ndarray_types(); - int eig_vecs_type_id = - array_types.typenum_to_lookup_id(eig_vecs.get_typenum()); - int eig_vals_type_id = - array_types.typenum_to_lookup_id(eig_vals.get_typenum()); - - if (eig_vecs_type_id != eig_vals_type_id) { - throw py::value_error( - "Types of eigenvectors and eigenvalues are mismatched"); - } - - syevd_impl_fn_ptr_t syevd_fn = syevd_dispatch_vector[eig_vecs_type_id]; - if (syevd_fn == nullptr) { - throw py::value_error("No syevd implementation defined for a type of " - "eigenvectors and eigenvalues"); - } - - char *eig_vecs_data = eig_vecs.get_data(); - char *eig_vals_data = eig_vals.get_data(); - - const std::int64_t n = eig_vecs_shape[0]; - const oneapi::mkl::job jobz_val = static_cast(jobz); - const oneapi::mkl::uplo uplo_val = - static_cast(upper_lower); - - std::vector host_task_events; - sycl::event syevd_ev = - syevd_fn(exec_q, jobz_val, uplo_val, n, eig_vecs_data, eig_vals_data, - host_task_events, depends); - - sycl::event args_ev = dpctl::utils::keep_args_alive( - exec_q, {eig_vecs, eig_vals}, host_task_events); - return std::make_pair(args_ev, syevd_ev); -} - -template +template struct SyevdContigFactory { fnT get() { - if constexpr (types::SyevdTypePairSupportFactory::is_defined) { - return syevd_impl; + if constexpr (types::SyevdTypePairSupportFactory::is_defined) + { + return syevd_impl; } else { return nullptr; @@ -244,12 +125,34 @@ struct SyevdContigFactory } }; -void init_syevd_dispatch_vector(void) +using evd::evd_impl_fn_ptr_t; + +void init_syevd(py::module_ m) { - dpctl_td_ns::DispatchVectorBuilder - contig; - contig.populate_dispatch_vector(syevd_dispatch_vector); + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + + static evd_impl_fn_ptr_t syevd_dispatch_table[dpctl_td_ns::num_types] + [dpctl_td_ns::num_types]; + + { + evd::init_evd_dispatch_table( + syevd_dispatch_table); + + auto syevd_pyapi = [&](sycl::queue &exec_q, const std::int8_t jobz, + const std::int8_t upper_lower, arrayT &eig_vecs, + arrayT &eig_vals, + const event_vecT &depends = {}) { + return evd::evd_func(exec_q, jobz, upper_lower, eig_vecs, eig_vals, + depends, syevd_dispatch_table); + }; + m.def("_syevd", syevd_pyapi, + "Call `syevd` from OneMKL LAPACK library to return " + "the eigenvalues and eigenvectors of a real symmetric matrix", + py::arg("sycl_queue"), py::arg("jobz"), py::arg("upper_lower"), + py::arg("eig_vecs"), py::arg("eig_vals"), + py::arg("depends") = py::list()); + } } } // namespace lapack } // namespace ext diff --git a/dpnp/backend/extensions/lapack/syevd.hpp b/dpnp/backend/extensions/lapack/syevd.hpp index 1b6750487fd5..fe596187e049 100644 --- a/dpnp/backend/extensions/lapack/syevd.hpp +++ b/dpnp/backend/extensions/lapack/syevd.hpp @@ -1,5 +1,5 @@ //***************************************************************************** -// Copyright (c) 2023-2024, Intel Corporation +// Copyright (c) 2024, Intel Corporation // All rights reserved. // // Redistribution and use in source and binary forms, with or without @@ -25,10 +25,12 @@ #pragma once -#include -#include +#include +#include -#include +#include "evd_common.hpp" + +namespace py = pybind11; namespace dpnp { @@ -38,15 +40,7 @@ namespace ext { namespace lapack { -extern std::pair - syevd(sycl::queue exec_q, - const std::int8_t jobz, - const std::int8_t upper_lower, - dpctl::tensor::usm_ndarray eig_vecs, - dpctl::tensor::usm_ndarray eig_vals, - const std::vector &depends = {}); - -extern void init_syevd_dispatch_vector(void); +void init_syevd(py::module_ m); } // namespace lapack } // namespace ext } // namespace backend diff --git a/dpnp/backend/extensions/lapack/types_matrix.hpp b/dpnp/backend/extensions/lapack/types_matrix.hpp index a5edffa56dc0..0e69ef24dea9 100644 --- a/dpnp/backend/extensions/lapack/types_matrix.hpp +++ b/dpnp/backend/extensions/lapack/types_matrix.hpp @@ -373,12 +373,12 @@ struct PotrfBatchTypePairSupportFactory * @tparam T Type of array containing input matrix A and an output arrays with * eigenvectors and eigenvectors. */ -template +template struct SyevdTypePairSupportFactory { static constexpr bool is_defined = std::disjunction< - dpctl_td_ns::TypePairDefinedEntry, - dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, // fall-through dpctl_td_ns::NotDefinedEntry>::is_defined; };