This commit is contained in:
caandt
2025-03-13 16:56:36 -05:00
parent b2439eee3e
commit 046e80cdd1
27 changed files with 0 additions and 0 deletions
@@ -0,0 +1,74 @@
import json
import logging
import pathlib
from dataclasses import dataclass
from typing import Optional
@dataclass
class TrainingParameters:
    """Hyperparameter bundle describing one training run.

    A plain value object: it performs no validation and simply stores the
    three settings it is constructed with.
    """

    # Batch size to train with.
    batch_size: int
    # Number of epochs to train for.
    epochs: int
    # Learning rate handed to the training routine.
    learning_rate: float
@dataclass
class SegmentationConfiguration:
    """Settings for the segmentation pipeline plus names/paths derived from them.

    The stored fields are the raw configuration values. Every repository name
    and on-disk location is computed on demand from ``base_repo_name``,
    ``dataset_repo_name`` and ``cache_dir``, so the derived values always stay
    in sync with the fields.
    """

    # Prefix from which the tokenizer / MLM / segmenter repo names are built.
    base_repo_name: str
    # Name of the dataset repository.
    dataset_repo_name: str
    # Name of a pretrained MLM repository (not used by anything in this class).
    pretrained_mlm_repo_name: str
    # Root directory under which tokenizers, models and datasets are placed.
    cache_dir: pathlib.Path
    # NOTE(review): presumably the maximum sequence length in tokens -- confirm
    # against the code that consumes it.
    max_token_length: int
    # NOTE(review): presumably the share of the dataset to use, in percent --
    # confirm against the code that consumes it.
    dataset_percentage: int
    # Training settings for the MLM stage.
    mlm_training_parameters: TrainingParameters
    # Training settings for the segmentation stage.
    segmentation_training_parameters: TrainingParameters

    def __post_init__(self) -> None:
        """Normalise ``cache_dir`` to a ``pathlib.Path``.

        This lets callers pass a plain string (e.g. a value read from JSON)
        while the path-building properties can rely on ``Path`` operations.
        """
        self.cache_dir = pathlib.Path(self.cache_dir)

    # --- derived repository names -------------------------------------------

    @property
    def tokenizer_repo_name(self) -> str:
        """``base_repo_name`` with ``"-tokenizer"`` appended."""
        return self.base_repo_name + "-tokenizer"

    @property
    def tokenized_dataset_repo_name(self) -> str:
        """``dataset_repo_name`` with ``"-tokenized"`` appended."""
        return self.dataset_repo_name + "-tokenized"

    @property
    def mlm_repo_name(self) -> str:
        """``base_repo_name`` with ``"-mlm"`` appended."""
        return self.base_repo_name + "-mlm"

    @property
    def segmenter_repo_name(self) -> str:
        """``base_repo_name`` with ``"-segmenter"`` appended."""
        return self.base_repo_name + "-segmenter"

    # --- derived locations under cache_dir ----------------------------------

    @property
    def tokenizer_json_path(self) -> pathlib.Path:
        """``<cache_dir>/tokenizers/<tokenizer_repo_name>/tokenizer.json``."""
        return self.cache_dir.joinpath(
            "tokenizers", self.tokenizer_repo_name, "tokenizer.json"
        )

    @property
    def mlm_dir(self) -> pathlib.Path:
        """``<cache_dir>/models/<mlm_repo_name>``."""
        return self.cache_dir.joinpath("models", self.mlm_repo_name)

    @property
    def segmenter_dir(self) -> pathlib.Path:
        """``<cache_dir>/models/<segmenter_repo_name>``."""
        return self.cache_dir.joinpath("models", self.segmenter_repo_name)

    @property
    def dataset_dir(self) -> pathlib.Path:
        """``<cache_dir>/datasets/<dataset_repo_name>``."""
        return self.cache_dir.joinpath("datasets", self.dataset_repo_name)
def parse_segmentation_config_json(
    json_file_path: pathlib.Path,
    logger: Optional[logging.Logger] = None,
) -> SegmentationConfiguration:
    """Build a ``SegmentationConfiguration`` from a JSON file.

    The top-level JSON object supplies the keyword arguments of
    ``SegmentationConfiguration``. Its ``"mlm_training_parameters"`` and
    ``"segmentation_training_parameters"`` entries are themselves objects and
    are converted to ``TrainingParameters`` instances first.

    Args:
        json_file_path: Location of the JSON file. A plain string is accepted
            as well and is converted to a ``pathlib.Path``.
        logger: Optional logger; when given, an info message naming the file
            is emitted before it is read.

    Returns:
        The populated ``SegmentationConfiguration``.

    Raises:
        FileNotFoundError: If ``json_file_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If either training-parameters section is missing.
        TypeError: If the JSON keys do not match the dataclass fields.
    """
    # Coerce so that string paths work too instead of failing on `.exists()`.
    json_file_path = pathlib.Path(json_file_path)
    if not json_file_path.exists():
        raise FileNotFoundError(f"{json_file_path} does not exist")

    if logger is not None:
        logger.info("Loading model description from %s...", json_file_path)

    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with json_file_path.open(encoding="utf-8") as json_file:
        segmentation_config_dict = json.load(json_file)

    # The nested sections arrive as plain dicts; turn them into dataclasses
    # before handing everything to the configuration constructor.
    for section in ("mlm_training_parameters", "segmentation_training_parameters"):
        segmentation_config_dict[section] = TrainingParameters(
            **segmentation_config_dict[section]
        )

    return SegmentationConfiguration(**segmentation_config_dict)