Source code for gobbli.test.model.test_transformer
import pytest
from gobbli.model.transformer import Transformer
@pytest.mark.parametrize(
    "params,exception",
    [
        # Unknown param
        ({"unknown": None}, ValueError),
        # Bad type (max_seq_length): str instead of int
        ({"max_seq_length": "100"}, TypeError),
        # OK type (max_seq_length)
        ({"max_seq_length": 100}, None),
        # Bad type (config_overrides): int instead of dict
        ({"config_overrides": 1}, TypeError),
        # OK type (config_overrides)
        ({"config_overrides": {}}, None),
        # Bad type (lr): int instead of float
        ({"lr": 1}, TypeError),
        # OK type (lr)
        ({"lr": 1e-3}, None),
        # Bad type (adam_eps): int instead of float
        ({"adam_eps": 1}, TypeError),
        # OK type (adam_eps)
        ({"adam_eps": 1e-5}, None),
        # Bad type (gradient_accumulation_steps): float instead of int
        ({"gradient_accumulation_steps": 1.0}, TypeError),
        # OK type (gradient_accumulation_steps)
        ({"gradient_accumulation_steps": 2}, None),
        # OK values (all params)
        (
            {
                "max_seq_length": 100,
                "config_overrides": {},
                "lr": 1e-3,
                "adam_eps": 1e-5,
                "gradient_accumulation_steps": 2,
            },
            None,
        ),
    ],
)
def test_init(params, exception):
    """Check that ``Transformer`` accepts or rejects the given init params.

    Args:
        params: Keyword arguments forwarded to ``Transformer(**params)``.
        exception: Exception type the constructor is expected to raise, or
            ``None`` when construction is expected to succeed.
    """
    if exception is None:
        # No exception expected: any error raised here fails the test.
        Transformer(**params)
    else:
        with pytest.raises(exception):
            Transformer(**params)