Utils

summary

textpruner.summary(model: Union[torch.nn.modules.module.Module, Dict], max_level: Optional[int] = 2)

Show a summary of the model's parameters.

Parameters
  • model – the model to be inspected; can be a torch module or a state_dict.

  • max_level – the maximum depth of the module hierarchy to display. If max_level is None, show all levels.

Returns

A formatted string.

Example:

print(textpruner.summary(model))
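
A fuller sketch of typical usage, assuming a Hugging Face transformers model is available (the model name and the max_level value are illustrative only):

from transformers import BertModel
import textpruner

model = BertModel.from_pretrained('bert-base-uncased')
# Show parameter counts down to two levels of the module hierarchy,
# e.g. "embeddings" and "encoder.layer"
print(textpruner.summary(model, max_level=2))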

inference_time

textpruner.inference_time(model: torch.nn.modules.module.Module, dummy_inputs: Union[List, Tuple, Dict], warm_up: int = 5, repetitions: int = 10)

Measure and print the inference time of the model.

Parameters
  • model – the torch module to be measured.

  • dummy_inputs – the inputs to be fed into the model; can be a list, tuple, or dict.

  • warm_up – Number of steps to warm up the device.

  • repetitions – Number of steps to perform forward propagation. More repetitions result in more accurate measurements.

Example:

input_ids = torch.randint(low=0, high=10000, size=(32, 256))
textpruner.inference_time(model, dummy_inputs=[input_ids])
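
A more complete sketch, again assuming a Hugging Face transformers model; the model name, batch size, and sequence length are illustrative only:

import torch
from transformers import BertModel
import textpruner

model = BertModel.from_pretrained('bert-base-uncased')
# A batch of 32 random token-id sequences of length 256
input_ids = torch.randint(low=0, high=10000, size=(32, 256))
# Warm up for 5 steps, then average the forward-pass time over 10 repetitions
textpruner.inference_time(model, dummy_inputs=[input_ids], warm_up=5, repetitions=10)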