diff --git a/djongo/operations.py b/djongo/operations.py index 3a9abfb3..d4d08d9f 100644 --- a/djongo/operations.py +++ b/djongo/operations.py @@ -155,3 +155,7 @@ def date_trunc_sql(self, lookup_type, field_name): def datetime_trunc_sql(self, lookup_type, field_name, tzname): return "DATE_TRUNC(%s, %s)" % (lookup_type.upper(), field_name) + + # force to use where A = %s when A is BooleanField + def conditional_expression_supported_in_where_clause(self, expression): + return False \ No newline at end of file diff --git a/djongo/sql2mongo/converters.py b/djongo/sql2mongo/converters.py index 08230a82..7a059dd3 100644 --- a/djongo/sql2mongo/converters.py +++ b/djongo/sql2mongo/converters.py @@ -2,12 +2,12 @@ import typing from collections import OrderedDict from sqlparse import tokens, parse as sqlparse -from sqlparse.sql import Parenthesis +from sqlparse.sql import Parenthesis,Comparison from typing import Union as U, List, Optional as O from . import query as query_module from .sql_tokens import SQLIdentifier, SQLConstIdentifier, SQLComparison from .functions import SQLFunc, CountFuncAll -from .operators import WhereOp +from .operators import WhereOp,CmpOp from ..exceptions import SQLDecodeError from .sql_tokens import SQLToken, SQLStatement @@ -288,7 +288,6 @@ def to_mongo(self): class SetConverter(Converter): - def __init__(self, *args): self.sql_tokens: List[SQLComparison] = [] super().__init__(*args) @@ -296,15 +295,20 @@ def __init__(self, *args): def parse(self): tok = self.statement.next() self.sql_tokens.extend(SQLToken.tokens2sql(tok, self.query)) + + self.update_pipeline = [] + + for sql_token in self.sql_tokens: + if not isinstance(sql_token._token, Comparison): + continue + + parser = CmpOp(sql_token._token, self.query) + parser.evaluate() + self.update_pipeline.append(parser.to_mongo()) def to_mongo(self): return { - 'update': { - '$set': { - sql.left_column: self.query.params[sql.rhs_indexes] - if sql.rhs_indexes is not None else None - for sql in self.sql_tokens} - } + 'update': self.update_pipeline } diff --git a/djongo/sql2mongo/operators.py b/djongo/sql2mongo/operators.py index 02821d7e..c04e1d1b 100644 --- a/djongo/sql2mongo/operators.py +++ b/djongo/sql2mongo/operators.py @@ -4,7 +4,7 @@ from itertools import chain from sqlparse import tokens -from sqlparse.sql import Token, Parenthesis, Comparison, IdentifierList, Identifier, Function +from sqlparse.sql import Token, Parenthesis, Comparison, IdentifierList, Identifier, Function,Operation from ..exceptions import SQLDecodeError from .sql_tokens import SQLToken, SQLStatement @@ -432,6 +432,9 @@ def _token2op(self, elif isinstance(tok, Identifier): pass + elif isinstance(tok, Operation): + #debug_operation(tok) + op = ArithmeticOp(Comparison(tok), self.query) else: raise SQLDecodeError @@ -524,19 +527,36 @@ class CmpOp(_Op): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + #debug_operation(self.statement.left) self._identifier = SQLToken.token2sql(self.statement.left, self.query) - + self._is_expr = False + + # identifier_allow is special case for where a>b or update a=b+1 + identifier_allow = False if isinstance(self.statement.right, Identifier): - raise SQLDecodeError('Join using WHERE not supported') - + self._right_identifier = SQLToken.token2sql(self.statement.right, self.query) + identifier_allow = (self._right_identifier.table == self.query.left_table) or (self._right_identifier.table == self._identifier.column) + if not identifier_allow: + raise SQLDecodeError('Join using WHERE not supported') + self._operator = OPERATOR_MAP[self.statement.token_next(0)[1].value] - index = re_index(self.statement.right.value) - - self._constant = self.params[index] if index is not None else None - if isinstance(self._constant, dict): - self._field_ext, self._constant = next(iter(self._constant.items())) - else: + if isinstance(self.statement.right, Parenthesis): + parser = ParenthesisOp(self.statement.right, self.query) + parser.evaluate() + self._is_expr = True + self._constant = parser._op.to_mongo() self._field_ext = None + else: + if identifier_allow: + self._constant = f"${self._right_identifier.column}" + self._field_ext = None + else: + index = re_index(self.statement.right.value) + self._constant = self.params[index] if index is not None else None + if isinstance(self._constant, dict): + self._field_ext, self._constant = next(iter(self._constant.items())) + else: + self._field_ext = None def negate(self): self.is_negated = True @@ -545,14 +565,62 @@ def evaluate(self): pass def to_mongo(self): - field = self._identifier.field - if self._field_ext: - field += '.' + self._field_ext + if not hasattr(self.query, 'is_set') or not self.query.is_set: + field = self._identifier.field + if self._field_ext: + field += '.' + self._field_ext - if not self.is_negated: - return {field: {self._operator: self._constant}} + if self._is_expr: + return {"$expr": {self._operator: ['$' + field, self._constant]}} + + if not self.is_negated: + return {field: {self._operator: self._constant}} + else: + return {field: {'$not': {self._operator: self._constant}}} + else: + field = self._identifier.column + return {"$set":{field: self._constant}} + +class ArithmeticOp(_Op): + OPERATORS = { + '+': '$add', + '-': '$subtract', + '*': '$multiply', + '/': '$divide', + '%': '$mod' + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if isinstance(self.statement.left, Parenthesis): + parser = ParenthesisOp(self.statement.left, self.query) + parser.evaluate() + self._left = parser.to_mongo() + elif isinstance(self.statement.left, Identifier): + self._left = '$' + SQLToken.token2sql(self.statement.left, self.query).field + else: + self._left = self.statement.left.value + + self._operator = self.OPERATORS.get(self.statement.token_next(0)[1].value) + + if isinstance(self.statement.right, Parenthesis): + parser = ParenthesisOp(self.statement.right, self.query) + parser.evaluate() + self._right = parser.to_mongo() + elif isinstance(self.statement.right, Identifier): + self._right = '$' + SQLToken.token2sql(self.statement.right, self.query).field else: - return {field: {'$not': {self._operator: self._constant}}} + index = re_index(self.statement.right.value) + self._right = self.params[index] if index is not None else self.statement.right.value + + def evaluate(self): + pass + + def to_mongo(self): + return { + self._operator: [self._left, self._right] + } class FuncOp(CmpOp): @@ -592,3 +660,20 @@ def to_mongo(self): 'OR': 1, 'generic': 0 } + + +def debug_operation(op): + print("=== Operation Debug Info ===") + print(f"Operation: {op}") + print(f"Type: {type(op)}") + print(f"Value: {op.value}") + + # 打印所有tokens + print("All Tokens:") + for idx, token in enumerate(op.tokens): + print(f" Token {idx}:") + print(f" Value: {token.value}") + print(f" Type: {type(token)}") + print(f" ttype: {token.ttype}") + if hasattr(token, 'tokens'): + print(f" Sub-tokens: {[t.value for t in token.tokens]}") \ No newline at end of file diff --git a/djongo/sql2mongo/query.py b/djongo/sql2mongo/query.py index fefb52ca..c6eda7af 100644 --- a/djongo/sql2mongo/query.py +++ b/djongo/sql2mongo/query.py @@ -309,7 +309,9 @@ def parse(self): self.left_table = c.sql_tokens[0].table elif tok.match(tokens.Keyword, 'SET'): + self.is_set = True c = self.set_columns = SetConverter(self, statement) + self.is_set = False elif isinstance(tok, Where): c = self.where = WhereConverter(self, statement) @@ -320,7 +322,9 @@ def parse(self): self.kwargs = {} if self.where: self.kwargs.update(self.where.to_mongo()) - + else: + self.kwargs = {"filter":{}} + self.kwargs.update(self.set_columns.to_mongo()) def execute(self):