Fine-tune a BERT model by removing unused layers

I came across this code for BERT sentiment analysis, in which the unused layers are removed and the trainable variables/trainable weights are updated. I am looking for documentation that shows what the different layers in BERT are, how the unused ones can be removed, how weights are added, and so on, but I have been unable to find any.

# The snippet uses the TF1-style tensorflow_hub Module API, so it needs
# TensorFlow 1.x (or tf.compat.v1 behaviour) to run.
import tensorflow as tf
import tensorflow_hub as tf_hub
from tensorflow.keras import backend as K

BERT_PATH = "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1"
MAX_SEQ_LENGTH = 512

class BertLayer(tf.keras.layers.Layer):
  def __init__(self, bert_path, n_fine_tune_encoders=10, **kwargs):
    self.n_fine_tune_encoders = n_fine_tune_encoders
    self.trainable = True
    self.output_size = 768  # hidden size of BERT-base
    self.bert_path = bert_path
    super(BertLayer, self).__init__(**kwargs)
  def build(self, input_shape):
    self.bert = tf_hub.Module(self.bert_path,
                              trainable=self.trainable, 
                              name=f"{self.name}_module")
    # Remove unused layers: variables under "/cls/" belong to the
    # masked-LM and next-sentence-prediction pre-training heads, which a
    # sentiment classifier never uses.
    trainable_vars = self.bert.variables
    trainable_vars = [var for var in trainable_vars
                      if "/cls/" not in var.name]
    trainable_layers = ["embeddings", "pooler/dense"]

    # Select how many encoder blocks to fine-tune, counting down from the
    # top. BERT-base has 12 encoder layers, named encoder/layer_0 through
    # encoder/layer_11, so the topmost block is layer_11.
    for i in range(self.n_fine_tune_encoders):
        trainable_layers.append(f"encoder/layer_{11 - i}")

    # Update trainable vars to contain only the specified layers
    trainable_vars = [var for var in trainable_vars
                      if any(l in var.name for l in trainable_layers)]

    # Register the selected variables as trainable weights; every other
    # variable in the module is tracked as non-trainable (frozen).
    for var in trainable_vars:
        self._trainable_weights.append(var)
    for var in self.bert.variables:
        if var not in self._trainable_weights:
            self._non_trainable_weights.append(var)
    print('Trainable variables:', len(self._trainable_weights))
    print('Non-trainable variables:', len(self._non_trainable_weights))

    super(BertLayer, self).build(input_shape)
 
  def call(self, inputs):
    # The module expects int32 token ids, input mask, and segment ids.
    inputs = [K.cast(x, dtype="int32") for x in inputs]
    input_ids, input_mask, segment_ids = inputs
    bert_inputs = dict(input_ids=input_ids,
                       input_mask=input_mask,
                       segment_ids=segment_ids)

    # The "tokens" signature returns a dict; "pooled_output" is the
    # [batch_size, 768] embedding of the [CLS] token.
    pooled = self.bert(inputs=bert_inputs,
                       signature="tokens",
                       as_dict=True)["pooled_output"]

    return pooled

  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_size)

model = build_model(bert_path=BERT_PATH, max_seq_length=MAX_SEQ_LENGTH, n_fine_tune_encoders=10)
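For reference, build_model is not defined anywhere in the snippet I found. Below is a minimal sketch of how such a function might wire BertLayer into a Keras classifier; the dense head, its units, and the loss are my assumptions for a binary sentiment task, not part of the original code.

# Hypothetical build_model -- not shown in the original snippet; the
# classification head below is an assumed binary sentiment head.
def build_model(bert_path, max_seq_length, n_fine_tune_encoders):
    input_ids = tf.keras.layers.Input(shape=(max_seq_length,), name="input_ids")
    input_mask = tf.keras.layers.Input(shape=(max_seq_length,), name="input_mask")
    segment_ids = tf.keras.layers.Input(shape=(max_seq_length,), name="segment_ids")

    # BertLayer returns the 768-dim pooled [CLS] embedding.
    pooled = BertLayer(bert_path, n_fine_tune_encoders)(
        [input_ids, input_mask, segment_ids])

    dense = tf.keras.layers.Dense(256, activation="relu")(pooled)
    pred = tf.keras.layers.Dense(1, activation="sigmoid")(dense)

    model = tf.keras.models.Model(
        inputs=[input_ids, input_mask, segment_ids], outputs=pred)
    model.compile(loss="binary_crossentropy", optimizer="adam",
                  metrics=["accuracy"])
    return model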

Can anyone please point me to resources that explain the different layers in BERT, how to remove some of them, how to add weights, how many layers to fine-tune, etc.?
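In the absence of such documentation, one way to see what the "layers" actually are is to print the variable names of the Hub module itself: the name prefixes (embeddings/..., encoder/layer_0/ through encoder/layer_11/, pooler/dense/, and cls/ for the pre-training heads) are exactly the strings that the filters in build() match against. A small sketch, assuming the same TF1 environment as the snippet above:

# List every variable in the BERT module to see its structure. The name
# prefixes (embeddings, encoder/layer_0 ... encoder/layer_11, pooler/dense,
# and cls for the pre-training heads) are what BertLayer.build() filters on.
import tensorflow as tf
import tensorflow_hub as tf_hub

BERT_PATH = "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1"

with tf.Graph().as_default():
    bert = tf_hub.Module(BERT_PATH, trainable=True)
    for var in bert.variables:
        print(var.name, var.shape)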

Source: Python Questions
