Coverage for dj/sql/functions.py: 100%

393 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-17 20:05 -0700

1# pylint: disable=too-many-lines 

2# mypy: ignore-errors 

3 

4""" 

5SQL functions for type inference. 

6 

7This file holds all the functions that we want to support in the SQL used to define 

8nodes. The functions are used to infer types. 

9 

10Spark function reference 

11https://github.com/apache/spark/tree/74cddcfda3ac4779de80696cdae2ba64d53fc635/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions 

12 

13Java strictmath reference 

14https://docs.oracle.com/javase/8/docs/api/java/lang/StrictMath.html 

15 

16Databricks reference: 

17https://docs.databricks.com/sql/language-manual/sql-ref-functions-builtin-alpha.html 

18""" 

19 

20import inspect 

21import re 

22from itertools import zip_longest 

23 

24# pylint: disable=unused-argument, missing-function-docstring, arguments-differ, too-many-return-statements 

25from typing import ( 

26 TYPE_CHECKING, 

27 Callable, 

28 ClassVar, 

29 Dict, 

30 List, 

31 Optional, 

32 Tuple, 

33 Type, 

34 Union, 

35 get_origin, 

36) 

37 

38import dj.sql.parsing.types as ct 

39from dj.errors import ( 

40 DJError, 

41 DJInvalidInputException, 

42 DJNotImplementedException, 

43 ErrorCode, 

44) 

45from dj.sql.parsing.backends.exceptions import DJParseException 

46 

47if TYPE_CHECKING: 

48 from dj.sql.parsing.ast import Expression 

49 

50 

def compare_registers(types, register) -> bool:
    """
    Check whether the argument types of a call fit one registered signature.

    ``types`` holds ``(position, type)`` pairs describing the actual call
    arguments; ``register`` holds the ``(position, type)`` pairs of one
    registered signature, where position ``-1`` marks a ``*args`` parameter.
    The shorter side is padded with ``(-1, None)``:

    * surplus call arguments are checked against a trailing ``*args`` entry,
      and rejected when the signature has none (including a signature with
      no parameters at all);
    * missing call arguments are treated as ``NoneType``, so they only match
      a parameter registered with ``NoneType`` (e.g. from ``Optional[...]``).

    Returns True when every argument type is a subclass of the registered
    type at the same position.
    """
    for (type_a, register_a), (type_b, register_b) in zip_longest(
        types,
        register,
        fillvalue=(-1, None),
    ):
        if type_b == -1 and register_b is None:
            # More call arguments than registered parameters: only a trailing
            # ``*args`` parameter can absorb them. Guard the empty signature,
            # which previously raised IndexError here.
            if register and register[-1][0] == -1:  # args
                register_b = register[-1][1]
            else:
                return False  # pragma: no cover
        if type_a == -1:
            # Padded (missing) call argument: compare as NoneType.
            register_a = type(register_a)
        if not issubclass(register_a, register_b):  # type: ignore
            return False
    return True

70 

71 

class DispatchMeta(type):
    """
    Metaclass that turns registered function names into dispatching callables.

    Looking up an attribute whose name is registered for the class in
    ``registry`` returns a wrapper that asks ``cls.dispatch`` for the matching
    implementation and calls it; every other name resolves normally.
    """

    def __getattribute__(cls, func_name):  # pylint: disable=redefined-outer-name
        registered_names = type.__getattribute__(cls, "registry").get(cls, {})
        if func_name not in registered_names:
            return type.__getattribute__(cls, func_name)

        def dynamic_dispatch(*args: "Expression"):
            implementation = cls.dispatch(func_name, *args)
            return implementation(*args)

        return dynamic_dispatch

85 

86 

