# The current problem list corresponds to a sequence length parameter value
# of 40 (seq_len = 40) and a beam parameter value of 4
# batch = 56 (num_cores)

# =============================================================================
# Encoder part
# num_encoder_stages = 6
#   2d problems - M = batch * seq_len
#   4d problems - M = seq_len, B = batch x 16
# =============================================================================

# Encoder_MM_1: 2D matmul 2240x1024 : 1024x1024 (M = 2240 = 56 * 40).
# --reset returns all previously set options to their defaults.
--reset
# Data types are src:wei:dst = u8:s8:s8; bias is bf16 with --bia_mask=2.
--dt=u8:s8:s8 --bia_dt=bf16 --bia_mask=2
# Common zero points: 7 for src, 3 for dst.
# NOTE(review): the trailing '*' presumably marks the value as runtime-defined;
# confirm against the benchdnn version in use.
--attr-zero-points=src:common:7*+dst:common:3*
# Plain row-major src and dst; the weights layout is left to the library (any).
--stag=ab --wtag=any --dtag=ab
# 3 per encoder stage, 18 in total (3 * 6 stages)
2240x1024:1024x1024n"Transformer_lt:Encoder_MM_1*18"

# Encoder_MM_2: 2D matmul 2240x1024 : 1024x1024 (M = 2240 = 56 * 40).
# Same shape and options as Encoder_MM_1; listed separately with its own count.
# Data types are src:wei:dst = u8:s8:s8; bias is bf16 with --bia_mask=2.
--dt=u8:s8:s8 --bia_dt=bf16 --bia_mask=2
# Common zero points: 7 for src, 3 for dst.
--attr-zero-points=src:common:7*+dst:common:3*
# Plain row-major src and dst; the weights layout is left to the library (any).
--stag=ab --wtag=any --dtag=ab
# 2 per encoder stage, 12 in total (2 * 6 stages)
2240x1024:1024x1024n"Transformer_lt:Encoder_MM_2*12"

# Encoder_MM_4: 4D batched matmul, batch dims 56x16, 40x64 : 64x40
# (M = seq_len = 40, B = batch x 16).
# Data types are src:wei:dst = u8:s8:bf16; no bias (--bia_dt=undef).
--dt=u8:s8:bf16 --bia_dt=undef
# Common zero points: 7 for src, -2 for weights, 3 for dst.
--attr-zero-points=src:common:7*+wei:common:-2*+dst:common:3*
# Weights use the abdc layout (two innermost dims swapped in memory vs. abcd).
--stag=abcd --wtag=abdc --dtag=abcd
# Presumably 1 per encoder stage, 6 in total (from the *6 name suffix).
56x16x40x64:56x16x64x40n"Transformer_lt:Encoder_MM_4*6"

# Encoder_MM_5: 4D batched matmul, batch dims 56x16, 40x40 : 40x64
# (M = seq_len = 40, B = batch x 16).
# Data types are src:wei:dst = u8:s8:u8; no bias (--bia_dt=undef).
--dt=u8:s8:u8 --bia_dt=undef
# Common zero points: 7 for src, -2 for weights, 3 for dst.
--attr-zero-points=src:common:7*+wei:common:-2*+dst:common:3*
# All three tensors use the plain abcd layout.
--stag=abcd --wtag=abcd --dtag=abcd
# Presumably 1 per encoder stage, 6 in total (from the *6 name suffix).
56x16x40x40:56x16x40x64n"Transformer_lt:Encoder_MM_5*6"

# Encoder_MM_6: 2D matmul 2240x1024 : 1024x1024 (M = 2240 = 56 * 40).
# Data types are src:wei:dst = u8:s8:bf16; bias is bf16 with --bia_mask=2.
--dt=u8:s8:bf16 --bia_dt=bf16 --bia_mask=2
# Common zero points: 7 for src, 3 for dst.
--attr-zero-points=src:common:7*+dst:common:3*
# Plain row-major src and dst; the weights layout is left to the library (any).
--stag=ab --wtag=any --dtag=ab
# Presumably 1 per encoder stage, 6 in total (from the *6 name suffix).
2240x1024:1024x1024n"Transformer_lt:Encoder_MM_6*6"

