Search
Duplicate
🤗

huggingface BertForMaskedLM fine-tuning 하기

AntiBERTy는 따로 항체 서열을 위한 transformer를 직접 짠 게 아니라 huggingface BertForMaskedLM 을 fine-tuning했다고 한다.
이번 기회에 huggingface에서 제공하는 기본 BERT 모델 fine-tuning 하는 법을 정리해 보자.

transformers.BertForMaskedLM

초기화 방법은 간단하다. transformers.BertForMaskedLM(config) 로 한다.
configBertConfig의 instance로서 BERT 모델의 명세를 나타낸다. 자세히 알아보자.

transformers.BertConfig

어떤 파라미터들을 설정해줄 수 있나?

파라미터를 아래와 같이 쭉 정리해보면서 감을 잡아보자.
파라미터
타입
초깃값
설명
vocab_size
int, optional
30522
BERT 모델의 vocabulary size를 나타낸다.
hidden_size
int, optional
768
Encoder layer와 pooler layer의 차원 크기
num_hidden_layers
int, optional
12
Encoder layer의 수
num_attention_heads
int, optional
12
Encoder layer 내의 attention head 수
intermediate_size
int, optional
3072
Encoder layer 내의 FeedForward layer의 차원 크기
hidden_act
str, or Callable
“gelu”
Encoder/Pooler layer 내의 activation function 종류
hidden_dropout_prob
float, optional
0.1
Embedding, Encoder, Pooler layer 내 모든 fully connected layer에서 사용되는 dropout probability 값.
attention_probs_dropout_prob
float, optional
0.1
Attention probability 값들에 적용되는 dropout probability 값.
max_position_embeddings
int, optional
512
모델이 한 번에 입력받을 수 있는 최대 token 수 (최대 문장 길이)
type_vocab_size
int, optional
2
BertModel이나 TFBertModel을 실행할 때 사용되는 token_type_ids의 vocabulary size
initializer_large
float, optional
0.02
모든 가중치 초기화에 사용되는 truncated_normal_initializer의 standard deviation 값
layer_norm_eps
float, optional
1e-12
LayerNorm 레이어에 사용되는 eps 값
position_embedding_type
str, optional
“absolute”
Position embedding의 종류. "absolute", "relative_key", "relative_key_query" 중에 하나 고르자.

사용법?

from transformers import BertConfig, BertModel configuration = BertConfig() model = BertModel(configuration) configuration = model.config
Python
복사

주의) 파라미터 이름 오타 없는지 확인 잘 하자.

유효하지 않은 파라미터를 config에 설정해주고 모델에 건네 줘도 딱히 오류를 발생시키지 않는다.
config 파일에 설정되지 않은 파라미터들은 default 값으로 설정되므로, config를 설정해줬다고 생각했는데 오타 때문에 설정이 적용되지 않을 수도 있다. 확인 잘 하자.
from transformers import BertConfig, BertModel configuration = BertConfig(any_parameter_name=8) # No error model = BertModel(configuration) # No error, either! configuration = model.config # {any_parameter_name: 8, ...}
Python
복사

Tokenizer와 Vocabulary

모델 configuration을 어떻게 하는지는 잘 알아보았다.
자연어가 아닌 단백질 서열을 가지고 학습하고자 할 때는 Tokenizer와 Vocabulary도 아마 커스텀하게 만들어줘야 할 것 같은데, huggingface에서 제공하는 클래스를 활용하면 편할 것 같다. 할 수 있을까?

transformers.BertTokenizer를 사용하자