class Dispatch(metaclass=DispatchMeta):
    """
    Function registry with type-based dispatch.

    ``register`` records an implementation under its ``__name__`` for the
    class it is called on, keyed by the annotated types of its parameters.
    ``dispatch`` selects the implementation whose registered signature accepts
    the types of the runtime arguments.
    """

    # {dispatch class: {function name: {signature: implementation}}}; a
    # signature is a tuple of (position, type) pairs where position -1 marks
    # a ``*args`` parameter.
    registry: ClassVar[
        Dict[Type, Dict[str, Dict[Tuple[Tuple[int, Type], ...], Callable]]]
    ] = {}

    @classmethod
    def register(cls, func):  # pylint: disable=redefined-outer-name
        """
        Register ``func`` under ``func.__name__`` for this class.

        Every parameter must be annotated. A ``Union`` annotation (including
        ``Optional``) contributes one alternative per member type, and one
        signature is registered per combination of alternatives. ``*args`` is
        recorded at position -1; ``**kwargs`` is rejected.

        Returns ``func`` unchanged so the decorated name stays usable.

        Raises:
            ValueError: for ``**kwargs`` or an unannotated parameter.
        """
        func_name = func.__name__
        params = inspect.signature(func).parameters
        func_registry = cls.registry.setdefault(cls, {}).setdefault(func_name, {})
        signatures: List[List[Tuple[int, Type]]] = [[]]
        for position, param in enumerate(params.values()):
            if param.kind == inspect.Parameter.VAR_KEYWORD:
                raise ValueError(
                    "kwargs are not supported in dispatch.",
                )  # pragma: no cover
            index = -1 if param.kind == inspect.Parameter.VAR_POSITIONAL else position
            annotation = param.annotation
            if annotation is inspect.Parameter.empty:
                raise ValueError(  # pragma: no cover
                    "All arguments must have a type annotation.",
                )
            inner_types = (
                annotation.__args__ if get_origin(annotation) == Union else (annotation,)
            )
            # Cartesian product: extend every signature built so far with each
            # alternative for this parameter (alternatives in the outer loop,
            # which keeps the registration order the dispatcher relies on).
            signatures = [
                signature + [(index, inner_type)]
                for inner_type in inner_types
                for signature in signatures
            ]
        for signature in signatures:
            func_registry[tuple(signature)] = func  # type: ignore
        return func

    @classmethod
    def dispatch(  # pylint: disable=redefined-outer-name
        cls, func_name, *args: "Expression"
    ):
        """
        Return the implementation of ``func_name`` that accepts ``args``.

        Each argument's type is ``type(arg.type)`` when it has a ``type``
        attribute, otherwise ``type(arg)``. An exact signature match wins;
        otherwise signatures are tried in registration order with
        ``compare_registers``.

        Raises:
            ValueError: when nothing is registered under ``func_name`` for
                this class.
            TypeError: when no registered signature accepts the argument types.
        """
        type_registry = cls.registry.get(cls, {}).get(func_name)  # type: ignore
        if not type_registry:
            raise ValueError(
                f"No function registered on `{cls.__name__}.{func_name}`.",
            )  # pragma: no cover

        types = tuple(
            (position, type(arg.type) if hasattr(arg, "type") else type(arg))
            for position, arg in enumerate(args)
        )

        if types in type_registry:  # type: ignore
            return type_registry[types]  # type: ignore

        for signature, func in type_registry.items():  # type: ignore
            if compare_registers(types, signature):
                return func

        raise TypeError(
            f"`{cls.__name__}.{func_name}` got an invalid "
            "combination of types: "
            f'{", ".join(str(t[1].__name__) for t in types)}',
        )

156 

157 

class Function(Dispatch):  # pylint: disable=too-few-public-methods
    """
    Base class for DJ functions.

    Subclasses either override ``infer_type`` as a static method or register
    type-specific ``infer_type`` implementations through ``register``.
    """

    # Overridden with True by subclasses that aggregate over rows.
    is_aggregation: ClassVar[bool] = False

    @staticmethod
    def infer_type(*args) -> ct.ColumnType:
        """Infer the result type of a call; subclasses must supply this."""
        raise NotImplementedError()

168 

169 

class TableFunction(Dispatch):  # pylint: disable=too-few-public-methods
    """
    Base class for DJ table-valued functions, whose type inference yields a
    list of column types rather than a single type.
    """

    @staticmethod
    def infer_type(*args) -> List[ct.ColumnType]:
        """Infer the output column types; subclasses must supply this."""
        raise NotImplementedError()

178 

179 

class Avg(Function):  # pylint: disable=abstract-method
    """
    Aggregate returning the mean of a column or expression.

    DECIMAL(p, s) input yields DECIMAL(p + 4, s + 4); interval input yields
    the same interval type; any other number yields DOUBLE.
    """

    is_aggregation = True


@Avg.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    arg: ct.DecimalType,
) -> ct.DecimalType:
    # NOTE(review): precision/scale are widened without an upper bound —
    # confirm ct.DecimalType accepts or caps values beyond its maximum.
    decimal = arg.type
    widened_precision = decimal.precision + 4
    widened_scale = decimal.scale + 4
    return ct.DecimalType(widened_precision, widened_scale)


@Avg.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    arg: ct.IntervalTypeBase,
) -> ct.IntervalTypeBase:
    # Averaging intervals keeps the concrete interval class of the input.
    interval_cls = type(arg.type)
    return interval_cls()


@Avg.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    arg: ct.NumberType,
) -> ct.DoubleType:
    return ct.DoubleType()

208 

209 

class Min(Function):  # pylint: disable=abstract-method
    """
    Computes the minimum value of the input column or expression.

    Numeric and string inputs are accepted (matching ``Max``); the result has
    the same type as the input.
    """

    is_aggregation = True


@Min.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    arg: Union[ct.NumberType, ct.StringType],
) -> Union[ct.NumberType, ct.StringType]:
    # Strings are orderable, so MIN is as valid for them as MAX; previously
    # only numbers were registered and MIN(string_col) failed to dispatch.
    return arg.type

223 

224 

class Max(Function):  # pylint: disable=abstract-method
    """
    Aggregate returning the largest value of a numeric or string column or
    expression; the result keeps the input's type.
    """

    is_aggregation = True


@Max.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    arg: Union[ct.NumberType, ct.StringType],
) -> Union[ct.NumberType, ct.StringType]:
    # One signature per union member is registered: numbers first, then strings.
    return arg.type

245 

246 

class Sum(Function):  # pylint: disable=abstract-method
    """
    Computes the sum of the input column or expression.

    Result types: integral input -> BIGINT; DECIMAL(p, s) ->
    DECIMAL(p + min(10, 31 - p), s); interval input -> the same interval
    type; any other number -> DOUBLE.
    """

    is_aggregation = True