# Encoder_MM_7: 2D matmul 2240x1024 : 1024x4096 (M = 2240 = 56 * 40).
# Data types are src:wei:dst = u8:s8:u8; bias is bf16 with --bia_mask=2.
--dt=u8:s8:u8 --bia_dt=bf16 --bia_mask=2
# Common zero points: 7 for src, 3 for dst.
--attr-zero-points=src:common:7*+dst:common:3*
# Plain row-major src and dst; the weights layout is left to the library (any).
--stag=ab --wtag=any --dtag=ab
# Presumably 1 per encoder stage, 6 in total (from the *6 name suffix).
2240x1024:1024x4096n"Transformer_lt:Encoder_MM_7*6"

# Encoder_MM_8: 2D matmul 2240x4096 : 4096x1024 (M = 2240 = 56 * 40).
# Data types are src:wei:dst = u8:s8:bf16; bias is bf16 with --bia_mask=2.
--dt=u8:s8:bf16 --bia_dt=bf16 --bia_mask=2
# Common zero points: 7 for src, 3 for dst.
--attr-zero-points=src:common:7*+dst:common:3*
# Plain row-major src and dst; the weights layout is left to the library (any).
--stag=ab --wtag=any --dtag=ab
# Presumably 1 per encoder stage, 6 in total (from the *6 name suffix).
2240x4096:4096x1024n"Transformer_lt:Encoder_MM_8*6"

# =============================================================================
# Decoder part
#   2d problems - M = batch * beam
#   4d problems - M = beam, B = batch x 16
# number of calls depends on sequence length value
# =============================================================================

# Decoder_MM_1: 2D matmul 224x1024 : 1024x1024 (M = 224 = 56 * 4).
# Data types are src:wei:dst = u8:s8:s8; bias is bf16 with --bia_mask=2.
--dt=u8:s8:s8 --bia_dt=bf16 --bia_mask=2
# Common zero points: 7 for src, 3 for dst.
--attr-zero-points=src:common:7*+dst:common:3*
# Plain row-major src and dst; the weights layout is left to the library (any).
--stag=ab --wtag=any --dtag=ab
# 3 per stage and sequence length: 3 * 6 * 40 = 720 in total for seq_len = 40
224x1024:1024x1024n"Transformer_lt:Decoder_MM_1*720"

# Decoder_MM_2: 2D matmul 224x1024 : 1024x1024 (M = 224 = 56 * 4).
# Data types are src:wei:dst = u8:s8:u8; bias is bf16 with --bia_mask=2.
--dt=u8:s8:u8 --bia_dt=bf16 --bia_mask=2
# Common zero points: 7 for src, 3 for dst.
--attr-zero-points=src:common:7*+dst:common:3*
# Plain row-major src and dst; the weights layout is left to the library (any).
--stag=ab --wtag=any --dtag=ab
# 1 per stage and sequence length: 1 * 6 * 40 = 240 in total for seq_len = 40
224x1024:1024x1024n"Transformer_lt:Decoder_MM_2*240"

# Decoder_MM_4: 4D batched matmul, batch dims 56x16, 4x64 : 64x40
# (M = beam = 4, B = batch x 16).
# Data types are src:wei:dst = u8:s8:bf16; no bias (--bia_dt=undef).
--dt=u8:s8:bf16 --bia_dt=undef
# Common zero points: 7 for src, -2 for weights, 3 for dst.
--attr-zero-points=src:common:7*+wei:common:-2*+dst:common:3*
# Weights use the abdc layout (two innermost dims swapped in memory vs. abcd).
--stag=abcd --wtag=abdc --dtag=abcd
# 1 per stage and sequence length: 1 * 6 * 40 = 240 in total for seq_len = 40
56x16x4x64:56x16x64x40n"Transformer_lt:Decoder_MM_4*240"

# Decoder_MM_5: 4D batched matmul, batch dims 56x16, 4x40 : 40x64
# (M = beam = 4, B = batch x 16).
# Data types are src:wei:dst = u8:s8:u8; no bias (--bia_dt=undef).
--dt=u8:s8:u8 --bia_dt=undef
# Common zero points: 7 for src, -2 for weights, 3 for dst.
--attr-zero-points=src:common:7*+wei:common:-2*+dst:common:3*
# All three tensors use the plain abcd layout.
--stag=abcd --wtag=abcd --dtag=abcd
# 1 per stage and sequence length: 1 * 6 * 40 = 240 in total for seq_len = 40
56x16x4x40:56x16x40x64n"Transformer_lt:Decoder_MM_5*240"

