Utils
summary
- textpruner.summary(model: Union[torch.nn.modules.module.Module, Dict], max_level: Optional[int] = 2)
Show a summary of the model's parameters.
- Parameters
model – the model to inspect; either a torch module or a state_dict.
max_level – the maximum level to display. If max_level is None, all levels are shown.
- Returns
A formatted string.
Example:
print(textpruner.summary(model))
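To make the role of max_level concrete, here is a minimal pure-Python sketch of grouping parameter counts by level. It assumes that a "level" is one component of a dotted parameter name (as in state_dict keys such as encoder.layer.0.weight). count_by_level is a hypothetical helper, not part of TextPruner, and it only illustrates the idea; it does not reproduce the output format of textpruner.summary.

```python
from collections import OrderedDict
from math import prod
from typing import Dict, Optional, Tuple


def count_by_level(
    shapes: Dict[str, Tuple[int, ...]], max_level: Optional[int] = 2
) -> "OrderedDict[str, int]":
    """Sum parameter counts per dotted-name prefix, truncated to max_level parts.

    Hypothetical helper for illustration only; assumes a "level" is one
    component of a dotted parameter name. Not TextPruner's implementation.
    """
    counts: "OrderedDict[str, int]" = OrderedDict()
    for name, shape in shapes.items():
        parts = name.split(".")
        # max_level=None keeps the full name, i.e. every level is shown
        key = ".".join(parts if max_level is None else parts[:max_level])
        # prod(()) == 1, so scalar parameters count as one element
        counts[key] = counts.get(key, 0) + prod(shape)
    return counts


shapes = {
    "embeddings.word.weight": (100, 8),
    "encoder.layer.0.weight": (8, 8),
    "encoder.layer.1.weight": (8, 8),
}
for prefix, n in count_by_level(shapes, max_level=1).items():
    print(f"{prefix:<12} {n:>6}")
```

With a real torch model, the input mapping could be built as {name: tuple(p.shape) for name, p in model.state_dict().items()}. Smaller values of max_level merge more parameters into each row; max_level=None lists every parameter individually.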
inference_time
- textpruner.inference_time(model: torch.nn.modules.module.Module, dummy_inputs: Union[List, Tuple, Dict], warm_up: int = 5, repetitions: int = 10)
Measure and print the inference time of the model.
- Parameters
model – the torch module to be measured.
dummy_inputs – the inputs to feed into the model; can be a list, tuple, or dict.
warm_up – Number of steps to warm up the device.
repetitions – Number of steps to perform forward propagation. More repetitions result in more accurate measurements.
Example:
input_ids = torch.randint(low=0, high=10000, size=(32, 256))
textpruner.inference_time(model, dummy_inputs=[input_ids])
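The measurement procedure implied by warm_up and repetitions can be sketched as follows. This is an illustrative sketch under stated assumptions, not TextPruner's implementation: measure_inference_time is a hypothetical name, list/tuple inputs are assumed to be passed positionally and dict inputs as keyword arguments, and timing uses host wall-clock time via time.perf_counter.

```python
import statistics
import time
from typing import Any, Callable, Dict, List, Tuple, Union


def measure_inference_time(
    model: Callable[..., Any],
    dummy_inputs: Union[List, Tuple, Dict],
    warm_up: int = 5,
    repetitions: int = 10,
) -> Tuple[float, float]:
    """Return (mean, stdev) of wall-clock seconds per forward pass.

    Hypothetical helper for illustration; not TextPruner's code.
    """
    if repetitions < 1:
        raise ValueError("repetitions must be at least 1")

    def forward() -> Any:
        # Assumption: a dict is unpacked as keyword arguments,
        # a list or tuple as positional arguments
        if isinstance(dummy_inputs, dict):
            return model(**dummy_inputs)
        return model(*dummy_inputs)

    # Warm-up passes run but are not timed (lazy init, caches, device clocks)
    for _ in range(warm_up):
        forward()

    timings = []
    for _ in range(repetitions):
        # On a GPU, call torch.cuda.synchronize() before each clock read,
        # because CUDA kernels are launched asynchronously
        start = time.perf_counter()
        forward()
        timings.append(time.perf_counter() - start)

    stdev = statistics.stdev(timings) if len(timings) > 1 else 0.0
    return statistics.mean(timings), stdev
```

With a torch model, the calls would normally also be wrapped in torch.no_grad() after model.eval(), so that dropout is disabled and no autograd graph is built during timing. Reporting the spread alongside the mean shows whether the chosen number of repetitions is large enough to give a stable estimate.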