Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
dfa35ea
`Mat()` of overlap is correct.
SaltyChiang Mar 31, 2024
3e813a2
Merge branch 'develop' into feature/overlap
SaltyChiang Aug 23, 2024
a3029c1
Merge commit '4d3236eddfeba114b1a4c892d76b2f382c9fd78a' into feature/…
SaltyChiang Dec 16, 2024
4691a74
modify dirac_overlap codes to work with newest quda
V3-vvv Sep 19, 2024
cb16be2
enable code to run overlap DiracM eigensolver
Dec 12, 2024
9468668
modify interface_quda.cpp for overlap settings
Dec 16, 2024
cbc3ebc
code modification
Dec 16, 2024
d506ae7
Add chrial projection for spinor field.
SaltyChiang Dec 25, 2024
6b2cbb5
overlap invert
V3-vvv Jun 4, 2025
df1b14b
Merge branch 'develop' into feature/overlap
SaltyChiang Jun 5, 2025
5e90942
Remove interface_quda.cpp.bak.
SaltyChiang Jun 5, 2025
823d6ea
Cleanup.
SaltyChiang Jun 5, 2025
2443e2c
Multi-shift solver for overlap is correct now.
SaltyChiang Jun 11, 2025
8f8b9bf
Add `OverlapKernel` to handle the eigensystem of gamma_5 Wilson .
SaltyChiang Jun 13, 2025
2fc439b
Add inversion for low-mode.
SaltyChiang Jun 14, 2025
dcd732c
Merge branch 'develop' into feature/overlap
SaltyChiang Jun 14, 2025
eb95324
Merge branch 'develop' into feature/overlap
SaltyChiang Jun 17, 2025
3a3fd7c
Migrate `invertOverlapQuda` to standard interface.
SaltyChiang Jun 19, 2025
7020c15
Merge branch 'develop' into feature/overlap
SaltyChiang Jun 19, 2025
8477b46
Enable MRHS solver for overlap fermion.
SaltyChiang Jun 21, 2025
2840c4e
Migrate `invertOverlapMultiShiftQuda` to the standard interface.
SaltyChiang Jun 21, 2025
71e0cbf
Apply clang-format.
SaltyChiang Jun 21, 2025
31bec8b
Enable mixed precision solver for overlap chiral fermion.
SaltyChiang Jun 22, 2025
f66f5af
Enable the eigensolver for chiral vectors.
SaltyChiang Jun 22, 2025
4e7098e
Cannot use QudaEigParam in QudaInvertParam.
SaltyChiang Jun 22, 2025
018d7c1
Revert the change in eigensolve_quda.cpp.
SaltyChiang Jun 22, 2025
5e4d944
Fix an issue in `chebyOp`.
SaltyChiang Jun 23, 2025
cf291b7
Add memory type to the field aux string.
SaltyChiang Jun 24, 2025
8e9e333
Merge branch 'develop' into feature/overlap
SaltyChiang Jun 28, 2025
a24f330
Use existing methods of `ColorSpinor` to implement chiral project and…
SaltyChiang Jun 28, 2025
7ea9ab2
Cleanup.
V3-vvv Jul 7, 2025
36691e0
Merge branch 'develop' into feature/overlap
SaltyChiang Jul 12, 2025
0624505
Merge remote-tracking branch 'origin/develop' into feature/overlap
ELI-C0DE Aug 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions include/blas_helper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,18 @@ namespace quda
template <> struct VectorType<int8_t, 24> {
using type = array<int8_t, 24>;
};
// 12-element storage vectors, one specialization per storage type.
// NOTE(review): 12 is the length returned by the n_vector<..., false, 2, true>
// specializations in this file, so these presumably back the site-unrolled CPU
// layout of nSpin = 2 fields -- confirm.
template <> struct VectorType<double, 12> {
using type = array<double, 12>;
};
template <> struct VectorType<float, 12> {
using type = array<float, 12>;
};
template <> struct VectorType<short, 12> {
using type = array<short, 12>;
};
template <> struct VectorType<int8_t, 12> {
using type = array<int8_t, 12>;
};
template <> struct VectorType<double, 6> {
using type = array<double, 6>;
};
Expand Down Expand Up @@ -343,37 +355,49 @@ namespace quda