# Decoder_MM_6: 2D matmul 224x1024 : 1024x1024 (M = 224 = 56 * 4).
# Data types are src:wei:dst = u8:s8:bf16; bias is bf16 with --bia_mask=2.
--dt=u8:s8:bf16 --bia_dt=bf16 --bia_mask=2
# Common zero points: 7 for src, 3 for dst.
--attr-zero-points=src:common:7*+dst:common:3*
# Plain row-major src and dst; the weights layout is left to the library (any).
--stag=ab --wtag=any --dtag=ab
# 2 per stage and sequence length: 2 * 6 * 40 = 480 in total for seq_len = 40
224x1024:1024x1024n"Transformer_lt:Decoder_MM_6*480"

# Decoder_MM_7: 2D matmul 224x1024 : 1024x4096 (M = 224 = 56 * 4).
# Data types are src:wei:dst = u8:s8:u8; bias is bf16 with --bia_mask=2.
--dt=u8:s8:u8 --bia_dt=bf16 --bia_mask=2
# Common zero points: 7 for src, 3 for dst.
--attr-zero-points=src:common:7*+dst:common:3*
# Plain row-major src and dst; the weights layout is left to the library (any).
--stag=ab --wtag=any --dtag=ab
# 1 per stage and sequence length: 1 * 6 * 40 = 240 in total for seq_len = 40
224x1024:1024x4096n"Transformer_lt:Decoder_MM_7*240"

# Decoder_MM_8: 2D matmul 224x4096 : 4096x1024 (M = 224 = 56 * 4).
# Data types are src:wei:dst = u8:s8:bf16; bias is bf16 with --bia_mask=2.
--dt=u8:s8:bf16 --bia_dt=bf16 --bia_mask=2
# Common zero points: 7 for src, 3 for dst.
--attr-zero-points=src:common:7*+dst:common:3*
# Plain row-major src and dst; the weights layout is left to the library (any).
--stag=ab --wtag=any --dtag=ab
# 1 per stage and sequence length: 1 * 6 * 40 = 240 in total for seq_len = 40
224x4096:4096x1024n"Transformer_lt:Decoder_MM_8*240"

# Decoder_vocabulary: 2D matmul 224x1024 : 1024x32768 (M = 224 = 56 * 4).
# Data types are src:wei:dst = u8:s8:bf16; bias is bf16 with --bia_mask=2.
--dt=u8:s8:bf16 --bia_dt=bf16 --bia_mask=2
# Common zero points: 7 for src, 3 for dst.
--attr-zero-points=src:common:7*+dst:common:3*
# Plain row-major src and dst; the weights layout is left to the library (any).
--stag=ab --wtag=any --dtag=ab
# 1 per sequence length value, 40 in total for seq_len = 40
224x1024:1024x32768n"Transformer_lt:Decoder_vocabulary*40"

# for each encoder and sequence length, there is a set of problems like
# batch x 16 x 1 x 64 : batch x 16 x 64 x sl
# batch x 16 x 1 x sl : batch x 16 x sl x 64
# for 1 <= sl <= seq_len
# instead of running the whole set of such problems, we fix sl = 20 and treat
# this problem as running 6 * seq_len times

# Decoder_MM_xx20: 4D batched matmul, batch dims 56x16, 1x64 : 64x20
# (the "batch x 16 x 1 x 64 : batch x 16 x 64 x sl" shape with sl fixed to 20).
# Data types are src:wei:dst = u8:s8:bf16; no bias (--bia_dt=undef).
--dt=u8:s8:bf16 --bia_dt=undef
# Common zero points: 7 for src, -2 for weights, 3 for dst.
--attr-zero-points=src:common:7*+wei:common:-2*+dst:common:3*
# Weights use the abdc layout (two innermost dims swapped in memory vs. abcd).
--stag=abcd --wtag=abdc --dtag=abcd
# Counted as 6 * seq_len = 240 runs for seq_len = 40
56x16x1x64:56x16x64x20n"Transformer_lt:Decoder_MM_xx20*240"
# Decoder_MM_yy20: 4D batched matmul, batch dims 56x16, 1x20 : 20x64
# (the "batch x 16 x 1 x sl : batch x 16 x sl x 64" shape with sl fixed to 20).
# Data types are src:wei:dst = u8:s8:u8; no bias (--bia_dt=undef).
--dt=u8:s8:u8 --bia_dt=undef
# Common zero points: 7 for src, -2 for weights, 3 for dst.
--attr-zero-points=src:common:7*+wei:common:-2*+dst:common:3*
# All three tensors use the plain abcd layout.
--stag=abcd --wtag=abcd --dtag=abcd
# Counted as 6 * seq_len = 240 runs for seq_len = 40
56x16x1x20:56x16x20x64n"Transformer_lt:Decoder_MM_yy20*240"