@Sum.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    arg: ct.IntegerBase,
) -> ct.BigIntType:
    return ct.BigIntType()


@Sum.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    arg: ct.DecimalType,
) -> ct.DecimalType:
    precision = arg.type.precision
    scale = arg.type.scale
    return ct.DecimalType(precision + min(10, 31 - precision), scale)


@Sum.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    arg: ct.NumberType,
) -> ct.DoubleType:
    return ct.DoubleType()


@Sum.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    arg: ct.IntervalTypeBase,
) -> ct.IntervalTypeBase:
    # A sum of intervals is an interval of the same kind (as Avg models), not
    # a DOUBLE. Registered last to keep the original signature order.
    return type(arg.type)()

276 

277 

class Ceil(Function):  # pylint: disable=abstract-method
    """
    Rounds a number up: to an integral value, or to ``target_scale`` decimal
    places when a target scale is supplied.
    """


@Ceil.register
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    args: ct.NumberType,
    _target_scale: ct.IntegerType,
) -> ct.DecimalType:
    target_scale = _target_scale.value
    arg_type = args.type
    if isinstance(arg_type, ct.DecimalType):
        return ct.DecimalType(
            max(arg_type.precision - arg_type.scale + 1, -target_scale + 1),
            min(arg_type.scale, max(0, target_scale)),
        )
    # (input type, integral digits, largest scale kept) for the remaining
    # numeric types; integer inputs never keep fractional digits.
    for candidate, integral_digits, scale_cap in (
        (ct.TinyIntType(), 3, 0),
        (ct.SmallIntType(), 5, 0),
        (ct.IntegerType(), 10, 0),
        (ct.BigIntType(), 20, 0),
        (ct.FloatType(), 14, 7),
        (ct.DoubleType(), 30, 15),
    ):
        if arg_type == candidate:
            return ct.DecimalType(
                max(integral_digits, -target_scale + 1),
                min(scale_cap, max(0, target_scale)),
            )

    raise DJParseException(
        f"Unhandled numeric type in Ceil `{args.type}`",
    )  # pragma: no cover


@Ceil.register
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    args: ct.DecimalType,
) -> ct.DecimalType:
    # Keep the integral digits plus one for the carry; drop the fraction.
    integral_digits = args.type.precision - args.type.scale
    return ct.DecimalType(integral_digits + 1, 0)


@Ceil.register
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    args: ct.NumberType,
) -> ct.BigIntType:
    return ct.BigIntType()

332 

333 

class Count(Function):  # pylint: disable=abstract-method
    """
    Aggregate counting the non-null values of its argument(s).
    """

    is_aggregation = True


@Count.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    *args: ct.ColumnType,
) -> ct.BigIntType:
    # Whatever is counted, the result is a BIGINT.
    return ct.BigIntType()

347 

348 

class Coalesce(Function):  # pylint: disable=abstract-method
    """
    Returns the first non-null argument.

    The inferred type is that of the first argument whose type is not NULL,
    or NULL when every argument is NULL-typed.
    """

    is_aggregation = False


@Coalesce.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    *args: ct.ColumnType,
) -> ct.ColumnType:
    if not args:  # pragma: no cover
        raise DJInvalidInputException(
            message="Wrong number of arguments to function",
            errors=[
                DJError(
                    code=ErrorCode.INVALID_ARGUMENTS_TO_FUNCTION,
                    message="You need to pass at least one argument to `COALESCE`.",
                ),
            ],
        )
    for arg in args:
        if arg.type != ct.NullType():
            return arg.type
    return ct.NullType()

375 

376 

class CurrentDate(Function):  # pylint: disable=abstract-method
    """
    Today's date; takes no arguments.
    """


@CurrentDate.register  # type: ignore
def infer_type() -> ct.DateType:  # noqa: F811 # pylint: disable=function-redefined
    # Always a DATE.
    return ct.DateType()

386 

387 

class CurrentDatetime(Function):  # pylint: disable=abstract-method
    """
    The present date and time of day; takes no arguments.
    """


@CurrentDatetime.register  # type: ignore
def infer_type() -> ct.TimestampType:  # noqa: F811 # pylint: disable=function-redefined
    # Reported as a (timezone-less) TIMESTAMP.
    return ct.TimestampType()

397 

398 

class CurrentTime(Function):  # pylint: disable=abstract-method
    """
    The present time of day; takes no arguments.
    """


@CurrentTime.register  # type: ignore
def infer_type() -> ct.TimeType:  # noqa: F811 # pylint: disable=function-redefined
    # Always a TIME.
    return ct.TimeType()

408 

409 

class CurrentTimestamp(Function):  # pylint: disable=abstract-method
    """
    The present timestamp; takes no arguments.
    """


@CurrentTimestamp.register  # type: ignore
def infer_type() -> ct.TimestampType:  # noqa: F811 # pylint: disable=function-redefined
    # Always a TIMESTAMP.
    return ct.TimestampType()

419 

420 

class Now(Function):  # pylint: disable=abstract-method
    """
    The present timestamp, typed with a time zone; takes no arguments.
    """


@Now.register  # type: ignore
def infer_type() -> ct.TimestamptzType:  # noqa: F811 # pylint: disable=function-redefined
    # Unlike CURRENT_TIMESTAMP, this is inferred as a zoned timestamp.
    return ct.TimestamptzType()

