I believe it does, at least for the static `range(20)` in the notebook (Edit: this is noted in "🔪 JAX - The Sharp Bits 🔪 — JAX documentation", among other places). Using the incantation from "print optimized code · Discussion #7068 · google/jax · GitHub" to look at the optimized HLO:
In [24]: print(module.to_string(xla_ext.HloPrintOptions.short_parsable()))
HloModule xla_computation_run_jax_kernel.7
fused_computation.clone {
param_1.9 = c64[2000,2000]{1,0} parameter(1)
multiply.20 = c64[2000,2000]{1,0} multiply(param_1.9, param_1.9)
add.20 = c64[2000,2000]{1,0} add(multiply.20, param_1.9)
multiply.21 = c64[2000,2000]{1,0} multiply(add.20, add.20)
add.21 = c64[2000,2000]{1,0} add(multiply.21, param_1.9)
multiply.22 = c64[2000,2000]{1,0} multiply(add.21, add.21)
add.22 = c64[2000,2000]{1,0} add(multiply.22, param_1.9)
multiply.23 = c64[2000,2000]{1,0} multiply(add.22, add.22)
add.23 = c64[2000,2000]{1,0} add(multiply.23, param_1.9)
multiply.24 = c64[2000,2000]{1,0} multiply(add.23, add.23)
add.24 = c64[2000,2000]{1,0} add(multiply.24, param_1.9)
multiply.25 = c64[2000,2000]{1,0} multiply(add.24, add.24)
add.25 = c64[2000,2000]{1,0} add(multiply.25, param_1.9)
multiply.26 = c64[2000,2000]{1,0} multiply(add.25, add.25)
add.26 = c64[2000,2000]{1,0} add(multiply.26, param_1.9)
multiply.28 = c64[2000,2000]{1,0} multiply(add.26, add.26)
add.27 = c64[2000,2000]{1,0} add(multiply.28, param_1.9)
multiply.29 = c64[2000,2000]{1,0} multiply(add.27, add.27)
add.29 = c64[2000,2000]{1,0} add(multiply.29, param_1.9)
multiply.30 = c64[2000,2000]{1,0} multiply(add.29, add.29)
add.30 = c64[2000,2000]{1,0} add(multiply.30, param_1.9)
multiply.31 = c64[2000,2000]{1,0} multiply(add.30, add.30)
add.31 = c64[2000,2000]{1,0} add(multiply.31, param_1.9)
multiply.32 = c64[2000,2000]{1,0} multiply(add.31, add.31)
add.32 = c64[2000,2000]{1,0} add(multiply.32, param_1.9)
multiply.33 = c64[2000,2000]{1,0} multiply(add.32, add.32)
add.33 = c64[2000,2000]{1,0} add(multiply.33, param_1.9)
multiply.34 = c64[2000,2000]{1,0} multiply(add.33, add.33)
add.34 = c64[2000,2000]{1,0} add(multiply.34, param_1.9)
multiply.35 = c64[2000,2000]{1,0} multiply(add.34, add.34)
add.35 = c64[2000,2000]{1,0} add(multiply.35, param_1.9)
multiply.36 = c64[2000,2000]{1,0} multiply(add.35, add.35)
add.36 = c64[2000,2000]{1,0} add(multiply.36, param_1.9)
multiply.37 = c64[2000,2000]{1,0} multiply(add.36, add.36)
add.37 = c64[2000,2000]{1,0} add(multiply.37, param_1.9)
multiply.38 = c64[2000,2000]{1,0} multiply(add.37, add.37)
add.38 = c64[2000,2000]{1,0} add(multiply.38, param_1.9)
multiply.39 = c64[2000,2000]{1,0} multiply(add.38, add.38)
add.39 = c64[2000,2000]{1,0} add(multiply.39, param_1.9)
multiply.41 = c64[2000,2000]{1,0} multiply(add.39, add.39)
add.40 = c64[2000,2000]{1,0} add(multiply.41, param_1.9)
abs.40 = f32[2000,2000]{1,0} abs(add.40)
constant.46 = f32[] constant(2)
broadcast.47 = f32[2000,2000]{1,0} broadcast(constant.46), dimensions={}
compare.89 = pred[2000,2000]{1,0} compare(abs.40, broadcast.47), direction=GT
abs.39 = f32[2000,2000]{1,0} abs(add.39)
compare.87 = pred[2000,2000]{1,0} compare(abs.39, broadcast.47), direction=GT
abs.38 = f32[2000,2000]{1,0} abs(add.38)
compare.85 = pred[2000,2000]{1,0} compare(abs.38, broadcast.47), direction=GT
abs.37 = f32[2000,2000]{1,0} abs(add.37)
compare.81 = pred[2000,2000]{1,0} compare(abs.37, broadcast.47), direction=GT
abs.36 = f32[2000,2000]{1,0} abs(add.36)
compare.79 = pred[2000,2000]{1,0} compare(abs.36, broadcast.47), direction=GT
abs.35 = f32[2000,2000]{1,0} abs(add.35)
compare.77 = pred[2000,2000]{1,0} compare(abs.35, broadcast.47), direction=GT
abs.34 = f32[2000,2000]{1,0} abs(add.34)
compare.75 = pred[2000,2000]{1,0} compare(abs.34, broadcast.47), direction=GT
abs.33 = f32[2000,2000]{1,0} abs(add.33)
compare.73 = pred[2000,2000]{1,0} compare(abs.33, broadcast.47), direction=GT
abs.32 = f32[2000,2000]{1,0} abs(add.32)
compare.71 = pred[2000,2000]{1,0} compare(abs.32, broadcast.47), direction=GT
abs.31 = f32[2000,2000]{1,0} abs(add.31)
compare.67 = pred[2000,2000]{1,0} compare(abs.31, broadcast.47), direction=GT
abs.30 = f32[2000,2000]{1,0} abs(add.30)
compare.65 = pred[2000,2000]{1,0} compare(abs.30, broadcast.47), direction=GT
abs.28 = f32[2000,2000]{1,0} abs(add.29)
compare.63 = pred[2000,2000]{1,0} compare(abs.28, broadcast.47), direction=GT
abs.27 = f32[2000,2000]{1,0} abs(add.27)
compare.61 = pred[2000,2000]{1,0} compare(abs.27, broadcast.47), direction=GT
abs.26 = f32[2000,2000]{1,0} abs(add.26)
compare.59 = pred[2000,2000]{1,0} compare(abs.26, broadcast.47), direction=GT
abs.25 = f32[2000,2000]{1,0} abs(add.25)
compare.55 = pred[2000,2000]{1,0} compare(abs.25, broadcast.47), direction=GT
abs.24 = f32[2000,2000]{1,0} abs(add.24)
compare.53 = pred[2000,2000]{1,0} compare(abs.24, broadcast.47), direction=GT
abs.23 = f32[2000,2000]{1,0} abs(add.23)
compare.51 = pred[2000,2000]{1,0} compare(abs.23, broadcast.47), direction=GT
abs.22 = f32[2000,2000]{1,0} abs(add.22)
compare.49 = pred[2000,2000]{1,0} compare(abs.22, broadcast.47), direction=GT
abs.21 = f32[2000,2000]{1,0} abs(add.21)
compare.47 = pred[2000,2000]{1,0} compare(abs.21, broadcast.47), direction=GT
abs.20 = f32[2000,2000]{1,0} abs(add.20)
compare.45 = pred[2000,2000]{1,0} compare(abs.20, broadcast.47), direction=GT
param_0.4 = s32[2000,2000]{1,0} parameter(0)
constant.45 = s32[] constant(20)
broadcast.46 = s32[2000,2000]{1,0} broadcast(constant.45), dimensions={}
compare.42 = pred[2000,2000]{1,0} compare(param_0.4, broadcast.46), direction=EQ
and.20 = pred[2000,2000]{1,0} and(compare.45, compare.42)
constant.44 = s32[] constant(0)
broadcast.45 = s32[2000,2000]{1,0} broadcast(constant.44), dimensions={}
select.41 = s32[2000,2000]{1,0} select(and.20, broadcast.45, param_0.4)
compare.46 = pred[2000,2000]{1,0} compare(select.41, broadcast.46), direction=EQ
and.21 = pred[2000,2000]{1,0} and(compare.47, compare.46)
constant.47 = s32[] constant(1)
broadcast.48 = s32[2000,2000]{1,0} broadcast(constant.47), dimensions={}
select.42 = s32[2000,2000]{1,0} select(and.21, broadcast.48, select.41)
compare.48 = pred[2000,2000]{1,0} compare(select.42, broadcast.46), direction=EQ
and.22 = pred[2000,2000]{1,0} and(compare.49, compare.48)
constant.48 = s32[] constant(2)
broadcast.49 = s32[2000,2000]{1,0} broadcast(constant.48), dimensions={}
select.43 = s32[2000,2000]{1,0} select(and.22, broadcast.49, select.42)
compare.50 = pred[2000,2000]{1,0} compare(select.43, broadcast.46), direction=EQ
and.23 = pred[2000,2000]{1,0} and(compare.51, compare.50)
constant.49 = s32[] constant(3)
broadcast.51 = s32[2000,2000]{1,0} broadcast(constant.49), dimensions={}
select.44 = s32[2000,2000]{1,0} select(and.23, broadcast.51, select.43)
compare.52 = pred[2000,2000]{1,0} compare(select.44, broadcast.46), direction=EQ
and.24 = pred[2000,2000]{1,0} and(compare.53, compare.52)
constant.50 = s32[] constant(4)
broadcast.52 = s32[2000,2000]{1,0} broadcast(constant.50), dimensions={}
select.45 = s32[2000,2000]{1,0} select(and.24, broadcast.52, select.44)
compare.54 = pred[2000,2000]{1,0} compare(select.45, broadcast.46), direction=EQ
and.25 = pred[2000,2000]{1,0} and(compare.55, compare.54)
constant.51 = s32[] constant(5)
broadcast.53 = s32[2000,2000]{1,0} broadcast(constant.51), dimensions={}
select.46 = s32[2000,2000]{1,0} select(and.25, broadcast.53, select.45)
compare.58 = pred[2000,2000]{1,0} compare(select.46, broadcast.46), direction=EQ
and.26 = pred[2000,2000]{1,0} and(compare.59, compare.58)
constant.52 = s32[] constant(6)
broadcast.54 = s32[2000,2000]{1,0} broadcast(constant.52), dimensions={}
select.47 = s32[2000,2000]{1,0} select(and.26, broadcast.54, select.46)
compare.60 = pred[2000,2000]{1,0} compare(select.47, broadcast.46), direction=EQ
and.27 = pred[2000,2000]{1,0} and(compare.61, compare.60)
constant.53 = s32[] constant(7)
broadcast.55 = s32[2000,2000]{1,0} broadcast(constant.53), dimensions={}
select.48 = s32[2000,2000]{1,0} select(and.27, broadcast.55, select.47)
compare.62 = pred[2000,2000]{1,0} compare(select.48, broadcast.46), direction=EQ
and.28 = pred[2000,2000]{1,0} and(compare.63, compare.62)
constant.54 = s32[] constant(8)
broadcast.56 = s32[2000,2000]{1,0} broadcast(constant.54), dimensions={}
select.49 = s32[2000,2000]{1,0} select(and.28, broadcast.56, select.48)
compare.64 = pred[2000,2000]{1,0} compare(select.49, broadcast.46), direction=EQ
and.29 = pred[2000,2000]{1,0} and(compare.65, compare.64)
constant.55 = s32[] constant(9)
broadcast.57 = s32[2000,2000]{1,0} broadcast(constant.55), dimensions={}
select.50 = s32[2000,2000]{1,0} select(and.29, broadcast.57, select.49)
compare.66 = pred[2000,2000]{1,0} compare(select.50, broadcast.46), direction=EQ
and.30 = pred[2000,2000]{1,0} and(compare.67, compare.66)
constant.56 = s32[] constant(10)
broadcast.58 = s32[2000,2000]{1,0} broadcast(constant.56), dimensions={}
select.52 = s32[2000,2000]{1,0} select(and.30, broadcast.58, select.50)
compare.68 = pred[2000,2000]{1,0} compare(select.52, broadcast.46), direction=EQ
and.31 = pred[2000,2000]{1,0} and(compare.71, compare.68)
constant.57 = s32[] constant(11)
broadcast.59 = s32[2000,2000]{1,0} broadcast(constant.57), dimensions={}
select.53 = s32[2000,2000]{1,0} select(and.31, broadcast.59, select.52)
compare.72 = pred[2000,2000]{1,0} compare(select.53, broadcast.46), direction=EQ
and.33 = pred[2000,2000]{1,0} and(compare.73, compare.72)
constant.58 = s32[] constant(12)
broadcast.60 = s32[2000,2000]{1,0} broadcast(constant.58), dimensions={}
select.54 = s32[2000,2000]{1,0} select(and.33, broadcast.60, select.53)
compare.74 = pred[2000,2000]{1,0} compare(select.54, broadcast.46), direction=EQ
and.34 = pred[2000,2000]{1,0} and(compare.75, compare.74)
constant.59 = s32[] constant(13)
broadcast.61 = s32[2000,2000]{1,0} broadcast(constant.59), dimensions={}
select.55 = s32[2000,2000]{1,0} select(and.34, broadcast.61, select.54)
compare.76 = pred[2000,2000]{1,0} compare(select.55, broadcast.46), direction=EQ
and.35 = pred[2000,2000]{1,0} and(compare.77, compare.76)
constant.60 = s32[] constant(14)
broadcast.62 = s32[2000,2000]{1,0} broadcast(constant.60), dimensions={}
select.56 = s32[2000,2000]{1,0} select(and.35, broadcast.62, select.55)
compare.78 = pred[2000,2000]{1,0} compare(select.56, broadcast.46), direction=EQ
and.36 = pred[2000,2000]{1,0} and(compare.79, compare.78)
constant.61 = s32[] constant(15)
broadcast.64 = s32[2000,2000]{1,0} broadcast(constant.61), dimensions={}
select.57 = s32[2000,2000]{1,0} select(and.36, broadcast.64, select.56)
compare.80 = pred[2000,2000]{1,0} compare(select.57, broadcast.46), direction=EQ
and.37 = pred[2000,2000]{1,0} and(compare.81, compare.80)
constant.62 = s32[] constant(16)
broadcast.65 = s32[2000,2000]{1,0} broadcast(constant.62), dimensions={}
select.58 = s32[2000,2000]{1,0} select(and.37, broadcast.65, select.57)
compare.84 = pred[2000,2000]{1,0} compare(select.58, broadcast.46), direction=EQ
and.38 = pred[2000,2000]{1,0} and(compare.85, compare.84)
constant.63 = s32[] constant(17)
broadcast.66 = s32[2000,2000]{1,0} broadcast(constant.63), dimensions={}
select.59 = s32[2000,2000]{1,0} select(and.38, broadcast.66, select.58)
compare.86 = pred[2000,2000]{1,0} compare(select.59, broadcast.46), direction=EQ
and.39 = pred[2000,2000]{1,0} and(compare.87, compare.86)
constant.64 = s32[] constant(18)
broadcast.67 = s32[2000,2000]{1,0} broadcast(constant.64), dimensions={}
select.60 = s32[2000,2000]{1,0} select(and.39, broadcast.67, select.59)
compare.88 = pred[2000,2000]{1,0} compare(select.60, broadcast.46), direction=EQ
and.40 = pred[2000,2000]{1,0} and(compare.89, compare.88)
constant.65 = s32[] constant(19)
broadcast.68 = s32[2000,2000]{1,0} broadcast(constant.65), dimensions={}
ROOT select.61 = s32[2000,2000]{1,0} select(and.40, broadcast.68, select.60)
}
parallel_fusion {
p = s32[2000,2000]{1,0} parameter(0)
p.1 = c64[2000,2000]{1,0} parameter(1)
ROOT fusion.clone = s32[2000,2000]{1,0} fusion(p, p.1), kind=kLoop, calls=fused_computation.clone, outer_dimension_partitions={8}
}
ENTRY main.288 {
Arg_1.2 = s32[2000,2000]{1,0} parameter(1)
Arg_0.1 = c64[2000,2000]{1,0} parameter(0)
call = s32[2000,2000]{1,0} call(Arg_1.2, Arg_0.1), to_apply=parallel_fusion
ROOT tuple.287 = (s32[2000,2000]{1,0}) tuple(call)
}
The canonical form may have more low-level details, but it exceeds the Discourse character limit. I've dumped it in "JAX Mandelbrot canonical HLO · GitHub".
Some more diagnostics, using the approach from "Add public API for computation cost analysis (flops, memory use, etc.) · Issue #10542 · google/jax · GitHub":
In [28]: module = comp.as_hlo_module()
In [29]: client = jax.lib.xla_bridge.get_backend()
In [30]: analysis = jax.lib.xla_client._xla.hlo_module_cost_analysis(client, module)
In [31]: analysis
Out[31]:
{'bytes accessed': 7871990784.0,
'bytes accessed operand 0 {}': 2720000000.0,
'bytes accessed operand 1 {}': 2320000000.0,
'bytes accessed operand 2 {}': 320000000.0,
'bytes accessed output {}': 2512000000.0,
'flops': 560000000.0,
'optimal_seconds': 0.0}