Source code for textpruner.pruners.transformer_pruner

import torch
from torch import nn
import os

from torch.nn.functional import softmax, log_softmax
from .utils import move_to_device, generate_mask, infer_model_type
from .utils import infer_logits, infer_loss
from ..configurations import TransformerPruningConfig, GeneralConfig

from ..model_map import MODEL_MAP
import logging
from tqdm import tqdm
from collections import abc
from typing import Mapping, Optional, Tuple
from copy import deepcopy
logger = logging.getLogger(__name__)

class TransformerPruner:
    '''
    Args:
        model : The model to be pruned.
        transformer_pruning_config : a :class:`~textpruner.configurations.TransformerPruningConfig` object.
        general_config : a :class:`~textpruner.configurations.GeneralConfig` object.
        base_model_prefix : The prefix of the base model, i.e., the name of the base model as a member of the model. \
For example, if ``model.bert_encoder = BertModel(...)``, then the ``base_model_prefix`` is ``bert_encoder``. \
TextPruner will try to infer the ``base_model_prefix``, so its value can usually be left as ``None``. If the inference fails, users have to set its value explicitly.
    '''
    def __init__(self,
                 model: nn.Module,
                 transformer_pruning_config: Optional[TransformerPruningConfig] = None,
                 general_config: Optional[GeneralConfig] = None,
                 base_model_prefix: Optional[str] = None):
        self.model = model
        base_model, model_type = infer_model_type(model, base_model_prefix)
        assert model_type in MODEL_MAP, \
            f"Model type {model_type} is not supported, or not understood. Model type must be one of {list(MODEL_MAP.keys())}"
        self.base_model = base_model
        self.model_type = model_type
        self.model_structure = MODEL_MAP[self.model_type]['structure']

        self.general_config = GeneralConfig() if general_config is None else general_config
        self.transformer_pruning_config = TransformerPruningConfig() if transformer_pruning_config is None else transformer_pruning_config

        self.model.to(self.general_config.device)

        self.output_dir: str = self.general_config.output_dir

        # None before pruning
        self.head_mask: Optional[torch.Tensor] = None
        self.ffn_mask: Optional[torch.Tensor] = None
        self.keep_shape: Optional[bool] = None

        os.makedirs(self.output_dir, exist_ok=True)

        self.shoule_cache_logits = True
        self.soft_labels = []
        if self.transformer_pruning_config.use_logits is True:
            # keep a half-precision copy of the unpruned model as the reference
            # for the logits-based importance scores
            self.model_rep = deepcopy(model)
            self.model_rep.half().to(model.device)
        self.save_dir = None
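    # A minimal usage sketch (hypothetical values; the keyword names mirror the
    # config attributes read in iterative_pruning below):
    #
    #   from textpruner import TransformerPruningConfig, TransformerPruner
    #   config = TransformerPruningConfig(target_ffn_size=2048,
    #                                     target_num_of_heads=8,
    #                                     pruning_method='iterative',
    #                                     n_iters=4)
    #   pruner = TransformerPruner(model, transformer_pruning_config=config)
    #   pruner.prune(dataloader=dataloader, save_model=True)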
    def prune(self, dataloader=None, adaptor=None, batch_postprocessor=None,
              head_mask: Optional[torch.Tensor] = None, ffn_mask: Optional[torch.Tensor] = None,
              keep_shape=False, save_model=True, rewrite_cache=True):
        '''
        Prunes the transformers. If ``self.transformer_pruning_config.pruning_method=='masks'``, the pruner prunes the attention heads and the FFN neurons based on
        the ``head_mask`` and ``ffn_mask``; if ``self.transformer_pruning_config.pruning_method=='iterative'``, the pruner prunes the attention heads and the FFN neurons based on
        the importance scores calculated on the batches from the ``dataloader``.

        Args:
            dataloader : a dataloader that generates batches. Each batch should contain both the inputs and the labels.
            adaptor : a function that takes the model output and returns the loss.
            batch_postprocessor : a function that takes a batch produced by the dataloader and returns a batch. It is used for post-processing the batches if needed.
            head_mask : a tensor of shape ``(num_layers, num_attention_heads)``. ``1`` means to keep, ``0`` means to prune.
            ffn_mask : a tensor of shape ``(num_layers, intermediate_hidden_size)``. ``1`` means to keep, ``0`` means to prune.
            keep_shape : if ``True``, the model is not actually pruned and the model structure is unchanged, but the weights that *should be pruned* are set to zero.
            save_model : whether to save the model when the pruning is finished.
            rewrite_cache : whether to recalculate the importance scores even if cached scores exist in the output directory.
        '''
        pruning_method = self.transformer_pruning_config.pruning_method
        if pruning_method == 'masks':
            if head_mask is not None or ffn_mask is not None:
                save_dir = self.prune_with_masks(head_mask=head_mask, ffn_mask=ffn_mask, set_masks=True, save_model=save_model)
            else:
                raise TypeError("Pruning method is 'masks', but no masks are given.")
        elif pruning_method == 'iterative':
            assert dataloader is not None, "Pruning method is 'iterative', but dataloader is not given."
            save_dir = self.iterative_pruning(dataloader, adaptor, batch_postprocessor, keep_shape, save_model=save_model, rewrite_cache=rewrite_cache)
        else:
            raise NotImplementedError(f"Unknown pruning method {pruning_method}.")
        self.save_dir = save_dir
        return save_dir
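    # Sketch of the 'masks' path (the 12-layer, 12-head, 3072-neuron geometry
    # is only an illustration; requires pruning_method='masks' in the config):
    #
    #   head_mask = torch.ones(12, 12)
    #   head_mask[:, 6:] = 0            # keep the first 6 heads of every layer
    #   ffn_mask = torch.ones(12, 3072)
    #   ffn_mask[:, 1536:] = 0          # keep 1536 FFN neurons per layer
    #   pruner.prune(head_mask=head_mask, ffn_mask=ffn_mask, save_model=True)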
    def prune_with_masks(self,
                         head_mask: Optional[torch.Tensor] = None,
                         ffn_mask: Optional[torch.Tensor] = None,
                         keep_shape: bool = False,
                         set_masks=False,
                         save_model=False) -> Optional[str]:
        if head_mask is None:
            head_mask = self.head_mask
        if ffn_mask is None:
            ffn_mask = self.ffn_mask
        if set_masks is True:
            if head_mask is not None:
                self.head_mask = head_mask
            if ffn_mask is not None:
                self.ffn_mask = ffn_mask
        if ffn_mask is not None:
            ffn_mask_tensor = ffn_mask.clone().detach().to(dtype=torch.float32, device=self.general_config.device)
            self.reorder_ffn_weights(ffn_mask_tensor, keep_shape)
        if head_mask is not None:
            if keep_shape:
                head_mask_tensor = head_mask.clone().detach().to(dtype=torch.float32, device=self.general_config.device)
                self.reorder_attention_heads(head_mask_tensor, keep_shape)
            else:
                heads_to_prune_dict = {}
                for layer_num, layer_head in enumerate(head_mask.tolist()):
                    heads_to_prune_dict[layer_num] = []
                    for head_idx, v in enumerate(layer_head):
                        if v == 0:
                            heads_to_prune_dict[layer_num].append(head_idx)
                self.base_model.prune_heads(heads_to_prune_dict)
        self.keep_shape = keep_shape
        if save_model is True:
            return self.save_model()

    def iterative_pruning(self, dataloader, adaptor, batch_postprocessor=None, keep_shape=False, save_model=True, rewrite_cache=False) -> Optional[str]:
        target_ffn_size = self.transformer_pruning_config.target_ffn_size
        target_num_of_heads = self.transformer_pruning_config.target_num_of_heads
        n_iters = self.transformer_pruning_config.n_iters
        multiple_of = self.transformer_pruning_config.multiple_of
        head_even_masking = self.transformer_pruning_config.head_even_masking
        ffn_even_masking = self.transformer_pruning_config.ffn_even_masking
        pruning_order = self.transformer_pruning_config.pruning_order

        head_importance_fn = os.path.join(self.output_dir, 'head_importance.pt')
        ffn_importance_fn = os.path.join(self.output_dir, 'ffn_importance.pt')
        if os.path.exists(head_importance_fn) and os.path.exists(ffn_importance_fn) and rewrite_cache is False:
            logger.info(f"Loading pre-cached head importance score {head_importance_fn}")
            head_importance = torch.load(head_importance_fn)
            logger.info(f"Loading pre-cached ffn importance score {ffn_importance_fn}")
            ffn_importance = torch.load(ffn_importance_fn)
        else:
            logger.info("Calculating head importance and ffn importance")
            if self.transformer_pruning_config.use_logits:
                head_importance, ffn_importance = self.get_importance_score_with_logits(dataloader, adaptor, batch_postprocessor)
            else:
                head_importance, ffn_importance = self.get_importance_score(dataloader, adaptor, batch_postprocessor)
            head_importance = head_importance.cpu()  # (num_layers, num_heads)
            ffn_importance = ffn_importance.cpu()    # (num_layers, intermediate_size)
            # save importance scores
            logger.info("Save...")
            torch.save(head_importance, head_importance_fn)
            torch.save(ffn_importance, ffn_importance_fn)

        total_num_of_heads = head_importance.size(0) * head_importance.size(1)
        total_ffn_size = ffn_importance.size(0) * ffn_importance.size(1)
        total_target_ffn_size = target_ffn_size * ffn_importance.size(0)
        total_target_num_of_heads = target_num_of_heads * head_importance.size(0)

        ffn_size_per_iter = (total_ffn_size - total_target_ffn_size) // n_iters
        num_of_heads_per_iter = (total_num_of_heads - total_target_num_of_heads) // n_iters
        ffn_size_res = (total_ffn_size - total_target_ffn_size) % n_iters
        num_of_heads_res = (total_num_of_heads - total_target_num_of_heads) % n_iters

        dffn_size = total_ffn_size
        dnum_of_heads = total_num_of_heads

        if pruning_order is None:
            for i in range(n_iters):
                logger.info(f'Number of pruning iterations: {i+1}/{n_iters}')
                if i > 0:
                    logger.info("Calculating head importance and ffn importance")
                    if self.transformer_pruning_config.use_logits:
                        head_importance, ffn_importance = self.get_importance_score_with_logits(dataloader, adaptor, batch_postprocessor)
                    else:
                        head_importance, ffn_importance = self.get_importance_score(dataloader, adaptor, batch_postprocessor)
                    head_importance = head_importance.cpu()  # (num_layers, num_heads)
                    ffn_importance = ffn_importance.cpu()    # (num_layers, intermediate_size)
                    assert torch.all(head_importance == head_importance * self.head_mask)
                    assert torch.all(ffn_importance == ffn_importance * self.ffn_mask)
                    #head_importance *= self.head_mask
                    #ffn_importance *= self.ffn_mask
                dffn_size -= (ffn_size_per_iter + 1 if i < ffn_size_res else ffn_size_per_iter)
                dnum_of_heads -= (num_of_heads_per_iter + 1 if i < num_of_heads_res else num_of_heads_per_iter)
                self.head_mask = generate_mask(head_importance, dnum_of_heads, head_even_masking)
                self.ffn_mask = generate_mask(ffn_importance, dffn_size, ffn_even_masking, multiple_of=multiple_of)
                logger.info(f"New ffn size: {self.ffn_mask.sum(-1).tolist()}")
                logger.info(f"New num heads: {self.head_mask.sum(-1).tolist()}")

                if i == n_iters - 1:
                    self.prune_with_masks(keep_shape=keep_shape, save_model=False)
                else:
                    self.prune_with_masks(keep_shape=True, save_model=False)
        else:
            for i in range(n_iters * 2):  # n_iters for heads, n_iters for FFN
                logger.info(f'Number of pruning iterations: {i+1}/{n_iters * 2}')
                if pruning_order == 'head-first':
                    current_is_head = (i % 2 == 0)
                    current_is_ffn = (i % 2 == 1)
                else:
                    current_is_ffn = (i % 2 == 0)
                    current_is_head = (i % 2 == 1)
                if i > 0:
                    logger.info("Calculating head importance and ffn importance")
                    if self.transformer_pruning_config.use_logits:
                        head_importance, ffn_importance = self.get_importance_score_with_logits(dataloader, adaptor, batch_postprocessor)
                    else:
                        head_importance, ffn_importance = self.get_importance_score(dataloader, adaptor, batch_postprocessor)
                    head_importance = head_importance.cpu()  # (num_layers, num_heads)
                    ffn_importance = ffn_importance.cpu()    # (num_layers, intermediate_size)
                if current_is_ffn:
                    dffn_size -= (ffn_size_per_iter + 1 if i // 2 < ffn_size_res else ffn_size_per_iter)
                    self.ffn_mask = generate_mask(ffn_importance, dffn_size, ffn_even_masking, multiple_of=multiple_of)
                    logger.info(f"New ffn size: {self.ffn_mask.sum(-1).tolist()}")
                if current_is_head:
                    dnum_of_heads -= (num_of_heads_per_iter + 1 if i // 2 < num_of_heads_res else num_of_heads_per_iter)
                    self.head_mask = generate_mask(head_importance, dnum_of_heads, head_even_masking)
                    logger.info(f"New num heads: {self.head_mask.sum(-1).tolist()}")

                if i == 2 * n_iters - 1:
                    self.prune_with_masks(keep_shape=keep_shape, save_model=False)
                else:
                    self.prune_with_masks(keep_shape=True, save_model=False)

        # clear the logits cache
        self.soft_labels = []
        self.shoule_cache_logits = True

        logger.info("Head and ffn masks have been generated; they can be accessed via self.head_mask and self.ffn_mask")
        if save_model is True:
            return self.save_model()

    def save_masks(self, name='mask.pt') -> str:
        save_dir = os.path.join(self.general_config.output_dir, 'head_ffn_masks')
        os.makedirs(save_dir, exist_ok=True)
        torch.save((self.head_mask, self.ffn_mask), os.path.join(save_dir, f'{name}'))
        logger.info(f"Masks have been saved to {save_dir}")
        return save_dir
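    # Schedule arithmetic sketch (hypothetical numbers): with 12 layers of
    # 12 heads (144 in total), target_num_of_heads=8 (96 in total) and
    # n_iters=4, each iteration removes (144 - 96) // 4 = 12 heads globally;
    # a nonzero remainder (144 - 96) % 4 would be spread as one extra head
    # over the first iterations.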
    def save_model(self, dir_name=None) -> str:
        ffn_sizes = self.ffn_mask.to(int).sum(-1).tolist()
        if self.keep_shape is False:
            ffn_size = ffn_sizes[0]
            num_of_heads = self.head_mask.sum().item() / self.head_mask.size(0)  # average per layer; alternative: self.head_mask.to(int).sum().item()
            if len(set(ffn_sizes)) != 1:
                raise NotImplementedError("Cannot save a pruned model whose layers have different ffn sizes when keep_shape=False. \
Call TransformerPruner.save_masks or TransformerPruner.save_jit_model manually instead.")
            else:
                self.base_model.config.intermediate_size = ffn_size
        else:
            ffn_size = self.ffn_mask.size(1)       # base_model.config.intermediate_size
            num_of_heads = self.head_mask.size(1)  # self.transformer_pruning_config.target_num_of_heads
        if dir_name is None:
            save_dir = os.path.join(self.general_config.output_dir, f'pruned_H{num_of_heads}F{ffn_size}')
        else:
            save_dir = os.path.join(self.general_config.output_dir, dir_name)
        os.makedirs(save_dir, exist_ok=True)
        torch.save(self.model.state_dict(), os.path.join(save_dir, 'pytorch_model.bin'))
        # save config
        self.base_model.config.save_pretrained(save_dir)
        logger.info(f"Model and configuration have been saved to {save_dir}")
        return save_dir

    def save_jit_model(self, example_inputs, dir_name=None) -> str:
        self.model.eval()
        with torch.no_grad():
            traced_model = torch.jit.trace(self.model, example_inputs=example_inputs, strict=False)
        if dir_name is None:
            # derive the size tags from the current masks, mirroring save_model
            num_of_heads = self.head_mask.sum().item() / self.head_mask.size(0)
            ffn_size = int(self.ffn_mask.to(int).sum(-1)[0].item())
            save_dir = os.path.join(self.general_config.output_dir, f'pruned_H{num_of_heads}F{ffn_size}_traced')
        else:
            save_dir = os.path.join(self.general_config.output_dir, dir_name)
        os.makedirs(save_dir, exist_ok=True)
        torch.jit.save(traced_model, os.path.join(save_dir, 'pytorch_model.ts'))
        return save_dir

    def reorder_attention_heads(self, head_mask, keep_shape=False):
        n_layers = head_mask.size(0)
        head_size = int(self.base_model.config.hidden_size / self.base_model.config.num_attention_heads)
        #assert torch.all(new_num_heads_vec==new_num_heads_vec[0]), "Numbers of heads in each layer must be equal"
        att_queries = self.model_structure.get_att_query(self.base_model, ignore_model_prefix=True)
        att_keys = self.model_structure.get_att_key(self.base_model, ignore_model_prefix=True)
        att_values = self.model_structure.get_att_value(self.base_model, ignore_model_prefix=True)
        att_outputs = self.model_structure.get_att_output(self.base_model, ignore_model_prefix=True)

        for layer_num in range(n_layers):
            query_weight = att_queries[layer_num].weight
            query_bias = att_queries[layer_num].bias
            key_weight = att_keys[layer_num].weight
            key_bias = att_keys[layer_num].bias
            value_weight = att_values[layer_num].weight
            value_bias = att_values[layer_num].bias
            output_weight = att_outputs[layer_num].weight

            # select the rows of query, key and value according to the head mask
            query_weight, query_bias = rearange_weights(query_weight, query_bias, head_mask[layer_num], head_size, keep_shape)
            att_queries[layer_num].weight = torch.nn.Parameter(query_weight)
            att_queries[layer_num].bias = torch.nn.Parameter(query_bias)

            key_weight, key_bias = rearange_weights(key_weight, key_bias, head_mask[layer_num], head_size, keep_shape)
            att_keys[layer_num].weight = torch.nn.Parameter(key_weight)
            att_keys[layer_num].bias = torch.nn.Parameter(key_bias)

            value_weight, value_bias = rearange_weights(value_weight, value_bias, head_mask[layer_num], head_size, keep_shape)
            att_values[layer_num].weight = torch.nn.Parameter(value_weight)
            att_values[layer_num].bias = torch.nn.Parameter(value_bias)

            output_weight, _ = rearange_weights(output_weight.transpose(0, 1), None, head_mask[layer_num], head_size, keep_shape)
            output_weight = output_weight.transpose(0, 1)
            att_outputs[layer_num].weight = torch.nn.Parameter(output_weight)
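    # TorchScript export sketch (the tensors are placeholders; pass real
    # tokenized inputs that match the model's forward signature):
    #
    #   example_inputs = (input_ids, attention_mask)
    #   pruner.save_jit_model(example_inputs, dir_name='pruned_traced')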
    def reorder_ffn_weights(self, ffn_mask, keep_shape=False):
        head_size = 1  # int(base_model.config.hidden_size / base_model.config.num_attention_heads)
        n_layers = ffn_mask.size(0)
        ffn_interm = self.model_structure.get_ffn_interm(self.base_model, ignore_model_prefix=True)
        ffn_output = self.model_structure.get_ffn_output(self.base_model, ignore_model_prefix=True)
        for layer_num in range(n_layers):
            inter_weight = ffn_interm[layer_num].weight
            inter_bias = ffn_interm[layer_num].bias
            output_weight = ffn_output[layer_num].weight

            # select the FFN neurons according to the mask
            inter_weight, inter_bias = rearange_weights(inter_weight, inter_bias, ffn_mask[layer_num], head_size, keep_shape)
            ffn_interm[layer_num].weight = torch.nn.Parameter(inter_weight)
            ffn_interm[layer_num].bias = torch.nn.Parameter(inter_bias)

            output_weight, _ = rearange_weights(output_weight.transpose(0, 1), None, ffn_mask[layer_num], head_size, keep_shape)
            output_weight = output_weight.transpose(0, 1)
            ffn_output[layer_num].weight = torch.nn.Parameter(output_weight)

    def get_importance_score(self, dataloader, adaptor=None, batch_postprocessor=None) -> Tuple[torch.Tensor, torch.Tensor]:
        model = self.model
        n_layers = self.model_structure.get_num_layers(self.base_model, ignore_model_prefix=True)
        n_heads = self.base_model.config.num_attention_heads
        intermediate_size = self.base_model.config.intermediate_size
        device = self.general_config.device

        logger.info("***** Running forward and backward passes to calculate importance scores *****")
        logger.info(" Length of dataloader = %d", len(dataloader))
        model.eval()

        head_importance = torch.zeros(n_layers, n_heads).to(device)

        # get FFN and attention-output weights and biases
        ffn_inter_weights = []
        ffn_inter_biases = []
        ffn_output_weights = []
        att_output_weights = []
        ffn_interm = self.model_structure.get_ffn_interm(self.base_model, ignore_model_prefix=True)
        ffn_output = self.model_structure.get_ffn_output(self.base_model, ignore_model_prefix=True)
        att_output = self.model_structure.get_att_output(self.base_model, ignore_model_prefix=True)
        for layer_num in range(n_layers):
            ffn_inter_weights.append(ffn_interm[layer_num].weight)    # .detach().to(device)
            ffn_inter_biases.append(ffn_interm[layer_num].bias)       # .detach().to(device)
            ffn_output_weights.append(ffn_output[layer_num].weight)   # .detach().to(device)
            att_output_weights.append(att_output[layer_num].weight)

        ffn_importance = torch.zeros(n_layers, intermediate_size).to(device)  # e.g. (12, 3072)

        num_examples = 0.0
        for batch in tqdm(dataloader, desc="Calculating IS with loss"):
            if batch_postprocessor is not None:
                batch = batch_postprocessor(batch)
            batch = move_to_device(batch, device)
            if isinstance(batch, abc.Mapping):
                outputs = model(**batch)
                batch_num_examples = len(list(batch.values())[0])
            else:
                outputs = model(*batch)
                batch_num_examples = len(batch[0])
            loss = infer_loss(outputs, adaptor)
            loss.backward()

            for layer_num in range(n_layers):
                weight = att_output_weights[layer_num]
                head_importance[layer_num] += (weight.grad * weight).view(weight.size(0), n_heads, -1).sum(dim=(0, 2)).abs().detach()  # (num_heads, )

            for layer_num in range(n_layers):
                weight1 = ffn_inter_weights[layer_num]
                bias1 = ffn_inter_biases[layer_num]
                weight2 = ffn_output_weights[layer_num]
                if self.transformer_pruning_config.ffn_even_masking:
                    ffn_importance[layer_num] += ((weight1.grad * weight1).sum(dim=1) + bias1.grad * bias1).abs().detach()
                    ffn_importance[layer_num] += ((weight2.grad * weight2).sum(dim=0)).abs().detach()

            model.zero_grad()
            num_examples += batch_num_examples

        head_importance /= num_examples
        ffn_importance /= num_examples
        return head_importance, ffn_importance
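    # Adaptor sketch for the loss-based scores (hypothetical; suitable for
    # models whose outputs expose a .loss attribute, HuggingFace style):
    #
    #   def adaptor(outputs):
    #       return outputs.loss
    #
    # The accumulated head score is the first-order Taylor estimate
    # |sum(grad(W) * W)| taken over the attention-output columns of each head.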
    def get_importance_score_with_logits(self, dataloader, adaptor=None, batch_postprocessor=None) -> Tuple[torch.Tensor, torch.Tensor]:
        model = self.model
        n_layers = self.model_structure.get_num_layers(self.base_model, ignore_model_prefix=True)
        n_heads = self.base_model.config.num_attention_heads
        intermediate_size = self.base_model.config.intermediate_size
        device = self.general_config.device

        logger.info("***** Running forward and backward passes to calculate importance scores *****")
        logger.info(" Length of dataloader = %d", len(dataloader))
        model.eval()
        self.model_rep.eval()

        head_importance = torch.zeros(n_layers, n_heads).to(device)

        # get FFN and attention-output weights and biases
        ffn_inter_weights = []
        ffn_inter_biases = []
        ffn_output_weights = []
        att_output_weights = []
        ffn_interm = self.model_structure.get_ffn_interm(self.base_model, ignore_model_prefix=True)
        ffn_output = self.model_structure.get_ffn_output(self.base_model, ignore_model_prefix=True)
        att_output = self.model_structure.get_att_output(self.base_model, ignore_model_prefix=True)
        for layer_num in range(n_layers):
            ffn_inter_weights.append(ffn_interm[layer_num].weight)    # .detach().to(device)
            ffn_inter_biases.append(ffn_interm[layer_num].bias)       # .detach().to(device)
            ffn_output_weights.append(ffn_output[layer_num].weight)   # .detach().to(device)
            att_output_weights.append(att_output[layer_num].weight)

        ffn_importance = torch.zeros(n_layers, intermediate_size).to(device)  # e.g. (12, 3072)

        num_examples = 0.0
        for idx, batch in enumerate(tqdm(dataloader, desc="Calculating IS with logits")):
            if batch_postprocessor is not None:
                batch = batch_postprocessor(batch)
            batch = move_to_device(batch, device)
            if isinstance(batch, abc.Mapping):
                outputs = model(**batch)
                batch_num_examples = len(list(batch.values())[0])
            else:
                outputs = model(*batch)
                batch_num_examples = len(batch[0])
            with torch.no_grad():
                outputs_rep = self.model_rep(**batch) if isinstance(batch, abc.Mapping) else self.model_rep(*batch)
            logits_rep = infer_logits(outputs_rep, adaptor)
            logits = infer_logits(outputs, adaptor)

            #if self.shoule_cache_logits is True:
            #    # cache soft labels if the cache is empty
            #    p = softmax(logits, dim=-1).detach()
            #    self.soft_labels.append(p)
            if isinstance(logits, (list, tuple)):
                entropy = 0
                for logits_p, logits_q in zip(logits_rep, logits):
                    current_p = softmax(logits_p, dim=-1).detach()
                    current_q = logits_q
                    entropy += -(log_softmax(current_q, dim=-1) * current_p).sum(dim=-1).mean()
            else:
                current_p = softmax(logits_rep, dim=-1).detach()
                #p = softmax(logits, dim=-1).detach()  # self.soft_labels[idx]
                #current_p = self.soft_labels[idx]
                current_q = logits
                entropy = -(log_softmax(current_q, dim=-1) * current_p).sum(dim=-1).mean()
            entropy.backward()

            for layer_num in range(n_layers):
                weight = att_output_weights[layer_num]
                head_importance[layer_num] += (weight.grad * weight).view(weight.size(0), n_heads, -1).sum(dim=(0, 2)).abs().detach()  # (num_heads, )

            for layer_num in range(n_layers):
                weight1 = ffn_inter_weights[layer_num]
                bias1 = ffn_inter_biases[layer_num]
                weight2 = ffn_output_weights[layer_num]
                if self.transformer_pruning_config.ffn_even_masking:
                    ffn_importance[layer_num] += ((weight1.grad * weight1).sum(dim=1) + bias1.grad * bias1).abs().detach()
                    ffn_importance[layer_num] += ((weight2.grad * weight2).sum(dim=0)).abs().detach()

            model.zero_grad()
            num_examples += batch_num_examples

        if self.shoule_cache_logits is True:
            self.shoule_cache_logits = False
        head_importance /= num_examples
        ffn_importance /= num_examples
        return head_importance, ffn_importance
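    # The objective above is the cross-entropy H(p, q) between the soft labels
    # p of the half-precision reference model (model_rep, built in __init__)
    # and the predictions q of the model being pruned, so the importance
    # scores can be computed from batches without ground-truth labels.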
def rearange_weights(weight, bias, mask, head_size, keep_shape=False):
    num_heads = mask.size(0)
    mask_dim3 = mask.view(num_heads, 1, 1).to(torch.bool)            # e.g. (12, 1, 1)
    weight_dim3 = weight.view(num_heads, head_size, weight.size(1))  # e.g. (12, 64, 768)
    if keep_shape is False:
        selected_weight = weight_dim3.masked_select(mask_dim3)
        new_num_heads = int(mask.sum().item())
    else:
        selected_weight = torch.mul(weight_dim3, mask_dim3)
        new_num_heads = num_heads
    # reshape back
    selected_weight = selected_weight.view(new_num_heads * head_size, weight.size(1))

    selected_bias = None
    if bias is not None:
        mask_dim2 = mask.view(num_heads, 1).to(torch.bool)  # e.g. (12, 1)
        bias_dim2 = bias.view(num_heads, head_size)         # e.g. (12, 64)
        if keep_shape is False:
            selected_bias = bias_dim2.masked_select(mask_dim2)
        else:
            selected_bias = torch.mul(bias_dim2, mask_dim2)
        selected_bias = selected_bias.view(new_num_heads * head_size)
    return selected_weight, selected_bias
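# Shape sketch for rearange_weights (hypothetical 12-head layer with
# head_size=64 and hidden_size=768):
#
#   weight: (768, 768), mask: (12,) with 8 ones
#   keep_shape=False -> returned weight has shape (8*64, 768) = (512, 768)
#   keep_shape=True  -> returned weight keeps shape (768, 768), pruned heads zeroed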