Source code for scitex_ml.utils._check_params

#!/usr/bin/env python3
# Time-stamp: "2024-02-17 12:38:40 (ywatanabe)"

from pprint import pprint as _pprint
from time import sleep

# def get_params(model, tgt_name=None, sleep_sec=2, show=False):

#     name_shape_dict = {}
#     for name, param in model.named_parameters():
#         learnable = "Learnable" if param.requires_grad else "Freezed"

#         if (tgt_name is not None) & (name == tgt_name):
#             return param
#         if tgt_name is None:
#             # print(f"\n{param}\n{param.shape}\nname: {name}\n")
#             if show is True:
#                 print(
#                     f"\n{param}: {param.shape}\nname: {name}\nStatus: {learnable}\n"
#                 )
#                 sleep(sleep_sec)
#             name_shape_dict[name] = list(param.shape)

#     if tgt_name is None:
#         print()
#         _pprint(name_shape_dict)
#         print()


def check_params(model, tgt_name=None, show=False):
    """Summarize the shape and trainability of a model's parameters.

    Parameters
    ----------
    model : object
        Any object exposing ``named_parameters()`` that yields
        ``(name, param)`` pairs, where each ``param`` has ``shape`` and
        ``requires_grad`` attributes (presumably a ``torch.nn.Module`` --
        confirm against callers).
    tgt_name : str, optional
        When given, only the parameter whose name equals ``tgt_name`` is
        reported. When ``None`` (default), every parameter is reported.
    show : bool, optional
        When truthy, print each collected entry as ``"\\n{name}\\n{value}"``.

    Returns
    -------
    dict
        Maps parameter name to ``(param.shape, status)``, where ``status`` is
        ``"Learnable"`` if ``param.requires_grad`` is truthy and ``"Freezed"``
        otherwise. Empty when ``tgt_name`` matches no parameter.
    """
    out_dict = {}
    for name, param in model.named_parameters():
        # Guard clause: with a target name set, skip every other parameter.
        if tgt_name is not None and name != tgt_name:
            continue
        # NOTE: the "Freezed" spelling is part of the returned value and is
        # kept verbatim so existing callers comparing against it still work.
        learnable = "Learnable" if param.requires_grad else "Freezed"
        out_dict[name] = (param.shape, learnable)

    if show:
        for k, v in out_dict.items():
            print(f"\n{k}\n{v}")

    return out_dict