430 

431 

class DateAdd(Function):  # pylint: disable=abstract-method
    """
    Moves a date forward by a whole number of days.

    The start date may be a DATE or a STRING; the result is always a DATE.
    """


@DateAdd.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    start_date: Union[ct.DateType, ct.StringType],
    days: ct.IntegerBase,
) -> ct.DateType:
    # The union registers (DATE, int) first and (STRING, int) second.
    return ct.DateType()

452 

453 

class DateSub(Function):  # pylint: disable=abstract-method
    """
    Moves a date backward by a whole number of days.

    The start date may be a DATE or a STRING; the result is always a DATE.
    """


@DateSub.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    start_date: Union[ct.DateType, ct.StringType],
    days: ct.IntegerBase,
) -> ct.DateType:
    # The union registers (DATE, int) first and (STRING, int) second.
    return ct.DateType()

474 

475 

class If(Function):  # pylint: disable=abstract-method
    """
    Conditional expression.

    if(condition, result, else_result) yields ``result`` when ``condition``
    is true and ``else_result`` otherwise; both branches must share a type.
    """


@If.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    cond: ct.BooleanType,
    then: ct.ColumnType,
    else_: ct.ColumnType,
) -> ct.ColumnType:
    then_type = then.type
    else_type = else_.type
    if then_type != else_type:
        raise DJInvalidInputException(
            message="The then result and else result must match in type! "
            f"Got {then_type} and {else_type}",
        )
    return then_type

498 

499 

class DateDiff(Function):  # pylint: disable=abstract-method
    """
    Number of days between two dates, given either as two DATEs or as two
    STRINGs; the result is an INT.
    """


@DateDiff.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    start_date: ct.DateType,
    end_date: ct.DateType,
) -> ct.IntegerType:
    # DATE, DATE form.
    return ct.IntegerType()


@DateDiff.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    start_date: ct.StringType,
    end_date: ct.StringType,
) -> ct.IntegerType:
    # STRING, STRING form.
    return ct.IntegerType()

520 

521 

class Extract(Function):
    """
    Returns a specified component of a timestamp, such as year, month or day.

    The SECOND field yields DECIMAL(8, 6) (seconds with a fractional part);
    every other field yields an INT.
    """

    @staticmethod
    def infer_type(  # type: ignore
        field: "Expression",
        source: "Expression",
    ) -> Union[ct.DecimalType, ct.IntegerType]:
        # SQL field names are case-insensitive: ``extract(second FROM ts)``
        # must be recognised as well as ``extract(SECOND FROM ts)``.
        if str(field.name).upper() == "SECOND":  # type: ignore
            return ct.DecimalType(8, 6)
        return ct.IntegerType()

535 

536 

class ToDate(Function):  # pragma: no cover # pylint: disable=abstract-method
    """
    Parses a STRING (optionally with a format STRING) into a DATE.
    """


@ToDate.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    expr: ct.StringType,
    fmt: Optional[ct.StringType] = None,
) -> ct.DateType:
    # ``Optional`` also registers a NoneType slot, so the one-argument form
    # dispatches here too.
    return ct.DateType()

549 

550 

class Day(Function):  # pylint: disable=abstract-method
    """
    Day-of-month component of a STRING, DATE or TIMESTAMP, as an INT.
    """


@Day.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    arg: Union[ct.StringType, ct.DateType, ct.TimestampType],
) -> ct.IntegerType:  # type: ignore
    # Same result type for every accepted input.
    return ct.IntegerType()

562 

563 

class Exp(Function):  # pylint: disable=abstract-method
    """
    Exponential function: e raised to the power of the argument.
    """


@Exp.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    args: ct.ColumnType,
) -> ct.DoubleType:
    # Any argument type produces a DOUBLE.
    return ct.DoubleType()

575 

576 

class Floor(Function):  # pylint: disable=abstract-method
    """
    Rounds a number down: to an integral value, or to ``target_scale`` decimal
    places when a target scale is supplied.
    """


@Floor.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    args: ct.DecimalType,
) -> ct.DecimalType:
    # Keep the integral digits plus one; drop the fraction.
    integral_digits = args.type.precision - args.type.scale
    return ct.DecimalType(integral_digits + 1, 0)


@Floor.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    args: ct.NumberType,
) -> ct.BigIntType:
    return ct.BigIntType()


@Floor.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    args: ct.NumberType,
    _target_scale: ct.IntegerType,
) -> ct.DecimalType:
    target_scale = _target_scale.value
    arg_type = args.type
    if isinstance(arg_type, ct.DecimalType):
        return ct.DecimalType(
            max(arg_type.precision - arg_type.scale + 1, -target_scale + 1),
            min(arg_type.scale, max(0, target_scale)),
        )
    # (input type, integral digits, largest scale kept) for the remaining
    # numeric types; integer inputs never keep fractional digits.
    for candidate, integral_digits, scale_cap in (
        (ct.TinyIntType(), 3, 0),
        (ct.SmallIntType(), 5, 0),
        (ct.IntegerType(), 10, 0),
        (ct.BigIntType(), 20, 0),
        (ct.FloatType(), 14, 7),
        (ct.DoubleType(), 30, 15),
    ):
        if arg_type == candidate:
            return ct.DecimalType(
                max(integral_digits, -target_scale + 1),
                min(scale_cap, max(0, target_scale)),
            )

    raise DJParseException(
        f"Unhandled numeric type in Floor `{args.type}`",
    )  # pragma: no cover

