# seq_len = 384 for inference and 512 for training
# batch = num_cores for throughput inference and for training,
# and batch = 1 for real-time inference

# mb = batch * seq_len
# num_encoder_stages = 24
# for int8 inference there are u8s8s8

# Active encoder shapes. mb384 matches batch = 1 with seq_len = 384
# (mb = batch * seq_len).
#
# ic1024 -> oc1024: 4 per stage, 96 in total (4 * 24 encoder stages); the
# "*96" suffix in the name matches this count.
# For int8 the 4 per stage are: 2 x u8s8s8 + u8s8u8 + u8s8f32.
mb384ic1024oc1024n"BERT-L:Encoder_MM_1*96"
# ic1024 -> oc4096 and ic4096 -> oc1024: 1 of each per stage, 24 in total
# for each shape (this note applies to both lines below).
mb384ic1024oc4096n"BERT-L:Encoder_MM_7*24"
mb384ic4096oc1024n"BERT-L:Encoder_MM_8*24"

# Mostly used for training; the shapes below are commented out
# num_mask_tokens = 20
# mb = num_mask_tokens * batch
#mb20ic1024oc1024n"BERT-L:Masked_1*1"
#mb20ic1024oc30522n"BERT-L:Masked_2*1"

# mb = batch
#mb1ic1024oc1024n"BERT-L:Pooler*1"
#mb1ic1024oc2n"BERT-L:Prediction*1"

# mb = batch * seq_len
#mb512ic2oc1024n"BERT-L:Embedding*1"
