diff --git a/pathwaysutils/profiling.py b/pathwaysutils/profiling.py index 7fcffd1..f91e77a 100644 --- a/pathwaysutils/profiling.py +++ b/pathwaysutils/profiling.py @@ -16,6 +16,7 @@ import asyncio from collections.abc import Mapping import dataclasses +import datetime import json import logging import os @@ -270,26 +271,44 @@ def start_trace( "features for Pathways on Cloud and may not be fully supported." ) - if jax.version.__version_info__ < (0, 9, 2) and profiler_options is not None: - _logger.warning( - "ProfileOptions are not supported until JAX 0.9.2 and will be omitted. " - "Some options can be specified via command line flags." - ) - profiler_options = None + if jax.version.__version_info__ < (0, 9, 2): + if profiler_options is not None: + _logger.warning( + "ProfileOptions are not supported until JAX 0.9.2 and will be omitted. " + "Some options can be specified via command line flags." + ) + profiler_options = None + else: + if profiler_options is None: + profiler_options = jax.profiler.ProfileOptions() + if not profiler_options.session_id: + profiler_options.session_id = datetime.datetime.now().strftime( + "%Y_%m_%d_%H_%M_%S" + ) profile_request = _create_profile_request( - log_dir, profiler_options, max_num_hosts=max_num_hosts + log_dir, + profiler_options, + max_num_hosts=max_num_hosts, ) _logger.debug("Profile request: %s", profile_request) _start_pathways_trace_from_profile_request(profile_request) - _original_start_trace( - log_dir=log_dir, - create_perfetto_link=create_perfetto_link, - create_perfetto_trace=create_perfetto_trace, - ) + if jax.version.__version_info__ >= (0, 9, 2): + _original_start_trace( + log_dir=log_dir, + create_perfetto_link=create_perfetto_link, + create_perfetto_trace=create_perfetto_trace, + profiler_options=profiler_options, + ) + else: + _original_start_trace( + log_dir=log_dir, + create_perfetto_link=create_perfetto_link, + create_perfetto_trace=create_perfetto_trace, + ) def stop_trace() -> None: diff --git a/pathwaysutils/test/profiling_test.py b/pathwaysutils/test/profiling_test.py index 5cb2fa8..c189fb0 100644 --- a/pathwaysutils/test/profiling_test.py +++ b/pathwaysutils/test/profiling_test.py @@ -15,6 +15,7 @@ import json import logging from unittest import mock +from typing import Any from absl.testing import absltest from absl.testing import parameterized @@ -53,6 +54,40 @@ def setUp(self): self.mock_original_stop_trace = self.enter_context( mock.patch.object(profiling, "_original_stop_trace", autospec=True) ) + self.mock_datetime = self.enter_context( + mock.patch.object(profiling.datetime, "datetime", autospec=True) + ) + self.mock_datetime.now.return_value.strftime.return_value = ( + "2026_06_04_05_29_33" + ) + + def _get_expected_profile_request( + self, + trace_location: str, + max_num_hosts: int = 1, + session_id: str = "2026_06_04_05_29_33", + ) -> dict[str, Any]: + if jax.version.__version_info__ >= (0, 9, 2): + return { + "profileRequest": { + "traceLocation": trace_location, + "maxNumHosts": max_num_hosts, + "xprofTraceOptions": { + "traceDirectory": trace_location, + "pwTraceOptions": { + "enablePythonTracer": True, + }, + "traceSessionName": session_id, + }, + } + } + else: + return { + "profileRequest": { + "traceLocation": trace_location, + "maxNumHosts": max_num_hosts, + } + } @parameterized.parameters(8000, 1234) def test_collect_profile_port(self, port): @@ -228,40 +263,67 @@ def test_start_trace_success(self): profiling.start_trace("gs://test_bucket/test_dir") self.mock_toy_computation.assert_called_once() + expected_request = self._get_expected_profile_request( + "gs://test_bucket/test_dir", max_num_hosts=1 + ) self.mock_plugin_executable_cls.assert_called_once_with( - json.dumps({ - "profileRequest": { - "traceLocation": "gs://test_bucket/test_dir", - "maxNumHosts": 1, - } - }) + json.dumps(expected_request) ) self.mock_plugin_executable_cls.return_value.call.assert_called_once() - self.mock_original_start_trace.assert_called_once_with( - log_dir="gs://test_bucket/test_dir", - create_perfetto_link=False, - create_perfetto_trace=False, - ) + self.mock_original_start_trace.assert_called_once() + call_args = self.mock_original_start_trace.call_args[1] + self.assertEqual(call_args["log_dir"], "gs://test_bucket/test_dir") + self.assertFalse(call_args["create_perfetto_link"]) + self.assertFalse(call_args["create_perfetto_trace"]) + if jax.version.__version_info__ >= (0, 9, 2): + self.assertEqual( + call_args["profiler_options"].session_id, "2026_06_04_05_29_33" + ) self.assertIsNotNone(profiling._profile_state.executable) def test_start_trace_with_max_num_hosts(self): profiling.start_trace("gs://test_bucket/test_dir", max_num_hosts=10) self.mock_toy_computation.assert_called_once() + expected_request = self._get_expected_profile_request( + "gs://test_bucket/test_dir", max_num_hosts=10 + ) self.mock_plugin_executable_cls.assert_called_once_with( - json.dumps({ - "profileRequest": { - "traceLocation": "gs://test_bucket/test_dir", - "maxNumHosts": 10, - } - }) + json.dumps(expected_request) ) self.mock_plugin_executable_cls.return_value.call.assert_called_once() - self.mock_original_start_trace.assert_called_once_with( - log_dir="gs://test_bucket/test_dir", - create_perfetto_link=False, - create_perfetto_trace=False, + self.mock_original_start_trace.assert_called_once() + call_args = self.mock_original_start_trace.call_args[1] + self.assertEqual(call_args["log_dir"], "gs://test_bucket/test_dir") + self.assertFalse(call_args["create_perfetto_link"]) + self.assertFalse(call_args["create_perfetto_trace"]) + if jax.version.__version_info__ >= (0, 9, 2): + self.assertEqual( + call_args["profiler_options"].session_id, "2026_06_04_05_29_33" + ) + + @absltest.skipIf( + jax.version.__version_info__ < (0, 9, 2), + "ProfileOptions requires JAX 0.9.2 or newer", + ) + def test_start_trace_with_session_id_in_options(self): + options = jax.profiler.ProfileOptions() + options.session_id = "options_session" + profiling.start_trace("gs://test_bucket/test_dir", profiler_options=options) + + expected_request = self._get_expected_profile_request( + "gs://test_bucket/test_dir", max_num_hosts=1, session_id="options_session" ) + self.mock_plugin_executable_cls.assert_called_once_with( + json.dumps(expected_request) + ) + self.assertEqual(options.session_id, "options_session") + self.mock_original_start_trace.assert_called_once() + call_args = self.mock_original_start_trace.call_args[1] + self.assertEqual(call_args["log_dir"], "gs://test_bucket/test_dir") + self.assertFalse(call_args["create_perfetto_link"]) + self.assertFalse(call_args["create_perfetto_trace"]) + self.assertEqual(call_args["profiler_options"].session_id, "options_session") def test_start_trace_no_toy_computation_second_time(self): profiling.start_trace("gs://test_bucket/test_dir")