nntool.slurm.task

Functions

reconstruct_command_line(argv)

Classes

DistributedTaskConfig([num_processes, ...])

Configuration for distributed tasks.

PyTorchDistributedTask(launch_cmd, argv, ...)

A task that runs on Slurm and sets up the PyTorch distributed environment variables.

Task(argv, slurm_config[, verbose])

The base class for all tasks that will be run on Slurm.

class nntool.slurm.task.Task(argv, slurm_config, verbose=False)[source]

The base class for all tasks that will be run on Slurm. Especially useful for distributed tasks that need to set up the distributed environment variables.

Parameters:
  • argv (list[str]) – the command line arguments to run the task. This will be passed to the command method to reconstruct the command line.

  • slurm_config (SlurmConfig) – the Slurm configuration to use for the task.

  • verbose (bool, optional) – whether to print verbose output. Defaults to False.

log(msg)[source]

Log a message to the console if verbose is enabled.

Parameters:

msg (str) – the message to log.

command()[source]

Return the command to run the task. This method should be implemented by subclasses to return the actual command line to run the task.

Raises:

NotImplementedError – If the method is not implemented by the subclass.

Returns:

the command to run the task.

Return type:

str

checkpoint()[source]

Return a checkpoint for the task. This is used to save the state of the task so it can be resumed, e.g. after Slurm preemption or requeueing.
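As a hypothetical illustration of the contract above, a subclass would override command() to rebuild a shell-safe command line from argv. The class and argument values below are illustrative stand-ins, not part of nntool's API:

```python
# Illustrative sketch of a Task-like subclass; only the command()
# contract mirrors the documented interface.
import shlex


class EchoTask:
    """Stand-in mirroring the documented Task interface."""

    def __init__(self, argv, verbose=False):
        self.argv = argv
        self.verbose = verbose

    def command(self):
        # Reconstruct a shell-safe command line from argv.
        return " ".join(shlex.quote(a) for a in self.argv)


task = EchoTask(["python", "train.py", "--lr", "1e-3"])
print(task.command())  # python train.py --lr 1e-3
```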

class nntool.slurm.task.DistributedTaskConfig(num_processes='$nntool_num_processes', num_machines='$nntool_num_machines', machine_rank='$nntool_machine_rank', main_process_ip='$nntool_main_process_ip', main_process_port='$nntool_main_process_port')[source]

Configuration for distributed tasks. This is used to set up the distributed environment variables for PyTorch distributed training.

Parameters:
  • num_processes (int) – The total number of processes to run across all machines.

  • num_machines (int) – The number of machines to run the task on.

  • machine_rank (int) – The rank of the current machine in the distributed setup.

  • main_process_ip (str) – The IP address of the main process (rank 0) in the distributed setup.

  • main_process_port (int) – The port of the main process (rank 0) in the distributed setup.

export_bash(output_folder)[source]

Export the distributed environment variables to a bash script. This script can be sourced to set the environment variables for the distributed task.

Parameters:

output_folder (str) – the folder to save the bash script to.
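For intuition, a sourceable script of this kind boils down to one `export` line per config field. The variable naming and exact file layout below are assumptions for illustration, not nntool's actual output:

```python
# Hypothetical sketch of the kind of script export_bash might write;
# the NNTOOL_* variable names are an assumption based on the config fields.
from dataclasses import dataclass


@dataclass
class DistEnv:
    num_processes: int = 8
    num_machines: int = 2
    machine_rank: int = 0
    main_process_ip: str = "10.0.0.1"
    main_process_port: int = 29500

    def to_bash(self) -> str:
        # One `export` line per field so the script can be `source`d.
        return "\n".join(
            f"export NNTOOL_{name.upper()}={value}"
            for name, value in vars(self).items()
        )


print(DistEnv().to_bash())
```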

class nntool.slurm.task.PyTorchDistributedTask(launch_cmd, argv, slurm_config, verbose=False, **env_setup_kwargs)[source]

A task that runs on Slurm and sets up the PyTorch distributed environment variables. In non-Slurm execution modes, the command is run locally instead.

Parameters:
  • launch_cmd (str) – The command to launch the task.

  • argv (list[str]) – The command line arguments for the task.

  • slurm_config (SlurmConfig) – The Slurm configuration to use for the task.

  • verbose (bool, optional) – whether to print verbose output. Defaults to False.

References

  • https://github.com/huggingface/accelerate/issues/1239

  • https://github.com/yuvalkirstain/PickScore/blob/main/trainer/slurm_scripts/slurm_train.py

  • https://github.com/facebookincubator/submitit/pull/1703

set_up_dist_env()[source]

Set up the distributed environment variables for PyTorch distributed training.
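torch.distributed's env:// initialization reads a well-known set of variables (MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK). The sketch below sets them from a DistributedTaskConfig-style set of values; how nntool actually derives the per-process rank is an assumption (here: an even split of processes across machines):

```python
# Sketch of the standard variables torch.distributed reads at init time;
# the RANK derivation (even split across machines) is an assumption.
import os


def set_up_dist_env(num_processes: int, num_machines: int,
                    machine_rank: int, main_ip: str, main_port: int) -> None:
    os.environ["MASTER_ADDR"] = main_ip            # rendezvous host (rank 0)
    os.environ["MASTER_PORT"] = str(main_port)     # rendezvous port
    os.environ["WORLD_SIZE"] = str(num_processes)  # total processes, all machines
    os.environ["NODE_RANK"] = str(machine_rank)    # rank of this machine
    # Rank of this machine's first process, assuming an even split.
    procs_per_machine = num_processes // num_machines
    os.environ["RANK"] = str(machine_rank * procs_per_machine)


set_up_dist_env(8, 2, 1, "10.0.0.1", 29500)
```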

command()[source]

Return the command to run the task.

Returns:

the command to run the task.

Return type:

str