631 

632 

class IfNull(Function):
    """
    Yields the second expression when the first is null, otherwise the first.
    """

    @staticmethod
    def infer_type(*args: "Expression") -> ct.ColumnType:  # type: ignore
        # NOTE(review): when neither argument is NULL-typed this returns the
        # second argument's type, whereas Coalesce prefers the first — confirm
        # this is intended.
        fallback_type = args[1].type
        if fallback_type == ct.NullType():
            return args[0].type
        return fallback_type

643 

644 

class Length(Function):  # pylint: disable=abstract-method
    """
    Length of a STRING argument, as an INT.
    """


@Length.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    arg: ct.StringType,
) -> ct.IntegerType:
    # Only STRING input is registered.
    return ct.IntegerType()

656 

657 

class Levenshtein(Function):  # pylint: disable=abstract-method
    """
    Edit (Levenshtein) distance between two STRINGs, as an INT.
    """


@Levenshtein.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    string1: ct.StringType,
    string2: ct.StringType,
) -> ct.IntegerType:
    # Both operands must be STRINGs.
    return ct.IntegerType()

670 

671 

class Ln(Function):  # pylint: disable=abstract-method
    """
    Natural (base-e) logarithm of the argument.
    """


@Ln.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    args: ct.ColumnType,
) -> ct.DoubleType:
    # Any argument type produces a DOUBLE.
    return ct.DoubleType()

683 

684 

class Log(Function):  # pylint: disable=abstract-method
    """
    Logarithm of ``expr`` in the given ``base``.
    """


@Log.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    base: ct.ColumnType,
    expr: ct.ColumnType,
) -> ct.DoubleType:
    # Any pair of argument types produces a DOUBLE.
    return ct.DoubleType()

697 

698 

class Log2(Function):  # pylint: disable=abstract-method
    """
    Binary (base-2) logarithm of the argument.
    """


@Log2.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    args: ct.ColumnType,
) -> ct.DoubleType:
    # Any argument type produces a DOUBLE.
    return ct.DoubleType()

710 

711 

class Log10(Function):  # pylint: disable=abstract-method
    """
    Common (base-10) logarithm of the argument.
    """


@Log10.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    args: ct.ColumnType,
) -> ct.DoubleType:
    # Any argument type produces a DOUBLE.
    return ct.DoubleType()

723 

724 

class Lower(Function):
    """
    Lower-cases its string argument.
    """

    @staticmethod
    def infer_type(  # type: ignore
        arg: "Expression",
    ) -> ct.StringType:
        return ct.StringType()

733 

734 

class Month(Function):
    """
    Month component (1-12) of a date or timestamp.
    """

    @staticmethod
    def infer_type(  # type: ignore
        arg: "Expression",
    ) -> ct.TinyIntType:
        return ct.TinyIntType()

743 

744 

class Pow(Function):  # pylint: disable=abstract-method
    """
    ``base`` raised to the exponent ``power``.
    """


@Pow.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    base: ct.ColumnType,
    power: ct.ColumnType,
) -> ct.DoubleType:
    # Any pair of argument types produces a DOUBLE.
    return ct.DoubleType()

757 

758 

class PercentRank(Function):
    """
    Window function giving each row's relative rank (percentile) within its
    window partition.
    """

    is_aggregation = True

    @staticmethod
    def infer_type() -> ct.DoubleType:
        # Takes no arguments; the rank is a DOUBLE.
        return ct.DoubleType()

769 

770 

class Quantile(Function):  # pragma: no cover
    """
    Aggregate returning a quantile of a numeric column or expression.
    """

    is_aggregation = True

    @staticmethod
    def infer_type(arg1: "Expression", arg2: "Expression") -> ct.DoubleType:  # type: ignore
        return ct.DoubleType()

784 

785 

class ApproxQuantile(Function):  # pragma: no cover
    """
    Aggregate returning an approximate quantile of a numeric column or
    expression.
    """

    is_aggregation = True

    @staticmethod
    def infer_type(arg1: "Expression", arg2: "Expression") -> ct.DoubleType:  # type: ignore
        return ct.DoubleType()

799 

800 

class RegexpLike(Function):  # pragma: no cover
    """
    Tests a string column or expression against a regular-expression pattern.
    """

    @staticmethod
    def infer_type(arg1: "Expression", arg2: "Expression") -> ct.BooleanType:  # type: ignore
        # The match result is a BOOLEAN.
        return ct.BooleanType()

812 

813 

class Round(Function):  # pylint: disable=abstract-method
    """
    Rounds a numeric column or expression to the specified number of decimal places.
    """


