diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py
index 3ee9e9a2f76f4af4a80529f1895022975fdd3e97..eb812304aace83f8a2d1b8d8eb402ad8e9a35e4f 100644
--- a/Cython/Compiler/Nodes.py
+++ b/Cython/Compiler/Nodes.py
@@ -4157,9 +4157,11 @@ class OverrideCheckNode(StatNode):
         code.funcstate.release_temp(func_node_temp)
         code.putln("}")
 
+
 class ClassDefNode(StatNode, BlockNode):
     pass
 
+
 class PyClassDefNode(ClassDefNode):
     #  A Python class definition.
     #
@@ -4185,7 +4187,7 @@ class PyClassDefNode(ClassDefNode):
     mkw = None
 
     def __init__(self, pos, name, bases, doc, body, decorators=None,
-                 keyword_args=None, starstar_arg=None, force_py3_semantics=False):
+                 keyword_args=None, force_py3_semantics=False):
         StatNode.__init__(self, pos)
         self.name = name
         self.doc = doc
@@ -4200,31 +4202,30 @@ class PyClassDefNode(ClassDefNode):
             doc_node = None
 
         allow_py2_metaclass = not force_py3_semantics
-        if keyword_args or starstar_arg:
+        if keyword_args:
             allow_py2_metaclass = False
             self.is_py3_style_class = True
-            if keyword_args and not starstar_arg:
-                for i, item in list(enumerate(keyword_args.key_value_pairs))[::-1]:
-                    if item.key.value == 'metaclass':
-                        if self.metaclass is not None:
-                            error(item.pos, "keyword argument 'metaclass' passed multiple times")
-                        # special case: we already know the metaclass,
-                        # so we don't need to do the "build kwargs,
-                        # find metaclass" dance at runtime
-                        self.metaclass = item.value
-                        del keyword_args.key_value_pairs[i]
-            if starstar_arg:
-                self.mkw = ExprNodes.ProxyNode(ExprNodes.KeywordArgsNode(
-                    pos, keyword_args=keyword_args and keyword_args.key_value_pairs or [],
-                    starstar_arg=starstar_arg))
-            elif keyword_args.key_value_pairs:
-                self.mkw = keyword_args
+            if keyword_args.is_dict_literal:
+                if keyword_args.key_value_pairs:
+                    for i, item in list(enumerate(keyword_args.key_value_pairs))[::-1]:
+                        if item.key.value == 'metaclass':
+                            if self.metaclass is not None:
+                                error(item.pos, "keyword argument 'metaclass' passed multiple times")
+                            # special case: we already know the metaclass,
+                            # so we don't need to do the "build kwargs,
+                            # find metaclass" dance at runtime
+                            self.metaclass = item.value
+                            del keyword_args.key_value_pairs[i]
+                    self.mkw = keyword_args
+                else:
+                    assert self.metaclass is not None
             else:
-                assert self.metaclass is not None
+                # KeywordArgsNode
+                self.mkw = ExprNodes.ProxyNode(keyword_args)
 
         if force_py3_semantics or self.bases or self.mkw or self.metaclass:
             if self.metaclass is None:
-                if starstar_arg:
+                if keyword_args and not keyword_args.is_dict_literal:
                     # **kwargs may contain 'metaclass' arg
                     mkdict = self.mkw
                 else:
diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py
index 14e17659172192e7daec3c907d89d443017ac8f7..42b23a894c2f24adbabbe8c722baab72b1f35ca3 100644
--- a/Cython/Compiler/Parsing.py
+++ b/Cython/Compiler/Parsing.py
@@ -2996,7 +2996,6 @@ def p_class_statement(s, decorators):
     class_name.encoding = s.source_encoding  # FIXME: why is this needed?
     arg_tuple = None
     keyword_dict = None
-    starstar_arg = None
     if s.sy == '(':
         positional_args, keyword_args = p_call_parse_args(s, allow_genexp=False)
         arg_tuple, keyword_dict = p_call_build_packed_args(pos, positional_args, keyword_args)
@@ -3008,7 +3007,6 @@ def p_class_statement(s, decorators):
         pos, name=class_name,
         bases=arg_tuple,
         keyword_args=keyword_dict,
-        starstar_arg=starstar_arg,
         doc=doc, body=body, decorators=decorators,
         force_py3_semantics=s.context.language_level >= 3)