// native ordering
// NOTE(review): the template arguments appear to be
// <storage type, native ordering flag, nSpin, site unrolling flag>; this is
// inferred from the section comments in this file -- confirm against the
// primary template.
// Each argument list must be specialized exactly once: a second definition of
// the same explicit specialization is a redefinition error.
template <> constexpr int n_vector<double, true, 4, false>() { return 2; }
template <> constexpr int n_vector<double, true, 2, false>() { return 2; }
template <> constexpr int n_vector<double, true, 1, false>() { return 2; }

template <> constexpr int n_vector<double, true, 4, true>() { return 2; }
template <> constexpr int n_vector<double, true, 2, true>() { return 2; }
template <> constexpr int n_vector<double, true, 1, true>() { return 2; }

template <> constexpr int n_vector<float, true, 4, false>() { return 4; }
template <> constexpr int n_vector<float, true, 2, false>() { return 4; }
template <> constexpr int n_vector<float, true, 1, false>() { return 4; } // TODO: correct?

template <> constexpr int n_vector<float, true, 4, true>() { return 4; }
template <> constexpr int n_vector<float, true, 2, true>() { return QUDA_ORDER_SP_MG; }
template <> constexpr int n_vector<float, true, 1, true>() { return 2; }

template <> constexpr int n_vector<short, true, 4, true>() { return QUDA_ORDER_FP; }
template <> constexpr int n_vector<short, true, 2, true>() { return QUDA_ORDER_FP_MG; }
template <> constexpr int n_vector<short, true, 1, true>() { return 2; }

template <> constexpr int n_vector<int8_t, true, 4, true>() { return QUDA_ORDER_FP; }
template <> constexpr int n_vector<int8_t, true, 2, true>() { return QUDA_ORDER_FP_MG; }
template <> constexpr int n_vector<int8_t, true, 1, true>() { return 2; }

// Just use float-2/float-4 ordering on CPU when not site unrolling.
// The length does not depend on the third template argument (4, 2 or 1):
// it is 2 for double and 4 for float.
template <> constexpr int n_vector<double, false, 4, false>() { return 2; }
template <> constexpr int n_vector<double, false, 2, false>() { return 2; }
template <> constexpr int n_vector<double, false, 1, false>() { return 2; }
template <> constexpr int n_vector<float, false, 4, false>() { return 4; }
template <> constexpr int n_vector<float, false, 2, false>() { return 4; }
template <> constexpr int n_vector<float, false, 1, false>() { return 4; }

// AoS ordering is used on the CPU when we are site unrolling.
// The length is the same for every storage type and depends only on the
// third template argument: 24 for 4, 12 for 2 and 6 for 1.
template <> constexpr int n_vector<double, false, 4, true>() { return 24; }
template <> constexpr int n_vector<double, false, 2, true>() { return 12; }
template <> constexpr int n_vector<double, false, 1, true>() { return 6; }
template <> constexpr int n_vector<float, false, 4, true>() { return 24; }
template <> constexpr int n_vector<float, false, 2, true>() { return 12; }
template <> constexpr int n_vector<float, false, 1, true>() { return 6; }
template <> constexpr int n_vector<short, false, 4, true>() { return 24; }
template <> constexpr int n_vector<short, false, 2, true>() { return 12; }
template <> constexpr int n_vector<short, false, 1, true>() { return 6; }
template <> constexpr int n_vector<int8_t, false, 4, true>() { return 24; }
template <> constexpr int n_vector<int8_t, false, 2, true>() { return 12; }
template <> constexpr int n_vector<int8_t, false, 1, true>() { return 6; }