@Round.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    child: ct.DecimalType,
    scale: ct.IntegerBase,
) -> ct.NumberType:
    """
    Result type of rounding DECIMAL(p, s) to ``scale`` digits.

    With a negative scale the result is DECIMAL(max(p - s + 1, -scale + 1), 0);
    otherwise the scale is capped at ``s`` and the precision is
    ``p - s + 1 + new_scale``.
    """
    child_type = child.type
    integral_least_num_digits = child_type.precision - child_type.scale + 1
    target_scale = scale.value
    if target_scale < 0:
        # The requested scale is read from ``scale.value`` (as in the
        # non-negative branch); the previous ``scale.type.value`` read the
        # argument's column type instead of the literal.
        new_precision = max(
            integral_least_num_digits,
            -target_scale + 1,
        )  # pragma: no cover
        return ct.DecimalType(new_precision, 0)  # pragma: no cover
    new_scale = min(child_type.scale, target_scale)
    return ct.DecimalType(integral_least_num_digits + new_scale, new_scale)


@Round.register
def infer_type(  # noqa: F811 # pylint: disable=function-redefined # type: ignore
    child: ct.NumberType,
    scale: ct.IntegerBase,
) -> ct.NumberType:
    # Non-decimal numbers keep their own type.
    return child.type

843 

844 

class SafeDivide(Function):  # pragma: no cover
    """
    Division of two numeric columns or expressions that yields NULL instead
    of failing when the denominator is 0.
    """

    @staticmethod
    def infer_type(  # type: ignore
        arg1: "Expression",
        arg2: "Expression",
    ) -> ct.DoubleType:
        return ct.DoubleType()

853 

854 

class Substring(Function):
    """
    Slice of a string column or expression; the result is a STRING.
    """

    @staticmethod
    def infer_type(arg1: "Expression", arg2: "Expression", arg3: "Expression") -> ct.StringType:  # type: ignore
        return ct.StringType()

867 

868 

class StrPosition(Function):  # pylint: disable=abstract-method
    """
    Position of the first occurrence of a substring within a string column or
    expression, as an INT.
    """


@StrPosition.register
def infer_type(  # noqa: F811 # pylint: disable=function-redefined # pragma: no cover
    arg1: ct.StringType,
    arg2: ct.StringType,
) -> ct.IntegerType:
    return ct.IntegerType()  # pragma: no cover

881 

882 

class StrToDate(Function):  # pragma: no cover
    """
    Parses a string in the given format into a DATE.
    """

    @staticmethod
    def infer_type(
        arg1: "Expression",
        arg2: "Expression",
    ) -> ct.DateType:
        return ct.DateType()

891 

892 

class StrToTime(Function):  # pragma: no cover
    """
    Parses a string in the given format into a TIMESTAMP.
    """

    @staticmethod
    def infer_type(
        arg1: "Expression",
        arg2: "Expression",
    ) -> ct.TimestampType:
        return ct.TimestampType()

901 

902 

class Sqrt(Function):
    """
    Square root of a numeric column or expression, as a DOUBLE.
    """

    @staticmethod
    def infer_type(
        arg: "Expression",
    ) -> ct.DoubleType:
        return ct.DoubleType()

911 

912 

class Stddev(Function):
    """
    Aggregate returning the sample standard deviation of a numeric column or
    expression.
    """

    is_aggregation = True

    @staticmethod
    def infer_type(
        arg: "Expression",
    ) -> ct.DoubleType:
        return ct.DoubleType()

923 

924 

class StddevPop(Function):  # pragma: no cover
    """
    Aggregate returning the population standard deviation of a column or
    expression.
    """

    is_aggregation = True

    @staticmethod
    def infer_type(
        arg: "Expression",
    ) -> ct.DoubleType:
        return ct.DoubleType()

935 

936 

class StddevSamp(Function):  # pragma: no cover
    """
    Aggregate returning the sample standard deviation of a column or
    expression.
    """

    is_aggregation = True

    @staticmethod
    def infer_type(
        arg: "Expression",
    ) -> ct.DoubleType:
        return ct.DoubleType()

947 

948 

class TimeToStr(Function):  # pragma: no cover
    """
    Formats a time value as a STRING using the given format.
    """

    @staticmethod
    def infer_type(
        arg1: "Expression",
        arg2: "Expression",
    ) -> ct.StringType:
        return ct.StringType()

957 

958 

class TimeToTimeStr(Function):  # pragma: no cover
    """
    Formats a time value as a STRING using the given format.
    """

    @staticmethod
    def infer_type(
        arg1: "Expression",
        arg2: "Expression",
    ) -> ct.StringType:
        return ct.StringType()

967 

968 

class TimeStrToDate(Function):  # pragma: no cover
    """
    Parses a string value into a DATE.
    """

    @staticmethod
    def infer_type(
        arg: "Expression",
    ) -> ct.DateType:
        return ct.DateType()

977 

978 

class TimeStrToTime(Function):  # pragma: no cover
    """
    Parses a string value into a TIMESTAMP.
    """

    @staticmethod
    def infer_type(
        arg: "Expression",
    ) -> ct.TimestampType:
        return ct.TimestampType()

987 

988 

class Trim(Function):  # pragma: no cover
    """
    Strips leading and trailing whitespace from a string value.
    """

    @staticmethod
    def infer_type(
        arg: "Expression",
    ) -> ct.StringType:
        return ct.StringType()

997 

998 

class TsOrDsToDateStr(Function):  # pragma: no cover
    """
    Formats a timestamp or date value as a STRING using the given format.
    """

    @staticmethod
    def infer_type(
        arg1: "Expression",
        arg2: "Expression",
    ) -> ct.StringType:
        return ct.StringType()

