Source code for oumi.core.synthesis.attribute_transformation

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import uuid
from typing import Any, Union

from oumi.core.configs.params.synthesis_params import (
    GeneralSynthesisParams,
    TransformationStrategy,
    TransformationType,
    TransformedAttribute,
)
from oumi.core.synthesis.attribute_formatter import AttributeFormatter
from oumi.core.types.conversation import Conversation, Message

SampleValue = Union[str, list[str], dict[str, str], Conversation]


class AttributeTransformer:
    """Derives new sample attributes by applying configured transformations.

    Each configured ``TransformedAttribute`` yields one extra key per sample,
    whose value is a string, a list of strings, a dict of strings, or a
    conversation serialized to a dict, depending on the strategy type.
    """

    def __init__(self, params: GeneralSynthesisParams):
        """Sets up the formatter and the list of attributes to produce.

        Args:
            params: General synthesis parameters. Its ``transformed_attributes``
                field is read; a falsy value (e.g. ``None`` or an empty list)
                is treated as "no transformed attributes".
        """
        self._formatter = AttributeFormatter(params)
        self._transformed_attributes = params.transformed_attributes or []

    def transform(
        self,
        samples: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Adds every configured transformed attribute to every sample.

        The sample dicts are updated in place, and the same list object is
        returned. Attributes are processed in configuration order, each one
        across all samples before the next attribute is started.

        Args:
            samples: Samples whose existing values feed the transformations.

        Returns:
            The input list, with one new key per transformed attribute set on
            each sample.
        """
        for spec in self._transformed_attributes:
            target_key = spec.id
            for row in samples:
                row[target_key] = self._transform_attribute(row, spec)
        return samples

    def _transform_attribute(
        self,
        sample: dict[str, Any],
        attribute: TransformedAttribute,
    ) -> SampleValue:
        """Dispatches to the transformation matching the attribute's strategy.

        Raises:
            ValueError: If the strategy type is not STRING, LIST, DICT or CHAT.
        """
        strategy = attribute.get_strategy()
        kind = strategy.type

        if kind == TransformationType.STRING:
            assert strategy.string_transform is not None  # Validated in __post_init__
            return self._transform_string(sample, strategy.string_transform)
        if kind == TransformationType.LIST:
            return self._transform_list(sample, strategy)
        if kind == TransformationType.DICT:
            return self._transform_dict(sample, strategy)
        if kind == TransformationType.CHAT:
            return self._transform_chat(sample, strategy, attribute.id)

        raise ValueError(f"Unsupported transformation strategy: {strategy.type}")

    def _transform_string(
        self,
        sample: dict[str, SampleValue],
        transform: str,
    ) -> str:
        """Formats a single template string against the sample.

        Only the string-valued entries of ``sample`` are handed to the
        formatter; list, dict and conversation values are filtered out first.
        The formatter is called with ``missing_values_allowed=False``.
        """
        only_strings = {
            name: value for name, value in sample.items() if isinstance(value, str)
        }
        return self._formatter.format(
            only_strings,
            transform,
            missing_values_allowed=False,
        )

    def _transform_list(
        self,
        sample: dict[str, SampleValue],
        transform: TransformationStrategy,
    ) -> list[str]:
        """Formats each template in ``transform.list_transform``, in order."""
        assert transform.list_transform is not None
        formatted_items = []
        for template in transform.list_transform:
            formatted_items.append(self._transform_string(sample, template))
        return formatted_items

    def _transform_dict(
        self,
        sample: dict[str, SampleValue],
        transform: TransformationStrategy,
    ) -> dict[str, str]:
        """Formats each value template in ``transform.dict_transform``.

        Keys are copied through unchanged; only the values are formatted.
        """
        assert transform.dict_transform is not None  # Validated in __post_init__
        formatted_entries = {}
        for key, template in transform.dict_transform.items():
            formatted_entries[key] = self._transform_string(sample, template)
        return formatted_entries

    def _transform_chat(
        self,
        sample: dict[str, SampleValue],
        transform: TransformationStrategy,
        attribute_id: str,
    ) -> dict[str, Any]:
        """Builds a conversation from the chat template and returns it as a dict.

        Every message's content is formatted against the sample, and so is
        every metadata value when metadata is configured. If the chat template
        has no (or an empty) ``conversation_id``, one is generated as
        ``"<attribute_id>-<uuid4>"``.

        Raises:
            ValueError: If a template message's content is not a string.
        """
        assert transform.chat_transform is not None  # Validated in __post_init__
        chat = transform.chat_transform

        rendered_messages = []
        for template_message in chat.messages:
            raw_content = template_message.content
            if not isinstance(raw_content, str):
                raise ValueError(
                    "ChatTransform.transforms.messages.content must be a string."
                )
            rendered_messages.append(
                Message(
                    role=template_message.role,
                    content=self._transform_string(sample, raw_content),
                )
            )

        rendered_metadata = {}
        if chat.metadata:
            # Wrap the metadata in a DICT strategy so it goes through the same
            # code path as a regular dict transformation.
            metadata_strategy = TransformationStrategy(
                type=TransformationType.DICT,
                dict_transform=chat.metadata,
            )
            rendered_metadata = self._transform_dict(sample, metadata_strategy)

        # Fall back to a generated id only when the template does not supply one.
        conversation_id = chat.conversation_id or f"{attribute_id}-{uuid.uuid4()}"

        conversation = Conversation(
            messages=rendered_messages,
            conversation_id=conversation_id,
            metadata=rendered_metadata,
        )
        return conversation.to_dict()