update a test for checking zero ROCm GPU event #585
cj401-amd wants to merge 11 commits into rocm-jaxlib-v0.8.0
tests/profiler_test.py
Outdated
a = jnp.ones((1024, 1024), dtype=jnp.float16)
b = jnp.ones((1024, 1024), dtype=jnp.float16)
Just curious: is it important to check this on large matrices? The goal of the test is only to check whether GPU events are seen, so the matrix size should be irrelevant. So, IIUC, by using large matrices here we only prolong the test runtime and waste resources?
IIRC, we used to get no GPU events on small matrices?
Instead of just one big size, maybe it's better to test a range of sizes: 1024, 512, 128, 64, and 32. This gives us a better guarantee of profiling robustness.
Hmm, so there might be some dynamic routing to a backend depending on the size of the operation?
jax.jit() has a backend parameter, which could be set to "gpu" to explicitly request that the GPU do the work. That should never be re-routed to a CPU, for example, so it should keep the code simple, clean, and still very lightweight. Is this correct?
See ticket SWDEV-568283 and https://github.com/ROCm/xla/blob/rocm-jaxlib-v0.7.1/xla/backends/profiler/gpu/rocm_profiler_sdk.cc#L427: the code there used a pre-fixed size and we didn't flush properly. This should be fixed by openxla/xla#34968, which is backported to 0.8.0 as well. So it's best to check more matrix sizes in the UTs, just in case of future changes.
oh, thanks... so it happened because of profiler integration bugs? Oh, then absolutely agree, there should be several tests for different sizes starting from the smallest...
ADDED: remove the approval until resolved. Thanks Chao for noticing!
And since these are GPU kernels, even small matrices will launch similar GPU kernels anyway...
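The multi-size check suggested in this thread could be sketched roughly as below. This is only an illustration, not the code in the PR: `count_gpu_events` is a hypothetical hook that would run an n x n matmul under the profiler and return how many GPU events the trace contains, and the helper name `sizes_without_gpu_events` is likewise made up for this sketch.

```python
from typing import Callable, Iterable, List

# Sizes proposed in the thread, smallest first.
SIZES = (32, 64, 128, 512, 1024)


def sizes_without_gpu_events(
    count_gpu_events: Callable[[int], int],
    sizes: Iterable[int] = SIZES,
) -> List[int]:
    """Return every size for which zero GPU events were recorded.

    `count_gpu_events` is a hypothetical hook (an assumption of this sketch):
    given n, it should profile an n x n matmul and return the number of GPU
    events found in the resulting trace.
    """
    return [n for n in sizes if count_gpu_events(n) == 0]


if __name__ == "__main__":
    # Stand-in collector mimicking a bug where small sizes yield no events.
    fake_collector = lambda n: 0 if n < 128 else 5
    print(sizes_without_gpu_events(fake_collector))  # prints [32, 64]
```

A test built on this would assert that the returned list is empty, so a regression at any one size is reported together with the size that triggered it.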
tests/profiler_test.py
Outdated
r = subprocess.run([sys.executable, "-c", code],
                   env=env, capture_output=True, text=True)
if r.returncode != 0:
Not necessary for this PR, but for the future: subprocess.run (https://docs.python.org/3.11/library/subprocess.html#subprocess.run) has a check=True argument that performs the same return-code check and raises under the hood, so you don't have to.
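For reference, a minimal sketch of the check=True variant mentioned above. The wrapper name `run_snippet` is hypothetical and not part of the PR; the subprocess.run arguments mirror the ones in the quoted diff.

```python
import subprocess
import sys


def run_snippet(code, env=None):
    """Run `code` in a fresh interpreter.

    With check=True, subprocess.run raises subprocess.CalledProcessError
    on a non-zero exit code, so no manual returncode check is needed.
    """
    return subprocess.run(
        [sys.executable, "-c", code],
        env=env, capture_output=True, text=True, check=True,
    )


if __name__ == "__main__":
    try:
        run_snippet("import sys; sys.exit(3)")
    except subprocess.CalledProcessError as err:
        # The captured stdout/stderr are available on the exception object.
        print("child failed with exit code", err.returncode)
```

One trade-off: the manual `if r.returncode != 0` form lets the test put the child's stderr into its own failure message, whereas with check=True that output has to be read from the exception.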
Arech8
left a comment
Thanks Chunyu for this test, it's really important to have it!
I'm approving it, but before merging, please consider my comment about matrix size: if I didn't miss anything, we could save quite a bit of compute by making it something trivially small, like 8x8 instead of 1024x1024.
After the merge, please (this is really important!) also make a PR into the upstream jax-ml/jax. You will probably want to make the test disabled by default if the profiling support it requires isn't merged into upstream XLA yet. We'll enable it later, once the changes to XLA propagate to the upstream JAX XLA commit.
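One generic way to keep such a test disabled by default is an opt-in skip using the standard unittest module, sketched below. The environment variable name and the test class are assumptions made for illustration; the JAX test suite has its own skip helpers, which may be the better fit for the upstream PR.

```python
import os
import unittest

# Hypothetical opt-in switch; the name is an assumption of this sketch,
# not an existing JAX or XLA setting.
ENABLE_VAR = "ENABLE_ROCM_PROFILER_TEST"


class RocmProfilerTest(unittest.TestCase):

    @unittest.skipUnless(
        os.environ.get(ENABLE_VAR) == "1",
        f"set {ENABLE_VAR}=1 to run; needs ROCm profiler support in XLA",
    )
    def test_gpu_events_present(self):
        # Placeholder body: the real test would collect a trace and assert
        # that it contains at least one GPU event.
        self.fail("replace with the real profiler check")


if __name__ == "__main__":
    unittest.main()
```

Without the variable set, the test is reported as skipped rather than failed, and enabling it later only requires removing the decorator.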
@cj401-amd @Arech8 is this PR ready to merge?
@cj401-amd I noticed that you have a failing test in your PR description. Is that the current outcome of this test?
Yes, it's ready to be merged. The message I posted earlier shows the test failing, which indicates no ROCm GPU profiling, so the test can help catch the case of no GPU events.
Does it require the local XLA path for those GPU events to show up? I want to merge this, but I have to be sure that it also works if we don't specify a local XLA path.
I believe so. The upstream is here: jax-ml#34135.
Motivation
For the kernel_details test, jaxlib needs to be built like the following, otherwise the trace file might miss kernel_details.
Run specific ROCm profiler test