sqlglot.dialects.dialect
1from __future__ import annotations 2 3import logging 4import typing as t 5from enum import Enum, auto 6from functools import reduce 7 8from sqlglot import exp 9from sqlglot.errors import ParseError 10from sqlglot.generator import Generator, unsupported_args 11from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses 12from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path 13from sqlglot.parser import Parser 14from sqlglot.time import TIMEZONES, format_time, subsecond_precision 15from sqlglot.tokens import Token, Tokenizer, TokenType 16from sqlglot.trie import new_trie 17 18DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff] 19DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub] 20JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar] 21 22 23if t.TYPE_CHECKING: 24 from sqlglot._typing import B, E, F 25 26 from sqlglot.optimizer.annotate_types import TypeAnnotator 27 28 AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]] 29 30logger = logging.getLogger("sqlglot") 31 32UNESCAPED_SEQUENCES = { 33 "\\a": "\a", 34 "\\b": "\b", 35 "\\f": "\f", 36 "\\n": "\n", 37 "\\r": "\r", 38 "\\t": "\t", 39 "\\v": "\v", 40 "\\\\": "\\", 41} 42 43 44def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]: 45 return lambda self, e: self._annotate_with_type(e, data_type) 46 47 48class Dialects(str, Enum): 49 """Dialects supported by SQLGLot.""" 50 51 DIALECT = "" 52 53 ATHENA = "athena" 54 BIGQUERY = "bigquery" 55 CLICKHOUSE = "clickhouse" 56 DATABRICKS = "databricks" 57 DORIS = "doris" 58 DRILL = "drill" 59 DUCKDB = "duckdb" 60 HIVE = "hive" 61 MATERIALIZE = "materialize" 62 MYSQL = "mysql" 63 ORACLE = "oracle" 64 POSTGRES = "postgres" 65 PRESTO = "presto" 66 PRQL = "prql" 67 REDSHIFT = "redshift" 68 RISINGWAVE = "risingwave" 69 SNOWFLAKE = "snowflake" 70 SPARK = "spark" 71 SPARK2 = "spark2" 72 SQLITE = "sqlite" 73 STARROCKS = "starrocks" 74 TABLEAU = "tableau" 75 TERADATA = "teradata" 76 TRINO = "trino" 77 TSQL = "tsql" 78 79 80class NormalizationStrategy(str, AutoName): 81 """Specifies the strategy according to which identifiers should be normalized.""" 82 83 LOWERCASE = auto() 84 """Unquoted identifiers are lowercased.""" 85 86 UPPERCASE = auto() 87 """Unquoted identifiers are uppercased.""" 88 89 CASE_SENSITIVE = auto() 90 """Always case-sensitive, regardless of quotes.""" 91 92 CASE_INSENSITIVE = auto() 93 """Always case-insensitive, regardless of quotes.""" 94 95 96class _Dialect(type): 97 classes: t.Dict[str, t.Type[Dialect]] = {} 98 99 def __eq__(cls, other: t.Any) -> bool: 100 if cls is other: 101 return True 102 if isinstance(other, str): 103 return cls is cls.get(other) 104 if isinstance(other, Dialect): 105 return cls is type(other) 106 107 return False 108 109 def __hash__(cls) -> int: 110 return hash(cls.__name__.lower()) 111 112 @classmethod 113 def __getitem__(cls, key: str) -> t.Type[Dialect]: 114 return cls.classes[key] 115 116 @classmethod 117 def get( 118 cls, key: str, default: t.Optional[t.Type[Dialect]] = None 119 ) -> t.Optional[t.Type[Dialect]]: 120 return cls.classes.get(key, default) 121 122 def __new__(cls, clsname, bases, attrs): 123 klass = super().__new__(cls, clsname, bases, attrs) 124 enum = Dialects.__members__.get(clsname.upper()) 125 cls.classes[enum.value if enum is not None else clsname.lower()] = klass 126 127 klass.TIME_TRIE = new_trie(klass.TIME_MAPPING) 128 klass.FORMAT_TRIE = ( 129 new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE 130 ) 131 klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} 132 klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING) 133 klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()} 134 klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING) 135 136 klass.INVERSE_CREATABLE_KIND_MAPPING = { 137 v: k for k, v in klass.CREATABLE_KIND_MAPPING.items() 138 } 139 140 base = seq_get(bases, 0) 141 base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),) 142 base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),) 143 base_parser = (getattr(base, "parser_class", Parser),) 144 base_generator = (getattr(base, "generator_class", Generator),) 145 146 klass.tokenizer_class = klass.__dict__.get( 147 "Tokenizer", type("Tokenizer", base_tokenizer, {}) 148 ) 149 klass.jsonpath_tokenizer_class = klass.__dict__.get( 150 "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {}) 151 ) 152 klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {})) 153 klass.generator_class = klass.__dict__.get( 154 "Generator", type("Generator", base_generator, {}) 155 ) 156 157 klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0] 158 klass.IDENTIFIER_START, klass.IDENTIFIER_END = list( 159 klass.tokenizer_class._IDENTIFIERS.items() 160 )[0] 161 162 def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]: 163 return next( 164 ( 165 (s, e) 166 for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items() 167 if t == token_type 168 ), 169 (None, None), 170 ) 171 172 klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING) 173 klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING) 174 klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING) 175 klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING) 176 177 if "\\" in klass.tokenizer_class.STRING_ESCAPES: 178 klass.UNESCAPED_SEQUENCES = { 179 **UNESCAPED_SEQUENCES, 180 **klass.UNESCAPED_SEQUENCES, 181 } 182 183 klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()} 184 185 klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS 186 187 if enum not in ("", "bigquery"): 188 klass.generator_class.SELECT_KINDS = () 189 190 if enum not in ("", "athena", "presto", "trino"): 191 klass.generator_class.TRY_SUPPORTED = False 192 klass.generator_class.SUPPORTS_UESCAPE = False 193 194 if enum not in ("", "databricks", "hive", "spark", "spark2"): 195 modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy() 196 for modifier in ("cluster", "distribute", "sort"): 197 modifier_transforms.pop(modifier, None) 198 199 klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms 200 201 if enum not in ("", "doris", "mysql"): 202 klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | { 203 TokenType.STRAIGHT_JOIN, 204 } 205 klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { 206 TokenType.STRAIGHT_JOIN, 207 } 208 209 if not klass.SUPPORTS_SEMI_ANTI_JOIN: 210 klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { 211 TokenType.ANTI, 212 TokenType.SEMI, 213 } 214 215 return klass 216 217 218class Dialect(metaclass=_Dialect): 219 INDEX_OFFSET = 0 220 """The base index offset for arrays.""" 221 222 WEEK_OFFSET = 0 223 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 224 225 UNNEST_COLUMN_ONLY = False 226 """Whether `UNNEST` table aliases are treated as column aliases.""" 227 228 ALIAS_POST_TABLESAMPLE = False 229 """Whether the table alias comes after tablesample.""" 230 231 TABLESAMPLE_SIZE_IS_PERCENT = False 232 """Whether a size in the table sample clause represents percentage.""" 233 234 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 235 """Specifies the strategy according to which identifiers should be normalized.""" 236 237 IDENTIFIERS_CAN_START_WITH_DIGIT = False 238 """Whether an unquoted identifier can start with a digit.""" 239 240 DPIPE_IS_STRING_CONCAT = True 241 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 242 243 STRICT_STRING_CONCAT = False 244 """Whether `CONCAT`'s arguments must be strings.""" 245 246 SUPPORTS_USER_DEFINED_TYPES = True 247 """Whether user-defined data types are supported.""" 248 249 SUPPORTS_SEMI_ANTI_JOIN = True 250 """Whether `SEMI` or `ANTI` joins are supported.""" 251 252 SUPPORTS_COLUMN_JOIN_MARKS = False 253 """Whether the old-style outer join (+) syntax is supported.""" 254 255 COPY_PARAMS_ARE_CSV = True 256 """Separator of COPY statement parameters.""" 257 258 NORMALIZE_FUNCTIONS: bool | str = "upper" 259 """ 260 Determines how function names are going to be normalized. 261 Possible values: 262 "upper" or True: Convert names to uppercase. 263 "lower": Convert names to lowercase. 264 False: Disables function name normalization. 265 """ 266 267 LOG_BASE_FIRST: t.Optional[bool] = True 268 """ 269 Whether the base comes first in the `LOG` function. 270 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 271 """ 272 273 NULL_ORDERING = "nulls_are_small" 274 """ 275 Default `NULL` ordering method to use if not explicitly set. 276 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 277 """ 278 279 TYPED_DIVISION = False 280 """ 281 Whether the behavior of `a / b` depends on the types of `a` and `b`. 282 False means `a / b` is always float division. 283 True means `a / b` is integer division if both `a` and `b` are integers. 284 """ 285 286 SAFE_DIVISION = False 287 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 288 289 CONCAT_COALESCE = False 290 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 291 292 HEX_LOWERCASE = False 293 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 294 295 DATE_FORMAT = "'%Y-%m-%d'" 296 DATEINT_FORMAT = "'%Y%m%d'" 297 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 298 299 TIME_MAPPING: t.Dict[str, str] = {} 300 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 301 302 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 303 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 304 FORMAT_MAPPING: t.Dict[str, str] = {} 305 """ 306 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 307 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 308 """ 309 310 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 311 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 312 313 PSEUDOCOLUMNS: t.Set[str] = set() 314 """ 315 Columns that are auto-generated by the engine corresponding to this dialect. 316 For example, such columns may be excluded from `SELECT *` queries. 317 """ 318 319 PREFER_CTE_ALIAS_COLUMN = False 320 """ 321 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 322 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 323 any projection aliases in the subquery. 324 325 For example, 326 WITH y(c) AS ( 327 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 328 ) SELECT c FROM y; 329 330 will be rewritten as 331 332 WITH y(c) AS ( 333 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 334 ) SELECT c FROM y; 335 """ 336 337 COPY_PARAMS_ARE_CSV = True 338 """ 339 Whether COPY statement parameters are separated by comma or whitespace 340 """ 341 342 FORCE_EARLY_ALIAS_REF_EXPANSION = False 343 """ 344 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 345 346 For example: 347 WITH data AS ( 348 SELECT 349 1 AS id, 350 2 AS my_id 351 ) 352 SELECT 353 id AS my_id 354 FROM 355 data 356 WHERE 357 my_id = 1 358 GROUP BY 359 my_id, 360 HAVING 361 my_id = 1 362 363 In most dialects, "my_id" would refer to "data.my_id" across the query, except: 364 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e 365 it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 366 - Clickhouse, which will forward the alias across the query i.e it resolves 367 to "WHERE id = 1 GROUP BY id HAVING id = 1" 368 """ 369 370 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 371 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 372 373 SUPPORTS_ORDER_BY_ALL = False 374 """ 375 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 376 """ 377 378 HAS_DISTINCT_ARRAY_CONSTRUCTORS = False 379 """ 380 Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) 381 as the former is of type INT[] vs the latter which is SUPER 382 """ 383 384 SUPPORTS_FIXED_SIZE_ARRAYS = False 385 """ 386 Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. 387 in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should 388 be interpreted as a subscript/index operator. 389 """ 390 391 STRICT_JSON_PATH_SYNTAX = True 392 """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.""" 393 394 ON_CONDITION_EMPTY_BEFORE_ERROR = True 395 """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).""" 396 397 ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True 398 """Whether ArrayAgg needs to filter NULL values.""" 399 400 REGEXP_EXTRACT_DEFAULT_GROUP = 0 401 """The default value for the capturing group.""" 402 403 SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = { 404 exp.Except: True, 405 exp.Intersect: True, 406 exp.Union: True, 407 } 408 """ 409 Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL` 410 must be explicitly specified. 411 """ 412 413 CREATABLE_KIND_MAPPING: dict[str, str] = {} 414 """ 415 Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse 416 equivalent of CREATE SCHEMA is CREATE DATABASE. 417 """ 418 419 # --- Autofilled --- 420 421 tokenizer_class = Tokenizer 422 jsonpath_tokenizer_class = JSONPathTokenizer 423 parser_class = Parser 424 generator_class = Generator 425 426 # A trie of the time_mapping keys 427 TIME_TRIE: t.Dict = {} 428 FORMAT_TRIE: t.Dict = {} 429 430 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 431 INVERSE_TIME_TRIE: t.Dict = {} 432 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 433 INVERSE_FORMAT_TRIE: t.Dict = {} 434 435 INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} 436 437 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 438 439 # Delimiters for string literals and identifiers 440 QUOTE_START = "'" 441 QUOTE_END = "'" 442 IDENTIFIER_START = '"' 443 IDENTIFIER_END = '"' 444 445 # Delimiters for bit, hex, byte and unicode literals 446 BIT_START: t.Optional[str] = None 447 BIT_END: t.Optional[str] = None 448 HEX_START: t.Optional[str] = None 449 HEX_END: t.Optional[str] = None 450 BYTE_START: t.Optional[str] = None 451 BYTE_END: t.Optional[str] = None 452 UNICODE_START: t.Optional[str] = None 453 UNICODE_END: t.Optional[str] = None 454 455 DATE_PART_MAPPING = { 456 "Y": "YEAR", 457 "YY": "YEAR", 458 "YYY": "YEAR", 459 "YYYY": "YEAR", 460 "YR": "YEAR", 461 "YEARS": "YEAR", 462 "YRS": "YEAR", 463 "MM": "MONTH", 464 "MON": "MONTH", 465 "MONS": "MONTH", 466 "MONTHS": "MONTH", 467 "D": "DAY", 468 "DD": "DAY", 469 "DAYS": "DAY", 470 "DAYOFMONTH": "DAY", 471 "DAY OF WEEK": "DAYOFWEEK", 472 "WEEKDAY": "DAYOFWEEK", 473 "DOW": "DAYOFWEEK", 474 "DW": "DAYOFWEEK", 475 "WEEKDAY_ISO": "DAYOFWEEKISO", 476 "DOW_ISO": "DAYOFWEEKISO", 477 "DW_ISO": "DAYOFWEEKISO", 478 "DAY OF YEAR": "DAYOFYEAR", 479 "DOY": "DAYOFYEAR", 480 "DY": "DAYOFYEAR", 481 "W": "WEEK", 482 "WK": "WEEK", 483 "WEEKOFYEAR": "WEEK", 484 "WOY": "WEEK", 485 "WY": "WEEK", 486 "WEEK_ISO": "WEEKISO", 487 "WEEKOFYEARISO": "WEEKISO", 488 "WEEKOFYEAR_ISO": "WEEKISO", 489 "Q": "QUARTER", 490 "QTR": "QUARTER", 491 "QTRS": "QUARTER", 492 "QUARTERS": "QUARTER", 493 "H": "HOUR", 494 "HH": "HOUR", 495 "HR": "HOUR", 496 "HOURS": "HOUR", 497 "HRS": "HOUR", 498 "M": "MINUTE", 499 "MI": "MINUTE", 500 "MIN": "MINUTE", 501 "MINUTES": "MINUTE", 502 "MINS": "MINUTE", 503 "S": "SECOND", 504 "SEC": "SECOND", 505 "SECONDS": "SECOND", 506 "SECS": "SECOND", 507 "MS": "MILLISECOND", 508 "MSEC": "MILLISECOND", 509 "MSECS": "MILLISECOND", 510 "MSECOND": "MILLISECOND", 511 "MSECONDS": "MILLISECOND", 512 "MILLISEC": "MILLISECOND", 513 "MILLISECS": "MILLISECOND", 514 "MILLISECON": "MILLISECOND", 515 "MILLISECONDS": "MILLISECOND", 516 "US": "MICROSECOND", 517 "USEC": "MICROSECOND", 518 "USECS": "MICROSECOND", 519 "MICROSEC": "MICROSECOND", 520 "MICROSECS": "MICROSECOND", 521 "USECOND": "MICROSECOND", 522 "USECONDS": "MICROSECOND", 523 "MICROSECONDS": "MICROSECOND", 524 "NS": "NANOSECOND", 525 "NSEC": "NANOSECOND", 526 "NANOSEC": "NANOSECOND", 527 "NSECOND": "NANOSECOND", 528 "NSECONDS": "NANOSECOND", 529 "NANOSECS": "NANOSECOND", 530 "EPOCH_SECOND": "EPOCH", 531 "EPOCH_SECONDS": "EPOCH", 532 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 533 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 534 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 535 "TZH": "TIMEZONE_HOUR", 536 "TZM": "TIMEZONE_MINUTE", 537 "DEC": "DECADE", 538 "DECS": "DECADE", 539 "DECADES": "DECADE", 540 "MIL": "MILLENIUM", 541 "MILS": "MILLENIUM", 542 "MILLENIA": "MILLENIUM", 543 "C": "CENTURY", 544 "CENT": "CENTURY", 545 "CENTS": "CENTURY", 546 "CENTURIES": "CENTURY", 547 } 548 549 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 550 exp.DataType.Type.BIGINT: { 551 exp.ApproxDistinct, 552 exp.ArraySize, 553 exp.Length, 554 }, 555 exp.DataType.Type.BOOLEAN: { 556 exp.Between, 557 exp.Boolean, 558 exp.In, 559 exp.RegexpLike, 560 }, 561 exp.DataType.Type.DATE: { 562 exp.CurrentDate, 563 exp.Date, 564 exp.DateFromParts, 565 exp.DateStrToDate, 566 exp.DiToDate, 567 exp.StrToDate, 568 exp.TimeStrToDate, 569 exp.TsOrDsToDate, 570 }, 571 exp.DataType.Type.DATETIME: { 572 exp.CurrentDatetime, 573 exp.Datetime, 574 exp.DatetimeAdd, 575 exp.DatetimeSub, 576 }, 577 exp.DataType.Type.DOUBLE: { 578 exp.ApproxQuantile, 579 exp.Avg, 580 exp.Exp, 581 exp.Ln, 582 exp.Log, 583 exp.Pow, 584 exp.Quantile, 585 exp.Round, 586 exp.SafeDivide, 587 exp.Sqrt, 588 exp.Stddev, 589 exp.StddevPop, 590 exp.StddevSamp, 591 exp.ToDouble, 592 exp.Variance, 593 exp.VariancePop, 594 }, 595 exp.DataType.Type.INT: { 596 exp.Ceil, 597 exp.DatetimeDiff, 598 exp.DateDiff, 599 exp.TimestampDiff, 600 exp.TimeDiff, 601 exp.DateToDi, 602 exp.Levenshtein, 603 exp.Sign, 604 exp.StrPosition, 605 exp.TsOrDiToDi, 606 }, 607 exp.DataType.Type.JSON: { 608 exp.ParseJSON, 609 }, 610 exp.DataType.Type.TIME: { 611 exp.Time, 612 }, 613 exp.DataType.Type.TIMESTAMP: { 614 exp.CurrentTime, 615 exp.CurrentTimestamp, 616 exp.StrToTime, 617 exp.TimeAdd, 618 exp.TimeStrToTime, 619 exp.TimeSub, 620 exp.TimestampAdd, 621 exp.TimestampSub, 622 exp.UnixToTime, 623 }, 624 exp.DataType.Type.TINYINT: { 625 exp.Day, 626 exp.Month, 627 exp.Week, 628 exp.Year, 629 exp.Quarter, 630 }, 631 exp.DataType.Type.VARCHAR: { 632 exp.ArrayConcat, 633 exp.Concat, 634 exp.ConcatWs, 635 exp.DateToDateStr, 636 exp.GroupConcat, 637 exp.Initcap, 638 exp.Lower, 639 exp.Substring, 640 exp.TimeToStr, 641 exp.TimeToTimeStr, 642 exp.Trim, 643 exp.TsOrDsToDateStr, 644 exp.UnixToStr, 645 exp.UnixToTimeStr, 646 exp.Upper, 647 }, 648 } 649 650 ANNOTATORS: AnnotatorsType = { 651 **{ 652 expr_type: lambda self, e: self._annotate_unary(e) 653 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 654 }, 655 **{ 656 expr_type: lambda self, e: self._annotate_binary(e) 657 for expr_type in subclasses(exp.__name__, exp.Binary) 658 }, 659 **{ 660 expr_type: _annotate_with_type_lambda(data_type) 661 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 662 for expr_type in expressions 663 }, 664 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 665 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 666 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 667 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 668 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 669 exp.Bracket: lambda self, e: self._annotate_bracket(e), 670 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 671 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 672 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 673 exp.Count: lambda self, e: self._annotate_with_type( 674 e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT 675 ), 676 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 677 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 678 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 679 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 680 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 681 exp.Div: lambda self, e: self._annotate_div(e), 682 exp.Dot: lambda self, e: self._annotate_dot(e), 683 exp.Explode: lambda self, e: self._annotate_explode(e), 684 exp.Extract: lambda self, e: self._annotate_extract(e), 685 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 686 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 687 e, exp.DataType.build("ARRAY<DATE>") 688 ), 689 exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( 690 e, exp.DataType.build("ARRAY<TIMESTAMP>") 691 ), 692 exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 693 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 694 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 695 exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 696 exp.Literal: lambda self, e: self._annotate_literal(e), 697 exp.Map: lambda self, e: self._annotate_map(e), 698 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 699 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 700 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 701 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 702 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 703 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 704 exp.Struct: lambda self, e: self._annotate_struct(e), 705 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 706 exp.Timestamp: lambda self, e: self._annotate_with_type( 707 e, 708 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 709 ), 710 exp.ToMap: lambda self, e: self._annotate_to_map(e), 711 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 712 exp.Unnest: lambda self, e: self._annotate_unnest(e), 713 exp.VarMap: lambda self, e: self._annotate_map(e), 714 } 715 716 @classmethod 717 def get_or_raise(cls, dialect: DialectType) -> Dialect: 718 """ 719 Look up a dialect in the global dialect registry and return it if it exists. 720 721 Args: 722 dialect: The target dialect. If this is a string, it can be optionally followed by 723 additional key-value pairs that are separated by commas and are used to specify 724 dialect settings, such as whether the dialect's identifiers are case-sensitive. 725 726 Example: 727 >>> dialect = dialect_class = get_or_raise("duckdb") 728 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 729 730 Returns: 731 The corresponding Dialect instance. 732 """ 733 734 if not dialect: 735 return cls() 736 if isinstance(dialect, _Dialect): 737 return dialect() 738 if isinstance(dialect, Dialect): 739 return dialect 740 if isinstance(dialect, str): 741 try: 742 dialect_name, *kv_strings = dialect.split(",") 743 kv_pairs = (kv.split("=") for kv in kv_strings) 744 kwargs = {} 745 for pair in kv_pairs: 746 key = pair[0].strip() 747 value: t.Union[bool | str | None] = None 748 749 if len(pair) == 1: 750 # Default initialize standalone settings to True 751 value = True 752 elif len(pair) == 2: 753 value = pair[1].strip() 754 755 # Coerce the value to boolean if it matches to the truthy/falsy values below 756 value_lower = value.lower() 757 if value_lower in ("true", "1"): 758 value = True 759 elif value_lower in ("false", "0"): 760 value = False 761 762 kwargs[key] = value 763 764 except ValueError: 765 raise ValueError( 766 f"Invalid dialect format: '{dialect}'. " 767 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 768 ) 769 770 result = cls.get(dialect_name.strip()) 771 if not result: 772 from difflib import get_close_matches 773 774 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 775 if similar: 776 similar = f" Did you mean {similar}?" 777 778 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 779 780 return result(**kwargs) 781 782 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 783 784 @classmethod 785 def format_time( 786 cls, expression: t.Optional[str | exp.Expression] 787 ) -> t.Optional[exp.Expression]: 788 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 789 if isinstance(expression, str): 790 return exp.Literal.string( 791 # the time formats are quoted 792 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 793 ) 794 795 if expression and expression.is_string: 796 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 797 798 return expression 799 800 def __init__(self, **kwargs) -> None: 801 normalization_strategy = kwargs.pop("normalization_strategy", None) 802 803 if normalization_strategy is None: 804 self.normalization_strategy = self.NORMALIZATION_STRATEGY 805 else: 806 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 807 808 self.settings = kwargs 809 810 def __eq__(self, other: t.Any) -> bool: 811 # Does not currently take dialect state into account 812 return type(self) == other 813 814 def __hash__(self) -> int: 815 # Does not currently take dialect state into account 816 return hash(type(self)) 817 818 def normalize_identifier(self, expression: E) -> E: 819 """ 820 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 821 822 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 823 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 824 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 825 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 826 827 There are also dialects like Spark, which are case-insensitive even when quotes are 828 present, and dialects like MySQL, whose resolution rules match those employed by the 829 underlying operating system, for example they may always be case-sensitive in Linux. 830 831 Finally, the normalization behavior of some engines can even be controlled through flags, 832 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 833 834 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 835 that it can analyze queries in the optimizer and successfully capture their semantics. 836 """ 837 if ( 838 isinstance(expression, exp.Identifier) 839 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 840 and ( 841 not expression.quoted 842 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 843 ) 844 ): 845 expression.set( 846 "this", 847 ( 848 expression.this.upper() 849 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 850 else expression.this.lower() 851 ), 852 ) 853 854 return expression 855 856 def case_sensitive(self, text: str) -> bool: 857 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 858 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 859 return False 860 861 unsafe = ( 862 str.islower 863 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 864 else str.isupper 865 ) 866 return any(unsafe(char) for char in text) 867 868 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 869 """Checks if text can be identified given an identify option. 870 871 Args: 872 text: The text to check. 873 identify: 874 `"always"` or `True`: Always returns `True`. 875 `"safe"`: Only returns `True` if the identifier is case-insensitive. 876 877 Returns: 878 Whether the given text can be identified. 879 """ 880 if identify is True or identify == "always": 881 return True 882 883 if identify == "safe": 884 return not self.case_sensitive(text) 885 886 return False 887 888 def quote_identifier(self, expression: E, identify: bool = True) -> E: 889 """ 890 Adds quotes to a given identifier. 891 892 Args: 893 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 894 identify: If set to `False`, the quotes will only be added if the identifier is deemed 895 "unsafe", with respect to its characters and this dialect's normalization strategy. 896 """ 897 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 898 name = expression.this 899 expression.set( 900 "quoted", 901 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 902 ) 903 904 return expression 905 906 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 907 if isinstance(path, exp.Literal): 908 path_text = path.name 909 if path.is_number: 910 path_text = f"[{path_text}]" 911 try: 912 return parse_json_path(path_text, self) 913 except ParseError as e: 914 if self.STRICT_JSON_PATH_SYNTAX: 915 logger.warning(f"Invalid JSON path syntax. {str(e)}") 916 917 return path 918 919 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 920 return self.parser(**opts).parse(self.tokenize(sql), sql) 921 922 def parse_into( 923 self, expression_type: exp.IntoType, sql: str, **opts 924 ) -> t.List[t.Optional[exp.Expression]]: 925 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 926 927 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 928 return self.generator(**opts).generate(expression, copy=copy) 929 930 def transpile(self, sql: str, **opts) -> t.List[str]: 931 return [ 932 self.generate(expression, copy=False, **opts) if expression else "" 933 for expression in self.parse(sql) 934 ] 935 936 def tokenize(self, sql: str) -> t.List[Token]: 937 return self.tokenizer.tokenize(sql) 938 939 @property 940 def tokenizer(self) -> Tokenizer: 941 return self.tokenizer_class(dialect=self) 942 943 @property 944 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 945 return self.jsonpath_tokenizer_class(dialect=self) 946 947 def parser(self, **opts) -> Parser: 948 return self.parser_class(dialect=self, **opts) 949 950 def generator(self, **opts) -> Generator: 951 return self.generator_class(dialect=self, **opts) 952 953 954DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 955 956 957def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 958 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 959 960 961@unsupported_args("accuracy") 962def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 963 return self.func("APPROX_COUNT_DISTINCT", expression.this) 964 965 966def if_sql( 967 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 968) -> t.Callable[[Generator, exp.If], str]: 969 def _if_sql(self: Generator, expression: exp.If) -> str: 970 return self.func( 971 name, 972 expression.this, 973 expression.args.get("true"), 974 expression.args.get("false") or false_value, 975 ) 976 977 return _if_sql 978 979 980def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 981 this = expression.this 982 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 983 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 984 985 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") 986 987 988def inline_array_sql(self: Generator, expression: exp.Array) -> str: 989 return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]" 990 991 992def inline_array_unless_query(self: Generator, expression: exp.Array) -> str: 993 elem = seq_get(expression.expressions, 0) 994 if isinstance(elem, exp.Expression) and elem.find(exp.Query): 995 return self.func("ARRAY", elem) 996 return inline_array_sql(self, expression) 997 998 999def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 1000 return self.like_sql( 1001 exp.Like( 1002 this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression) 1003 ) 1004 ) 1005 1006 1007def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 1008 zone = self.sql(expression, "this") 1009 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 1010 1011 1012def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 1013 if expression.args.get("recursive"): 1014 self.unsupported("Recursive CTEs are unsupported") 1015 expression.args["recursive"] = False 1016 return self.with_sql(expression) 1017 1018 1019def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: 1020 n = self.sql(expression, "this") 1021 d = self.sql(expression, "expression") 1022 return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)" 1023 1024 1025def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 1026 self.unsupported("TABLESAMPLE unsupported") 1027 return self.sql(expression.this) 1028 1029 1030def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 1031 self.unsupported("PIVOT unsupported") 1032 return "" 1033 1034 1035def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 1036 return self.cast_sql(expression) 1037 1038 1039def no_comment_column_constraint_sql( 1040 self: Generator, expression: exp.CommentColumnConstraint 1041) -> str: 1042 self.unsupported("CommentColumnConstraint unsupported") 1043 return "" 1044 1045 1046def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: 1047 self.unsupported("MAP_FROM_ENTRIES unsupported") 1048 return "" 1049 1050 1051def property_sql(self: Generator, expression: exp.Property) -> str: 1052 return f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}" 1053 1054 1055def str_position_sql( 1056 self: Generator, 1057 expression: exp.StrPosition, 1058 generate_instance: bool = False, 1059 str_position_func_name: str = "STRPOS", 1060) -> str: 1061 this = self.sql(expression, "this") 1062 substr = self.sql(expression, "substr") 1063 position = self.sql(expression, "position") 1064 instance = expression.args.get("instance") if generate_instance else None 1065 position_offset = "" 1066 1067 if position: 1068 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1069 this = self.func("SUBSTR", this, position) 1070 position_offset = f" + {position} - 1" 1071 1072 return self.func(str_position_func_name, this, substr, instance) + position_offset 1073 1074 1075def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 1076 return ( 1077 f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" 1078 ) 1079 1080 1081def var_map_sql( 1082 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1083) -> str: 1084 keys = expression.args["keys"] 1085 values = expression.args["values"] 1086 1087 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1088 self.unsupported("Cannot convert array columns into map.") 1089 return self.func(map_func_name, keys, values) 1090 1091 args = [] 1092 for key, value in zip(keys.expressions, values.expressions): 1093 args.append(self.sql(key)) 1094 args.append(self.sql(value)) 1095 1096 return self.func(map_func_name, *args) 1097 1098 1099def build_formatted_time( 1100 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1101) -> t.Callable[[t.List], E]: 1102 """Helper used for time expressions. 1103 1104 Args: 1105 exp_class: the expression class to instantiate. 1106 dialect: target sql dialect. 1107 default: the default format, True being time. 1108 1109 Returns: 1110 A callable that can be used to return the appropriately formatted time expression. 1111 """ 1112 1113 def _builder(args: t.List): 1114 return exp_class( 1115 this=seq_get(args, 0), 1116 format=Dialect[dialect].format_time( 1117 seq_get(args, 1) 1118 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1119 ), 1120 ) 1121 1122 return _builder 1123 1124 1125def time_format( 1126 dialect: DialectType = None, 1127) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1128 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1129 """ 1130 Returns the time format for a given expression, unless it's equivalent 1131 to the default time format of the dialect of interest. 1132 """ 1133 time_format = self.format_time(expression) 1134 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1135 1136 return _time_format 1137 1138 1139def build_date_delta( 1140 exp_class: t.Type[E], 1141 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1142 default_unit: t.Optional[str] = "DAY", 1143) -> t.Callable[[t.List], E]: 1144 def _builder(args: t.List) -> E: 1145 unit_based = len(args) == 3 1146 this = args[2] if unit_based else seq_get(args, 0) 1147 unit = None 1148 if unit_based or default_unit: 1149 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1150 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1151 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1152 1153 return _builder 1154 1155 1156def build_date_delta_with_interval( 1157 expression_class: t.Type[E], 1158) -> t.Callable[[t.List], t.Optional[E]]: 1159 def _builder(args: t.List) -> t.Optional[E]: 1160 if len(args) < 2: 1161 return None 1162 1163 interval = args[1] 1164 1165 if not isinstance(interval, exp.Interval): 1166 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1167 1168 return expression_class(this=args[0], expression=interval.this, unit=unit_to_str(interval)) 1169 1170 return _builder 1171 1172 1173def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1174 unit = seq_get(args, 0) 1175 this = seq_get(args, 1) 1176 1177 if isinstance(this, exp.Cast) and this.is_type("date"): 1178 return exp.DateTrunc(unit=unit, this=this) 1179 return exp.TimestampTrunc(this=this, unit=unit) 1180 1181 1182def date_add_interval_sql( 1183 data_type: str, kind: str 1184) -> t.Callable[[Generator, exp.Expression], str]: 1185 def func(self: Generator, expression: exp.Expression) -> str: 1186 this = self.sql(expression, "this") 1187 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1188 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1189 1190 return func 1191 1192 1193def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1194 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1195 args = [unit_to_str(expression), expression.this] 1196 if zone: 1197 args.append(expression.args.get("zone")) 1198 return self.func("DATE_TRUNC", *args) 1199 1200 return _timestamptrunc_sql 1201 1202 1203def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1204 zone = expression.args.get("zone") 1205 if not zone: 1206 from sqlglot.optimizer.annotate_types import annotate_types 1207 1208 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1209 return self.sql(exp.cast(expression.this, target_type)) 1210 if zone.name.lower() in TIMEZONES: 1211 return self.sql( 1212 exp.AtTimeZone( 1213 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1214 zone=zone, 1215 ) 1216 ) 1217 return self.func("TIMESTAMP", expression.this, zone) 1218 1219 1220def no_time_sql(self: Generator, expression: exp.Time) -> str: 1221 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1222 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1223 expr = exp.cast( 1224 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1225 ) 1226 return self.sql(expr) 1227 1228 1229def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1230 this = expression.this 1231 expr = expression.expression 1232 1233 if expr.name.lower() in TIMEZONES: 1234 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1235 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1236 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1237 return self.sql(this) 1238 1239 this = exp.cast(this, exp.DataType.Type.DATE) 1240 expr = exp.cast(expr, exp.DataType.Type.TIME) 1241 1242 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP)) 1243 1244 1245def locate_to_strposition(args: t.List) -> exp.Expression: 1246 return exp.StrPosition( 1247 this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) 1248 ) 1249 1250 1251def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str: 1252 return self.func( 1253 "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position") 1254 ) 1255 1256 1257def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: 1258 return self.sql( 1259 exp.Substring( 1260 this=expression.this, start=exp.Literal.number(1), length=expression.expression 1261 ) 1262 ) 1263 1264 1265def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: 1266 return self.sql( 1267 exp.Substring( 1268 this=expression.this, 1269 start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1), 1270 ) 1271 ) 1272 1273 1274def timestrtotime_sql( 1275 self: Generator, 1276 expression: exp.TimeStrToTime, 1277 include_precision: bool = False, 1278) -> str: 1279 datatype = exp.DataType.build( 1280 exp.DataType.Type.TIMESTAMPTZ 1281 if expression.args.get("zone") 1282 else exp.DataType.Type.TIMESTAMP 1283 ) 1284 1285 if isinstance(expression.this, exp.Literal) and include_precision: 1286 precision = subsecond_precision(expression.this.name) 1287 if precision > 0: 1288 datatype = exp.DataType.build( 1289 datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))] 1290 ) 1291 1292 return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect)) 1293 1294 1295def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: 1296 return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE)) 1297 1298 1299# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8 1300def encode_decode_sql( 1301 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1302) -> str: 1303 charset = expression.args.get("charset") 1304 if charset and charset.name.lower() != "utf-8": 1305 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1306 1307 return self.func(name, expression.this, expression.args.get("replace") if replace else None) 1308 1309 1310def min_or_least(self: Generator, expression: exp.Min) -> str: 1311 name = "LEAST" if expression.expressions else "MIN" 1312 return rename_func(name)(self, expression) 1313 1314 1315def max_or_greatest(self: Generator, expression: exp.Max) -> str: 1316 name = "GREATEST" if expression.expressions else "MAX" 1317 return rename_func(name)(self, expression) 1318 1319 1320def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1321 cond = expression.this 1322 1323 if isinstance(expression.this, exp.Distinct): 1324 cond = expression.this.expressions[0] 1325 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1326 1327 return self.func("sum", exp.func("if", cond, 1, 0)) 1328 1329 1330def trim_sql(self: Generator, expression: exp.Trim) -> str: 1331 target = self.sql(expression, "this") 1332 trim_type = self.sql(expression, "position") 1333 remove_chars = self.sql(expression, "expression") 1334 collation = self.sql(expression, "collation") 1335 1336 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1337 if not remove_chars: 1338 return self.trim_sql(expression) 1339 1340 trim_type = f"{trim_type} " if trim_type else "" 1341 remove_chars = f"{remove_chars} " if remove_chars else "" 1342 from_part = "FROM " if trim_type or remove_chars else "" 1343 collation = f" COLLATE {collation}" if collation else "" 1344 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" 1345 1346 1347def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: 1348 return self.func("STRPTIME", expression.this, self.format_time(expression)) 1349 1350 1351def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str: 1352 return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) 1353 1354 1355def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str: 1356 delim, *rest_args = expression.expressions 1357 return self.sql( 1358 reduce( 1359 lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)), 1360 rest_args, 1361 ) 1362 ) 1363 1364 1365@unsupported_args("position", "occurrence", "parameters") 1366def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 1367 group = expression.args.get("group") 1368 1369 # Do not render group if it's the default value for this dialect 1370 if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP): 1371 group = None 1372 1373 return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group) 1374 1375 1376@unsupported_args("position", "occurrence", "modifiers") 1377def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 1378 return self.func( 1379 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 1380 ) 1381 1382 1383def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1384 names = [] 1385 for agg in aggregations: 1386 if isinstance(agg, exp.Alias): 1387 names.append(agg.alias) 1388 else: 1389 """ 1390 This case corresponds to aggregations without aliases being used as suffixes 1391 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1392 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1393 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1394 """ 1395 agg_all_unquoted = agg.transform( 1396 lambda node: ( 1397 exp.Identifier(this=node.name, quoted=False) 1398 if isinstance(node, exp.Identifier) 1399 else node 1400 ) 1401 ) 1402 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1403 1404 return names 1405 1406 1407def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: 1408 return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) 1409 1410 1411# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects 1412def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: 1413 return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) 1414 1415 1416def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: 1417 return self.func("MAX", expression.this) 1418 1419 1420def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: 1421 a = self.sql(expression.left) 1422 b = self.sql(expression.right) 1423 return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" 1424 1425 1426def is_parse_json(expression: exp.Expression) -> bool: 1427 return isinstance(expression, exp.ParseJSON) or ( 1428 isinstance(expression, exp.Cast) and expression.is_type("json") 1429 ) 1430 1431 1432def isnull_to_is_null(args: t.List) -> exp.Expression: 1433 return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) 1434 1435 1436def generatedasidentitycolumnconstraint_sql( 1437 self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint 1438) -> str: 1439 start = self.sql(expression, "start") or "1" 1440 increment = self.sql(expression, "increment") or "1" 1441 return f"IDENTITY({start}, {increment})" 1442 1443 1444def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1445 @unsupported_args("count") 1446 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1447 return self.func(name, expression.this, expression.expression) 1448 1449 return _arg_max_or_min_sql 1450 1451 1452def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1453 this = expression.this.copy() 1454 1455 return_type = expression.return_type 1456 if return_type.is_type(exp.DataType.Type.DATE): 1457 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1458 # can truncate timestamp strings, because some dialects can't cast them to DATE 1459 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1460 1461 expression.this.replace(exp.cast(this, return_type)) 1462 return expression 1463 1464 1465def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1466 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1467 if cast and isinstance(expression, exp.TsOrDsAdd): 1468 expression = ts_or_ds_add_cast(expression) 1469 1470 return self.func( 1471 name, 1472 unit_to_var(expression), 1473 expression.expression, 1474 expression.this, 1475 ) 1476 1477 return _delta_sql 1478 1479 1480def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1481 unit = expression.args.get("unit") 1482 1483 if isinstance(unit, exp.Placeholder): 1484 return unit 1485 if unit: 1486 return exp.Literal.string(unit.name) 1487 return exp.Literal.string(default) if default else None 1488 1489 1490def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1491 unit = expression.args.get("unit") 1492 1493 if isinstance(unit, (exp.Var, exp.Placeholder)): 1494 return unit 1495 return exp.Var(this=default) if default else None 1496 1497 1498@t.overload 1499def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var: 1500 pass 1501 1502 1503@t.overload 1504def map_date_part( 1505 part: t.Optional[exp.Expression], dialect: DialectType = Dialect 1506) -> t.Optional[exp.Expression]: 1507 pass 1508 1509 1510def map_date_part(part, dialect: DialectType = Dialect): 1511 mapped = ( 1512 Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None 1513 ) 1514 return exp.var(mapped) if mapped else part 1515 1516 1517def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1518 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1519 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1520 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1521 1522 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE)) 1523 1524 1525def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1526 """Remove table refs from columns in when statements.""" 1527 alias = expression.this.args.get("alias") 1528 1529 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1530 return self.dialect.normalize_identifier(identifier).name if identifier else None 1531 1532 targets = {normalize(expression.this.this)} 1533 1534 if alias: 1535 targets.add(normalize(alias.this)) 1536 1537 for when in expression.expressions: 1538 # only remove the target names from the THEN clause 1539 # theyre still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED 1540 # ref: https://github.com/TobikoData/sqlmesh/issues/2934 1541 then = when.args.get("then") 1542 if then: 1543 then.transform( 1544 lambda node: ( 1545 exp.column(node.this) 1546 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1547 else node 1548 ), 1549 copy=False, 1550 ) 1551 1552 return self.merge_sql(expression) 1553 1554 1555def build_json_extract_path( 1556 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1557) -> t.Callable[[t.List], F]: 1558 def _builder(args: t.List) -> F: 1559 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1560 for arg in args[1:]: 1561 if not isinstance(arg, exp.Literal): 1562 # We use the fallback parser because we can't really transpile non-literals safely 1563 return expr_type.from_arg_list(args) 1564 1565 text = arg.name 1566 if is_int(text): 1567 index = int(text) 1568 segments.append( 1569 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1570 ) 1571 else: 1572 segments.append(exp.JSONPathKey(this=text)) 1573 1574 # This is done to avoid failing in the expression validator due to the arg count 1575 del args[2:] 1576 return expr_type( 1577 this=seq_get(args, 0), 1578 expression=exp.JSONPath(expressions=segments), 1579 only_json_types=arrow_req_json_type, 1580 ) 1581 1582 return _builder 1583 1584 1585def json_extract_segments( 1586 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1587) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1588 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1589 path = expression.expression 1590 if not isinstance(path, exp.JSONPath): 1591 return rename_func(name)(self, expression) 1592 1593 escape = path.args.get("escape") 1594 1595 segments = [] 1596 for segment in path.expressions: 1597 path = self.sql(segment) 1598 if path: 1599 if isinstance(segment, exp.JSONPathPart) and ( 1600 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1601 ): 1602 if escape: 1603 path = self.escape_str(path) 1604 1605 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1606 1607 segments.append(path) 1608 1609 if op: 1610 return f" {op} ".join([self.sql(expression.this), *segments]) 1611 return self.func(name, expression.this, *segments) 1612 1613 return _json_extract_segments 1614 1615 1616def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str: 1617 if isinstance(expression.this, exp.JSONPathWildcard): 1618 self.unsupported("Unsupported wildcard in JSONPathKey expression") 1619 1620 return expression.name 1621 1622 1623def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1624 cond = expression.expression 1625 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1626 alias = cond.expressions[0] 1627 cond = cond.this 1628 elif isinstance(cond, exp.Predicate): 1629 alias = "_u" 1630 else: 1631 self.unsupported("Unsupported filter condition") 1632 return "" 1633 1634 unnest = exp.Unnest(expressions=[expression.this]) 1635 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1636 return self.sql(exp.Array(expressions=[filtered])) 1637 1638 1639def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str: 1640 return self.func( 1641 "TO_NUMBER", 1642 expression.this, 1643 expression.args.get("format"), 1644 expression.args.get("nlsparam"), 1645 ) 1646 1647 1648def build_default_decimal_type( 1649 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1650) -> t.Callable[[exp.DataType], exp.DataType]: 1651 def _builder(dtype: exp.DataType) -> exp.DataType: 1652 if dtype.expressions or precision is None: 1653 return dtype 1654 1655 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1656 return exp.DataType.build(f"DECIMAL({params})") 1657 1658 return _builder 1659 1660 1661def build_timestamp_from_parts(args: t.List) -> exp.Func: 1662 if len(args) == 2: 1663 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1664 # so we parse this into Anonymous for now instead of introducing complexity 1665 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1666 1667 return exp.TimestampFromParts.from_arg_list(args) 1668 1669 1670def sha256_sql(self: Generator, expression: exp.SHA2) -> str: 1671 return self.func(f"SHA{expression.text('length') or '256'}", expression.this) 1672 1673 1674def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str: 1675 start = expression.args.get("start") 1676 end = expression.args.get("end") 1677 step = expression.args.get("step") 1678 1679 if isinstance(start, exp.Cast): 1680 target_type = start.to 1681 elif isinstance(end, exp.Cast): 1682 target_type = end.to 1683 else: 1684 target_type = None 1685 1686 if start and end and target_type and target_type.is_type("date", "timestamp"): 1687 if isinstance(start, exp.Cast) and target_type is start.to: 1688 end = exp.cast(end, target_type) 1689 else: 1690 start = exp.cast(start, target_type) 1691 1692 return self.func("SEQUENCE", start, end, step) 1693 1694 1695def build_regexp_extract(args: t.List, dialect: Dialect) -> exp.RegexpExtract: 1696 return exp.RegexpExtract( 1697 this=seq_get(args, 0), 1698 expression=seq_get(args, 1), 1699 group=seq_get(args, 2) or exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP), 1700 ) 1701 1702 1703def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str: 1704 if isinstance(expression.this, exp.Explode): 1705 return self.sql( 1706 exp.Join( 1707 this=exp.Unnest( 1708 expressions=[expression.this.this], 1709 alias=expression.args.get("alias"), 1710 offset=isinstance(expression.this, exp.Posexplode), 1711 ), 1712 kind="cross", 1713 ) 1714 ) 1715 return self.lateral_sql(expression)
49class Dialects(str, Enum): 50 """Dialects supported by SQLGLot.""" 51 52 DIALECT = "" 53 54 ATHENA = "athena" 55 BIGQUERY = "bigquery" 56 CLICKHOUSE = "clickhouse" 57 DATABRICKS = "databricks" 58 DORIS = "doris" 59 DRILL = "drill" 60 DUCKDB = "duckdb" 61 HIVE = "hive" 62 MATERIALIZE = "materialize" 63 MYSQL = "mysql" 64 ORACLE = "oracle" 65 POSTGRES = "postgres" 66 PRESTO = "presto" 67 PRQL = "prql" 68 REDSHIFT = "redshift" 69 RISINGWAVE = "risingwave" 70 SNOWFLAKE = "snowflake" 71 SPARK = "spark" 72 SPARK2 = "spark2" 73 SQLITE = "sqlite" 74 STARROCKS = "starrocks" 75 TABLEAU = "tableau" 76 TERADATA = "teradata" 77 TRINO = "trino" 78 TSQL = "tsql"
Dialects supported by SQLGLot.
81class NormalizationStrategy(str, AutoName): 82 """Specifies the strategy according to which identifiers should be normalized.""" 83 84 LOWERCASE = auto() 85 """Unquoted identifiers are lowercased.""" 86 87 UPPERCASE = auto() 88 """Unquoted identifiers are uppercased.""" 89 90 CASE_SENSITIVE = auto() 91 """Always case-sensitive, regardless of quotes.""" 92 93 CASE_INSENSITIVE = auto() 94 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.
Always case-sensitive, regardless of quotes.
Always case-insensitive, regardless of quotes.
219class Dialect(metaclass=_Dialect): 220 INDEX_OFFSET = 0 221 """The base index offset for arrays.""" 222 223 WEEK_OFFSET = 0 224 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 225 226 UNNEST_COLUMN_ONLY = False 227 """Whether `UNNEST` table aliases are treated as column aliases.""" 228 229 ALIAS_POST_TABLESAMPLE = False 230 """Whether the table alias comes after tablesample.""" 231 232 TABLESAMPLE_SIZE_IS_PERCENT = False 233 """Whether a size in the table sample clause represents percentage.""" 234 235 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 236 """Specifies the strategy according to which identifiers should be normalized.""" 237 238 IDENTIFIERS_CAN_START_WITH_DIGIT = False 239 """Whether an unquoted identifier can start with a digit.""" 240 241 DPIPE_IS_STRING_CONCAT = True 242 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 243 244 STRICT_STRING_CONCAT = False 245 """Whether `CONCAT`'s arguments must be strings.""" 246 247 SUPPORTS_USER_DEFINED_TYPES = True 248 """Whether user-defined data types are supported.""" 249 250 SUPPORTS_SEMI_ANTI_JOIN = True 251 """Whether `SEMI` or `ANTI` joins are supported.""" 252 253 SUPPORTS_COLUMN_JOIN_MARKS = False 254 """Whether the old-style outer join (+) syntax is supported.""" 255 256 COPY_PARAMS_ARE_CSV = True 257 """Separator of COPY statement parameters.""" 258 259 NORMALIZE_FUNCTIONS: bool | str = "upper" 260 """ 261 Determines how function names are going to be normalized. 262 Possible values: 263 "upper" or True: Convert names to uppercase. 264 "lower": Convert names to lowercase. 265 False: Disables function name normalization. 266 """ 267 268 LOG_BASE_FIRST: t.Optional[bool] = True 269 """ 270 Whether the base comes first in the `LOG` function. 271 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 272 """ 273 274 NULL_ORDERING = "nulls_are_small" 275 """ 276 Default `NULL` ordering method to use if not explicitly set. 277 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 278 """ 279 280 TYPED_DIVISION = False 281 """ 282 Whether the behavior of `a / b` depends on the types of `a` and `b`. 283 False means `a / b` is always float division. 284 True means `a / b` is integer division if both `a` and `b` are integers. 285 """ 286 287 SAFE_DIVISION = False 288 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 289 290 CONCAT_COALESCE = False 291 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 292 293 HEX_LOWERCASE = False 294 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 295 296 DATE_FORMAT = "'%Y-%m-%d'" 297 DATEINT_FORMAT = "'%Y%m%d'" 298 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 299 300 TIME_MAPPING: t.Dict[str, str] = {} 301 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 302 303 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 304 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 305 FORMAT_MAPPING: t.Dict[str, str] = {} 306 """ 307 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 308 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 309 """ 310 311 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 312 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 313 314 PSEUDOCOLUMNS: t.Set[str] = set() 315 """ 316 Columns that are auto-generated by the engine corresponding to this dialect. 317 For example, such columns may be excluded from `SELECT *` queries. 318 """ 319 320 PREFER_CTE_ALIAS_COLUMN = False 321 """ 322 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 323 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 324 any projection aliases in the subquery. 325 326 For example, 327 WITH y(c) AS ( 328 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 329 ) SELECT c FROM y; 330 331 will be rewritten as 332 333 WITH y(c) AS ( 334 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 335 ) SELECT c FROM y; 336 """ 337 338 COPY_PARAMS_ARE_CSV = True 339 """ 340 Whether COPY statement parameters are separated by comma or whitespace 341 """ 342 343 FORCE_EARLY_ALIAS_REF_EXPANSION = False 344 """ 345 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 346 347 For example: 348 WITH data AS ( 349 SELECT 350 1 AS id, 351 2 AS my_id 352 ) 353 SELECT 354 id AS my_id 355 FROM 356 data 357 WHERE 358 my_id = 1 359 GROUP BY 360 my_id, 361 HAVING 362 my_id = 1 363 364 In most dialects, "my_id" would refer to "data.my_id" across the query, except: 365 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e 366 it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 367 - Clickhouse, which will forward the alias across the query i.e it resolves 368 to "WHERE id = 1 GROUP BY id HAVING id = 1" 369 """ 370 371 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 372 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 373 374 SUPPORTS_ORDER_BY_ALL = False 375 """ 376 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 377 """ 378 379 HAS_DISTINCT_ARRAY_CONSTRUCTORS = False 380 """ 381 Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) 382 as the former is of type INT[] vs the latter which is SUPER 383 """ 384 385 SUPPORTS_FIXED_SIZE_ARRAYS = False 386 """ 387 Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. 388 in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should 389 be interpreted as a subscript/index operator. 390 """ 391 392 STRICT_JSON_PATH_SYNTAX = True 393 """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.""" 394 395 ON_CONDITION_EMPTY_BEFORE_ERROR = True 396 """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).""" 397 398 ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True 399 """Whether ArrayAgg needs to filter NULL values.""" 400 401 REGEXP_EXTRACT_DEFAULT_GROUP = 0 402 """The default value for the capturing group.""" 403 404 SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = { 405 exp.Except: True, 406 exp.Intersect: True, 407 exp.Union: True, 408 } 409 """ 410 Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL` 411 must be explicitly specified. 412 """ 413 414 CREATABLE_KIND_MAPPING: dict[str, str] = {} 415 """ 416 Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse 417 equivalent of CREATE SCHEMA is CREATE DATABASE. 418 """ 419 420 # --- Autofilled --- 421 422 tokenizer_class = Tokenizer 423 jsonpath_tokenizer_class = JSONPathTokenizer 424 parser_class = Parser 425 generator_class = Generator 426 427 # A trie of the time_mapping keys 428 TIME_TRIE: t.Dict = {} 429 FORMAT_TRIE: t.Dict = {} 430 431 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 432 INVERSE_TIME_TRIE: t.Dict = {} 433 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 434 INVERSE_FORMAT_TRIE: t.Dict = {} 435 436 INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} 437 438 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 439 440 # Delimiters for string literals and identifiers 441 QUOTE_START = "'" 442 QUOTE_END = "'" 443 IDENTIFIER_START = '"' 444 IDENTIFIER_END = '"' 445 446 # Delimiters for bit, hex, byte and unicode literals 447 BIT_START: t.Optional[str] = None 448 BIT_END: t.Optional[str] = None 449 HEX_START: t.Optional[str] = None 450 HEX_END: t.Optional[str] = None 451 BYTE_START: t.Optional[str] = None 452 BYTE_END: t.Optional[str] = None 453 UNICODE_START: t.Optional[str] = None 454 UNICODE_END: t.Optional[str] = None 455 456 DATE_PART_MAPPING = { 457 "Y": "YEAR", 458 "YY": "YEAR", 459 "YYY": "YEAR", 460 "YYYY": "YEAR", 461 "YR": "YEAR", 462 "YEARS": "YEAR", 463 "YRS": "YEAR", 464 "MM": "MONTH", 465 "MON": "MONTH", 466 "MONS": "MONTH", 467 "MONTHS": "MONTH", 468 "D": "DAY", 469 "DD": "DAY", 470 "DAYS": "DAY", 471 "DAYOFMONTH": "DAY", 472 "DAY OF WEEK": "DAYOFWEEK", 473 "WEEKDAY": "DAYOFWEEK", 474 "DOW": "DAYOFWEEK", 475 "DW": "DAYOFWEEK", 476 "WEEKDAY_ISO": "DAYOFWEEKISO", 477 "DOW_ISO": "DAYOFWEEKISO", 478 "DW_ISO": "DAYOFWEEKISO", 479 "DAY OF YEAR": "DAYOFYEAR", 480 "DOY": "DAYOFYEAR", 481 "DY": "DAYOFYEAR", 482 "W": "WEEK", 483 "WK": "WEEK", 484 "WEEKOFYEAR": "WEEK", 485 "WOY": "WEEK", 486 "WY": "WEEK", 487 "WEEK_ISO": "WEEKISO", 488 "WEEKOFYEARISO": "WEEKISO", 489 "WEEKOFYEAR_ISO": "WEEKISO", 490 "Q": "QUARTER", 491 "QTR": "QUARTER", 492 "QTRS": "QUARTER", 493 "QUARTERS": "QUARTER", 494 "H": "HOUR", 495 "HH": "HOUR", 496 "HR": "HOUR", 497 "HOURS": "HOUR", 498 "HRS": "HOUR", 499 "M": "MINUTE", 500 "MI": "MINUTE", 501 "MIN": "MINUTE", 502 "MINUTES": "MINUTE", 503 "MINS": "MINUTE", 504 "S": "SECOND", 505 "SEC": "SECOND", 506 "SECONDS": "SECOND", 507 "SECS": "SECOND", 508 "MS": "MILLISECOND", 509 "MSEC": "MILLISECOND", 510 "MSECS": "MILLISECOND", 511 "MSECOND": "MILLISECOND", 512 "MSECONDS": "MILLISECOND", 513 "MILLISEC": "MILLISECOND", 514 "MILLISECS": "MILLISECOND", 515 "MILLISECON": "MILLISECOND", 516 "MILLISECONDS": "MILLISECOND", 517 "US": "MICROSECOND", 518 "USEC": "MICROSECOND", 519 "USECS": "MICROSECOND", 520 "MICROSEC": "MICROSECOND", 521 "MICROSECS": "MICROSECOND", 522 "USECOND": "MICROSECOND", 523 "USECONDS": "MICROSECOND", 524 "MICROSECONDS": "MICROSECOND", 525 "NS": "NANOSECOND", 526 "NSEC": "NANOSECOND", 527 "NANOSEC": "NANOSECOND", 528 "NSECOND": "NANOSECOND", 529 "NSECONDS": "NANOSECOND", 530 "NANOSECS": "NANOSECOND", 531 "EPOCH_SECOND": "EPOCH", 532 "EPOCH_SECONDS": "EPOCH", 533 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 534 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 535 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 536 "TZH": "TIMEZONE_HOUR", 537 "TZM": "TIMEZONE_MINUTE", 538 "DEC": "DECADE", 539 "DECS": "DECADE", 540 "DECADES": "DECADE", 541 "MIL": "MILLENIUM", 542 "MILS": "MILLENIUM", 543 "MILLENIA": "MILLENIUM", 544 "C": "CENTURY", 545 "CENT": "CENTURY", 546 "CENTS": "CENTURY", 547 "CENTURIES": "CENTURY", 548 } 549 550 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 551 exp.DataType.Type.BIGINT: { 552 exp.ApproxDistinct, 553 exp.ArraySize, 554 exp.Length, 555 }, 556 exp.DataType.Type.BOOLEAN: { 557 exp.Between, 558 exp.Boolean, 559 exp.In, 560 exp.RegexpLike, 561 }, 562 exp.DataType.Type.DATE: { 563 exp.CurrentDate, 564 exp.Date, 565 exp.DateFromParts, 566 exp.DateStrToDate, 567 exp.DiToDate, 568 exp.StrToDate, 569 exp.TimeStrToDate, 570 exp.TsOrDsToDate, 571 }, 572 exp.DataType.Type.DATETIME: { 573 exp.CurrentDatetime, 574 exp.Datetime, 575 exp.DatetimeAdd, 576 exp.DatetimeSub, 577 }, 578 exp.DataType.Type.DOUBLE: { 579 exp.ApproxQuantile, 580 exp.Avg, 581 exp.Exp, 582 exp.Ln, 583 exp.Log, 584 exp.Pow, 585 exp.Quantile, 586 exp.Round, 587 exp.SafeDivide, 588 exp.Sqrt, 589 exp.Stddev, 590 exp.StddevPop, 591 exp.StddevSamp, 592 exp.ToDouble, 593 exp.Variance, 594 exp.VariancePop, 595 }, 596 exp.DataType.Type.INT: { 597 exp.Ceil, 598 exp.DatetimeDiff, 599 exp.DateDiff, 600 exp.TimestampDiff, 601 exp.TimeDiff, 602 exp.DateToDi, 603 exp.Levenshtein, 604 exp.Sign, 605 exp.StrPosition, 606 exp.TsOrDiToDi, 607 }, 608 exp.DataType.Type.JSON: { 609 exp.ParseJSON, 610 }, 611 exp.DataType.Type.TIME: { 612 exp.Time, 613 }, 614 exp.DataType.Type.TIMESTAMP: { 615 exp.CurrentTime, 616 exp.CurrentTimestamp, 617 exp.StrToTime, 618 exp.TimeAdd, 619 exp.TimeStrToTime, 620 exp.TimeSub, 621 exp.TimestampAdd, 622 exp.TimestampSub, 623 exp.UnixToTime, 624 }, 625 exp.DataType.Type.TINYINT: { 626 exp.Day, 627 exp.Month, 628 exp.Week, 629 exp.Year, 630 exp.Quarter, 631 }, 632 exp.DataType.Type.VARCHAR: { 633 exp.ArrayConcat, 634 exp.Concat, 635 exp.ConcatWs, 636 exp.DateToDateStr, 637 exp.GroupConcat, 638 exp.Initcap, 639 exp.Lower, 640 exp.Substring, 641 exp.TimeToStr, 642 exp.TimeToTimeStr, 643 exp.Trim, 644 exp.TsOrDsToDateStr, 645 exp.UnixToStr, 646 exp.UnixToTimeStr, 647 exp.Upper, 648 }, 649 } 650 651 ANNOTATORS: AnnotatorsType = { 652 **{ 653 expr_type: lambda self, e: self._annotate_unary(e) 654 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 655 }, 656 **{ 657 expr_type: lambda self, e: self._annotate_binary(e) 658 for expr_type in subclasses(exp.__name__, exp.Binary) 659 }, 660 **{ 661 expr_type: _annotate_with_type_lambda(data_type) 662 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 663 for expr_type in expressions 664 }, 665 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 666 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 667 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 668 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 669 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 670 exp.Bracket: lambda self, e: self._annotate_bracket(e), 671 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 672 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 673 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 674 exp.Count: lambda self, e: self._annotate_with_type( 675 e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT 676 ), 677 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 678 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 679 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 680 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 681 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 682 exp.Div: lambda self, e: self._annotate_div(e), 683 exp.Dot: lambda self, e: self._annotate_dot(e), 684 exp.Explode: lambda self, e: self._annotate_explode(e), 685 exp.Extract: lambda self, e: self._annotate_extract(e), 686 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 687 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 688 e, exp.DataType.build("ARRAY<DATE>") 689 ), 690 exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( 691 e, exp.DataType.build("ARRAY<TIMESTAMP>") 692 ), 693 exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 694 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 695 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 696 exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 697 exp.Literal: lambda self, e: self._annotate_literal(e), 698 exp.Map: lambda self, e: self._annotate_map(e), 699 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 700 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 701 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 702 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 703 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 704 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 705 exp.Struct: lambda self, e: self._annotate_struct(e), 706 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 707 exp.Timestamp: lambda self, e: self._annotate_with_type( 708 e, 709 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 710 ), 711 exp.ToMap: lambda self, e: self._annotate_to_map(e), 712 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 713 exp.Unnest: lambda self, e: self._annotate_unnest(e), 714 exp.VarMap: lambda self, e: self._annotate_map(e), 715 } 716 717 @classmethod 718 def get_or_raise(cls, dialect: DialectType) -> Dialect: 719 """ 720 Look up a dialect in the global dialect registry and return it if it exists. 721 722 Args: 723 dialect: The target dialect. If this is a string, it can be optionally followed by 724 additional key-value pairs that are separated by commas and are used to specify 725 dialect settings, such as whether the dialect's identifiers are case-sensitive. 726 727 Example: 728 >>> dialect = dialect_class = get_or_raise("duckdb") 729 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 730 731 Returns: 732 The corresponding Dialect instance. 733 """ 734 735 if not dialect: 736 return cls() 737 if isinstance(dialect, _Dialect): 738 return dialect() 739 if isinstance(dialect, Dialect): 740 return dialect 741 if isinstance(dialect, str): 742 try: 743 dialect_name, *kv_strings = dialect.split(",") 744 kv_pairs = (kv.split("=") for kv in kv_strings) 745 kwargs = {} 746 for pair in kv_pairs: 747 key = pair[0].strip() 748 value: t.Union[bool | str | None] = None 749 750 if len(pair) == 1: 751 # Default initialize standalone settings to True 752 value = True 753 elif len(pair) == 2: 754 value = pair[1].strip() 755 756 # Coerce the value to boolean if it matches to the truthy/falsy values below 757 value_lower = value.lower() 758 if value_lower in ("true", "1"): 759 value = True 760 elif value_lower in ("false", "0"): 761 value = False 762 763 kwargs[key] = value 764 765 except ValueError: 766 raise ValueError( 767 f"Invalid dialect format: '{dialect}'. " 768 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 769 ) 770 771 result = cls.get(dialect_name.strip()) 772 if not result: 773 from difflib import get_close_matches 774 775 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 776 if similar: 777 similar = f" Did you mean {similar}?" 778 779 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 780 781 return result(**kwargs) 782 783 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 784 785 @classmethod 786 def format_time( 787 cls, expression: t.Optional[str | exp.Expression] 788 ) -> t.Optional[exp.Expression]: 789 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 790 if isinstance(expression, str): 791 return exp.Literal.string( 792 # the time formats are quoted 793 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 794 ) 795 796 if expression and expression.is_string: 797 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 798 799 return expression 800 801 def __init__(self, **kwargs) -> None: 802 normalization_strategy = kwargs.pop("normalization_strategy", None) 803 804 if normalization_strategy is None: 805 self.normalization_strategy = self.NORMALIZATION_STRATEGY 806 else: 807 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 808 809 self.settings = kwargs 810 811 def __eq__(self, other: t.Any) -> bool: 812 # Does not currently take dialect state into account 813 return type(self) == other 814 815 def __hash__(self) -> int: 816 # Does not currently take dialect state into account 817 return hash(type(self)) 818 819 def normalize_identifier(self, expression: E) -> E: 820 """ 821 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 822 823 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 824 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 825 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 826 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 827 828 There are also dialects like Spark, which are case-insensitive even when quotes are 829 present, and dialects like MySQL, whose resolution rules match those employed by the 830 underlying operating system, for example they may always be case-sensitive in Linux. 831 832 Finally, the normalization behavior of some engines can even be controlled through flags, 833 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 834 835 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 836 that it can analyze queries in the optimizer and successfully capture their semantics. 837 """ 838 if ( 839 isinstance(expression, exp.Identifier) 840 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 841 and ( 842 not expression.quoted 843 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 844 ) 845 ): 846 expression.set( 847 "this", 848 ( 849 expression.this.upper() 850 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 851 else expression.this.lower() 852 ), 853 ) 854 855 return expression 856 857 def case_sensitive(self, text: str) -> bool: 858 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 859 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 860 return False 861 862 unsafe = ( 863 str.islower 864 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 865 else str.isupper 866 ) 867 return any(unsafe(char) for char in text) 868 869 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 870 """Checks if text can be identified given an identify option. 871 872 Args: 873 text: The text to check. 874 identify: 875 `"always"` or `True`: Always returns `True`. 876 `"safe"`: Only returns `True` if the identifier is case-insensitive. 877 878 Returns: 879 Whether the given text can be identified. 880 """ 881 if identify is True or identify == "always": 882 return True 883 884 if identify == "safe": 885 return not self.case_sensitive(text) 886 887 return False 888 889 def quote_identifier(self, expression: E, identify: bool = True) -> E: 890 """ 891 Adds quotes to a given identifier. 892 893 Args: 894 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 895 identify: If set to `False`, the quotes will only be added if the identifier is deemed 896 "unsafe", with respect to its characters and this dialect's normalization strategy. 897 """ 898 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 899 name = expression.this 900 expression.set( 901 "quoted", 902 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 903 ) 904 905 return expression 906 907 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 908 if isinstance(path, exp.Literal): 909 path_text = path.name 910 if path.is_number: 911 path_text = f"[{path_text}]" 912 try: 913 return parse_json_path(path_text, self) 914 except ParseError as e: 915 if self.STRICT_JSON_PATH_SYNTAX: 916 logger.warning(f"Invalid JSON path syntax. {str(e)}") 917 918 return path 919 920 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 921 return self.parser(**opts).parse(self.tokenize(sql), sql) 922 923 def parse_into( 924 self, expression_type: exp.IntoType, sql: str, **opts 925 ) -> t.List[t.Optional[exp.Expression]]: 926 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 927 928 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 929 return self.generator(**opts).generate(expression, copy=copy) 930 931 def transpile(self, sql: str, **opts) -> t.List[str]: 932 return [ 933 self.generate(expression, copy=False, **opts) if expression else "" 934 for expression in self.parse(sql) 935 ] 936 937 def tokenize(self, sql: str) -> t.List[Token]: 938 return self.tokenizer.tokenize(sql) 939 940 @property 941 def tokenizer(self) -> Tokenizer: 942 return self.tokenizer_class(dialect=self) 943 944 @property 945 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 946 return self.jsonpath_tokenizer_class(dialect=self) 947 948 def parser(self, **opts) -> Parser: 949 return self.parser_class(dialect=self, **opts) 950 951 def generator(self, **opts) -> Generator: 952 return self.generator_class(dialect=self, **opts)
801 def __init__(self, **kwargs) -> None: 802 normalization_strategy = kwargs.pop("normalization_strategy", None) 803 804 if normalization_strategy is None: 805 self.normalization_strategy = self.NORMALIZATION_STRATEGY 806 else: 807 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 808 809 self.settings = kwargs
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
"upper" or True: Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization.
Whether the base comes first in the LOG
function.
Possible values: True
, False
, None
(two arguments are not supported by LOG
)
Default NULL
ordering method to use if not explicitly set.
Possible values: "nulls_are_small"
, "nulls_are_large"
, "nulls_are_last"
Whether the behavior of a / b
depends on the types of a
and b
.
False means a / b
is always float division.
True means a / b
is integer division if both a
and b
are integers.
A NULL
arg in CONCAT
yields NULL
by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python strftime
formats.
Helper which is used for parsing the special syntax CAST(x AS DATE FORMAT 'yyyy')
.
If empty, the corresponding trie will be constructed off of TIME_MAPPING
.
Mapping of an escaped sequence (\n
) to its unescaped version (
).
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from SELECT *
queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example, WITH y(c) AS ( SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 ) SELECT c FROM y;
will be rewritten as
WITH y(c) AS (
SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
For example:
WITH data AS ( SELECT 1 AS id, 2 AS my_id ) SELECT id AS my_id FROM data WHERE my_id = 1 GROUP BY my_id, HAVING my_id = 1
In most dialects, "my_id" would refer to "data.my_id" across the query, except: - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" - Clickhouse, which will forward the alias across the query i.e it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
Whether alias reference expansion before qualification should only happen for the GROUP BY clause.
Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) as the former is of type INT[] vs the latter which is SUPER
Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should be interpreted as a subscript/index operator.
Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.
Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).
Whether a set operation uses DISTINCT by default. This is None
when either DISTINCT
or ALL
must be explicitly specified.
Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse equivalent of CREATE SCHEMA is CREATE DATABASE.
717 @classmethod 718 def get_or_raise(cls, dialect: DialectType) -> Dialect: 719 """ 720 Look up a dialect in the global dialect registry and return it if it exists. 721 722 Args: 723 dialect: The target dialect. If this is a string, it can be optionally followed by 724 additional key-value pairs that are separated by commas and are used to specify 725 dialect settings, such as whether the dialect's identifiers are case-sensitive. 726 727 Example: 728 >>> dialect = dialect_class = get_or_raise("duckdb") 729 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 730 731 Returns: 732 The corresponding Dialect instance. 733 """ 734 735 if not dialect: 736 return cls() 737 if isinstance(dialect, _Dialect): 738 return dialect() 739 if isinstance(dialect, Dialect): 740 return dialect 741 if isinstance(dialect, str): 742 try: 743 dialect_name, *kv_strings = dialect.split(",") 744 kv_pairs = (kv.split("=") for kv in kv_strings) 745 kwargs = {} 746 for pair in kv_pairs: 747 key = pair[0].strip() 748 value: t.Union[bool | str | None] = None 749 750 if len(pair) == 1: 751 # Default initialize standalone settings to True 752 value = True 753 elif len(pair) == 2: 754 value = pair[1].strip() 755 756 # Coerce the value to boolean if it matches to the truthy/falsy values below 757 value_lower = value.lower() 758 if value_lower in ("true", "1"): 759 value = True 760 elif value_lower in ("false", "0"): 761 value = False 762 763 kwargs[key] = value 764 765 except ValueError: 766 raise ValueError( 767 f"Invalid dialect format: '{dialect}'. " 768 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 769 ) 770 771 result = cls.get(dialect_name.strip()) 772 if not result: 773 from difflib import get_close_matches 774 775 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 776 if similar: 777 similar = f" Did you mean {similar}?" 778 779 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 780 781 return result(**kwargs) 782 783 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb") >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
785 @classmethod 786 def format_time( 787 cls, expression: t.Optional[str | exp.Expression] 788 ) -> t.Optional[exp.Expression]: 789 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 790 if isinstance(expression, str): 791 return exp.Literal.string( 792 # the time formats are quoted 793 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 794 ) 795 796 if expression and expression.is_string: 797 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 798 799 return expression
Converts a time format in this dialect to its equivalent Python strftime
format.
819 def normalize_identifier(self, expression: E) -> E: 820 """ 821 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 822 823 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 824 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 825 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 826 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 827 828 There are also dialects like Spark, which are case-insensitive even when quotes are 829 present, and dialects like MySQL, whose resolution rules match those employed by the 830 underlying operating system, for example they may always be case-sensitive in Linux. 831 832 Finally, the normalization behavior of some engines can even be controlled through flags, 833 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 834 835 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 836 that it can analyze queries in the optimizer and successfully capture their semantics. 837 """ 838 if ( 839 isinstance(expression, exp.Identifier) 840 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 841 and ( 842 not expression.quoted 843 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 844 ) 845 ): 846 expression.set( 847 "this", 848 ( 849 expression.this.upper() 850 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 851 else expression.this.lower() 852 ), 853 ) 854 855 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like FoO
would be resolved as foo
in Postgres, because it
lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
it would resolve it as FOO
. If it was quoted, it'd need to be treated as case-sensitive,
and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
857 def case_sensitive(self, text: str) -> bool: 858 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 859 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 860 return False 861 862 unsafe = ( 863 str.islower 864 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 865 else str.isupper 866 ) 867 return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
869 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 870 """Checks if text can be identified given an identify option. 871 872 Args: 873 text: The text to check. 874 identify: 875 `"always"` or `True`: Always returns `True`. 876 `"safe"`: Only returns `True` if the identifier is case-insensitive. 877 878 Returns: 879 Whether the given text can be identified. 880 """ 881 if identify is True or identify == "always": 882 return True 883 884 if identify == "safe": 885 return not self.case_sensitive(text) 886 887 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
"always"
orTrue
: Always returnsTrue
."safe"
: Only returnsTrue
if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
889 def quote_identifier(self, expression: E, identify: bool = True) -> E: 890 """ 891 Adds quotes to a given identifier. 892 893 Args: 894 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 895 identify: If set to `False`, the quotes will only be added if the identifier is deemed 896 "unsafe", with respect to its characters and this dialect's normalization strategy. 897 """ 898 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 899 name = expression.this 900 expression.set( 901 "quoted", 902 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 903 ) 904 905 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an
Identifier
, this method is a no-op. - identify: If set to
False
, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
907 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 908 if isinstance(path, exp.Literal): 909 path_text = path.name 910 if path.is_number: 911 path_text = f"[{path_text}]" 912 try: 913 return parse_json_path(path_text, self) 914 except ParseError as e: 915 if self.STRICT_JSON_PATH_SYNTAX: 916 logger.warning(f"Invalid JSON path syntax. {str(e)}") 917 918 return path
967def if_sql( 968 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 969) -> t.Callable[[Generator, exp.If], str]: 970 def _if_sql(self: Generator, expression: exp.If) -> str: 971 return self.func( 972 name, 973 expression.this, 974 expression.args.get("true"), 975 expression.args.get("false") or false_value, 976 ) 977 978 return _if_sql
981def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 982 this = expression.this 983 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 984 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 985 986 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
1056def str_position_sql( 1057 self: Generator, 1058 expression: exp.StrPosition, 1059 generate_instance: bool = False, 1060 str_position_func_name: str = "STRPOS", 1061) -> str: 1062 this = self.sql(expression, "this") 1063 substr = self.sql(expression, "substr") 1064 position = self.sql(expression, "position") 1065 instance = expression.args.get("instance") if generate_instance else None 1066 position_offset = "" 1067 1068 if position: 1069 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1070 this = self.func("SUBSTR", this, position) 1071 position_offset = f" + {position} - 1" 1072 1073 return self.func(str_position_func_name, this, substr, instance) + position_offset
1082def var_map_sql( 1083 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1084) -> str: 1085 keys = expression.args["keys"] 1086 values = expression.args["values"] 1087 1088 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1089 self.unsupported("Cannot convert array columns into map.") 1090 return self.func(map_func_name, keys, values) 1091 1092 args = [] 1093 for key, value in zip(keys.expressions, values.expressions): 1094 args.append(self.sql(key)) 1095 args.append(self.sql(value)) 1096 1097 return self.func(map_func_name, *args)
1100def build_formatted_time( 1101 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1102) -> t.Callable[[t.List], E]: 1103 """Helper used for time expressions. 1104 1105 Args: 1106 exp_class: the expression class to instantiate. 1107 dialect: target sql dialect. 1108 default: the default format, True being time. 1109 1110 Returns: 1111 A callable that can be used to return the appropriately formatted time expression. 1112 """ 1113 1114 def _builder(args: t.List): 1115 return exp_class( 1116 this=seq_get(args, 0), 1117 format=Dialect[dialect].format_time( 1118 seq_get(args, 1) 1119 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1120 ), 1121 ) 1122 1123 return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
1126def time_format( 1127 dialect: DialectType = None, 1128) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1129 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1130 """ 1131 Returns the time format for a given expression, unless it's equivalent 1132 to the default time format of the dialect of interest. 1133 """ 1134 time_format = self.format_time(expression) 1135 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1136 1137 return _time_format
1140def build_date_delta( 1141 exp_class: t.Type[E], 1142 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1143 default_unit: t.Optional[str] = "DAY", 1144) -> t.Callable[[t.List], E]: 1145 def _builder(args: t.List) -> E: 1146 unit_based = len(args) == 3 1147 this = args[2] if unit_based else seq_get(args, 0) 1148 unit = None 1149 if unit_based or default_unit: 1150 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1151 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1152 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1153 1154 return _builder
1157def build_date_delta_with_interval( 1158 expression_class: t.Type[E], 1159) -> t.Callable[[t.List], t.Optional[E]]: 1160 def _builder(args: t.List) -> t.Optional[E]: 1161 if len(args) < 2: 1162 return None 1163 1164 interval = args[1] 1165 1166 if not isinstance(interval, exp.Interval): 1167 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1168 1169 return expression_class(this=args[0], expression=interval.this, unit=unit_to_str(interval)) 1170 1171 return _builder
1174def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1175 unit = seq_get(args, 0) 1176 this = seq_get(args, 1) 1177 1178 if isinstance(this, exp.Cast) and this.is_type("date"): 1179 return exp.DateTrunc(unit=unit, this=this) 1180 return exp.TimestampTrunc(this=this, unit=unit)
1183def date_add_interval_sql( 1184 data_type: str, kind: str 1185) -> t.Callable[[Generator, exp.Expression], str]: 1186 def func(self: Generator, expression: exp.Expression) -> str: 1187 this = self.sql(expression, "this") 1188 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1189 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1190 1191 return func
1194def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1195 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1196 args = [unit_to_str(expression), expression.this] 1197 if zone: 1198 args.append(expression.args.get("zone")) 1199 return self.func("DATE_TRUNC", *args) 1200 1201 return _timestamptrunc_sql
1204def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1205 zone = expression.args.get("zone") 1206 if not zone: 1207 from sqlglot.optimizer.annotate_types import annotate_types 1208 1209 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1210 return self.sql(exp.cast(expression.this, target_type)) 1211 if zone.name.lower() in TIMEZONES: 1212 return self.sql( 1213 exp.AtTimeZone( 1214 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1215 zone=zone, 1216 ) 1217 ) 1218 return self.func("TIMESTAMP", expression.this, zone)
1221def no_time_sql(self: Generator, expression: exp.Time) -> str: 1222 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1223 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1224 expr = exp.cast( 1225 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1226 ) 1227 return self.sql(expr)
1230def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1231 this = expression.this 1232 expr = expression.expression 1233 1234 if expr.name.lower() in TIMEZONES: 1235 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1236 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1237 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1238 return self.sql(this) 1239 1240 this = exp.cast(this, exp.DataType.Type.DATE) 1241 expr = exp.cast(expr, exp.DataType.Type.TIME) 1242 1243 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
1275def timestrtotime_sql( 1276 self: Generator, 1277 expression: exp.TimeStrToTime, 1278 include_precision: bool = False, 1279) -> str: 1280 datatype = exp.DataType.build( 1281 exp.DataType.Type.TIMESTAMPTZ 1282 if expression.args.get("zone") 1283 else exp.DataType.Type.TIMESTAMP 1284 ) 1285 1286 if isinstance(expression.this, exp.Literal) and include_precision: 1287 precision = subsecond_precision(expression.this.name) 1288 if precision > 0: 1289 datatype = exp.DataType.build( 1290 datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))] 1291 ) 1292 1293 return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect))
1301def encode_decode_sql( 1302 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1303) -> str: 1304 charset = expression.args.get("charset") 1305 if charset and charset.name.lower() != "utf-8": 1306 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1307 1308 return self.func(name, expression.this, expression.args.get("replace") if replace else None)
1321def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1322 cond = expression.this 1323 1324 if isinstance(expression.this, exp.Distinct): 1325 cond = expression.this.expressions[0] 1326 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1327 1328 return self.func("sum", exp.func("if", cond, 1, 0))
1331def trim_sql(self: Generator, expression: exp.Trim) -> str: 1332 target = self.sql(expression, "this") 1333 trim_type = self.sql(expression, "position") 1334 remove_chars = self.sql(expression, "expression") 1335 collation = self.sql(expression, "collation") 1336 1337 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1338 if not remove_chars: 1339 return self.trim_sql(expression) 1340 1341 trim_type = f"{trim_type} " if trim_type else "" 1342 remove_chars = f"{remove_chars} " if remove_chars else "" 1343 from_part = "FROM " if trim_type or remove_chars else "" 1344 collation = f" COLLATE {collation}" if collation else "" 1345 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
1366@unsupported_args("position", "occurrence", "parameters") 1367def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 1368 group = expression.args.get("group") 1369 1370 # Do not render group if it's the default value for this dialect 1371 if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP): 1372 group = None 1373 1374 return self.func("REGEXP_EXTRACT", expression.this, expression.expression, group)
1384def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1385 names = [] 1386 for agg in aggregations: 1387 if isinstance(agg, exp.Alias): 1388 names.append(agg.alias) 1389 else: 1390 """ 1391 This case corresponds to aggregations without aliases being used as suffixes 1392 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1393 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1394 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1395 """ 1396 agg_all_unquoted = agg.transform( 1397 lambda node: ( 1398 exp.Identifier(this=node.name, quoted=False) 1399 if isinstance(node, exp.Identifier) 1400 else node 1401 ) 1402 ) 1403 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1404 1405 return names
1445def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1446 @unsupported_args("count") 1447 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1448 return self.func(name, expression.this, expression.expression) 1449 1450 return _arg_max_or_min_sql
1453def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1454 this = expression.this.copy() 1455 1456 return_type = expression.return_type 1457 if return_type.is_type(exp.DataType.Type.DATE): 1458 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1459 # can truncate timestamp strings, because some dialects can't cast them to DATE 1460 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1461 1462 expression.this.replace(exp.cast(this, return_type)) 1463 return expression
1466def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1467 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1468 if cast and isinstance(expression, exp.TsOrDsAdd): 1469 expression = ts_or_ds_add_cast(expression) 1470 1471 return self.func( 1472 name, 1473 unit_to_var(expression), 1474 expression.expression, 1475 expression.this, 1476 ) 1477 1478 return _delta_sql
1481def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1482 unit = expression.args.get("unit") 1483 1484 if isinstance(unit, exp.Placeholder): 1485 return unit 1486 if unit: 1487 return exp.Literal.string(unit.name) 1488 return exp.Literal.string(default) if default else None
1518def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1519 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1520 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1521 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1522 1523 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1526def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1527 """Remove table refs from columns in when statements.""" 1528 alias = expression.this.args.get("alias") 1529 1530 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1531 return self.dialect.normalize_identifier(identifier).name if identifier else None 1532 1533 targets = {normalize(expression.this.this)} 1534 1535 if alias: 1536 targets.add(normalize(alias.this)) 1537 1538 for when in expression.expressions: 1539 # only remove the target names from the THEN clause 1540 # theyre still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED 1541 # ref: https://github.com/TobikoData/sqlmesh/issues/2934 1542 then = when.args.get("then") 1543 if then: 1544 then.transform( 1545 lambda node: ( 1546 exp.column(node.this) 1547 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1548 else node 1549 ), 1550 copy=False, 1551 ) 1552 1553 return self.merge_sql(expression)
Remove table refs from columns in when statements.
1556def build_json_extract_path( 1557 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1558) -> t.Callable[[t.List], F]: 1559 def _builder(args: t.List) -> F: 1560 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1561 for arg in args[1:]: 1562 if not isinstance(arg, exp.Literal): 1563 # We use the fallback parser because we can't really transpile non-literals safely 1564 return expr_type.from_arg_list(args) 1565 1566 text = arg.name 1567 if is_int(text): 1568 index = int(text) 1569 segments.append( 1570 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1571 ) 1572 else: 1573 segments.append(exp.JSONPathKey(this=text)) 1574 1575 # This is done to avoid failing in the expression validator due to the arg count 1576 del args[2:] 1577 return expr_type( 1578 this=seq_get(args, 0), 1579 expression=exp.JSONPath(expressions=segments), 1580 only_json_types=arrow_req_json_type, 1581 ) 1582 1583 return _builder
1586def json_extract_segments( 1587 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1588) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1589 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1590 path = expression.expression 1591 if not isinstance(path, exp.JSONPath): 1592 return rename_func(name)(self, expression) 1593 1594 escape = path.args.get("escape") 1595 1596 segments = [] 1597 for segment in path.expressions: 1598 path = self.sql(segment) 1599 if path: 1600 if isinstance(segment, exp.JSONPathPart) and ( 1601 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1602 ): 1603 if escape: 1604 path = self.escape_str(path) 1605 1606 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1607 1608 segments.append(path) 1609 1610 if op: 1611 return f" {op} ".join([self.sql(expression.this), *segments]) 1612 return self.func(name, expression.this, *segments) 1613 1614 return _json_extract_segments
1624def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1625 cond = expression.expression 1626 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1627 alias = cond.expressions[0] 1628 cond = cond.this 1629 elif isinstance(cond, exp.Predicate): 1630 alias = "_u" 1631 else: 1632 self.unsupported("Unsupported filter condition") 1633 return "" 1634 1635 unnest = exp.Unnest(expressions=[expression.this]) 1636 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1637 return self.sql(exp.Array(expressions=[filtered]))
1649def build_default_decimal_type( 1650 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1651) -> t.Callable[[exp.DataType], exp.DataType]: 1652 def _builder(dtype: exp.DataType) -> exp.DataType: 1653 if dtype.expressions or precision is None: 1654 return dtype 1655 1656 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1657 return exp.DataType.build(f"DECIMAL({params})") 1658 1659 return _builder
1662def build_timestamp_from_parts(args: t.List) -> exp.Func: 1663 if len(args) == 2: 1664 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1665 # so we parse this into Anonymous for now instead of introducing complexity 1666 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1667 1668 return exp.TimestampFromParts.from_arg_list(args)
1675def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str: 1676 start = expression.args.get("start") 1677 end = expression.args.get("end") 1678 step = expression.args.get("step") 1679 1680 if isinstance(start, exp.Cast): 1681 target_type = start.to 1682 elif isinstance(end, exp.Cast): 1683 target_type = end.to 1684 else: 1685 target_type = None 1686 1687 if start and end and target_type and target_type.is_type("date", "timestamp"): 1688 if isinstance(start, exp.Cast) and target_type is start.to: 1689 end = exp.cast(end, target_type) 1690 else: 1691 start = exp.cast(start, target_type) 1692 1693 return self.func("SEQUENCE", start, end, step)
1704def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str: 1705 if isinstance(expression.this, exp.Explode): 1706 return self.sql( 1707 exp.Join( 1708 this=exp.Unnest( 1709 expressions=[expression.this.this], 1710 alias=expression.args.get("alias"), 1711 offset=isinstance(expression.this, exp.Posexplode), 1712 ), 1713 kind="cross", 1714 ) 1715 ) 1716 return self.lateral_sql(expression)