Skip to content

Commit 0fbde60

Browse files
committed
trafaret/Enum: support Python >=3.4 enum module
1 parent b38366d commit 0fbde60

2 files changed

Lines changed: 52 additions & 0 deletions

File tree

tests/test_base.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# -*- coding: utf-8 -*-
2+
import sys
23
import unittest
34
import trafaret as t
5+
import trafaret.utils as tu
46
from collections import Mapping as AbcMapping
57
from trafaret import extract_error, ignore, DataError
68
from trafaret.extras import KeysSubset
@@ -264,6 +266,40 @@ def test_enum(self):
264266
res = extract_error(trafaret, 2)
265267
self.assertEqual(res, "value doesn't match any variant")
266268

269+
@unittest.skipIf(sys.version_info < (3, 4),
270+
"not supported in this veresion"
271+
)
272+
def test_enum_py3(self):
273+
import enum
274+
275+
class Colors(enum.Enum):
276+
red = 0
277+
green = 1
278+
blue = 2
279+
280+
trafaret = t.Enum(Colors)
281+
self.assertEqual(repr(trafaret), "<Enum('red', 'green', 'blue')>")
282+
res = trafaret.check('red')
283+
res = trafaret.check('green')
284+
res = extract_error(trafaret, 'unknown')
285+
self.assertEqual(res, "value doesn't match any variant")
286+
287+
# check multiple fails
288+
289+
class Fruits(enum.Enum):
290+
orange = 0
291+
apple = 1
292+
293+
with self.assertRaises(TypeError):
294+
trafaret = t.Enum(Colors, Fruits)
295+
296+
# check mixin fails in any order
297+
298+
with self.assertRaises(TypeError):
299+
trafaret = t.Enum(Colors, 1)
300+
301+
with self.assertRaises(TypeError):
302+
trafaret = t.Enum(1, Colors)
267303

268304

269305
class TestFloat(unittest.TestCase):

trafaret/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,19 @@
1515
__VERSION__ = (0, 10, 2)
1616

1717

18+
enum = None
19+
1820
# Python3 support
1921
if py3:
2022
import urllib.parse as urlparse
2123
str_types = (str, bytes)
2224
unicode = str
25+
26+
try:
27+
import enum
28+
except ImportError:
29+
pass
30+
2331
else:
2432
try:
2533
from future_builtins import map
@@ -1345,6 +1353,14 @@ class Enum(Trafaret):
13451353
__slots__ = ['variants']
13461354

13471355
def __init__(self, *variants):
1356+
if enum is not None:
1357+
has_meta = any(isinstance(v, enum.EnumMeta) for v in variants)
1358+
if has_meta:
1359+
if len(variants) != 1:
1360+
raise TypeError("You can't use enum.Enum with other arguments")
1361+
1362+
variants = [x.name for x in variants[0]]
1363+
13481364
self.variants = variants[:]
13491365

13501366
def check_value(self, value):

0 commit comments

Comments
 (0)