122 changes: 117 additions & 5 deletions jax/_src/pallas/mosaic_gpu/lowering.py
@@ -365,6 +365,8 @@ def _run_scoped_resource_estimator(

@_register_resource_estimator(lax.reduce_sum_p)
@_register_resource_estimator(lax.reduce_max_p)
@_register_resource_estimator(lax.reduce_min_p)
@_register_resource_estimator(lax.reduce_prod_p)
def _reduce_resource_estimator(
ctx: ResourceEstimatorContext, x_aval: jax_core.ShapedArray, *, axes,
**kwargs
@@ -2404,6 +2406,79 @@ def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
return math_dialect.log(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath)


@register_lowering_rule(lax.abs_p, mgpu.LoweringSemantics.Lane)
@register_lowering_rule(lax.abs_p, mgpu.LoweringSemantics.Warpgroup)
def _abs_lowering_rule(ctx: LoweringRuleContext, x):
[x_aval] = ctx.avals_in
if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane:
return _ensure_fa(x, x_aval.dtype).abs()
x = _ensure_ir_value(x, x_aval.dtype)
if jax.numpy.issubdtype(x_aval.dtype, jax.numpy.floating):
return math_dialect.absf(x)
elif jax.numpy.issubdtype(x_aval.dtype, jax.numpy.integer):
return math_dialect.absi(x)
else:
raise NotImplementedError(f"Unsupported dtype for abs: {x_aval.dtype}")


@register_lowering_rule(lax.sign_p, mgpu.LoweringSemantics.Lane)
@register_lowering_rule(lax.sign_p, mgpu.LoweringSemantics.Warpgroup)
def _sign_lowering_rule(ctx: LoweringRuleContext, x):
[x_aval] = ctx.avals_in
if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane:
return _ensure_fa(x, x_aval.dtype).sign()
x = _ensure_ir_value(x, x_aval.dtype)
if jax.numpy.issubdtype(x_aval.dtype, jax.numpy.floating):
# For floats: sign(x) = copysign(1.0, x) if x != 0 else 0
mlir_dtype = mgpu_utils.dtype_to_ir_type(x_aval.dtype)
one = arith_dialect.constant(mlir_dtype, ir.FloatAttr.get(mlir_dtype, 1.0))
zero = arith_dialect.constant(mlir_dtype, ir.FloatAttr.get(mlir_dtype, 0.0))
if ir.VectorType.isinstance(x.type):
one = vector_dialect.broadcast(x.type, one)
zero = vector_dialect.broadcast(x.type, zero)
signed_one = math_dialect.copysign(one, x)
is_nonzero = arith_dialect.cmpf(arith_dialect.CmpFPredicate.ONE, x, zero)
return arith_dialect.select(is_nonzero, signed_one, zero)
elif jax.numpy.issubdtype(x_aval.dtype, jax.numpy.integer):
mlir_dtype = mgpu_utils.dtype_to_ir_type(x_aval.dtype)
zero = arith_dialect.constant(mlir_dtype, 0)
if ir.VectorType.isinstance(x.type):
zero = vector_dialect.broadcast(x.type, zero)
if jax.numpy.issubdtype(x_aval.dtype, jax.numpy.signedinteger):
# For signed integers: sign(x) = (x > 0) - (x < 0)
pos = arith_dialect.cmpi(arith_dialect.CmpIPredicate.sgt, x, zero)
neg = arith_dialect.cmpi(arith_dialect.CmpIPredicate.slt, x, zero)
pos_ext = arith_dialect.extui(x.type, pos)
neg_ext = arith_dialect.extui(x.type, neg)
return arith_dialect.subi(pos_ext, neg_ext)
else:
# For unsigned integers: sign(x) = (x > 0) ? 1 : 0
pos = arith_dialect.cmpi(arith_dialect.CmpIPredicate.ugt, x, zero)
return arith_dialect.extui(x.type, pos)
else:
raise NotImplementedError(f"Unsupported dtype for sign: {x_aval.dtype}")


@register_lowering_rule(lax.erf_p, mgpu.LoweringSemantics.Lane)
@register_lowering_rule(lax.erf_p, mgpu.LoweringSemantics.Warpgroup)
def _erf_lowering_rule(ctx: LoweringRuleContext, x):
[x_aval] = ctx.avals_in
if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane:
return _ensure_fa(x, x_aval.dtype).erf()
return math_dialect.erf(_ensure_ir_value(x, x_aval.dtype))


@register_lowering_rule(lax.atan2_p, mgpu.LoweringSemantics.Lane)
@register_lowering_rule(lax.atan2_p, mgpu.LoweringSemantics.Warpgroup)
def _atan2_lowering_rule(ctx: LoweringRuleContext, y, x):
[y_aval, x_aval] = ctx.avals_in
if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane:
return _ensure_fa(y, y_aval.dtype).atan2(_ensure_fa(x, x_aval.dtype))
return math_dialect.atan2(
_ensure_ir_value(y, y_aval.dtype), _ensure_ir_value(x, x_aval.dtype)
)
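
With the rules above, elementwise abs, sign, erf and atan2 written with the usual jax.numpy / jax.scipy APIs can now lower on the Mosaic GPU backend. A minimal sketch of a kernel that would exercise them (not part of this change; the pallas_call configuration, shapes, and backend selection are assumptions):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.scipy.special import erf

def elementwise_kernel(x_ref, y_ref, o_ref):
  x = x_ref[...]
  y = y_ref[...]
  # abs_p, sign_p, erf_p and atan2_p all hit the new lowering rules above.
  o_ref[...] = jnp.sign(x) * erf(jnp.abs(x)) + jnp.arctan2(y, x)

x = jnp.linspace(-3.0, 3.0, 128, dtype=jnp.float32)
y = jnp.ones_like(x)
out = pl.pallas_call(
    elementwise_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
)(x, y)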


@register_lowering_rule(lax.reshape_p, mgpu.LoweringSemantics.Lane)
def _reshape_lowering_rule(
ctx: LoweringRuleContext, x, new_sizes, dimensions, sharding
@@ -2492,6 +2567,12 @@ def _reduce_lowering_rule(op, ctx: LoweringRuleContext, x, *, axes, **kwargs):
register_lowering_rule(lax.reduce_max_p, mgpu.LoweringSemantics.Lane)(
functools.partial(_reduce_lowering_rule, "max")
)
register_lowering_rule(lax.reduce_min_p, mgpu.LoweringSemantics.Lane)(
functools.partial(_reduce_lowering_rule, "min")
)
register_lowering_rule(lax.reduce_prod_p, mgpu.LoweringSemantics.Lane)(
functools.partial(_reduce_lowering_rule, "prod")
)

def _reduce_lowering_rule_wg(
kind: vector_dialect.CombiningKind,
@@ -2538,17 +2619,48 @@ def _reduce_max_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes):
if jnp.issubdtype(x_aval.dtype, jnp.floating):
kind = vector_dialect.CombiningKind.MAXIMUMF
acc = float("-inf")
elif jnp.issubdtype(x_aval.dtype, jnp.signedinteger):
kind = vector_dialect.CombiningKind.MAXSI
acc = np.iinfo(x_aval.dtype).max
elif jnp.issubdtype(x_aval.dtype, jnp.unsignedinteger):
kind = vector_dialect.CombiningKind.MAXUI
elif jnp.issubdtype(x_aval.dtype, jnp.integer):
if jnp.issubdtype(x_aval.dtype, jnp.signedinteger):
kind = vector_dialect.CombiningKind.MAXSI
else:
kind = vector_dialect.CombiningKind.MAXUI
acc = np.iinfo(x_aval.dtype).min
else:
raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}")
return _reduce_lowering_rule_wg(kind, acc, ctx, x, axes=axes).result


@register_lowering_rule(lax.reduce_min_p, mgpu.LoweringSemantics.Warpgroup)
def _reduce_min_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes):
[x_aval] = ctx.avals_in
if jnp.issubdtype(x_aval.dtype, jnp.floating):
kind = vector_dialect.CombiningKind.MINIMUMF
acc = float("inf")
elif jnp.issubdtype(x_aval.dtype, jnp.integer):
if jnp.issubdtype(x_aval.dtype, jnp.signedinteger):
kind = vector_dialect.CombiningKind.MINSI
else:
kind = vector_dialect.CombiningKind.MINUI
acc = np.iinfo(x_aval.dtype).max
else:
raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}")
return _reduce_lowering_rule_wg(kind, acc, ctx, x, axes=axes).result


@register_lowering_rule(lax.reduce_prod_p, mgpu.LoweringSemantics.Warpgroup)
def _reduce_prod_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes):
[x_aval] = ctx.avals_in
if jnp.issubdtype(x_aval.dtype, jnp.floating):
kind = vector_dialect.CombiningKind.MUL
acc = 1.0
elif jnp.issubdtype(x_aval.dtype, jnp.integer):
kind = vector_dialect.CombiningKind.MUL
acc = 1
else:
raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}")
return _reduce_lowering_rule_wg(kind, acc, ctx, x, axes=axes).result
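
With reduce_min_p and reduce_prod_p now registered for both lowering semantics, jnp.min and jnp.prod reductions can appear inside Pallas Mosaic GPU kernels. A rough sketch (the pallas_call setup below is an assumption, not part of this diff):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def reduce_kernel(x_ref, min_ref, prod_ref):
  x = x_ref[...]
  min_ref[...] = jnp.min(x, axis=-1)    # lowers via lax.reduce_min_p
  prod_ref[...] = jnp.prod(x, axis=-1)  # lowers via lax.reduce_prod_p

x = jnp.linspace(0.5, 1.5, 64 * 32, dtype=jnp.float32).reshape(64, 32)
mins, prods = pl.pallas_call(
    reduce_kernel,
    out_shape=(
        jax.ShapeDtypeStruct((64,), jnp.float32),
        jax.ShapeDtypeStruct((64,), jnp.float32),
    ),
)(x)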


def _block_id(ctx: LoweringRuleContext, dim: gpu_dialect.Dimension) -> ir.Value:
result = gpu_dialect.block_id(dim)
cluster_size = ctx.launch_ctx.cluster_size
6 changes: 5 additions & 1 deletion jax/experimental/mosaic/gpu/dialect_lowering.py
@@ -1091,7 +1091,7 @@ def _conversion_op_lowering_rule(
def _unary_op_lowering_rule(
_: LoweringContext,
op: Any,
impl: Callable[..., fa.FragmentedArray],
impl: Any,
is_signed: bool | None = None,
) -> Sequence[ir.Value]:
in_layouts = inference_utils.in_layouts(op)
@@ -1116,6 +1116,9 @@ def _unary_op_lowering_rule(
(mlir_math.CosOp, fa.FragmentedArray.cos, None),
(mlir_math.LogOp, fa.FragmentedArray.log, None),
(mlir_math.TanhOp, fa.FragmentedArray.tanh, None),
(mlir_math.AbsFOp, fa.FragmentedArray.abs, None),
(mlir_math.AbsIOp, fa.FragmentedArray.abs, True),
(mlir_math.ErfOp, fa.FragmentedArray.erf, None),
]:
_lowerings[op.OPERATION_NAME] = functools.partial(
_unary_op_lowering_rule, impl=unary_impl, is_signed=is_signed
@@ -1161,6 +1164,7 @@ def _binary_op_lowering_rule(
(arith.MinSIOp, fa.FragmentedArray.min, True),
(arith.MinUIOp, fa.FragmentedArray.min, False),
(arith.MinimumFOp, fa.FragmentedArray.min, None),
(mlir_math.Atan2Op, fa.FragmentedArray.atan2, None),
]:
_lowerings[op.OPERATION_NAME] = functools.partial(
_binary_op_lowering_rule, impl=binary_impl, is_signed=is_signed
81 changes: 81 additions & 0 deletions jax/experimental/mosaic/gpu/fragmented_array.py
@@ -1662,6 +1662,63 @@ def rsqrt(self, *, approx: bool = False) -> FragmentedArray:
self._lift_fast_instr("rsqrt.approx.f32") if approx else mlir_math.rsqrt
)

def abs(self) -> FragmentedArray:
if ir.FloatType.isinstance(self.mlir_dtype):
return self._pointwise(mlir_math.absf)
elif ir.IntegerType.isinstance(self.mlir_dtype):
return self._pointwise(mlir_math.absi)
else:
raise NotImplementedError(self.mlir_dtype)

def sign(self) -> FragmentedArray:
if ir.FloatType.isinstance(self.mlir_dtype):
# For floats: sign(x) = copysign(1.0, x) if x != 0 else 0,
# implemented below with a select rather than a multiply.
dtype = self.mlir_dtype
one = arith.constant(dtype, ir.FloatAttr.get(dtype, 1.0))
zero = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.0))
def float_sign(x):
one_val = one
zero_val = zero
if ir.VectorType.isinstance(x.type):
one_val = vector.broadcast(x.type, one)
zero_val = vector.broadcast(x.type, zero)
signed_one = mlir_math.copysign(one_val, x)
is_nonzero = arith.cmpf(arith.CmpFPredicate.ONE, x, zero_val)
return arith.select(is_nonzero, signed_one, zero_val)
return self._pointwise(float_sign)
elif ir.IntegerType.isinstance(self.mlir_dtype):
# For integers: sign(x) = (x > 0) - (x < 0)
int_dtype = self.mlir_dtype
zero_scalar = arith.constant(int_dtype, 0)
is_signed = self.is_signed
def int_sign(x):
zero = zero_scalar
if ir.VectorType.isinstance(x.type):
zero = vector.broadcast(x.type, zero_scalar)
if is_signed:
pos = arith.cmpi(arith.CmpIPredicate.sgt, x, zero)
neg = arith.cmpi(arith.CmpIPredicate.slt, x, zero)
else:
pos = arith.cmpi(arith.CmpIPredicate.ugt, x, zero)
neg = arith.cmpi(arith.CmpIPredicate.ult, x, zero)
pos_ext = arith.extui(x.type, pos)
neg_ext = arith.extui(x.type, neg)
return arith.subi(pos_ext, neg_ext)
return self._pointwise(int_sign, output_is_signed=is_signed)
else:
raise NotImplementedError(self.mlir_dtype)

def erf(self) -> FragmentedArray:
if not ir.FloatType.isinstance(self.mlir_dtype):
raise NotImplementedError(self.mlir_dtype)
return self._pointwise(mlir_math.erf)

def atan2(self, other: FragmentedArray) -> FragmentedArray:
if not ir.FloatType.isinstance(self.mlir_dtype):
raise NotImplementedError(self.mlir_dtype)
return self._pointwise(mlir_math.atan2, other)
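
The sign() method above relies on two identities: for floats, sign(x) = copysign(1, x) when x != 0 and 0 otherwise; for integers, sign(x) = (x > 0) - (x < 0). A plain NumPy check of both, purely illustrative and not part of this diff:

import numpy as np

xf = np.array([-2.5, -0.0, 0.0, 3.0], dtype=np.float32)
ref_f = np.where(xf != 0, np.copysign(1.0, xf), 0.0)
assert np.array_equal(ref_f, np.sign(xf))

xi = np.array([-7, 0, 9], dtype=np.int32)
ref_i = (xi > 0).astype(np.int32) - (xi < 0).astype(np.int32)
assert np.array_equal(ref_i, np.sign(xi))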

@staticmethod
def _lift_fast_instr(
instr: str | Callable[[ir.Value], ir.Value],
@@ -2263,6 +2320,30 @@ def reduce(
else:
raise NotImplementedError(self.mlir_dtype)
splat_op = lambda x: x
case "min":
if ir.F32Type.isinstance(self.mlir_dtype):
op = self._lift_fast_instr("min.NaN.f32")
elif ir.FloatType.isinstance(self.mlir_dtype):
op = arith.minimumf
elif ir.IntegerType.isinstance(self.mlir_dtype):
op = arith.minsi if self.is_signed else arith.minui
else:
raise NotImplementedError(self.mlir_dtype)
splat_op = lambda x: x
case "prod":
reduced_elems = math.prod(self.shape[a] for a in axis)
if ir.FloatType.isinstance(self.mlir_dtype):
op = arith.mulf
# For splat, prod(x, x, ..., x) = x^n
splat_op = lambda x: mlir_math.powf(
x, c(float(reduced_elems), x.type)
)
elif ir.IntegerType.isinstance(self.mlir_dtype):
op = arith.muli
# For an integer splat we cannot easily compute x**n here, so leave
# splat_op unset and let the splat path fail instead of producing a
# wrong result.
splat_op = None
else:
raise NotImplementedError(self.mlir_dtype)
case _:
raise ValueError(f"Unrecognized reduction operator: {op}")
assert not isinstance(op, str)
62 changes: 62 additions & 0 deletions tests/mosaic/gpu_test.py
@@ -3502,6 +3502,8 @@ def kernel(ctx, dst, _):
(lambda x: mgpu.FragmentedArray.sin(x), np.sin),
(lambda x: mgpu.FragmentedArray.cos(x), np.cos),
(lambda x: mgpu.FragmentedArray.rsqrt(x), jax.lax.rsqrt),
(lambda x: mgpu.FragmentedArray.abs(x), np.abs),
(lambda x: mgpu.FragmentedArray.erf(x), jax.scipy.special.erf),
],
approx=[False, True],
)
@@ -3520,6 +3522,37 @@ def kernel(ctx, dst, _):
rtol = 4e-6 if approx else 2e-7
np.testing.assert_allclose(result, np_op(x), atol=atol, rtol=rtol)

@parameterized.product(
dtype=[jnp.float32, jnp.int32, jnp.uint32],
)
def test_sign(self, dtype, m=64, n=32):
def kernel(ctx, dst, _):
# Use values that include negative, zero, and positive
iota = iota_tensor(m, n, dtype)
shifted = iota - (m * n // 2) # Center around zero
shifted.sign().store_untiled(dst, optimized=False)

out_shape = jax.ShapeDtypeStruct((m, n), dtype)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
x = np.arange(m * n, dtype=dtype).reshape(m, n) - (m * n // 2)
np.testing.assert_array_equal(result, np.sign(x))

def test_atan2(self, m=64, n=32):
def kernel(ctx, dst, _):
y = iota_tensor(m, n, jnp.float32) + 1 # Avoid zero
x = iota_tensor(m, n, jnp.float32) + 2
y.atan2(x).store_untiled(dst, optimized=False)

out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
y = np.arange(m * n, dtype=jnp.float32).reshape(m, n) + 1
x = np.arange(m * n, dtype=jnp.float32).reshape(m, n) + 2
np.testing.assert_allclose(result, np.arctan2(y, x), atol=2e-7, rtol=2e-7)

def test_strided_copy_noncontig_good(self):
def kernel(ctx, src, dst, _):
src_slice = mgpu.memref_slice(src, (slice(None), 1))
@@ -3687,6 +3720,35 @@ def kernel(ctx, dst, _):
raise NotImplementedError(f"Unsupported op: {op}")
np.testing.assert_array_equal(result, expected)

@parameterized.product(
op=("min", "prod"),
m=(64,),
n=(32, 64),
)
def test_reduce_min_prod(self, op, m, n):
def kernel(ctx, dst, _):
# Use smaller values for prod to avoid overflow
iota = iota_tensor(m, n, jnp.float32)
if op == "prod":
# Normalize to values near 1 to avoid overflow in product
iota = iota / (m * n) + 0.5
iota.reduce(op, axis=1).broadcast_in_dim(
(m, n), (0,), mgpu.WGMMA_LAYOUT
).store_untiled(dst, optimized=False)
out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32)
result = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()
)()
x = np.arange(m * n, dtype=jnp.float32).reshape(m, n)
if op == "min":
expected = np.broadcast_to(x.min(axis=1, keepdims=True), x.shape)
elif op == "prod":
x_normalized = x / (m * n) + 0.5
expected = np.broadcast_to(x_normalized.prod(axis=1, keepdims=True), x.shape)
else:
raise NotImplementedError(f"Unsupported op: {op}")
np.testing.assert_allclose(result, expected, rtol=1e-5)

def test_splat_layout(self):
m, n = 64, 8
def kernel(ctx, dst, _):