# Source code for textpruner.configurations
from dataclasses import asdict
import torch
import json
import logging
from typing import Union, Optional
from dataclasses import dataclass, asdict
logger = logging.getLogger(__name__)
@dataclass
class Config:
    """Base class for :class:`~textpruner.configurations.GeneralConfig`,
    :class:`~textpruner.configurations.VocabularyPruningConfig` and
    :class:`~textpruner.configurations.TransformerPruningConfig`.

    Provides the JSON/dict (de)serialization shared by all configuration
    dataclasses. Deserialization dispatches on the ``config_class`` key via
    the module-level ``CONFIG_CLASS`` registry.
    """

    @classmethod
    def from_json(cls, json_filename: str):
        """Construct the configuration from a json file.

        The file must contain a ``config_class`` entry naming one of the
        registered configuration classes.
        """
        with open(json_filename, 'r') as f:
            config_map = json.load(f)
        # Dispatch to the concrete subclass named by 'config_class'.
        config = CONFIG_CLASS[config_map['config_class']].from_dict(config_map)
        return config

    @classmethod
    def from_dict(cls, config_map: dict):
        """Construct the configuration from a dict.

        ``config_map['config_class']`` selects the concrete class; all keys
        (including ``config_class`` itself, which is also a field) are passed
        as keyword arguments to its constructor.
        """
        config = CONFIG_CLASS[config_map['config_class']](**config_map)
        return config

    def save_to_json(self, json_filename: str):
        """Save the configuration to a json file."""
        config_map = asdict(self)
        with open(json_filename, 'w') as f:
            json.dump(config_map, f, indent=2)
@dataclass
class GeneralConfig(Config):
    '''
    Configurations for the device and the output directory.

    Args:
        use_device: ``'cpu'`` or ``'cuda'`` or ``'cuda:0'`` etc. Specify which device to use.
            If it is set to ``'auto'``, TextPruner will try to use the CUDA device
            if there is one; otherwise uses CPU. The resolved device is stored in
            ``self.device`` (set in ``__post_init__``; not a dataclass field, so it
            is not included when the configuration is serialized).
        output_dir: The directory to save the pruned models.
        config_class: Type of the configurations. Users should not change its value.
    '''
    use_device: str = 'auto'
    output_dir: str = './pruned_models'
    config_class: str = "GeneralConfig"

    def __post_init__(self):
        # Resolve 'auto' to a concrete device string at construction time.
        if self.use_device == 'auto':
            if torch.cuda.is_available():
                logger.info("Using current cuda device")
                self.device = 'cuda'
            else:
                logger.info("Using cpu device")
                self.device = 'cpu'
        else:
            self.device = self.use_device
@dataclass
class VocabularyPruningConfig(Config):
    '''
    Configurations for vocabulary pruning.

    Args:
        min_count: The threshold to decide if the token should be removed.
            The token will be removed from the vocabulary if it appears less than
            ``min_count`` times in the corpus.
        prune_lm_head: whether to prune the lm_head if the model has one.
            If ``prune_lm_head==False``, TextPruner will not prune the lm_head;
            if ``prune_lm_head==True``, TextPruner will prune the lm_head and raise
            an error if the model does not have an lm_head;
            if ``prune_lm_head=='auto'``, TextPruner will try to prune the lm_head
            and will continue if the model does not have an lm_head.
        config_class: Type of the configurations. Users should not change its value.
    '''
    min_count: int = 1
    prune_lm_head: Union[bool, str] = 'auto'
    config_class: str = "VocabularyPruningConfig"
@dataclass
class TransformerPruningConfig(Config):
    """
    Configurations for transformer pruning.

    Args:
        target_ffn_size: the target average FFN size per layer.
        target_num_of_heads: the target average number of heads per layer.
        pruning_method: ``'masks'`` or ``'iterative'``. If set to ``'masks'``, the pruner
            prunes the model with the given masks (``head_mask`` and ``ffn_mask``).
            If set to ``'iterative'``, the pruner calculates the importance scores of the
            neurons based on the data provided by the ``dataloader`` and then prunes the
            model based on the scores.
        ffn_even_masking: Whether the FFN size of each layer should be the same.
        head_even_masking: Whether the number of attention heads of each layer should be the same.
        n_iters: if ``pruning_method`` is set to ``'iterative'``, ``n_iters`` is the number of
            pruning iterations to prune the model progressively.
        multiple_of: if ``ffn_even_masking`` is ``False``, restrict the target FFN size of each
            layer to be a multiple of ``multiple_of``.
        pruning_order: ``None`` or ``'head-first'`` or ``'ffn-first'``. ``None``: prune the
            attention heads and ffn layer simultaneously; if set to ``'head-first'`` or
            ``'ffn-first'``, the actual number of iterations is ``2*n_iters``.
        use_logits: if ``True``, performs self-supervised pruning, where the logits are treated
            as the soft labels.
        config_class: Type of the configurations. Users should not change its value.

    Warning:
        if ``ffn_even_masking`` is ``False``, the pruned model cannot be saved normally
        (we cannot load the model with the transformers library with the saved weights).
        So make sure to set ``save_model=False`` when calling ``TransformerPruner.prune()``
        or ``PipelinePruner.prune()``. There are two ways to avoid this:

        * Save the model in TorchScript format manually;
        * Set ``keep_shape=False`` when calling ``TransformerPruner.prune()`` or
          ``PipelinePruner.prune()``, so the full model can be saved. Then save the
          ``ffn_masks`` and ``head_masks``. When loading the model, load the full model
          and then prune it with the masks.
    """
    target_ffn_size: Optional[int] = None
    target_num_of_heads: Optional[int] = None
    pruning_method: str = 'masks'
    ffn_even_masking: Optional[bool] = True
    head_even_masking: Optional[bool] = True
    n_iters: Optional[int] = 1
    multiple_of: int = 1
    pruning_order: Optional[str] = None
    use_logits: bool = False
    config_class: str = "TransformerPruningConfig"

    def __post_init__(self):
        # Validate options eagerly so misconfiguration fails at construction time.
        # (Messages fixed: were misspelled "Unrecgonized".)
        assert self.pruning_method in ('masks', 'iterative'), "Unrecognized pruning method"
        assert (self.pruning_order is None) or (self.pruning_order in ('head-first', 'ffn-first')), \
            "Unrecognized pruning order"
        if self.ffn_even_masking is False:
            # Uneven FFN masks produce ragged layer shapes that the transformers
            # library cannot reload; warn the user up front.
            logger.warning("ffn_even_masking is False. Pruned model can only be saved in TorchScript format manually.")
# Registry mapping the serialized 'config_class' string back to the concrete
# configuration class; used by Config.from_json / Config.from_dict for dispatch.
CONFIG_CLASS = {
    'GeneralConfig': GeneralConfig,
    'VocabularyPruningConfig': VocabularyPruningConfig,
    'TransformerPruningConfig': TransformerPruningConfig
}