Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions djongo/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,7 @@ def date_trunc_sql(self, lookup_type, field_name):

def datetime_trunc_sql(self, lookup_type, field_name, tzname):
return "DATE_TRUNC(%s, %s)" % (lookup_type.upper(), field_name)

# force to use where A = %s when A is BooleanField
def conditional_expression_supported_in_where_clause(self, expression):
return False
22 changes: 13 additions & 9 deletions djongo/sql2mongo/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
import typing
from collections import OrderedDict
from sqlparse import tokens, parse as sqlparse
from sqlparse.sql import Parenthesis
from sqlparse.sql import Parenthesis,Comparison
from typing import Union as U, List, Optional as O
from . import query as query_module
from .sql_tokens import SQLIdentifier, SQLConstIdentifier, SQLComparison
from .functions import SQLFunc, CountFuncAll
from .operators import WhereOp
from .operators import WhereOp,CmpOp
from ..exceptions import SQLDecodeError
from .sql_tokens import SQLToken, SQLStatement

Expand Down Expand Up @@ -288,23 +288,27 @@ def to_mongo(self):


class SetConverter(Converter):

def __init__(self, *args):
self.sql_tokens: List[SQLComparison] = []
super().__init__(*args)

def parse(self):
tok = self.statement.next()
self.sql_tokens.extend(SQLToken.tokens2sql(tok, self.query))

self.update_pipeline = []

for sql_token in self.sql_tokens:
if not isinstance(sql_token._token, Comparison):
continue

parser = CmpOp(sql_token._token, self.query)
parser.evaluate()
self.update_pipeline.append(parser.to_mongo())

def to_mongo(self):
return {
'update': {
'$set': {
sql.left_column: self.query.params[sql.rhs_indexes]
if sql.rhs_indexes is not None else None
for sql in self.sql_tokens}
}
'update': self.update_pipeline
}


Expand Down
117 changes: 101 additions & 16 deletions djongo/sql2mongo/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from itertools import chain

from sqlparse import tokens
from sqlparse.sql import Token, Parenthesis, Comparison, IdentifierList, Identifier, Function
from sqlparse.sql import Token, Parenthesis, Comparison, IdentifierList, Identifier, Function,Operation

from ..exceptions import SQLDecodeError
from .sql_tokens import SQLToken, SQLStatement
Expand Down Expand Up @@ -432,6 +432,9 @@ def _token2op(self,

elif isinstance(tok, Identifier):
pass
elif isinstance(tok, Operation):
#debug_operation(tok)
op = ArithmeticOp(Comparison(tok), self.query)
else:
raise SQLDecodeError

Expand Down Expand Up @@ -524,19 +527,36 @@ class CmpOp(_Op):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
#debug_operation(self.statement.left)
self._identifier = SQLToken.token2sql(self.statement.left, self.query)

self._is_expr = False

# identifier_allow is special case for where a>b or update a=b+1
identifier_allow = False
if isinstance(self.statement.right, Identifier):
raise SQLDecodeError('Join using WHERE not supported')

self._right_identifier = SQLToken.token2sql(self.statement.right, self.query)
identifier_allow = (self._right_identifier.table == self.query.left_table) or (self._right_identifier.table == self._identifier.column)
if not identifier_allow:
raise SQLDecodeError('Join using WHERE not supported')

self._operator = OPERATOR_MAP[self.statement.token_next(0)[1].value]
index = re_index(self.statement.right.value)

self._constant = self.params[index] if index is not None else None
if isinstance(self._constant, dict):
self._field_ext, self._constant = next(iter(self._constant.items()))
else:
if isinstance(self.statement.right, Parenthesis):
parser = ParenthesisOp(self.statement.right, self.query)
parser.evaluate()
self._is_expr = True
self._constant = parser._op.to_mongo()
self._field_ext = None
else:
if identifier_allow:
self._constant = f"${self._right_identifier.column}"
self._field_ext = None
else:
index = re_index(self.statement.right.value)
self._constant = self.params[index] if index is not None else None
if isinstance(self._constant, dict):
self._field_ext, self._constant = next(iter(self._constant.items()))
else:
self._field_ext = None

def negate(self):
self.is_negated = True
Expand All @@ -545,14 +565,62 @@ def evaluate(self):
pass

def to_mongo(self):
field = self._identifier.field
if self._field_ext:
field += '.' + self._field_ext
if not hasattr(self.query, 'is_set') or not self.query.is_set:
field = self._identifier.field
if self._field_ext:
field += '.' + self._field_ext

if not self.is_negated:
return {field: {self._operator: self._constant}}
if self._is_expr:
return {"$expr": {self._operator: ['$' + field, self._constant]}}

if not self.is_negated:
return {field: {self._operator: self._constant}}
else:
return {field: {'$not': {self._operator: self._constant}}}
else:
field = self._identifier.column
return {"$set":{field: self._constant}}

class ArithmeticOp(_Op):
OPERATORS = {
'+': '$add',
'-': '$subtract',
'*': '$multiply',
'/': '$divide',
'%': '$mod'
}

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

if isinstance(self.statement.left, Parenthesis):
parser = ParenthesisOp(self.statement.left, self.query)
parser.evaluate()
self._left = parser.to_mongo()
elif isinstance(self.statement.left, Identifier):
self._left = '$' + SQLToken.token2sql(self.statement.left, self.query).field
else:
self._left = self.statement.left.value

self._operator = self.OPERATORS.get(self.statement.token_next(0)[1].value)

if isinstance(self.statement.right, Parenthesis):
parser = ParenthesisOp(self.statement.right, self.query)
parser.evaluate()
self._right = parser.to_mongo()
elif isinstance(self.statement.right, Identifier):
self._right = '$' + SQLToken.token2sql(self.statement.right, self.query).field
else:
return {field: {'$not': {self._operator: self._constant}}}
index = re_index(self.statement.right.value)
self._right = self.params[index] if index is not None else self.statement.right.value

def evaluate(self):
pass

def to_mongo(self):
return {
self._operator: [self._left, self._right]
}


class FuncOp(CmpOp):
Expand Down Expand Up @@ -592,3 +660,20 @@ def to_mongo(self):
'OR': 1,
'generic': 0
}


def debug_operation(op):
print("=== Operation Debug Info ===")
print(f"Operation: {op}")
print(f"Type: {type(op)}")
print(f"Value: {op.value}")

# 打印所有tokens
print("All Tokens:")
for idx, token in enumerate(op.tokens):
print(f" Token {idx}:")
print(f" Value: {token.value}")
print(f" Type: {type(token)}")
print(f" ttype: {token.ttype}")
if hasattr(token, 'tokens'):
print(f" Sub-tokens: {[t.value for t in token.tokens]}")
6 changes: 5 additions & 1 deletion djongo/sql2mongo/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,9 @@ def parse(self):
self.left_table = c.sql_tokens[0].table

elif tok.match(tokens.Keyword, 'SET'):
self.is_set = True
c = self.set_columns = SetConverter(self, statement)
self.is_set = False

elif isinstance(tok, Where):
c = self.where = WhereConverter(self, statement)
Expand All @@ -320,7 +322,9 @@ def parse(self):
self.kwargs = {}
if self.where:
self.kwargs.update(self.where.to_mongo())

else:
self.kwargs = {"filter":{}}

self.kwargs.update(self.set_columns.to_mongo())

def execute(self):
Expand Down