>>> import pyarrow as pa
>>> import pyarrow.compute as pc
>>> from typing import Iterator
>>>
>>> def double_v(batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
...     for batch in batches:
...         v = pc.multiply(batch.column("v"), 2.0)
...         yield batch.set_column(1, "v", v)
...
>>> df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ["id", "v"])
>>> df.groupby("id").applyInArrow(double_v, schema="id long, v double").sort("id", "v").show()
+---+---+
| id|  v|
+---+---+
|  1|2.0|
|  1|4.0|
|  2|6.0|
+---+---+
