# Source code for oumi.core.synthesis.synthesis_pipeline
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Any
from oumi.core.configs.synthesis_config import SynthesisConfig
from oumi.core.synthesis.attribute_synthesizer import AttributeSynthesizer
from oumi.core.synthesis.attribute_transformation import AttributeTransformer
from oumi.core.synthesis.data_synthesizer import DataSynthesizer
from oumi.core.synthesis.dataset_planner import DatasetPlanner
from oumi.utils.io_utils import save_jsonlines
from oumi.utils.logging import logger
[docs]
class SynthesisPipeline:
    """Pipeline for synthesizing a dataset.

    Stages, run in order by :meth:`synthesize`:

    1. Plan the dataset with ``DatasetPlanner`` (values for every
       non-generated attribute, ``num_samples`` rows).
    2. Synthesize generated attributes with ``DataSynthesizer`` (only when
       ``strategy_params.generated_attributes`` is set).
    3. Add transformed attributes with ``AttributeTransformer`` (only when
       ``strategy_params.transformed_attributes`` is set).
    4. Keep only ``strategy_params.passthrough_attributes`` (only when set).
    5. Write the result as JSON Lines to ``output_path`` (only when set).
    """

    def __init__(self, config: SynthesisConfig):
        """Initialize the synthesis pipeline.

        Args:
            config: Synthesis configuration. ``config.strategy_params`` drives
                planning, generation, transformation and passthrough filtering;
                ``config.inference_config`` is handed to the attribute
                synthesizer used for generated attributes.
        """
        self._config = config
        strategy_params = config.strategy_params

        # The AttributeSynthesizer is only consumed by the DataSynthesizer, so
        # build it only when there are generated attributes to synthesize.
        self._data_synthesizer = None
        if strategy_params.generated_attributes:
            self._data_synthesizer = DataSynthesizer(
                strategy_params.generated_attributes,
                AttributeSynthesizer(strategy_params, config.inference_config),
            )
        self._attribute_transformer = AttributeTransformer(strategy_params)
        self._dataset_planner = DatasetPlanner()

    def synthesize(self) -> list[dict[str, Any]]:
        """Synthesize a dataset.

        Returns:
            The synthesized samples, each a dict mapping attribute name to
            value. When ``output_path`` is configured the same samples are
            also written to that file.

        Raises:
            ValueError: If ``output_path`` is configured but its suffix is not
                ``.jsonl``. This is checked before any synthesis work is done.
        """
        strategy_params = self._config.strategy_params

        # Fail fast: an unsupported output path would otherwise only be
        # detected after the (possibly expensive) synthesis has finished,
        # discarding all generated data.
        if self._config.output_path:
            self._resolve_output_path()

        # Populate the dataset plan with column values for each non-generated
        # attribute.
        logger.info(
            f"Loading dependencies to synthesize dataset with "
            f"{self._config.num_samples} samples"
        )
        dataset = self._dataset_planner.plan(
            strategy_params,
            self._config.num_samples,
        )

        # Synthesize the generated attributes.
        if self._data_synthesizer:
            logger.info("Synthesizing generated attributes")
            dataset = self._data_synthesizer.synthesize(dataset)

        # Add the transformed attributes to the dataset.
        if strategy_params.transformed_attributes:
            logger.info("Adding transformed attributes")
            dataset = self._attribute_transformer.transform(dataset)

        # If passthrough attributes are specified, keep only those attributes.
        if strategy_params.passthrough_attributes:
            logger.info("Keeping passthrough attributes")
            dataset = self._passthrough_attributes(dataset)

        # Save the dataset to the output path.
        if self._config.output_path:
            logger.info("Saving dataset")
            self._save_dataset(dataset)

        logger.info("Synthesis complete")
        return dataset

    def _passthrough_attributes(
        self,
        dataset: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Keep only the passthrough attributes in the dataset.

        Args:
            dataset: Samples to filter. They are not mutated; new dicts are
                returned.

        Returns:
            The samples restricted to keys listed in
            ``strategy_params.passthrough_attributes``, or ``dataset``
            unchanged when no passthrough attributes are configured.
        """
        configured = self._config.strategy_params.passthrough_attributes
        if not configured:
            return dataset
        # Build the lookup set once instead of scanning a list per key.
        keep = set(configured)
        return [
            {key: value for key, value in sample.items() if key in keep}
            for sample in dataset
        ]

    def _resolve_output_path(self) -> Path:
        """Return the configured output path after validating it.

        Raises:
            ValueError: If ``output_path`` is not set, or its suffix is not
                ``.jsonl`` (the only supported output format).
        """
        path_str = self._config.output_path
        if not path_str:
            raise ValueError("SynthesisConfig.output_path is not specified.")
        path = Path(path_str)
        if path.suffix != ".jsonl":
            raise ValueError(f"Unsupported output path: {path_str}")
        return path

    def _save_dataset(self, dataset: list[dict[str, Any]]) -> None:
        """Save the dataset to the output path as JSON Lines.

        Args:
            dataset: Samples to write, one per line.

        Raises:
            ValueError: If ``output_path`` is unset or not a ``.jsonl`` path.
        """
        # Validate before touching the filesystem so an unsupported path does
        # not leave freshly created directories behind.
        path = self._resolve_output_path()
        # exist_ok avoids the race between a separate exists() check and
        # mkdir(), and is a no-op when the directory is already present.
        path.parent.mkdir(parents=True, exist_ok=True)
        save_jsonlines(path, dataset)