diff --git a/test/modules/op/linspace.py b/test/modules/op/linspace.py new file mode 100644 index 00000000..3b87a741 --- /dev/null +++ b/test/modules/op/linspace.py @@ -0,0 +1,31 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from test.modules.base import TestModuleBase + +from test.utils import tag + + +class SimpleLinspace(TestModuleBase): + def __init__(self): + super().__init__() + + def forward(self, start, end, step): + return torch.linspace(start, end, step) + + def get_example_inputs(self): + return (torch.tensor(3, dtype = torch.int32), torch.tensor(3, dtype = torch.int32), int(3),), {} + diff --git a/tico/serialize/circle_mapping.py b/tico/serialize/circle_mapping.py index f001d04e..661d91dd 100644 --- a/tico/serialize/circle_mapping.py +++ b/tico/serialize/circle_mapping.py @@ -37,6 +37,7 @@ def to_circle_dtype( torch.int: circle.TensorType.TensorType.INT32, torch.int64: circle.TensorType.TensorType.INT64, torch.bool: circle.TensorType.TensorType.BOOL, + torch.float64: circle.TensorType.TensorType.FLOAT32, # ADDED } if torch_dtype not in dmap: