Source code for quasimoto.wave.writer
"""
A module implementing interfaces for writing WAVE files.
"""
# built-in
from contextlib import contextmanager
import os
from pathlib import Path
from typing import BinaryIO, Iterable, Iterator
# third-party
from runtimepy.primitives import Int16
# internal
from quasimoto.enums import ChunkType
from quasimoto.riff import RiffInterface
from quasimoto.riff.chunk import Chunk
from quasimoto.wave.mixins import FormatMixin
# Defaults used by WaveWriter when the caller does not override them:
# two channels of 16-bit samples at 48000 sample frames per second.
DEFAULT_CHANNELS = 2
DEFAULT_BITS = 16
DEFAULT_SAMPLE_RATE = 48_000
[docs]
class WaveWriter(FormatMixin):
    """A class for writing PCM WAVE files through a RIFF interface.

    Constructing an instance immediately writes the 'WAVE' form type, the
    'fmt ' chunk and the 'data' chunk identifier to the RIFF stream; sample
    data is then appended with :meth:`write`.
    """

    def __init__(
        self,
        riff: RiffInterface,
        num_channels: int = DEFAULT_CHANNELS,
        sample_rate: int = DEFAULT_SAMPLE_RATE,
        bits_per_sample: int = DEFAULT_BITS,
    ) -> None:
        """Initialize this instance and write the WAVE header.

        Args:
            riff: Interface to write into; ``riff.is_writer`` must be true.
            num_channels: Number of channel values per sample frame.
            sample_rate: Sample frames per second.
            bits_per_sample: Bits per channel value. Together with
                ``num_channels`` it must describe a whole number of bytes
                per frame.

        Raises:
            AssertionError: If ``riff`` is not a writer or a frame is not a
                whole number of bytes.
        """
        super().__init__()

        # Stream offset of the 'data' chunk size field (set by the first
        # call to 'write') and the running number of sample bytes written.
        self._data_size_pos = None
        self._data_size = 0

        self.riff = riff
        assert self.riff.is_writer, "RIFF interface is not open for writing"

        # Finish writing RIFF header.
        ChunkType.WAVE.to_stream(self.riff.stream)

        frame_bits = num_channels * bits_per_sample
        assert frame_bits % 8 == 0, (
            f"{num_channels} channel(s) x {bits_per_sample} bit(s) is not a "
            "whole number of bytes"
        )
        # Bytes per sample frame across all channels (the WAVE 'block align'
        # value, stored under the "class" key).
        class_num = frame_bits // 8

        # Write 'fmt ' chunk.
        self.format["type"] = "pcm"
        self.format["channels"] = num_channels
        self.format["sample_rate"] = sample_rate
        self.format["bytes_per_second"] = int(class_num * sample_rate)
        self.format["class"] = class_num
        self.format["bits_per_sample"] = bits_per_sample
        data = bytes(self.format)
        self.riff.write(Chunk(ChunkType.FMT, len(data), data=data))

        # Write 'data' chunk identifier; its size field is written by 'write'.
        ChunkType.DATA.to_stream(self.riff.stream)

    @classmethod
    def to_bytes(cls, value: int) -> bytes:
        """Encode one sample value as bytes.

        Values are always encoded with ``Int16`` using the class's byte
        order, which is why :meth:`write` only supports 16-bit samples.
        """
        # Note the underlying type assumption.
        return Int16.kind.encode(value, byte_order=cls.byte_order)

    @classmethod
    def to_stream(cls, stream: BinaryIO, value: int) -> int:
        """Write one encoded sample value to a stream.

        Returns:
            The number of bytes written.
        """
        data = cls.to_bytes(value)
        stream.write(data)
        return len(data)

    def write(self, samples: Iterable[tuple[int, ...]]) -> None:
        """Append sample frames to the 'data' chunk.

        The first call writes a placeholder size (``write_size(0)``) right
        after the 'data' chunk identifier. Every call then appends the
        encoded values at the end of the stream and rewrites that single
        size field with the total number of sample bytes written so far.
        Samples may therefore be supplied across several calls.

        Args:
            samples: Sample frames; each frame's values are written
                back-to-back in order, one 16-bit value per entry.

        Raises:
            AssertionError: If the configured sample width is not 2 bytes.
        """
        # Only support writing 16-bit samples (see 'to_bytes').
        assert self.sample_bytes == 2, "only 16-bit samples are supported"

        stream = self.riff.stream

        with self.log_time("Writing samples", reminder=True):
            stream.seek(0, os.SEEK_END)

            # Write the size field only once. Appending a new one on every
            # call would embed stray size fields inside the sample data and
            # leave the real one covering only the first batch.
            if self._data_size_pos is None:
                self._data_size_pos = stream.tell()
                self.riff.write_size(0)

            written = 0
            for sample in samples:
                for point in sample:
                    written += self.to_stream(stream, point)

            self._data_size += written
            self.riff.write_size(self._data_size, seek=self._data_size_pos)

    @classmethod
    @contextmanager
    def from_path(cls, path: Path, **kwargs) -> Iterator["WaveWriter"]:
        """Get a WAVE writer for a path.

        Args:
            path: File to write, passed to ``RiffInterface.from_path``.
            **kwargs: Forwarded to the constructor (e.g. ``num_channels``).

        Yields:
            A writer bound to the RIFF interface opened for ``path``. That
            interface's context is exited when this context exits.
        """
        # NOTE(review): relies on RiffInterface.from_path yielding an
        # interface with 'is_writer' set, since __init__ asserts it.
        with RiffInterface.from_path(path) as riff:
            yield cls(riff, **kwargs)