template <template <typename...> class Functor,
Expand All @@ -382,13 +406,18 @@ namespace quda
constexpr void instantiate(const T &a, const T &b, const T &c, V &x_, Args &&... args)
{
unwrap_t<V> &x(x_);
if (x.Nspin() == 4 || x.Nspin() == 2) {
if constexpr (is_enabled_spin(2) || is_enabled_spin(4)) {
// Nspin-2 takes Nspin-4 path here, and we check for this later
if (x.Nspin() == 4) {
if constexpr (is_enabled_spin(4)) {
Blas<Functor, store_t, y_store_t, 4, T>(a, b, c, x, args...);
} else {
errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
}
} else if (x.Nspin() == 2) {
if constexpr (is_enabled_spin(2)) {
Blas<Functor, store_t, y_store_t, 2, T>(a, b, c, x, args...);
} else {
errorQuda("blas has not been built for Nspin=%d fields", x.Nspin());
}
} else {
if constexpr (is_enabled_spin(1)) {
Blas<Functor, store_t, y_store_t, 1, T>(a, b, c, x, args...);
Expand Down
35 changes: 34 additions & 1 deletion include/color_spinor_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ namespace quda

struct ColorSpinorParam : public LatticeFieldParam {
int nColor = 0; // Number of colors of the field
int nSpin = 0; // =1 for staggered, =2 for coarse Dslash, =4 for 4d spinor
int nSpin = 0; // =1 for staggered, =2 for coarse Dslash and chiral overlap Dslash, =4 for 4d spinor
int nVec = 1; // number of packed vectors (for multigrid transfer operator)
int nVec_actual = 1; // The actual number of packed vectors (that are not zero padded)

Expand Down Expand Up @@ -1103,6 +1103,39 @@ namespace quda
*/
void spinorDistanceReweight(ColorSpinorField &src, double alpha0, int t0);

/**
@brief Reconstruct a chiral spinor into a full spinor
@param[out] dst The reconstructed full spinor nSpin = 4
@param[in] src The chiral spinor nSpin = 2
@param[in] chirality The chirality of the reconstruction (a QudaChirality value)
@note NOTE(review): how the opposite-chirality components of dst are treated is
not visible from this declaration -- confirm against the implementation.
*/
void spinorChiralReconstruct(ColorSpinorField &dst, const ColorSpinorField &src, QudaChirality chirality);

/**
@brief Reconstruct two chiral spinors into a full spinor
@param[out] dst The reconstructed full spinor nSpin = 4
@param[in] src_left The left chirality part nSpin = 2
@param[in] src_right The right chirality part nSpin = 2
*/
void spinorChiralReconstruct(ColorSpinorField &dst, const ColorSpinorField &src_left,
const ColorSpinorField &src_right);

/**
@brief Project a full spinor to a chiral spinor
@param[out] dst The projected chiral spinor nSpin = 2
@param[in] src The full spinor nSpin = 4
@param[in] chirality The chirality of the projection (a QudaChirality value)
*/
void spinorChiralProject(ColorSpinorField &dst, const ColorSpinorField &src, QudaChirality chirality);

/**
@brief Project a full spinor to two chiral spinors
@param[out] dst_left The projected left chirality part nSpin = 2
@param[out] dst_right The projected right chirality part nSpin = 2
@param[in] src The full spinor nSpin = 4
*/
void spinorChiralProject(ColorSpinorField &dst_left, ColorSpinorField &dst_right, const ColorSpinorField &src);

/**
@brief Helper function for determining if the spin of the fields is the same.
@param[in] a Input field
Expand Down
97 changes: 97 additions & 0 deletions include/dirac_quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <blas_quda.h>
#include <field_cache.h>
#include <memory>
#include <overlap_kernel.h>

namespace quda {

Expand Down Expand Up @@ -67,6 +68,8 @@ namespace quda {

bool use_mobius_fused_kernel; // Whether or not use fused kernels for Mobius

OverlapKernel *overlap_kernel;

double distance_pc_alpha0; // used by distance preconditioning
int distance_pc_t0; // used by distance preconditioning

Expand Down Expand Up @@ -149,6 +152,7 @@ namespace quda {
class DiracMMdag;
class DiracMdag;
class DiracG5M;
class DiracMdagMChiral;
//Forward declaration of multigrid Transfer class
class Transfer;

Expand All @@ -162,6 +166,7 @@ namespace quda {
friend class DiracMMdag;
friend class DiracMdag;
friend class DiracG5M;
friend class DiracMdagMChiral;

protected:
GaugeField *gauge;
Expand Down Expand Up @@ -350,6 +355,14 @@ namespace quda {
*/
virtual void MMdag(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) const;

/**
@brief Apply MdagM on single chirality

This default implementation does nothing but report an error through
errorQuda; operators that support a chiral MdagM (DiracOverlap declares
an override) replace it. The parameters are left unnamed because they are
unused here; in order they are the output fields, the input fields and
the chirality to act on.
*/
virtual void MdagMChiral(cvector_ref<ColorSpinorField> &, cvector_ref<const ColorSpinorField> &, QudaChirality) const
{
errorQuda("Not implemented!");
}

/**
@brief Prepare the source and solution vectors for solving given the solution type
@param[out] out Prepared solution vectors
Expand Down Expand Up @@ -1406,6 +1419,47 @@ namespace quda {
virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const override;
};

// Full overlap
class DiracOverlap : public Dirac
{

protected:
OverlapKernel *overlap_kernel;

public:
DiracOverlap(const DiracParam &param);
DiracOverlap(const DiracOverlap &dirac);
virtual ~DiracOverlap();
DiracOverlap &operator=(const DiracOverlap &dirac);

virtual void Dslash(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in,
QudaParity parity) const override;
virtual void DslashXpay(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in,
QudaParity parity, cvector_ref<const ColorSpinorField> &x, double k) const override;
virtual void M(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) const override;
virtual void MdagM(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) const override;
virtual void MdagMChiral(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in,
QudaChirality chirality) const override;

virtual void prepare(cvector_ref<ColorSpinorField> &out, cvector_ref<ColorSpinorField> &in,
cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b,
const QudaSolutionType solType) const override;
virtual void reconstruct(cvector_ref<ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &b,
const QudaSolutionType solType) const override;

virtual int getStencilSteps() const override { return 2 * (overlap_kernel->remez_order[0] + 1) + 1; }
virtual QudaDiracType getDiracType() const { return QUDA_OVERLAP_DIRAC; }

/**
@brief If managed memory and prefetch is enabled, prefetch
all relevant memory fields (gauge, clover, temporary spinors)
to the CPU or GPU as requested
@param[in] mem_space Memory space we are prefetching to
@param[in] stream Which stream to run the prefetch in (default 0)
*/
virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const;
};

// Full staggered
class DiracStaggered : public Dirac
{
Expand Down Expand Up @@ -2499,6 +2553,7 @@ namespace quda {
case QUDA_CLOVER_HASENBUSCH_TWIST_DIRAC:
case QUDA_TWISTED_MASS_DIRAC:
case QUDA_TWISTED_CLOVER_DIRAC:
case QUDA_OVERLAP_DIRAC:
// while the twisted ops don't have a Hermitian indefinite spectrum, they
// do have a spectrum of the form (real) + i mu
gamma5(vec, vec);
Expand Down Expand Up @@ -2584,6 +2639,8 @@ namespace quda {
|| dirac_type == QUDA_GAUGE_COVDEV_DIRAC)
return true;

if (dirac_type == QUDA_WILSON_DIRAC || dirac_type == QUDA_CLOVER_DIRAC) return true;

// subtle: odd operator gets a minus sign
if ((dirac_type == QUDA_STAGGEREDPC_DIRAC || dirac_type == QUDA_ASQTADPC_DIRAC)
&& (pc_type == QUDA_MATPC_EVEN_EVEN || pc_type == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC))
Expand All @@ -2593,6 +2650,46 @@ namespace quda {
}
};

/**
Gloms onto a DiracMatrix and provides an operator() for its MdagMChiral method
*/
class DiracMdagMChiral : public DiracMatrix
{
protected:
QudaChirality chirality;

public:
DiracMdagMChiral(const Dirac &d) : DiracMatrix(d) { }
DiracMdagMChiral(const Dirac *d) : DiracMatrix(d) { }

/**
@brief Multi-RHS operator application.
@param[out] out The vector of output fields
@param[in] in The vector of input fields
*/
void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) const override
{
dirac->MdagMChiral(out, in, chirality);
if (shift != 0.0) blas::axpy(shift, in, out);
}

int getStencilSteps() const override
{
if (dirac->getDiracType() == QUDA_OVERLAP_DIRAC) {
return dirac->getStencilSteps();
} else {
return dirac->getStencilSteps() * 2; // 2 for M and M dagger
}
}

/**
@brief return if the operator is HPD
*/
virtual bool hermitian() const override { return true; }

void setChirality(QudaChirality chirality_in) { chirality = chirality_in; }
};

/**
* Create the Dirac operator. By default, we also create operators with possibly different
* precisions: Sloppy, and Preconditioner.
Expand Down
10 changes: 10 additions & 0 deletions include/enum_quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ typedef enum QudaDslashType_s {
QUDA_DOMAIN_WALL_4D_DSLASH,
QUDA_MOBIUS_DWF_DSLASH,
QUDA_MOBIUS_DWF_EOFA_DSLASH,
QUDA_OVERLAP_DSLASH,
QUDA_STAGGERED_DSLASH,
QUDA_ASQTAD_DSLASH,
QUDA_TWISTED_MASS_DSLASH,
Expand Down Expand Up @@ -173,6 +174,8 @@ typedef enum QudaSolveType_s {
QUDA_NORMOP_PC_SOLVE,
QUDA_NORMERR_SOLVE,
QUDA_NORMERR_PC_SOLVE,
QUDA_NORMOP_CHIRAL_SOLVE,
QUDA_NORMERR_CHIRAL_SOLVE = QUDA_NORMOP_CHIRAL_SOLVE,
QUDA_NORMEQ_SOLVE = QUDA_NORMOP_SOLVE, // deprecated
QUDA_NORMEQ_PC_SOLVE = QUDA_NORMOP_PC_SOLVE, // deprecated
QUDA_INVALID_SOLVE = QUDA_INVALID_ENUM
Expand Down Expand Up @@ -309,6 +312,7 @@ typedef enum QudaDiracType_s {
QUDA_MOBIUS_DOMAIN_WALLPC_DIRAC,
QUDA_MOBIUS_DOMAIN_WALL_EOFA_DIRAC,
QUDA_MOBIUS_DOMAIN_WALLPC_EOFA_DIRAC,
QUDA_OVERLAP_DIRAC,
QUDA_STAGGERED_DIRAC,
QUDA_STAGGEREDPC_DIRAC,
QUDA_STAGGEREDKD_DIRAC,
Expand Down Expand Up @@ -637,6 +641,12 @@ typedef enum QudaExtLibType_s {
QUDA_EXTLIB_INVALID = QUDA_INVALID_ENUM
} QudaExtLibType;

// Chirality selector. The numeric value of each valid entry is the sign that
// multiplies \gamma_5 in the projector noted beside it.
typedef enum QudaChirality_s {
QUDA_LEFT_CHIRALITY = -1, // (1 - \gamma_5) / 2
QUDA_RIGHT_CHIRALITY = +1, // (1 + \gamma_5) / 2
QUDA_INVALID_CHIRALITY = QUDA_INVALID_ENUM
} QudaChirality;

typedef enum QudaDDType_s { QUDA_DD_NO, QUDA_DD_RED_BLACK, QUDA_DD_INVALID = QUDA_INVALID_ENUM } QudaDDType;

typedef enum QudaWFlowStepType_s {
Expand Down
Loading