1007 

1008 

class TsOrDsToDate(Function):  # pragma: no cover
    """
    Turns a timestamp or date value into a DATE.
    """

    @staticmethod
    def infer_type(
        arg: "Expression",
    ) -> ct.DateType:
        return ct.DateType()

1017 

1018 

class TsOrDiToDi(Function):  # pragma: no cover
    """
    Converts a timestamp or date value to an integer date representation.

    The inferred type is an INT, not a DATE.
    """

    @staticmethod
    def infer_type(arg: "Expression") -> ct.IntegerType:
        return ct.IntegerType()

1027 

1028 

class UnixToStr(Function):  # pragma: no cover
    """
    Formats a Unix timestamp as a STRING using the given format.
    """

    @staticmethod
    def infer_type(
        arg1: "Expression",
        arg2: "Expression",
    ) -> ct.StringType:
        return ct.StringType()

1037 

1038 

class UnixToTime(Function):  # pragma: no cover
    """
    Turns a Unix timestamp into a TIMESTAMP.
    """

    @staticmethod
    def infer_type(
        arg: "Expression",
    ) -> ct.TimestampType:
        return ct.TimestampType()

1047 

1048 

class UnixToTimeStr(Function):  # pragma: no cover
    """
    Formats a Unix timestamp as a STRING using the given format.
    """

    @staticmethod
    def infer_type(
        arg1: "Expression",
        arg2: "Expression",
    ) -> ct.StringType:
        return ct.StringType()

1057 

1058 

class Upper(Function):  # pragma: no cover
    """
    Upper-cases a string value.
    """

    @staticmethod
    def infer_type(
        arg: "Expression",
    ) -> ct.StringType:
        return ct.StringType()

1067 

1068 

class Variance(Function):  # pragma: no cover
    """
    Aggregate returning the sample variance of a column or expression.
    """

    is_aggregation = True

    @staticmethod
    def infer_type(
        arg: "Expression",
    ) -> ct.DoubleType:
        return ct.DoubleType()

1079 

1080 

class VariancePop(Function):  # pragma: no cover
    """
    Aggregate returning the population variance of a column or expression.
    """

    is_aggregation = True

    @staticmethod
    def infer_type(
        arg: "Expression",
    ) -> ct.DoubleType:
        return ct.DoubleType()

1091 

1092 

class Array(Function):  # pylint: disable=abstract-method
    """
    Builds an array from its arguments, which must all share one type.
    """


@Array.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    *elements: ct.ColumnType,
) -> ct.ListType:
    distinct_types = set()
    for element in elements:
        distinct_types.add(element.type)
    if len(distinct_types) > 1:
        listing = ", ".join(sorted(str(typ) for typ in distinct_types))
        raise DJParseException(
            f"Multiple types {listing} passed to array.",
        )
    # An empty array is typed as an array of NULL.
    if not elements:
        return ct.ListType(element_type=ct.NullType())
    return ct.ListType(element_type=elements[0].type)

1110 

1111 

class Map(Function):  # pylint: disable=abstract-method
    """
    Builds a map out of constant key/value pairs.

    The type-inference overload for this function is attached separately
    through ``Map.register``.
    """

1116 

1117 

def extract_consistent_type(elements):
    """
    Check if all elements are the same type and return that type.

    Args:
        elements: a sized sequence of objects exposing a ``.type`` attribute.

    Returns:
        ``ct.NullType()`` for an empty sequence. This matches how ``ARRAY()``
        types an empty array; without this guard, ``all()`` over nothing would
        be vacuously true and an arbitrary ``IntegerType`` would be returned.

        Otherwise, the first of ``IntegerType``, ``DoubleType`` or ``FloatType``
        that every element's ``.type`` is an instance of.

        Failing both, ``ct.StringType()``, which is also the result for mixed
        types.
    """
    if not elements:
        return ct.NullType()
    # Checked in this order, mirroring the original precedence.
    for candidate in (ct.IntegerType, ct.DoubleType, ct.FloatType):
        if all(isinstance(element.type, candidate) for element in elements):
            return candidate()
    return ct.StringType()

1129 

1130 

@Map.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    *elements: ct.ColumnType,
) -> ct.MapType:
    """
    Infer the type of ``MAP(k1, v1, k2, v2, ...)``.

    Arguments alternate key, value. The key and value types are each reduced
    with ``extract_consistent_type``.

    Raises:
        DJParseException: when an odd number of arguments is passed, so keys
            and values cannot be paired up.
    """
    # An odd argument count means there is one more key than there are values.
    if len(elements) % 2:
        raise DJParseException("Different number of keys and values for MAP.")

    return ct.MapType(
        key_type=extract_consistent_type(elements[0::2]),
        value_type=extract_consistent_type(elements[1::2]),
    )

1143 

1144 

class Week(Function):
    """
    Extracts the week-of-year number from a date value.

    The inferred result type is always TINYINT.
    """

    @staticmethod
    def infer_type(arg: "Expression") -> ct.TinyIntType:
        week_type = ct.TinyIntType()
        return week_type

1153 

1154 

class Year(Function):
    """
    Returns the year of the input date value.

    The result is typed as a 32-bit ``int``, as in Spark's ``year`` function.
    Year values such as 2023 do not fit in TINYINT's -128..127 range.
    """

    @staticmethod
    def infer_type(arg: "Expression") -> ct.IntType:
        return ct.IntType()