파라미터
타입
초깃값
설명
vocab_file
str
Vocabulary 파일 경로
do_lower_case
bool, optional
True
입력을 모두 소문자로 바꿀 것인지?
do_basic_tokenize
bool, optional
True
WordPiece 이전에 basic tokenization을 할 것인지?
never_split
Iterable, optional
Tokenization 과정 중에 절대 쪼개져서는 안 될 token들
unk_token
str, optional
"[UNK]”
The unknown token. Vocabulary에 없는 token들은 이 토큰으로 바뀐다
sep_token
str, optional
"[SEP]”
The separator token. 두개 이상의 문장을 입력으로 줄 때 문장 사이를 구분하는 token으로 활용된다.
pad_token
str, optional
"[PAD]”
서로 다른 길이의 서열들을 한 배치에 담을 때 padding이 필요하다. Padding된 자리는 이 token으로 채워진다.
cls_token
str, optional
"[CLS]”
Sequence-level classification을 할 때 [CLS] 토큰에 대해 예측을 수행하면 됨.
mask_token
str, optional
“[MASK]”
Masked language modeling으로 모델을 학습시킬 때 masking된 단어를 나타내는 토큰.
vocab_file 이 뭔데? 파일 어떻게 구성하면 되나? 아래 예시를 보자.
[PAD] [UNK] [CLS] [SEP] [MASK] 지시 ##훈련 ##기구 ##물에 런던 ##해지는 늘어난 상황이다 ...생략...
Python
복사
보통 파일명을 vocab.txt로 한다.
단백질에 대해서는 우선 아래 vocab.txt 파일로 디폴트 BertTokenizer를 만들고,
[PAD] [UNK] [CLS] [SEP] [MASK]
Python
복사
아래 코드로 아미노산 20개를 token에 추가 후 저장한다.
from transformers import BertTokenizer tokenizer = BertTokenizer('vocab.txt', do_lower_case=False) amino_acids = "ACDEFGHIKLMNPQRSTVWY" tokenizer.add_tokens(list(amino_acids)) tokenizer.save_pretrained('ProteinTokenizer') # 테스트 tokenizer = BertTokenizer.from_pretrained('ProteinTokenizer') tokenizer.tokenize('ACDEFGHIKLMNPQRSTVWY') # ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']
Python
복사
사실 이렇게 되면 infrequent amino acids들 (Pyrrolysine (O), Selenocysteine (U))과, indeterminate amino acid (X, B, J, Z)들은 모두 [UNK] 로 처리된다. (Tranception 논문에서는 이런 amino acid 전처리를 정성스럽게 했다. 더 잘 하려면 tranception 같이 해야할 듯)

MLM pretraining 방법

transformers.DataCollatorForLanguageModeling에 MLM masking을 맡기자

Dataset을 input_ids 를 리턴하도록 잘 짰다면, DataLoader에 건네주는 collator를 transformers.DataCollatorForLanguageModeling 클래스를 사용하면 MLM을 위한 token masking을 알아서 처리해준다.
실제로는 80:10:10 rule에 따라서, 15% 확률로 선택된 token에 대해서 80%는 masking, 10%는 replacement, 나머지 10%는 그대로 놔둔 batch를 만들게 된다.
아래처럼 구현한다.
from transformers import DataCollatorForLanguageModeling collator = DataCollatorForLanguageModeling( tokenizer=tokenizer, mlm=True, mlm_probability=0.15, ) # dataset이 있다고 가정 loader = DataLoader(dataset, batch_size=bsz, collate_fn=collator)
Python
복사
생성된 batch 데이터는 기본적으로 dictionary이고, 각 요소 별 특징은 다음과 같다.
input_ids: pad_token을 이용해서 적절히 padding된 형태로 제공된다.
mlm=True일 때, 무작위 token을 tokenizer의 mask_token으로 바꾼다.
mlm=False일 때, 그냥 input_ids와 동일하다.
labels: loss 계산이 수행될 부분, 즉 mlm=True일 때 masking이 일어난 부분은 값이 원래 token이고, 나머지 loss 계산에서 무시될 부분은 -100 값으로 채워진다.

BertForMaskedLM의 forward에 label을 넘겨주면 loss를 return해준다.

MLM 학습의 loss는 masking된 token에 대해서 예측 값과 실제 label 사이의 Cross-Entropy Loss이다.
nn.CrossEntropyLoss(reduction='none')과 적절한 masking → 평균 계산을 통해서 loss를 직접 구할 수도 있겠지만, BertForMaskedLM 클래스는 내부적으로 forward에 label 파라미터가 None이 아닌 경우 label != -100인 부분에 대해서 CrossEntropyLoss를 구해서 loss를 계산해 return 해준다. 편리한 기능이다.

BertForMaskedLM의 output

HuggingFace 내의 transformer모델들은 output이 transformers.modeling_outputs 내의 클래스 인스턴스인 경우가 많다.
BertForMaskedLM을 비롯해서 MLM task를 위한 모델들은 output을transformers.modeling_outputs.MaskedLMOutput의 인스턴스 형태로 제공한다.
MaskedLMOutput은 다음과 같은 attribute들을 가진다. out.logit과 같은 형태로 접근 가능하다.
attribute
type
description
loss
torch.FloatTensor of shape (1,)
returned when labels is provided
logits
torch.FloatTensor of shape (batch_size, sequence_length, config.vocab_size)
hidden_states
tuple(torch.FloatTensor)
returned when output_hidden_states is True or when config.output_hidden_states=True FloatTensor들의 tuple이다. (embedding layer가 있는 경우 embeddings, 각 layer의 output들 … )
attentions
tuple(torch.FloatTensor)
returned when output_attentions is True or when config.output_attentions=True FloatTensor들의 tuple이다. 각 layer 별로 (batch_size, num_heads, sequence_length, sequence_length)

참고