class textpruner.configurations.Config[source]

Base class for GeneralConfig, VocabularyPruningConfig and TransformerPruningConfig.

classmethod from_dict(config_map: dict)[source]

Construct the configuration from a dict.

classmethod from_json(json_filename: str)[source]

Construct the configuration from a json file.

save_to_json(json_filename: str)[source]

Save the configuration to a json file.
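The three methods above follow a common dict/JSON serialization pattern. A minimal, self-contained sketch of how such a base class can be built with the standard library (hypothetical code for illustration, not TextPruner's actual implementation):

```python
import json
from dataclasses import dataclass, asdict

@dataclass
class Config:
    """Hypothetical stand-in for textpruner.configurations.Config."""

    @classmethod
    def from_dict(cls, config_map: dict):
        # Construct the configuration from a dict.
        return cls(**config_map)

    @classmethod
    def from_json(cls, json_filename: str):
        # Construct the configuration from a json file.
        with open(json_filename) as f:
            return cls.from_dict(json.load(f))

    def save_to_json(self, json_filename: str):
        # Save the configuration to a json file.
        with open(json_filename, 'w') as f:
            json.dump(asdict(self), f, indent=2)

@dataclass
class GeneralConfig(Config):
    use_device: str = 'auto'
    output_dir: str = './pruned_models'
```

With this pattern, `save_to_json` followed by `from_json` round-trips the configuration, and `from_dict` accepts any subset of keyword fields.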


class textpruner.configurations.GeneralConfig(use_device: str = 'auto', output_dir: str = './pruned_models', config_class: str = 'GeneralConfig')[source]

Configurations for the device and the output directory.

  • use_device – 'cpu', 'cuda', 'cuda:0', etc. Specifies which device to use. If set to 'auto', TextPruner will use the CUDA device if one is available; otherwise it uses the CPU.

  • output_dir – The directory to save the pruned models.

  • config_class – Type of the configurations. Users should not change its value.
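The 'auto' device rule can be illustrated with a small helper (hypothetical code; that the real library checks CUDA availability internally is an assumption here):

```python
def resolve_device(use_device: str, cuda_available: bool) -> str:
    """Hypothetical helper illustrating GeneralConfig's 'auto' device rule."""
    if use_device == 'auto':
        # Prefer the CUDA device when one is present, fall back to CPU.
        return 'cuda' if cuda_available else 'cpu'
    # Explicit settings such as 'cpu', 'cuda' or 'cuda:0' are used as-is.
    return use_device
```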


class textpruner.configurations.VocabularyPruningConfig(min_count: int = 1, prune_lm_head: Union[bool, str] = 'auto', config_class: str = 'VocabularyPruningConfig')[source]

Configurations for vocabulary pruning.

  • min_count – The threshold for removing tokens: a token is removed from the vocabulary if it appears fewer than min_count times in the corpus.

  • prune_lm_head – Whether to prune the lm_head if the model has one. If prune_lm_head==False, TextPruner will not prune the lm_head; if prune_lm_head==True, TextPruner will prune the lm_head and raise an error if the model does not have one; if prune_lm_head=='auto', TextPruner will try to prune the lm_head and continue without error if the model does not have one.

  • config_class – Type of the configurations. Users should not change its value.
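The min_count rule amounts to a frequency filter over the corpus. A self-contained sketch (hypothetical helper, not the library's code):

```python
from collections import Counter

def kept_tokens(corpus_tokens, min_count: int = 1):
    """Return the tokens that survive vocabulary pruning: a token is kept
    only if it appears at least min_count times in the corpus."""
    counts = Counter(corpus_tokens)
    return {tok for tok, n in counts.items() if n >= min_count}
```

With the default min_count=1, only tokens that never occur in the corpus are removed; raising min_count prunes rare tokens as well.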


class textpruner.configurations.TransformerPruningConfig(target_ffn_size: Optional[int] = None, target_num_of_heads: Optional[int] = None, pruning_method: str = 'masks', ffn_even_masking: Optional[bool] = True, head_even_masking: Optional[bool] = True, n_iters: Optional[int] = 1, multiple_of: int = 1, pruning_order: Optional[str] = None, use_logits: bool = False, config_class: str = 'TransformerPruningConfig')[source]

Configurations for transformer pruning.

  • target_ffn_size – The target average FFN size per layer.

  • target_num_of_heads – The target average number of attention heads per layer.

  • pruning_method – 'masks' or 'iterative'. If set to 'masks', the pruner prunes the model with the given masks (head_mask and ffn_mask). If set to 'iterative', the pruner calculates the importance scores of the neurons from the data provided by the dataloader, then prunes the model based on these scores.

  • ffn_even_masking – Whether the FFN size of each layer should be the same.

  • head_even_masking – Whether the number of attention heads of each layer should be the same.

  • n_iters – If pruning_method is set to 'iterative', n_iters is the number of pruning iterations used to prune the model progressively.

  • multiple_of – If ffn_even_masking is False, restricts the target FFN size of each layer to be a multiple of multiple_of.

  • pruning_order – None, 'head-first' or 'ffn-first'. If None, the attention heads and the FFN layers are pruned simultaneously; if set to 'head-first' or 'ffn-first', they are pruned in separate passes, so the actual number of iterations is 2*n_iters.

  • use_logits – If True, performs self-supervised pruning, where the model's logits are treated as soft labels.

  • config_class – Type of the configurations. Users should not change its value.
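How n_iters, pruning_order, and multiple_of shape the pruning schedule can be sketched as follows (hypothetical helpers for illustration; the actual per-layer allocation in TextPruner is importance-score based):

```python
def round_to_multiple(size: int, multiple_of: int) -> int:
    """Round a per-layer FFN size down to the nearest multiple of multiple_of,
    as required when ffn_even_masking is False."""
    return (size // multiple_of) * multiple_of

def iteration_schedule(n_iters: int, pruning_order=None):
    """Sketch of the iterative schedule: with pruning_order set, heads and
    FFN neurons are pruned in separate passes, doubling the iteration count."""
    if pruning_order is None:
        # Heads and FFN are pruned together in each iteration.
        return ['heads+ffn'] * n_iters
    first, second = ('heads', 'ffn') if pruning_order == 'head-first' else ('ffn', 'heads')
    # Alternating passes: the actual number of iterations is 2 * n_iters.
    return [first, second] * n_iters
```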


If ffn_even_masking is False, the pruned model cannot be saved normally (the saved weights cannot be loaded with the transformers library), so make sure to set save_model=False when calling TransformerPruner.prune() or PipelinePruner.prune(). There are two workarounds:

  • Save the model in TorchScript format manually;

  • Set keep_shape=True when calling TransformerPruner.prune() or PipelinePruner.prune(), so the model keeps its original shape and can be saved normally. Then save the ffn_masks and head_masks as well. When loading the model, load the full model and then prune it with the saved masks.
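The second workaround boils down to persisting the masks alongside the full model. A minimal sketch of the mask round trip using JSON (the file layout and helper names are hypothetical; applying the masks on reload corresponds to TextPruner's pruning_method='masks' mode):

```python
import json

def save_masks(ffn_masks, head_masks, path: str):
    """Persist the binary FFN/head masks (nested lists of 0/1) so the
    pruned architecture can be reproduced after loading the full model."""
    with open(path, 'w') as f:
        json.dump({'ffn_masks': ffn_masks, 'head_masks': head_masks}, f)

def load_masks(path: str):
    """Load the masks back; pass them to a pruner configured with
    pruning_method='masks' after loading the full (unpruned) model."""
    with open(path) as f:
        data = json.load(f)
    return data['ffn_masks'], data['head_masks']
```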