1163 

1164 

class FromJson(Function):  # pragma: no cover # pylint: disable=abstract-method
    """
    Parses a JSON string into a struct or a map.

    The type-inference overload for this function is attached separately
    through ``FromJson.register``.
    """

1169 

1170 

@FromJson.register  # type: ignore
def infer_type(  # noqa: F811 # pylint: disable=function-redefined # pragma: no cover
    json: ct.StringType,
    schema: ct.StringType,
    options: Optional[Function] = None,
) -> ct.StructType:
    """
    Infer the type of ``FROM_JSON(json, schema[, options])``.

    ``schema.value`` is parsed with the antlr4 ``complexColTypeList`` rule, and
    the parsed pieces are passed positionally to ``ct.StructType``. The ``json``
    and ``options`` arguments do not affect the result.
    """
    # TODO: Handle options? # pylint: disable=fixme
    # pylint: disable=import-outside-toplevel
    from dj.sql.parsing.backends.antlr4 import parse_rule  # pragma: no cover

    struct_fields = parse_rule(schema.value, "complexColTypeList")  # pragma: no cover
    return ct.StructType(*struct_fields)  # pragma: no cover

1184 

1185 

class FunctionRegistryDict(dict):
    """
    Registry of function classes keyed by upper-cased name.

    A subscript lookup for an unregistered name raises
    ``DJNotImplementedException`` rather than a bare ``KeyError``. The message
    tells the user how to request or implement the missing function.
    ``get()`` and ``in`` behave like a plain ``dict``.
    """

    def __missing__(self, key):
        """
        Hook invoked by ``dict.__getitem__`` when ``key`` is absent.

        Raises a custom error about functions that haven't been implemented
        yet, chained from a ``KeyError`` for the missing key.
        """
        issue_url = (
            "https://github.com/DataJunction/dj/issues/new"
            f"?title=Function+missing:+{key}"
        )
        docs_url = "https://github.com/DataJunction/dj/blob/main/docs/functions.rst"
        message = (
            f"The function `{key}` hasn't been implemented in DJ yet. "
            f"You can file an issue at {issue_url} to request it to be added, "
            f"or use the documentation at {docs_url} to implement it."
        )
        raise DJNotImplementedException(message) from KeyError(key)

1206 

1207 

# Table-valued function reference:
# https://spark.apache.org/docs/3.3.2/sql-ref-syntax-qry-select-tvf.html#content
class Explode(TableFunction):  # pylint: disable=abstract-method
    """
    Table-valued function that expands an array, nested array, or map column
    into multiple rows, emitting one new row per element of that column.

    The type-inference overloads for list and map arguments are attached
    separately through ``Explode.register``.
    """

1216 

1217 

@Explode.register
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    arg: ct.ListType,
) -> List[ct.NestedField]:
    """Exploding a list produces a single output field: the list's element."""
    element_field = arg.element
    return [element_field]

1223 

1224 

@Explode.register
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    arg: ct.MapType,
) -> List[ct.NestedField]:
    """Exploding a map produces two output fields: the key, then the value."""
    key_field, value_field = arg.key, arg.value
    return [key_field, value_field]

1230 

1231 

class Unnest(TableFunction):  # pylint: disable=abstract-method
    """
    Table-valued function that expands an array, nested array, or map column
    into multiple rows, one new row per element of that column.

    The type-inference overloads for list and map arguments are attached
    separately through ``Unnest.register``.
    """

1238 

1239 

@Unnest.register
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    arg: ct.ListType,
) -> List[ct.NestedField]:
    """Unnesting a list produces a single output field: the list's element."""
    element_field = arg.element  # pragma: no cover
    return [element_field]  # pragma: no cover

1245 

1246 

@Unnest.register
def infer_type(  # noqa: F811 # pylint: disable=function-redefined
    arg: ct.MapType,
) -> List[ct.NestedField]:
    """Unnesting a map produces two output fields: the key, then the value."""
    key_field, value_field = arg.key, arg.value
    return [key_field, value_field]

1252 

1253 

# Matches every position just before an upper-case letter except at the start,
# e.g. "DateAdd" -> "Date_Add".
_CAMEL_BOUNDARY = re.compile(r"(?<!^)(?=[A-Z])")


def _register_subclasses(base, registry):
    """
    Register each direct subclass of ``base`` in ``registry`` and return it.

    Every subclass is stored under two keys: its upper-cased class name
    (``DateAdd`` -> ``DATEADD``) and its upper-cased snake_case name
    (``DateAdd`` -> ``DATE_ADD``). Only direct subclasses are visited, because
    ``__subclasses__()`` is not recursive.

    This replaces two copy-pasted module-level loops, which also leaked their
    ``cls``/``snake_cased`` loop variables into the module namespace.
    """
    for subclass in base.__subclasses__():
        registry[subclass.__name__.upper()] = subclass
        registry[_CAMEL_BOUNDARY.sub("_", subclass.__name__).upper()] = subclass
    return registry


function_registry = _register_subclasses(Function, FunctionRegistryDict())

table_function_registry = _register_subclasses(
    TableFunction,
    FunctionRegistryDict(),
)